Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for ConstantOp."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.eager import context
     24 from tensorflow.python.eager import test
     25 from tensorflow.python.framework import constant_op
     26 from tensorflow.python.framework import dtypes as dtypes_lib
     27 from tensorflow.python.framework import errors_impl
     28 from tensorflow.python.framework import ops
     29 from tensorflow.python.framework import test_util
     30 from tensorflow.python.ops import array_ops
     31 from tensorflow.python.util import compat
     32 
     33 
     34 # TODO(josh11b): add tests with lists/tuples, Shape.
     35 class ConstantTest(test.TestCase):
     36 
     37   def _testCpu(self, x):
     38     np_ans = np.array(x)
     39     with context.device("/device:CPU:0"):
     40       tf_ans = ops.convert_to_tensor(x).numpy()
     41     if np_ans.dtype in [np.float32, np.float64, np.complex64, np.complex128]:
     42       self.assertAllClose(np_ans, tf_ans)
     43     else:
     44       self.assertAllEqual(np_ans, tf_ans)
     45 
     46   def _testGpu(self, x):
     47     device = test_util.gpu_device_name()
     48     if device:
     49       np_ans = np.array(x)
     50       with context.device(device):
     51         tf_ans = ops.convert_to_tensor(x).numpy()
     52       if np_ans.dtype in [np.float32, np.float64, np.complex64, np.complex128]:
     53         self.assertAllClose(np_ans, tf_ans)
     54       else:
     55         self.assertAllEqual(np_ans, tf_ans)
     56 
     57   def _testAll(self, x):
     58     self._testCpu(x)
     59     self._testGpu(x)
     60 
     61   def testFloat(self):
     62     self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float32))
     63     self._testAll(
     64         np.random.normal(size=30).reshape([2, 3, 5]).astype(np.float32))
     65     self._testAll(np.empty((2, 0, 5)).astype(np.float32))
     66 
     67     orig = [-1.0, 2.0, 0.0]
     68     tf_ans = constant_op.constant(orig)
     69     self.assertEqual(dtypes_lib.float32, tf_ans.dtype)
     70     self.assertAllClose(np.array(orig), tf_ans.numpy())
     71 
     72     # Mix floats and ints
     73     orig = [-1.5, 2, 0]
     74     tf_ans = constant_op.constant(orig)
     75     self.assertEqual(dtypes_lib.float32, tf_ans.dtype)
     76     self.assertAllClose(np.array(orig), tf_ans.numpy())
     77 
     78     orig = [-5, 2.5, 0]
     79     tf_ans = constant_op.constant(orig)
     80     self.assertEqual(dtypes_lib.float32, tf_ans.dtype)
     81     self.assertAllClose(np.array(orig), tf_ans.numpy())
     82 
     83     # Mix floats and ints that don't fit in int32
     84     orig = [1, 2**42, 0.5]
     85     tf_ans = constant_op.constant(orig)
     86     self.assertEqual(dtypes_lib.float32, tf_ans.dtype)
     87     self.assertAllClose(np.array(orig), tf_ans.numpy())
     88 
     89   def testDouble(self):
     90     self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float64))
     91     self._testAll(
     92         np.random.normal(size=30).reshape([2, 3, 5]).astype(np.float64))
     93     self._testAll(np.empty((2, 0, 5)).astype(np.float64))
     94 
     95     orig = [-5, 2.5, 0]
     96     tf_ans = constant_op.constant(orig, dtypes_lib.float64)
     97     self.assertEqual(dtypes_lib.float64, tf_ans.dtype)
     98     self.assertAllClose(np.array(orig), tf_ans.numpy())
     99 
    100     # This integer is not exactly representable as a double, gets rounded.
    101     tf_ans = constant_op.constant(2**54 + 1, dtypes_lib.float64)
    102     self.assertEqual(2**54, tf_ans.numpy())
    103 
    104     # This integer is larger than all non-infinite numbers representable
    105     # by a double, raises an exception.
    106     with self.assertRaisesRegexp(ValueError, "out-of-range integer"):
    107       constant_op.constant(10**310, dtypes_lib.float64)
    108 
    109   def testInt32(self):
    110     self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(np.int32))
    111     self._testAll(
    112         (100 * np.random.normal(size=30)).reshape([2, 3, 5]).astype(np.int32))
    113     self._testAll(np.empty((2, 0, 5)).astype(np.int32))
    114     self._testAll([-1, 2])
    115 
    116   def testInt64(self):
    117     self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(np.int64))
    118     self._testAll(
    119         (100 * np.random.normal(size=30)).reshape([2, 3, 5]).astype(np.int64))
    120     self._testAll(np.empty((2, 0, 5)).astype(np.int64))
    121     # Should detect out of range for int32 and use int64 instead.
    122     orig = [2, 2**48, -2**48]
    123     tf_ans = constant_op.constant(orig)
    124     self.assertEqual(dtypes_lib.int64, tf_ans.dtype)
    125     self.assertAllClose(np.array(orig), tf_ans.numpy())
    126 
    127     # Out of range for an int64
    128     with self.assertRaisesRegexp(ValueError, "out-of-range integer"):
    129       constant_op.constant([2**72])
    130 
    131   def testComplex64(self):
    132     self._testAll(
    133         np.complex(1, 2) *
    134         np.arange(-15, 15).reshape([2, 3, 5]).astype(np.complex64))
    135     self._testAll(
    136         np.complex(1, 2) *
    137         np.random.normal(size=30).reshape([2, 3, 5]).astype(np.complex64))
    138     self._testAll(np.empty((2, 0, 5)).astype(np.complex64))
    139 
    140   def testComplex128(self):
    141     self._testAll(
    142         np.complex(1, 2) * np.arange(-15, 15).reshape([2, 3, 5
    143                                                       ]).astype(np.complex128))
    144     self._testAll(
    145         np.complex(1, 2) * np.random.normal(size=30).reshape(
    146             [2, 3, 5]).astype(np.complex128))
    147     self._testAll(np.empty((2, 0, 5)).astype(np.complex128))
    148 
    149   def testString(self):
    150     val = [compat.as_bytes(str(x)) for x in np.arange(-15, 15)]
    151     self._testCpu(np.array(val).reshape([2, 3, 5]))
    152     self._testCpu(np.empty((2, 0, 5)).astype(np.str_))
    153 
    154   def testStringWithNulls(self):
    155     val = ops.convert_to_tensor(b"\0\0\0\0").numpy()
    156     self.assertEqual(len(val), 4)
    157     self.assertEqual(val, b"\0\0\0\0")
    158 
    159     val = ops.convert_to_tensor(b"xx\0xx").numpy()
    160     self.assertEqual(len(val), 5)
    161     self.assertAllEqual(val, b"xx\0xx")
    162 
    163     nested = [[b"\0\0\0\0", b"xx\0xx"], [b"\0_\0_\0_\0", b"\0"]]
    164     val = ops.convert_to_tensor(nested).numpy()
    165     # NOTE(mrry): Do not use assertAllEqual, because it converts nested to a
    166     #   numpy array, which loses the null terminators.
    167     self.assertEqual(val.tolist(), nested)
    168 
    169   def testExplicitShapeNumPy(self):
    170     c = constant_op.constant(
    171         np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float32),
    172         shape=[2, 3, 5])
    173     self.assertEqual(c.get_shape(), [2, 3, 5])
    174 
    175   def testImplicitShapeNumPy(self):
    176     c = constant_op.constant(
    177         np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float32))
    178     self.assertEqual(c.get_shape(), [2, 3, 5])
    179 
    180   def testExplicitShapeList(self):
    181     c = constant_op.constant([1, 2, 3, 4, 5, 6, 7], shape=[7])
    182     self.assertEqual(c.get_shape(), [7])
    183 
    184   def testExplicitShapeFill(self):
    185     c = constant_op.constant(12, shape=[7])
    186     self.assertEqual(c.get_shape(), [7])
    187     self.assertAllEqual([12, 12, 12, 12, 12, 12, 12], c.numpy())
    188 
    189   def testExplicitShapeReshape(self):
    190     c = constant_op.constant(
    191         np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float32),
    192         shape=[5, 2, 3])
    193     self.assertEqual(c.get_shape(), [5, 2, 3])
    194 
    195   def testImplicitShapeList(self):
    196     c = constant_op.constant([1, 2, 3, 4, 5, 6, 7])
    197     self.assertEqual(c.get_shape(), [7])
    198 
    199   def testExplicitShapeNumber(self):
    200     c = constant_op.constant(1, shape=[1])
    201     self.assertEqual(c.get_shape(), [1])
    202 
    203   def testImplicitShapeNumber(self):
    204     c = constant_op.constant(1)
    205     self.assertEqual(c.get_shape(), [])
    206 
    207   def testShapeTooBig(self):
    208     with self.assertRaises(TypeError):
    209       constant_op.constant([1, 2, 3, 4, 5, 6, 7], shape=[10])
    210 
    211   def testShapeTooSmall(self):
    212     with self.assertRaises(TypeError):
    213       constant_op.constant([1, 2, 3, 4, 5, 6, 7], shape=[5])
    214 
    215   def testShapeWrong(self):
    216     with self.assertRaisesRegexp(TypeError, None):
    217       constant_op.constant([1, 2, 3, 4, 5, 6, 7], shape=[5])
    218 
    219   def testShape(self):
    220     self._testAll(constant_op.constant([1]).get_shape())
    221 
    222   def testDimension(self):
    223     x = constant_op.constant([1]).shape[0]
    224     self._testAll(x)
    225 
    226   def testDimensionList(self):
    227     x = [constant_op.constant([1]).shape[0]]
    228     self._testAll(x)
    229 
    230     # Mixing with regular integers is fine too
    231     self._testAll([1] + x)
    232     self._testAll(x + [1])
    233 
    234   def testDimensionTuple(self):
    235     x = constant_op.constant([1]).shape[0]
    236     self._testAll((x,))
    237     self._testAll((1, x))
    238     self._testAll((x, 1))
    239 
    240   def testInvalidLength(self):
    241 
    242     class BadList(list):
    243 
    244       def __init__(self):
    245         super(BadList, self).__init__([1, 2, 3])  # pylint: disable=invalid-length-returned
    246 
    247       def __len__(self):
    248         return -1
    249 
    250     with self.assertRaisesRegexp(ValueError, "should return >= 0"):
    251       constant_op.constant([BadList()])
    252     with self.assertRaisesRegexp(ValueError, "mixed types"):
    253       constant_op.constant([1, 2, BadList()])
    254     with self.assertRaisesRegexp(ValueError, "should return >= 0"):
    255       constant_op.constant(BadList())
    256     with self.assertRaisesRegexp(ValueError, "should return >= 0"):
    257       constant_op.constant([[BadList(), 2], 3])
    258     with self.assertRaisesRegexp(ValueError, "should return >= 0"):
    259       constant_op.constant([BadList(), [1, 2, 3]])
    260     with self.assertRaisesRegexp(ValueError, "should return >= 0"):
    261       constant_op.constant([BadList(), []])
    262 
    263     # TODO(allenl, josh11b): These cases should return exceptions rather than
    264     # working (currently shape checking only checks the first element of each
    265     # sequence recursively). Maybe the first one is fine, but the second one
    266     # silently truncating is rather bad.
    267 
    268     # with self.assertRaisesRegexp(ValueError, "should return >= 0"):
    269     #   constant_op.constant([[3, 2, 1], BadList()])
    270     # with self.assertRaisesRegexp(ValueError, "should return >= 0"):
    271     #   constant_op.constant([[], BadList()])
    272 
    273   def testSparseValuesRaiseErrors(self):
    274     with self.assertRaisesRegexp(ValueError, "non-rectangular Python sequence"):
    275       constant_op.constant([[1, 2], [3]], dtype=dtypes_lib.int32)
    276 
    277     with self.assertRaisesRegexp(ValueError, None):
    278       constant_op.constant([[1, 2], [3]])
    279 
    280     with self.assertRaisesRegexp(ValueError, None):
    281       constant_op.constant([[1, 2], [3], [4, 5]])
    282 
    283 
    284 class AsTensorTest(test.TestCase):
    285 
    286   def testAsTensorForTensorInput(self):
    287     t = constant_op.constant(10.0)
    288     x = ops.convert_to_tensor(t)
    289     self.assertIs(t, x)
    290 
    291   def testAsTensorForNonTensorInput(self):
    292     x = ops.convert_to_tensor(10.0)
    293     self.assertTrue(isinstance(x, ops.EagerTensor))
    294 
    295 
    296 class ZerosTest(test.TestCase):
    297 
    298   def _Zeros(self, shape):
    299     ret = array_ops.zeros(shape)
    300     self.assertEqual(shape, ret.get_shape())
    301     return ret.numpy()
    302 
    303   def testConst(self):
    304     self.assertTrue(
    305         np.array_equal(self._Zeros([2, 3]), np.array([[0] * 3] * 2)))
    306 
    307   def testScalar(self):
    308     self.assertEqual(0, self._Zeros([]))
    309     self.assertEqual(0, self._Zeros(()))
    310     scalar = array_ops.zeros(constant_op.constant([], dtype=dtypes_lib.int32))
    311     self.assertEqual(0, scalar.numpy())
    312 
    313   def testDynamicSizes(self):
    314     np_ans = np.array([[0] * 3] * 2)
    315     # Creates a tensor of 2 x 3.
    316     d = array_ops.fill([2, 3], 12., name="fill")
    317     # Constructs a tensor of zeros of the same dimensions as "d".
    318     z = array_ops.zeros(array_ops.shape(d))
    319     out = z.numpy()
    320     self.assertAllEqual(np_ans, out)
    321     self.assertShapeEqual(np_ans, d)
    322     self.assertShapeEqual(np_ans, z)
    323 
    324   def testDtype(self):
    325     d = array_ops.fill([2, 3], 12., name="fill")
    326     self.assertEqual(d.get_shape(), [2, 3])
    327     # Test default type for both constant size and dynamic size
    328     z = array_ops.zeros([2, 3])
    329     self.assertEqual(z.dtype, dtypes_lib.float32)
    330     self.assertEqual([2, 3], z.get_shape())
    331     self.assertAllEqual(z.numpy(), np.zeros([2, 3]))
    332     z = array_ops.zeros(array_ops.shape(d))
    333     self.assertEqual(z.dtype, dtypes_lib.float32)
    334     self.assertEqual([2, 3], z.get_shape())
    335     self.assertAllEqual(z.numpy(), np.zeros([2, 3]))
    336     # Test explicit type control
    337     for dtype in [
    338         dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
    339         dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8,
    340         dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.int64,
    341         dtypes_lib.bool,
    342         # TODO(josh11b): Support string type here.
    343         # dtypes_lib.string
    344     ]:
    345       z = array_ops.zeros([2, 3], dtype=dtype)
    346       self.assertEqual(z.dtype, dtype)
    347       self.assertEqual([2, 3], z.get_shape())
    348       z_value = z.numpy()
    349       self.assertFalse(np.any(z_value))
    350       self.assertEqual((2, 3), z_value.shape)
    351       z = array_ops.zeros(array_ops.shape(d), dtype=dtype)
    352       self.assertEqual(z.dtype, dtype)
    353       self.assertEqual([2, 3], z.get_shape())
    354       z_value = z.numpy()
    355       self.assertFalse(np.any(z_value))
    356       self.assertEqual((2, 3), z_value.shape)
    357 
    358 
    359 class ZerosLikeTest(test.TestCase):
    360 
    361   def _compareZeros(self, dtype, use_gpu):
    362     # Creates a tensor of non-zero values with shape 2 x 3.
    363     # NOTE(kearnes): The default numpy dtype associated with tf.string is
    364     # np.object (and can't be changed without breaking a lot things), which
    365     # causes a TypeError in constant_op.constant below. Here we catch the
    366     # special case of tf.string and set the numpy dtype appropriately.
    367     if dtype == dtypes_lib.string:
    368       numpy_dtype = np.string_
    369     else:
    370       numpy_dtype = dtype.as_numpy_dtype
    371     d = constant_op.constant(np.ones((2, 3), dtype=numpy_dtype), dtype=dtype)
    372     # Constructs a tensor of zeros of the same dimensions and type as "d".
    373     z_var = array_ops.zeros_like(d)
    374     # Test that the type is correct
    375     self.assertEqual(z_var.dtype, dtype)
    376     # Test that the shape is correct
    377     self.assertEqual([2, 3], z_var.get_shape())
    378 
    379     # Test that the value is correct
    380     z_value = z_var.numpy()
    381     self.assertFalse(np.any(z_value))
    382     self.assertEqual((2, 3), z_value.shape)
    383 
    384   def testZerosLikeCPU(self):
    385     for dtype in [
    386         dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
    387         dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8,
    388         dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.int64,
    389         # TODO(josh11b): Support string type here.
    390         # dtypes_lib.string
    391     ]:
    392       self._compareZeros(dtype, use_gpu=False)
    393 
    394   def testZerosLikeGPU(self):
    395     for dtype in [
    396         dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
    397         dtypes_lib.bool, dtypes_lib.int64,
    398         # TODO(josh11b): Support string type here.
    399         # dtypes_lib.string
    400     ]:
    401       self._compareZeros(dtype, use_gpu=True)
    402 
    403   def testZerosLikeDtype(self):
    404     # Make sure zeros_like works even for dtypes that cannot be cast between
    405     shape = (3, 5)
    406     dtypes = np.float32, np.complex64
    407     for in_type in dtypes:
    408       x = np.arange(15).astype(in_type).reshape(*shape)
    409       for out_type in dtypes:
    410         y = array_ops.zeros_like(x, dtype=out_type).numpy()
    411         self.assertEqual(y.dtype, out_type)
    412         self.assertEqual(y.shape, shape)
    413         self.assertAllEqual(y, np.zeros(shape, dtype=out_type))
    414 
    415 
    416 class OnesTest(test.TestCase):
    417 
    418   def _Ones(self, shape):
    419     ret = array_ops.ones(shape)
    420     self.assertEqual(shape, ret.get_shape())
    421     return ret.numpy()
    422 
    423   def testConst(self):
    424     self.assertTrue(np.array_equal(self._Ones([2, 3]), np.array([[1] * 3] * 2)))
    425 
    426   def testScalar(self):
    427     self.assertEqual(1, self._Ones([]))
    428     self.assertEqual(1, self._Ones(()))
    429     scalar = array_ops.ones(constant_op.constant([], dtype=dtypes_lib.int32))
    430     self.assertEqual(1, scalar.numpy())
    431 
    432   def testDynamicSizes(self):
    433     np_ans = np.array([[1] * 3] * 2)
    434     # Creates a tensor of 2 x 3.
    435     d = array_ops.fill([2, 3], 12., name="fill")
    436     # Constructs a tensor of ones of the same dimensions as "d".
    437     z = array_ops.ones(array_ops.shape(d))
    438     out = z.numpy()
    439     self.assertAllEqual(np_ans, out)
    440     self.assertShapeEqual(np_ans, d)
    441     self.assertShapeEqual(np_ans, z)
    442 
    443   def testDtype(self):
    444     d = array_ops.fill([2, 3], 12., name="fill")
    445     self.assertEqual(d.get_shape(), [2, 3])
    446     # Test default type for both constant size and dynamic size
    447     z = array_ops.ones([2, 3])
    448     self.assertEqual(z.dtype, dtypes_lib.float32)
    449     self.assertEqual([2, 3], z.get_shape())
    450     self.assertAllEqual(z.numpy(), np.ones([2, 3]))
    451     z = array_ops.ones(array_ops.shape(d))
    452     self.assertEqual(z.dtype, dtypes_lib.float32)
    453     self.assertEqual([2, 3], z.get_shape())
    454     self.assertAllEqual(z.numpy(), np.ones([2, 3]))
    455     # Test explicit type control
    456     for dtype in (dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
    457                   dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8,
    458                   dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.int64,
    459                   dtypes_lib.bool):
    460       z = array_ops.ones([2, 3], dtype=dtype)
    461       self.assertEqual(z.dtype, dtype)
    462       self.assertEqual([2, 3], z.get_shape())
    463       self.assertAllEqual(z.numpy(), np.ones([2, 3]))
    464       z = array_ops.ones(array_ops.shape(d), dtype=dtype)
    465       self.assertEqual(z.dtype, dtype)
    466       self.assertEqual([2, 3], z.get_shape())
    467       self.assertAllEqual(z.numpy(), np.ones([2, 3]))
    468 
    469 
    470 class OnesLikeTest(test.TestCase):
    471 
    472   def testOnesLike(self):
    473     for dtype in [
    474         dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
    475         dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8,
    476         dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.int64
    477     ]:
    478       numpy_dtype = dtype.as_numpy_dtype
    479       # Creates a tensor of non-zero values with shape 2 x 3.
    480       d = constant_op.constant(np.ones((2, 3), dtype=numpy_dtype), dtype=dtype)
    481       # Constructs a tensor of zeros of the same dimensions and type as "d".
    482       z_var = array_ops.ones_like(d)
    483       # Test that the type is correct
    484       self.assertEqual(z_var.dtype, dtype)
    485       z_value = z_var.numpy()
    486 
    487       # Test that the value is correct
    488       self.assertTrue(np.array_equal(z_value, np.array([[1] * 3] * 2)))
    489       self.assertEqual([2, 3], z_var.get_shape())
    490 
    491 
    492 class FillTest(test.TestCase):
    493 
    494   def _compare(self, dims, val, np_ans, use_gpu):
    495     ctx = context.get_default_context()
    496     device = "GPU:0" if (use_gpu and ctx.num_gpus()) else "CPU:0"
    497     with ops.device(device):
    498       tf_ans = array_ops.fill(dims, val, name="fill")
    499       out = tf_ans.numpy()
    500     self.assertAllClose(np_ans, out)
    501 
    502   def _compareAll(self, dims, val, np_ans):
    503     self._compare(dims, val, np_ans, False)
    504     self._compare(dims, val, np_ans, True)
    505 
    506   def testFillFloat(self):
    507     np_ans = np.array([[3.1415] * 3] * 2).astype(np.float32)
    508     self._compareAll([2, 3], np_ans[0][0], np_ans)
    509 
    510   def testFillDouble(self):
    511     np_ans = np.array([[3.1415] * 3] * 2).astype(np.float64)
    512     self._compareAll([2, 3], np_ans[0][0], np_ans)
    513 
    514   def testFillInt32(self):
    515     np_ans = np.array([[42] * 3] * 2).astype(np.int32)
    516     self._compareAll([2, 3], np_ans[0][0], np_ans)
    517 
    518   def testFillInt64(self):
    519     np_ans = np.array([[-42] * 3] * 2).astype(np.int64)
    520     self._compareAll([2, 3], np_ans[0][0], np_ans)
    521 
    522   def testFillComplex64(self):
    523     np_ans = np.array([[0.15] * 3] * 2).astype(np.complex64)
    524     self._compare([2, 3], np_ans[0][0], np_ans, use_gpu=False)
    525 
    526   def testFillComplex128(self):
    527     np_ans = np.array([[0.15] * 3] * 2).astype(np.complex128)
    528     self._compare([2, 3], np_ans[0][0], np_ans, use_gpu=False)
    529 
    530   def testFillString(self):
    531     np_ans = np.array([[b"yolo"] * 3] * 2)
    532     tf_ans = array_ops.fill([2, 3], np_ans[0][0], name="fill").numpy()
    533     self.assertAllEqual(np_ans, tf_ans)
    534 
    535   def testFillNegative(self):
    536     for shape in (-1,), (2, -1), (-1, 2), (-2), (-3):
    537       with self.assertRaises(errors_impl.InvalidArgumentError):
    538         array_ops.fill(shape, 7)
    539 
    540   def testShapeFunctionEdgeCases(self):
    541     # Non-vector dimensions.
    542     with self.assertRaises(errors_impl.InvalidArgumentError):
    543       array_ops.fill([[0, 1], [2, 3]], 1.0)
    544 
    545     # Non-scalar value.
    546     with self.assertRaises(errors_impl.InvalidArgumentError):
    547       array_ops.fill([3, 2], [1.0, 2.0])
    548 
    549 
    550 if __name__ == "__main__":
    551   test.main()
    552