     15 """Tests for core."""
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     21 import threading
     23 import numpy as np
     25 from tensorflow.core.protobuf import config_pb2
     26 from tensorflow.python import pywrap_tensorflow
     27 from tensorflow.python.eager import context
     28 from tensorflow.python.eager import core
     29 from tensorflow.python.eager import execute as execute_lib
     30 from tensorflow.python.eager import test
     31 from tensorflow.python.framework import constant_op
     32 from tensorflow.python.framework import dtypes
     33 from tensorflow.python.framework import errors
     34 from tensorflow.python.framework import ops
     35 from tensorflow.python.framework import test_util
     36 from tensorflow.python.ops import nn_ops
     39 def execute(op_name, num_outputs, inputs, attrs=None):
     40   return execute_lib.execute(
     41       op_name, num_outputs, inputs, attrs, context.context())
     44 def truncated_normal(shape):
     45   return execute(
     46       b'TruncatedNormal',
     47       1,
     48       inputs=[shape],
     49       attrs=('dtype', dtypes.float32.as_datatype_enum, 'T',
     50              shape.dtype.as_datatype_enum, 'seed', 0, 'seed2', 0))[0]
     53 class TFETest(test_util.TensorFlowTestCase):
     55   def testContext(self):
     56     ctx = context.Context()
     57     self.assertFalse(ctx.in_graph_mode())
     58     self.assertTrue(ctx.in_eager_mode())
     60     self.assertEqual('', ctx.scope_name)
     61     ctx.scope_name = 'foo'
     62     self.assertEqual('foo', ctx.scope_name)
     64     self.assertIsNone(ctx.summary_writer_resource)
     65     ctx.summary_writer_resource = 'mock'
     66     self.assertEqual('mock', ctx.summary_writer_resource)
     68     self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0',
     69                      ctx.device_name)
     70     self.assertEqual(ctx.device_name, ctx.device_spec.to_string())
     71     with ctx.device('GPU:0'):
     72       self.assertEqual('/job:localhost/replica:0/task:0/device:GPU:0',
     73                        ctx.device_name)
     74       self.assertEqual(ctx.device_name, ctx.device_spec.to_string())
     75       with ctx.device(None):
     76         self.assertEqual('', ctx.device_name)
     77         self.assertEqual(ctx.device_name, ctx.device_spec.to_string())
     78         with ctx.device('CPU:0'):
     79           self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0',
     80                            ctx.device_name)
     81           self.assertEqual(ctx.device_name, ctx.device_spec.to_string())
     83     has_cpu_device = False
     84     for x in ctx.devices():
     85       has_cpu_device = has_cpu_device or 'CPU' in x
     86     self.assertTrue(has_cpu_device)
     87     del ctx
     89   def testRunMetadata(self):
     90     context.enable_run_metadata()
     91     t = constant_op.constant(1.0)
     92     _ = t + t  # Runs an operation which will be in the RunMetadata
     93     run_metadata = context.export_run_metadata()
     94     context.disable_run_metadata()
     95     step_stats = run_metadata.step_stats
     96     self.assertGreater(len(step_stats.dev_stats), 0)
     97     cpu_stats = step_stats.dev_stats[0]
     98     self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0',
     99                      cpu_stats.device)
    100     self.assertEqual(len(cpu_stats.node_stats), 1)
    101     self.assertEqual(cpu_stats.node_stats[0].node_name, 'Add')
    103   def testContextStackContainsEagerMode(self):
    104     # Eager execution has been enabled, and no other context
    105     # switch has occurred, so `context_stack` should contain
    106     # exactly one entry.
    107     self.assertEqual(len(context.context_stack.stack), 1)
    108     stack_entry = context.context_stack.stack[0]
    110     # The entry should log that eager mode was entered.
    111     self.assertIs(stack_entry.enter_context_fn, context.eager_mode)
    113     # It is not possible to build a graph function when eager execution
    114     # is enabled; the stack entry should reflect this fact.
    115     self.assertFalse(stack_entry.is_building_function)
    117   def testInt32GPU(self):
    118     if not context.context().num_gpus():
    119       self.skipTest('No GPUs found')
    120     with ops.device('gpu:0'):
    121       xent = nn_ops.sparse_softmax_cross_entropy_with_logits(
    122           logits=[[0.0, 0.0]], labels=[0])
    123     self.assertAllClose(xent, [0.69314718])
    125   def _runInThread(self, target, args):
    126     t = threading.Thread(target=target, args=args)
    127     try:
    128       t.start()
    129       t.join()
    130     except Exception as e:
    131       raise e
    133   # Test that different thread local values are initialized to the same values
    134   # in different threads.
    135   def testContextThreadLocalMembers(self):
    137     def get_context_values(ctx):
    138       return [
    139           ctx.in_graph_mode(),
    140           ctx.in_eager_mode(), ctx.scope_name, ctx.summary_writer_resource,
    141           ctx.device_name, ctx.num_gpus()
    142       ]
    144     def get_values(ctx, values):
    145       values.extend(get_context_values(ctx))
    147     context_values = []
    148     ctx = context.Context()
    149     self._runInThread(get_values, (ctx, context_values))
    150     self.assertAllEqual(context_values, get_context_values(ctx))
    152   def testContextConfig(self):
    153     if not context.context().num_gpus():
    154       self.skipTest('No GPUs found')
    155     ctx = context.Context(config=config_pb2.ConfigProto(
    156         device_count={'GPU': 0}))
    157     self.assertEquals(0, ctx.num_gpus())
    159   def testTensorPlacement(self):
    160     if not context.context().num_gpus():
    161       self.skipTest('No GPUs found')
    163     x = constant_op.constant(1.).gpu()
    164     with context.device('gpu:0'):
    165       y = constant_op.constant(2.)
    166     # Add would fail if t2 were not on GPU
    167     result = execute(
    168         b'Add', 1, inputs=[x, y],
    169         attrs=('T', x.dtype.as_datatype_enum))[0].cpu().numpy()
    170     self.assertEqual(3, result)
    172   def testCopyBetweenDevices(self):
    173     if not context.context().num_gpus():
    174       self.skipTest('No GPUs found')
    176     x = constant_op.constant([[1., 2.], [3., 4.]])
    177     x = x.cpu()
    178     x = x.gpu()
    179     x = x.gpu()
    180     x = x.cpu()
    182     # Invalid device
    183     with self.assertRaises(RuntimeError):
    184       x.gpu(context.context().num_gpus() + 1)
    186   def testCopyScope(self):
    187     if not context.context().num_gpus():
    188       self.skipTest('No GPUs found')
    189     constant = constant_op.constant(1.0)
    190     with ops.device('gpu:0'):
    191       with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
    192         c = constant + 1.0
    193     self.assertAllEqual(c, 2.0)
    195   def testNumpyForceCPU(self):
    196     if not context.context().num_gpus():
    197       self.skipTest('No GPUs found')
    199     cpu = constant_op.constant([[1., 2.], [3., 4.]])
    200     c2g = cpu.gpu()
    201     self.assertAllEqual(c2g, cpu.numpy())
    203   def testCopyFromCPUToCPU(self):
    204     ta = constant_op.constant([[1, 2], [3, 4]])
    205     tb = ta.cpu()
    207     self.assertNotEqual(id(ta), id(tb))
    208     self.assertAllEqual(ta, tb.numpy())
    210   def testRegisterExceptionClass(self):
    211     with self.assertRaises(TypeError):
    212       pywrap_tensorflow.TFE_Py_RegisterExceptionClass(str)
    213     pywrap_tensorflow.TFE_Py_RegisterExceptionClass(core._NotOkStatusException)  # pylint: disable=protected-access
    215   # TODO(agarwal): add tests passing incorrect typed values to attrs.
    216   def testExecuteBasic(self):
    217     three = constant_op.constant(3)
    218     five = constant_op.constant(5)
    219     product = execute(
    220         b'Mul',
    221         num_outputs=1,
    222         inputs=[three, five],
    223         attrs=('T', three.dtype.as_datatype_enum))[0]
    224     self.assertAllEqual(15, product)
    226   def testExecuteTooManyNumOutputs(self):
    227     # num_outputs provided is 50, but only one output is produced.
    228     # That should be okay.
    229     product = execute(
    230         b'Mul',
    231         num_outputs=50,
    232         inputs=[constant_op.constant(3), constant_op.constant(5)],
    233         attrs=('T', dtypes.int32.as_datatype_enum))[0]
    234     self.assertAllEqual(15, product)
    236   def testMatMulGPU(self):
    237     if not context.context().num_gpus():
    238       self.skipTest('No GPUs found')
    239     three = constant_op.constant([[3.]]).gpu()
    240     five = constant_op.constant([[5.]]).gpu()
    241     product = execute(
    242         b'MatMul',
    243         num_outputs=1,
    244         inputs=[three, five],
    245         attrs=('transpose_a', False, 'transpose_b', False, 'T',
    246                three.dtype.as_datatype_enum))[0]
    247     self.assertAllEqual([[15.0]], product)
    249   def testExecuteStringAttr(self):
    250     checked_three = execute(
    251         b'CheckNumerics',
    252         num_outputs=1,
    253         inputs=[constant_op.constant(3.)],
    254         attrs=('message', 'just checking', 'T',
    255                dtypes.float32.as_datatype_enum))[0]
    256     self.assertEqual([[3]], checked_three.numpy())
    258   def testExecuteStringAttrBadValue(self):
    259     with self.assertRaises(errors.InvalidArgumentError):
    260       _ = execute(
    261           b'CheckNumerics',
    262           num_outputs=1,
    263           inputs=[constant_op.constant(3.)],
    264           attrs=('message', 1, 'T', dtypes.float32.as_datatype_enum))
    266   def testExecuteFloatAttr(self):
    267     almost_equal = execute(
    268         b'ApproximateEqual',
    269         num_outputs=1,
    270         inputs=[constant_op.constant(3.0), constant_op.constant(2.9)],
    271         attrs=('tolerance', 0.3, 'T', dtypes.float32.as_datatype_enum))[0]
    272     self.assertTrue(almost_equal)
    274   def testExecuteFloatAttrBadValue(self):
    275     with self.assertRaises(errors.InvalidArgumentError):
    276       _ = execute(
    277           b'ApproximateEqual',
    278           num_outputs=1,
    279           inputs=[constant_op.constant(3.0), constant_op.constant(2.9)],
    280           attrs=('tolerance', '0.3', 'T', dtypes.float32.as_datatype_enum))
    282   def testExecuteIntAttr(self):
    283     total = execute(
    284         b'AddN',
    285         num_outputs=1,
    286         inputs=[constant_op.constant(3), constant_op.constant(4)],
    287         attrs=('T', dtypes.int32.as_datatype_enum, 'N', 2))[0]
    288     self.assertAllEqual(7, total)
    290   def testExecuteIntAttrBadValue(self):
    291     with self.assertRaises(errors.InvalidArgumentError):
    292       _ = execute(
    293           b'AddN',
    294           num_outputs=1,
    295           inputs=[constant_op.constant(3), constant_op.constant(4)],
    296           attrs=('T', dtypes.int32.as_datatype_enum, 'N', '2'))
    298   # Looks like we don't have an existing op with list(bool) attrs.
    299   def testExecuteBoolAttr(self):
    300     product = execute(
    301         b'MatMul',
    302         num_outputs=1,
    303         inputs=[constant_op.constant([[3]]),
    304                 constant_op.constant([[5]])],
    305         attrs=('transpose_a', True, 'transpose_b', False, 'T',
    306                dtypes.int32.as_datatype_enum))[0]
    307     self.assertAllEqual([[15]], product)
    309   def testExecuteShapeAttr(self):
    310     execute(
    311         b'VarHandleOp',
    312         num_outputs=1,
    313         inputs=[],
    314         attrs=('shape', [1, 2], 'dtype', dtypes.int32.as_datatype_enum,
    315                'container', '', 'shared_name', ''))
    317   def testExecuteShapeAttrBadValue(self):
    318     with self.assertRaises(errors.InvalidArgumentError):
    319       execute(
    320           b'VarHandleOp',
    321           num_outputs=1,
    322           inputs=[],
    323           attrs=('shape', 1, 'dtype', dtypes.int32.as_datatype_enum,
    324                  'container', '', 'shared_name', ''))
    326   def testExecuteListStringAttr(self):
    327     execute(
    328         b'TensorSummary',
    329         num_outputs=1,
    330         inputs=[constant_op.constant(3.0)],
    331         attrs=('T', dtypes.float32.as_datatype_enum, 'description',
    332                'tensor_summary', 'labels', ['3',
    333                                             'summary'], 'display_name', 'test'))
    335   def testExecuteListStringAttrBadValue(self):
    336     with self.assertRaises(errors.InvalidArgumentError):
    337       execute(
    338           b'TensorSummary',
    339           num_outputs=1,
    340           inputs=[constant_op.constant(3.0)],
    341           attrs=('T', dtypes.float32.as_datatype_enum, 'description', '',
    342                  'labels', 3, 'display_name', 'test'))
    344   def testExecuteListStringAttrBadListValue(self):
    345     with self.assertRaises(errors.InvalidArgumentError):
    346       execute(
    347           b'TensorSummary',
    348           num_outputs=1,
    349           inputs=[constant_op.constant(3.0)],
    350           attrs=('T', dtypes.float32.as_datatype_enum, 'description', '',
    351                  'labels', [3], 'display_name', 'test'))
    353   def testExecuteListFloatAttr(self):
    354     b = execute(
    355         b'Bucketize',
    356         num_outputs=1,
    357         inputs=[constant_op.constant([3.0, 5.0, 7.0])],
    358         attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', [4.0,
    359                                                                     6.0]))[0]
    360     self.assertAllEqual([0, 1, 2], b)
    362   def testExecuteListFloatAttrBadValue(self):
    363     with self.assertRaises(errors.InvalidArgumentError):
    364       execute(
    365           b'Bucketize',
    366           num_outputs=1,
    367           inputs=[constant_op.constant([3.0, 5.0, 7.0])],
    368           attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', 4.0))
    370   def testExecuteListFloatAttrBadListValue(self):
    371     with self.assertRaises(errors.InvalidArgumentError):
    372       execute(
    373           b'Bucketize',
    374           num_outputs=1,
    375           inputs=[constant_op.constant([3.0, 5.0, 7.0])],
    376           attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries',
    377                  ['4.0', '6.0']))
    379   def testExecuteListIntAttr(self):
    380     b = execute(
    381         b'Squeeze',
    382         num_outputs=1,
    383         inputs=[constant_op.constant([[[3.0]]])],
    384         attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', [0, 2]))[0]
    385     self.assertAllEqual([3], b)
    387   def testExecuteListIntAttrBadValue(self):
    388     with self.assertRaises(errors.InvalidArgumentError):
    389       execute(
    390           b'Squeeze',
    391           num_outputs=1,
    392           inputs=[constant_op.constant([[[3.0]]])],
    393           attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', 0))
    395   def testExecuteListIntAttrBadListValue(self):
    396     with self.assertRaises(errors.InvalidArgumentError):
    397       execute(
    398           b'Squeeze',
    399           num_outputs=1,
    400           inputs=[constant_op.constant([[[3.0]]])],
    401           attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims',
    402                  ['0', '2']))
    404   def testExecuteListTypeListShapeAttr(self):
    405     execute(
    406         b'Barrier',
    407         num_outputs=1,
    408         inputs=[],
    409         attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
    410                [[1, 2]], 'capacity', -1, 'container', '', 'shared_name', ''))
    412   def testExecuteListTypeAttrBadValue(self):
    413     with self.assertRaises(errors.InvalidArgumentError):
    414       execute(
    415           b'Barrier',
    416           num_outputs=1,
    417           inputs=[],
    418           attrs=('component_types', dtypes.float64.as_datatype_enum, 'shapes',
    419                  [[1, 2]], 'capacity', -1, 'container', '', 'shared_name', ''))
    421   def testExecuteListTypeAttrBadListValue(self):
    422     with self.assertRaises(errors.InvalidArgumentError):
    423       execute(
    424           b'Barrier',
    425           num_outputs=1,
    426           inputs=[],
    427           attrs=('component_types', '1', 'shapes', [[1, 2]], 'capacity', -1,
    428                  'container', '', 'shared_name', ''))
    430   def testExecuteListShapeAttrBadValue(self):
    431     with self.assertRaises(errors.InvalidArgumentError):
    432       execute(
    433           b'Barrier',
    434           num_outputs=1,
    435           inputs=[],
    436           attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
    437                  [1, 2], 'capacity', -1, 'container', '', 'shared_name', ''))
    439   def testExecuteListShapeAttrBadListValue(self):
    440     with self.assertRaises(errors.InvalidArgumentError):
    441       execute(
    442           b'Barrier',
    443           num_outputs=1,
    444           inputs=[],
    445           attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
    446                  [1], 'capacity', -1, 'container', '', 'shared_name', ''))
    448   def testExecuteMultipleOutputs(self):
    449     split_dim = 1
    450     value = [[0, 1, 2], [3, 4, 5]]
    451     x1, x2, x3 = execute(
    452         b'Split',
    453         num_outputs=3,
    454         inputs=[constant_op.constant(split_dim),
    455                 constant_op.constant(value)],
    456         attrs=('num_split', 3, 'T', dtypes.int32.as_datatype_enum))
    457     self.assertAllEqual([[0], [3]], x1)
    458     self.assertAllEqual([[1], [4]], x2)
    459     self.assertAllEqual([[2], [5]], x3)
    461   def testExecuteBadNumOutputsArgument(self):
    462     with self.assertRaises(TypeError):
    463       execute(
    464           b'Relu', [],
    465           inputs=[constant_op.constant(3.0)],
    466           attrs=('T', dtypes.float32.as_datatype_enum))
    468   def testExecuteUnknownOp(self):
    469     with self.assertRaises(errors.NotFoundError):
    470       execute(b'BlahBlahBlah', num_outputs=1, inputs=[], attrs=None)
    472   def testExecuteUnknownAttr(self):
    473     with self.assertRaises(errors.InvalidArgumentError):
    474       execute(
    475           b'Identity',
    476           num_outputs=1,
    477           inputs=[constant_op.constant(3)],
    478           attrs=('T', dtypes.int32.as_datatype_enum, 'unknown_attr', 'blah'))
    480   def testComposition(self):
    482     def add(x, y):
    483       return execute(
    484           b'Add',
    485           num_outputs=1,
    486           inputs=[x, y],
    487           attrs=('T', dtypes.int32.as_datatype_enum))[0]
    489     x = constant_op.constant(1)
    490     three_x = add(add(x, x), x)
    491     self.assertEquals(dtypes.int32, three_x.dtype)
    492     self.assertAllEqual(3, three_x)
    494   def testOperationWithNoInputsRunsOnDevice(self):
    495     if not context.context().num_gpus():
    496       self.skipTest('No GPUs found')
    497     shape = constant_op.constant([], dtype=dtypes.int32)
    499     # x: Run the "TruncatedNormal" op CPU and copy result to GPU.
    500     x = truncated_normal(shape).gpu()
    501     # y: Explicitly run the "TruncatedNormal" op on GPU.
    502     with context.device('gpu:0'):
    503       y = truncated_normal(shape)
    504     # Add would fail if x and y were not on the same device.
    505     execute(
    506         b'Add', 1, inputs=[x, y], attrs=('T', x.dtype.as_datatype_enum))
    508   def testInvalidDevice(self):
    509     with self.assertRaises(ValueError):
    510       with context.device('pu:0'):
    511         _ = constant_op.constant(1)
    513   def testConvertMixedEagerTensors(self):
    514     array = np.zeros((), dtype=np.float32)
    515     tensor = constant_op.constant(0., dtype=dtypes.float32)
    516     types, tensors = execute_lib.convert_to_mixed_eager_tensors(
    517         [array, tensor], context.context())
    518     for typ, t in zip(types, tensors):
    519       self.assertEquals(typ, dtypes.float32)
    520       self.assertIsInstance(t, ops.EagerTensor)
if __name__ == '__main__':
  # Run every test case defined in this module with the runner from
  # `tensorflow.python.eager.test`.
  test.main()