Home | History | Annotate | Download | only in framework
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tensorflow.python.framework.ops."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import gc
     22 import weakref
     23 
     24 from tensorflow.core.framework import attr_value_pb2
     25 from tensorflow.core.framework import types_pb2
     26 from tensorflow.core.protobuf import config_pb2
     27 from tensorflow.python.client import session
     28 from tensorflow.python.eager import context
     29 from tensorflow.python.eager import function as eager_function
     30 from tensorflow.python.framework import common_shapes
     31 from tensorflow.python.framework import constant_op
     32 from tensorflow.python.framework import device as pydev
     33 from tensorflow.python.framework import dtypes
     34 from tensorflow.python.framework import errors
     35 from tensorflow.python.framework import function
     36 from tensorflow.python.framework import ops
     37 from tensorflow.python.framework import sparse_tensor
     38 from tensorflow.python.framework import tensor_shape
     39 from tensorflow.python.framework import tensor_util
     40 from tensorflow.python.framework import test_ops
     41 from tensorflow.python.framework import test_util
     42 from tensorflow.python.framework import versions
     43 from tensorflow.python.ops import array_ops
     44 from tensorflow.python.ops import control_flow_ops
     45 from tensorflow.python.ops import gen_array_ops
     46 from tensorflow.python.ops import math_ops
     47 from tensorflow.python.ops import resource_variable_ops
     48 from tensorflow.python.ops import resources
     49 from tensorflow.python.ops import variable_scope
     50 from tensorflow.python.ops import variables
     51 import tensorflow.python.ops.gradients  # pylint: disable=unused-import
     52 from tensorflow.python.platform import googletest
     53 from tensorflow.python.util import compat
     54 
# Module-level setup: hands common_shapes.call_cpp_shape_fn to ops through a
# private hook. NOTE(review): presumably this makes Python-side shape inference
# delegate to the C++ shape functions for the ops built in these tests --
# confirm against ops._set_call_cpp_shape_fn.
ops._set_call_cpp_shape_fn(common_shapes.call_cpp_shape_fn)
     56 
     57 
     58 @test_util.with_c_api
     59 class ResourceTest(test_util.TensorFlowTestCase):
     60 
     61   def testBuildGraph(self):
     62     with self.test_session():
     63       pt = test_ops.stub_resource_handle_op(container="a", shared_name="b")
     64       test_ops.resource_create_op(pt).run()
     65 
     66   def testInitialize(self):
     67     with self.test_session():
     68       handle = test_ops.stub_resource_handle_op(container="a", shared_name="b")
     69       resources.register_resource(
     70           handle=handle,
     71           create_op=test_ops.resource_create_op(handle),
     72           is_initialized_op=test_ops.resource_initialized_op(handle))
     73       self.assertEquals(
     74           len(
     75               resources.report_uninitialized_resources(
     76                   resources.shared_resources()).eval()), 1)
     77       resources.initialize_resources(resources.shared_resources()).run()
     78       self.assertEquals(
     79           len(
     80               resources.report_uninitialized_resources(
     81                   resources.shared_resources()).eval()), 0)
     82 
     83 
     84 @test_util.with_c_api
     85 class TensorAndShapeTest(test_util.TensorFlowTestCase):
     86 
     87   def testShape(self):
     88     op = ops.Operation(
     89         ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
     90     t = op.outputs[0]
     91     self.assertEqual(tensor_shape.unknown_shape(), t.get_shape())
     92     t.set_shape([1, 2, 3])
     93     self.assertEqual([1, 2, 3], t.get_shape())
     94 
     95   def testIterable(self):
     96     op = ops.Operation(
     97         ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
     98     t = op.outputs[0]
     99     self.assertTrue(isinstance(t, ops.Tensor))
    100     with self.assertRaisesRegexp(TypeError, "iter"):
    101       for _ in t:
    102         pass
    103 
    104   def testAddShape(self):
    105     with self.test_session():
    106       a = array_ops.zeros([2, 3])
    107       b = array_ops.ones([1, 3])
    108       c = a + b
    109       self.assertEqual([2, 3], c.shape)
    110 
    111   def testUnknownDim(self):
    112     with self.test_session():
    113       a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
    114       b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
    115       c = a + b
    116       self.assertEqual([2, None, 3], c.shape.as_list())
    117 
    118   def testUnknownShape(self):
    119     with self.test_session():
    120       a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
    121       b = array_ops.ones([1, 3])
    122       c = a + b
    123       self.assertEqual(tensor_shape.unknown_shape(), c.shape)
    124 
    125   def testScalarShape(self):
    126     with self.test_session():
    127       a = array_ops.placeholder(dtype=dtypes.float32, shape=[])
    128       b = array_ops.ones([])
    129       c = a + b
    130       self.assertEqual(tensor_shape.scalar(), c.shape)
    131 
    132   def testShapeFunctionError(self):
    133     with self.test_session():
    134       a = array_ops.ones([1, 2, 3])
    135       b = array_ops.ones([4, 5, 6])
    136       with self.assertRaisesRegexp(
    137           ValueError,
    138           r"Dimensions must be equal, but are 2 and 5 for 'add' \(op: 'Add'\) "
    139           r"with input shapes: \[1,2,3\], \[4,5,6\]."):
    140         _ = a + b
    141 
    142 
    143 @test_util.with_c_api
    144 class IndexedSlicesTest(test_util.TensorFlowTestCase):
    145 
    146   def testToTensor(self):
    147     with self.test_session():
    148       values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
    149       indices = constant_op.constant([0, 2])
    150       dense_shape = constant_op.constant([3, 2])
    151       x = ops.IndexedSlices(values, indices, dense_shape)
    152       tensor = ops.convert_to_tensor(x, name="tensor")
    153       self.assertAllEqual(tensor.eval(), [[2, 3], [0, 0], [5, 7]])
    154 
    155   def testNegation(self):
    156     with self.test_session():
    157       values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
    158       indices = constant_op.constant([0, 2])
    159       x = -ops.IndexedSlices(values, indices)
    160       self.assertAllEqual(x.values.eval(), [[-2, -3], [-5, -7]])
    161       self.assertAllEqual(x.indices.eval(), [0, 2])
    162 
    163   def testScalarMul(self):
    164     with self.test_session():
    165       values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
    166       indices = constant_op.constant([0, 2])
    167       x = math_ops.scalar_mul(-2, ops.IndexedSlices(values, indices))
    168       self.assertAllEqual(x.values.eval(), [[-4, -6], [-10, -14]])
    169       self.assertAllEqual(x.indices.eval(), [0, 2])
    170 
    171 
    172 @test_util.with_c_api
    173 class NodeDefConstructorTest(test_util.TensorFlowTestCase):
    174 
    175   def testNoArgs(self):
    176     nodedef = ops._NodeDef("None", "bar")
    177     self.assertProtoEquals("op: 'None' name: 'bar'", nodedef)
    178 
    179   def testArgs(self):
    180     nodedef = ops._NodeDef("foo", "bar", device="/device:baz:*")
    181     self.assertProtoEquals("op:'foo' name:'bar' device:'/device:baz:*'",
    182                            nodedef)
    183     nodedef = ops._NodeDef("foo", "bar", device=pydev.DeviceSpec(job="j"))
    184     self.assertProtoEquals("op:'foo' name:'bar' device:'/job:j'", nodedef)
    185 
    186 
    187 def _apply_op(g, *args, **kwargs):
    188   op = g.create_op(*args, **kwargs)
    189   if len(op.outputs) == 1:
    190     return op.outputs[0]
    191   else:
    192     return op.outputs
    193 
    194 
    195 @test_util.with_c_api
    196 class OperationTest(test_util.TensorFlowTestCase):
    197 
    198   def testNoInputs(self):
    199     op = test_ops.float_output_string_output(name="myop").a.op
    200     self.assertEqual(2, len(op.values()))
    201     self.assertEqual(0, len(op.inputs))
    202     self.assertEqual("myop", op.name)
    203 
    204     float_t, label_str_t = op.values()
    205     self.assertEqual(dtypes.float32, float_t.dtype)
    206     self.assertEqual(op, float_t.op)
    207     self.assertEqual(0, float_t._value_index)
    208     self.assertEqual(0, len(float_t.consumers()))
    209     self.assertEqual("myop", float_t._as_node_def_input())
    210 
    211     self.assertEqual(dtypes.string, label_str_t.dtype)
    212     self.assertEqual(op, label_str_t.op)
    213     self.assertEqual(1, label_str_t._value_index)
    214     self.assertEqual(0, len(label_str_t.consumers()))
    215     self.assertEqual("myop:1", label_str_t._as_node_def_input())
    216 
    217     self.assertProtoEquals("op:'FloatOutputStringOutput' name:'myop'",
    218                            op.node_def)
    219 
    220   def testNoOutputs(self):
    221     op1 = test_ops.float_output(name="myop1").op
    222     float_t, = op1.values()
    223     op2 = test_ops.float_input(float_t, name="myop2")
    224     self.assertEqual(0, len(op2.values()))
    225     self.assertEqual(1, len(op2.inputs))
    226     self.assertIs(float_t, op2.inputs[0])
    227 
    228     self.assertEqual(1, len(float_t.consumers()))
    229     self.assertEqual(op2, float_t.consumers()[0])
    230 
    231     self.assertProtoEquals("op:'FloatOutput' name:'myop1'", op1.node_def)
    232     self.assertProtoEquals("op:'FloatInput' name:'myop2' input:'myop1'",
    233                            op2.node_def)
    234 
    235   def testInputsAndOutputs(self):
    236     op1 = test_ops.float_output(name="myop1").op
    237     self.assertEqual(1, len(op1.values()))
    238     float1_t, = op1.values()
    239 
    240     op2 = test_ops.float_output_string_output(name="myop2").a.op
    241     self.assertEqual(2, len(op2.values()))
    242     float2_t, label2_str_t = op2.values()
    243 
    244     # Note that we consume label2_str_t twice here.
    245     op3 = test_ops.foo2(float1_t, label2_str_t, label2_str_t, name="myop3").d.op
    246     self.assertEqual(2, len(op3.values()))
    247 
    248     self.assertEqual(1, len(float1_t.consumers()))
    249     self.assertEqual(op3, float1_t.consumers()[0])
    250 
    251     self.assertEqual(0, len(float2_t.consumers()))
    252 
    253     self.assertEqual(2, len(label2_str_t.consumers()))
    254     self.assertEqual(op3, label2_str_t.consumers()[0])
    255     self.assertEqual(op3, label2_str_t.consumers()[1])
    256 
    257     self.assertProtoEquals("""
    258     op:'Foo2' name:'myop3'
    259     input:'myop1' input:'myop2:1' input:'myop2:1'
    260     """, op3.node_def)
    261 
    262   def testDeviceObject(self):
    263     op = ops.Operation(ops._NodeDef("None", "myop"), ops.Graph(), [], [])
    264     op._set_device("/job:goo/device:GPU:0")
    265     self.assertProtoEquals(
    266         "op:'None' name:'myop' device:'/job:goo/device:GPU:0' ", op.node_def)
    267     op = ops.Operation(ops._NodeDef("None", "op2"), ops.Graph(), [], [])
    268     op._set_device(
    269         pydev.DeviceSpec(
    270             job="muu", device_type="CPU", device_index=0))
    271     self.assertProtoEquals(
    272         "op:'None' name:'op2' device:'/job:muu/device:CPU:0'", op.node_def)
    273 
    274   def testReferenceInput(self):
    275     g = ops.Graph()
    276     op1 = ops.Operation(
    277         ops._NodeDef("RefOutputFloatOutput", "op1"), g, [],
    278         [dtypes.float32_ref, dtypes.float32])
    279     g._add_op(op1)
    280     self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def)
    281     self.assertEquals([], list(op1.inputs))
    282     ref_t, nonref_t = op1.values()
    283     # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    284     op2 = ops.Operation(
    285         ops._NodeDef("RefInputFloatInput", "op2"),
    286         g, [ref_t, nonref_t], [],
    287         input_types=[dtypes.float32_ref, dtypes.float32])
    288     g._add_op(op2)
    289     self.assertProtoEquals(
    290         "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
    291         op2.node_def)
    292     self.assertEquals([ref_t, nonref_t], list(op2.inputs))
    293     op3 = ops.Operation(
    294         ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], [])
    295     g._add_op(op3)
    296     self.assertProtoEquals(
    297         "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
    298         op3.node_def)
    299 
    300   def testInvalidNames(self):
    301     g = ops.Graph()
    302     with self.assertRaises(ValueError):
    303       ops.Operation(ops._NodeDef("op", ""), g)
    304     with self.assertRaises(ValueError):
    305       ops.Operation(ops._NodeDef("op", "_invalid"), g)
    306     with self.assertRaises(ValueError):
    307       ops.Operation(ops._NodeDef("op", "-invalid"), g)
    308     with self.assertRaises(ValueError):
    309       ops.Operation(ops._NodeDef("op", "/invalid"), g)
    310     with self.assertRaises(ValueError):
    311       ops.Operation(ops._NodeDef("op", "invalid:0"), g)
    312 
    313   def testNoShapeFunction(self):
    314     op = test_ops.a()
    315     self.assertEqual(tensor_shape.unknown_shape(), op.get_shape())
    316 
    317   def testConvertToTensorNestedArray(self):
    318     with self.test_session():
    319       values = [[2], [3], [5], [7]]
    320       tensor = ops.convert_to_tensor(values)
    321       self.assertAllEqual((4, 1), tensor.get_shape().as_list())
    322       self.assertAllEqual(values, tensor.eval())
    323 
    324   def testShapeTuple(self):
    325     with self.test_session():
    326       c = constant_op.constant(1)
    327       self.assertEqual(c._shape_tuple(), ())  # pylint: disable=protected-access
    328 
    329   def testConvertToTensorEager(self):
    330     with context.eager_mode():
    331       t = constant_op.constant(1)
    332       self.assertTrue(isinstance(t, ops.EagerTensor))
    333       converted = ops.convert_to_tensor(t)
    334       self.assertTrue(isinstance(converted, ops.EagerTensor))
    335       converted = ops.convert_to_tensor(1)
    336       self.assertTrue(isinstance(converted, ops.EagerTensor))
    337 
    338   def testConvertToTensorNestedTuple(self):
    339     with self.test_session():
    340       values = ((2,), (3,), (5,), (7,))
    341       tensor = ops.convert_to_tensor(values)
    342       self.assertAllEqual((4, 1), tensor.get_shape().as_list())
    343       self.assertAllEqual(values, ops.convert_to_tensor(values).eval())
    344 
    345   def testConvertToTensorNestedTensors(self):
    346     with self.test_session():
    347       values = ((2,), (3,), (5,), (7,))
    348       tensor = ops.convert_to_tensor(
    349           [constant_op.constant(row) for row in values])
    350       self.assertAllEqual((4, 1), tensor.get_shape().as_list())
    351       self.assertAllEqual(values, tensor.eval())
    352       tensor = ops.convert_to_tensor(
    353           [[constant_op.constant(v) for v in row] for row in values])
    354       self.assertAllEqual((4, 1), tensor.get_shape().as_list())
    355       self.assertAllEqual(values, tensor.eval())
    356 
    357   def testConvertToTensorNestedMix(self):
    358     with self.test_session():
    359       values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7]))
    360       tensor = ops.convert_to_tensor(values)
    361       self.assertAllEqual((4, 1), tensor.get_shape().as_list())
    362       self.assertAllEqual(((2,), (3,), (5,), (7,)), tensor.eval())
    363 
    364   def testConvertToTensorPreferred(self):
    365     with self.test_session():
    366       values = [2, 3, 5, 7]
    367       tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32)
    368       self.assertEqual(dtypes.float32, tensor.dtype)
    369 
    370     with self.test_session():
    371       # Convert empty tensor to anything.
    372       values = []
    373       tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
    374       self.assertEqual(dtypes.int64, tensor.dtype)
    375 
    376     with self.test_session():
    377       # The preferred dtype is a type error and will convert to
    378       # float32 instead.
    379       values = [1.23]
    380       tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
    381       self.assertEqual(dtypes.float32, tensor.dtype)
    382 
    383   def testConvertToInvalidTensorType(self):
    384     with self.assertRaises(TypeError):
    385       # Forcing an invalid dtype should fail with a type error.
    386       values = [1.23]
    387       _ = ops.convert_to_tensor(values, dtype=dtypes.int64)
    388 
    389   def testNoConvert(self):
    390     # Operation cannot be converted to Tensor.
    391     op = control_flow_ops.no_op()
    392     with self.assertRaisesRegexp(TypeError,
    393                                  r"Can't convert Operation '.*' to Tensor"):
    394       ops.convert_to_tensor(op)
    395 
    396   def testStr(self):
    397     node_def = ops._NodeDef("None", "op1")
    398     op = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32])
    399     self.assertEqual(str(node_def), str(op))
    400 
    401   def testRepr(self):
    402     op = ops.Operation(
    403         ops._NodeDef("None", "op1"), ops.Graph(), [], [dtypes.float32])
    404     self.assertEqual("<tf.Operation 'op1' type=None>", repr(op))
    405 
    406   def testGetAttr(self):
    407     op = test_ops.default_attrs()
    408     self.assertEqual(op.get_attr("string_val"), b"abc")
    409     self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""])
    410     self.assertEqual(op.get_attr("int_val"), 123)
    411     self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3])
    412     self.assertEqual(op.get_attr("float_val"), 10.0)
    413     self.assertEqual(op.get_attr("float_list_val"), [10.0])
    414     self.assertEqual(op.get_attr("bool_val"), True)
    415     self.assertEqual(op.get_attr("bool_list_val"), [True, False])
    416     self.assertEqual(op.get_attr("shape_val"),
    417                      tensor_shape.as_shape([2, 1]).as_proto())
    418     self.assertEqual(op.get_attr("shape_list_val"),
    419                      [tensor_shape.as_shape([]).as_proto(),
    420                       tensor_shape.as_shape([1]).as_proto()])
    421     self.assertEqual(op.get_attr("tensor_val"),
    422                      tensor_util.make_tensor_proto(1, dtypes.int32))
    423     self.assertEqual(op.get_attr("tensor_list_val"),
    424                      [tensor_util.make_tensor_proto(1, dtypes.int32)])
    425 
    426     type_val = op.get_attr("type_val")
    427     # First check that type_val is a DType, because the assertEquals will work
    428     # no matter what since DType overrides __eq__
    429     self.assertIsInstance(type_val, dtypes.DType)
    430     self.assertEqual(type_val, dtypes.int32)
    431 
    432     type_list_val = op.get_attr("type_list_val")
    433     self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val))
    434     self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32])
    435 
    436     @function.Defun(dtypes.float32, func_name="MyFunc")
    437     def func(x):
    438       return x
    439 
    440     op = test_ops.func_attr(func)
    441     self.assertEqual(op.get_attr("f"),
    442                      attr_value_pb2.NameAttrList(name="MyFunc"))
    443 
    444     # Try fetching missing attr
    445     if ops._USE_C_API:
    446       error_msg = "Operation 'FuncAttr' has no attr named 'FakeAttr'."
    447     else:
    448       error_msg = "No attr named 'FakeAttr' in name: \"FuncAttr\""
    449 
    450     with self.assertRaisesRegexp(ValueError, error_msg):
    451       op.get_attr("FakeAttr")
    452 
    453   # TODO(b/65162920): remove this test when users who are directly mutating the
    454   # node_def have been updated to proper usage.
    455   def testSetAttr(self):
    456     op = test_ops.int_attr().op
    457     op._set_attr("foo", attr_value_pb2.AttrValue(i=2))
    458     # TODO(skyewm): add node_def check
    459     self.assertEqual(op.get_attr("foo"), 2)
    460 
    461   # TODO(nolivia): test all error cases
    462   def testAddControlInput(self):
    463     # The C API dedups redundant control edges, pure Python does not
    464     if ops._USE_C_API: return
    465     with ops.Graph().as_default():
    466       x = constant_op.constant(1).op
    467       y = constant_op.constant(2).op
    468       z = constant_op.constant(3).op
    469     z._add_control_input(x)  # pylint: disable=protected-access
    470     self.assertEqual(z.control_inputs, [x])
    471     z._add_control_input(x)  # pylint: disable=protected-access
    472     self.assertEqual(z.control_inputs, [x, x])
    473     z._add_control_inputs([x, y, y])  # pylint: disable=protected-access
    474     self.assertEqual(z.control_inputs, [x, x, x, y, y])
    475 
    476   def testAddControlInputC(self):
    477     # The C API dedups redundant control edges, pure Python does not
    478     if not ops._USE_C_API: return
    479     with ops.Graph().as_default():
    480       x = constant_op.constant(1).op
    481       y = constant_op.constant(2).op
    482       z = constant_op.constant(3).op
    483     z._add_control_input(x)  # pylint: disable=protected-access
    484     self.assertEqual(z.control_inputs, [x])
    485     z._add_control_input(x)  # pylint: disable=protected-access
    486     self.assertEqual(z.control_inputs, [x])
    487     z._add_control_inputs([x, y, y])  # pylint: disable=protected-access
    488     self.assertEqual(z.control_inputs, [x, y])
    489 
    490   def testRemoveAllControlInputs(self):
    491     a = constant_op.constant(1)
    492     with ops.control_dependencies([a]):
    493       b = constant_op.constant(2)
    494     c = constant_op.constant(3)
    495     d = constant_op.constant(4)
    496     e = constant_op.constant(5)
    497     with ops.control_dependencies([a, c]):
    498       f = d + e
    499 
    500     self.assertEqual(a.op.control_inputs, [])
    501     self.assertEqual(b.op.control_inputs, [a.op])
    502     self.assertEqual(f.op.control_inputs, [a.op, c.op])
    503 
    504     a.op._remove_all_control_inputs()  # pylint: disable=protected-access
    505     self.assertEqual(a.op.control_inputs, [])
    506 
    507     b.op._remove_all_control_inputs()  # pylint: disable=protected-access
    508     self.assertEqual(b.op.control_inputs, [])
    509 
    510     f.op._remove_all_control_inputs()  # pylint: disable=protected-access
    511     self.assertEqual(f.op.control_inputs, [])
    512     self.assertEqual(list(f.op.inputs), [d, e])
    513 
    514   def testControlInputCycle(self):
    515     # Non-C API path has a different error message
    516     if not ops._USE_C_API: return
    517     graph = ops.Graph()
    518     with graph.as_default():
    519       z = constant_op.constant(0)
    520       x = constant_op.constant(1)
    521       y = constant_op.constant(2)
    522       y.op._add_control_input(z.op)  # pylint: disable=protected-access
    523       y.op._add_control_input(x.op)  # pylint: disable=protected-access
    524       x.op._add_control_input(y.op)  # pylint: disable=protected-access
    525     with self.test_session(graph=graph) as sess:
    526       with self.assertRaisesRegexp(
    527           errors.InvalidArgumentError,
    528           "Graph is invalid, contains a cycle with 2 nodes"):
    529         sess.run(x)
    530 
    531   def testUpdateInput(self):
    532     g = ops.Graph()
    533     with g.as_default():
    534       x = constant_op.constant(1)
    535       y = constant_op.constant(2)
    536       z = x + y
    537 
    538     z.op._update_input(0, y)  # pylint: disable=protected-access
    539     self.assertEquals(list(z.op.inputs), [y, y])
    540     self.assertEquals(x.consumers(), [])
    541     self.assertEquals(y.consumers(), [z.op, z.op])
    542     with session.Session(graph=g) as sess:
    543       self.assertEquals(sess.run(z), 4)
    544 
    545     z.op._update_input(0, x)  # pylint: disable=protected-access
    546     self.assertEquals(list(z.op.inputs), [x, y])
    547     self.assertEquals(x.consumers(), [z.op])
    548     self.assertEquals(y.consumers(), [z.op])
    549     with session.Session(graph=g) as sess:
    550       self.assertEquals(sess.run(z), 3)
    551 
    552     z.op._update_input(1, y)  # pylint: disable=protected-access
    553     self.assertEquals(list(z.op.inputs), [x, y])
    554     self.assertEquals(x.consumers(), [z.op])
    555     self.assertEquals(y.consumers(), [z.op])
    556     with session.Session(graph=g) as sess:
    557       self.assertEquals(sess.run(z), 3)
    558 
    559   def testUpdateInputGraphError(self):
    560     g_0 = ops.Graph()
    561     g_1 = ops.Graph()
    562     with g_0.as_default():
    563       x = constant_op.constant(1)
    564     with g_1.as_default():
    565       y = constant_op.constant(2)
    566       z = y * 2
    567       with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
    568         z.op._update_input(0, x)  # pylint: disable=protected-access
    569 
    570   def testUpdateInputTypeError(self):
    571     g = ops.Graph()
    572     with g.as_default():
    573       w = constant_op.constant(0)
    574       x = constant_op.constant("")
    575       y = constant_op.constant(1)
    576       z = y + w
    577       z.op._update_input(0, x)  # pylint: disable=protected-access
    578     with session.Session(graph=g) as sess:
    579       with self.assertRaisesRegexp(
    580           errors.InvalidArgumentError,
    581           "Input 0 of node add was passed string from Const_1:0 incompatible "
    582           "with expected int32"):
    583         sess.run(z)
    584 
    585   def testUpdateInputShapeError(self):
    586     # C-API throws the error differently.
    587     if ops._USE_C_API:
    588       return
    589     g = ops.Graph()
    590     with g.as_default():
    591       w = constant_op.constant(2, shape=[3, 1])
    592       x = constant_op.constant(0, shape=[3, 1])
    593       y = constant_op.constant(1, shape=[2, 2])
    594       z = w + x
    595       z.op._update_input(0, y)  # pylint: disable=protected-access
    596 
    597     with session.Session(graph=g) as sess:
    598       with self.assertRaisesRegexp(errors.InvalidArgumentError,
    599                                    r"Incompatible shapes: \[2,2\] vs. \[3,1\]"):
    600         sess.run(z)
    601 
    602   def testUpdateInputShapeErrorC(self):
    603     if not ops._USE_C_API:
    604       return
    605     g = ops.Graph()
    606     with g.as_default():
    607       w = constant_op.constant(2, shape=[3, 1])
    608       x = constant_op.constant(0, shape=[3, 1])
    609       y = constant_op.constant(1, shape=[2, 2])
    610       z = w + x
    611     with self.assertRaisesRegexp(
    612         errors.InvalidArgumentError,
    613         r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"):
    614       z.op._update_input(0, y)  # pylint: disable=protected-access
    615 
    616   def testUpdateInputOutOfRange(self):
    617     # C-API throws the error differently.
    618     if ops._USE_C_API: return
    619     g = ops.Graph()
    620     with g.as_default():
    621       x = constant_op.constant(1)
    622     with self.assertRaisesRegexp(IndexError, "list index out of range"):
    623       x.op._update_input(1, x)  # pylint: disable=protected-access
    624 
    625   def testUpdateInputOutOfRangeC(self):
    626     # C-API throws the error differently.
    627     if not ops._USE_C_API: return
    628     g = ops.Graph()
    629     with g.as_default():
    630       x = constant_op.constant(1)
    631     with self.assertRaisesRegexp(
    632         errors.OutOfRangeError,
    633         r"Cannot update edge. Input index \[1\] is greater than the number of "
    634         r"total inputs \[0\]."
    635     ):
    636       x.op._update_input(1, x)  # pylint: disable=protected-access
    637 
    638   def testOpDef(self):
    639     x = constant_op.constant(0)
    640     y = constant_op.constant(1)
    641     z = x + y
    642 
    643     # Pure Python mode doesn't create OpDefs for constants
    644     if ops._USE_C_API:
    645       self.assertEqual(x.op.op_def.name, "Const")
    646       self.assertEqual(len(x.op.op_def.input_arg), 0)
    647       self.assertEqual(len(x.op.op_def.output_arg), 1)
    648 
    649     self.assertEqual(z.op.op_def.name, "Add")
    650     self.assertEqual(len(z.op.op_def.input_arg), 2)
    651     self.assertEqual(len(z.op.op_def.output_arg), 1)
    652 
    653   def testInputFromDifferentGraphError(self):
    654     g_0 = ops.Graph()
    655     g_1 = ops.Graph()
    656     with g_0.as_default():
    657       x = constant_op.constant(1)
    658     with g_1.as_default():
    659       y = constant_op.constant(2)
    660       with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
    661         y * x  # pylint: disable=pointless-statement
    662 
    663   def testInputsAreImmutable(self):
    664     g = ops.Graph()
    665     with g.as_default():
    666       x = test_ops.int_output()
    667       op = test_ops.int_input_int_output(x, name="myop").op
    668     with self.assertRaisesRegexp(
    669         AttributeError, "'_InputList' object has no attribute 'append'"):
    670       op.inputs.append(None)
    671 
    672 
    673 @test_util.with_c_api
    674 class CreateOpTest(test_util.TensorFlowTestCase):
    675 
    676   def testNodeDefArgs(self):
    677     g = ops.Graph()
    678     op1 = g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")
    679     with g.device("/device:GPU:0"):
    680       op2 = g.create_op(
    681           "FloatOutputStringOutput", [], [dtypes.float32, dtypes.string], None,
    682           name="myop2")
    683     op3 = g.create_op(
    684         "Foo3",
    685         [list(op1.values())[0], list(op2.values())[1], list(op2.values())[0]],
    686         [dtypes.float32, dtypes.int32],
    687         None,
    688         name="myop3")
    689     self.assertDeviceEqual(None, op1.device)
    690     self.assertDeviceEqual("/device:GPU:0", op2.device)
    691     self.assertDeviceEqual(None, op3.device)
    692     self.assertProtoEquals("name:'myop1' op:'FloatOutput'", op1.node_def)
    693     self.assertProtoEquals(
    694         "name:'myop2' op:'FloatOutputStringOutput' device:'/device:GPU:0'",
    695         op2.node_def)
    696     self.assertProtoEquals(
    697         "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo3'",
    698         op3.node_def)
    699 
    700   def testReferenceInput(self):
    701     g = ops.Graph()
    702     op1 = g.create_op(
    703         "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
    704         name="op1")
    705     self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def)
    706     ref_t, nonref_t = op1.values()
    707     # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    708     op2 = g.create_op(
    709         "RefInputFloatInput", [ref_t, nonref_t], [],
    710         input_types=[dtypes.float32_ref, dtypes.float32],
    711         name="op2")
    712     self.assertProtoEquals(
    713         "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
    714         op2.node_def)
    715     op3 = g.create_op("TwoFloatInputs", [ref_t, nonref_t], [], name="op3")
    716     self.assertProtoEquals(
    717         "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
    718         op3.node_def)
    719 
    720   def testFinalized(self):
    721     g = ops.Graph()
    722     g.finalize()
    723     with self.assertRaises(RuntimeError):
    724       g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")
    725 
    726     # Test unfinalize.
    727     g._unsafe_unfinalize()
    728     g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")
    729 
    730 
    731 # NOTE(skyewm): these cases test the private Graph._create_op_from_tf_operation
    732 # method. Arguably we should only test the public APIs that depend on this
    733 # method. However, this logic is complex and tricky, and it can be difficult to
    734 # ascertain if we have adequate coverage (e.g. a graph may run successfully if
    735 # the control flow context isn't set properly, but a more complicated use case
    736 # that might not be obvious to test will fail). Thus we instead explicitly test
    737 # the low-level behavior.
@test_util.with_c_api
class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase):
  """Low-level tests for wrapping C-API TF_Operations as Python Operations.

  When ops._USE_C_API is true, each test creates the op with ops._create_c_op
  and wraps it with Graph._create_op_from_tf_operation (or
  Graph._add_new_tf_operations). Otherwise it builds the same op through the
  pure-Python op wrapper. Both paths must satisfy the same assertions.
  """

  def testBasic(self):
    """A wrapped op exposes name, type, outputs, inputs and graph lookups."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()
      if ops._USE_C_API:
        c_op = ops._create_c_op(
            g, ops._NodeDef("IntInputIntOutput", "myop"), [x], [])
        op = g._create_op_from_tf_operation(c_op)
      else:
        # Test pure-Python version to make sure C API has same behavior.
        op = test_ops.int_input_int_output(x, name="myop").op

    self.assertEqual(op.name, "myop")
    self.assertEqual(op.type, "IntInputIntOutput")
    self.assertEqual(len(op.outputs), 1)
    self.assertEqual(op.outputs[0].shape, tensor_shape.unknown_shape())
    self.assertEqual(list(op.inputs), [x])
    self.assertEqual(op.control_inputs, [])
    self.assertEqual(op.graph, g)
    # The producer must record the wrapped op as a consumer.
    self.assertEqual(x.consumers(), [op])
    self.assertIsNotNone(op.traceback)
    # The wrapped op and its output must be registered with the graph.
    self.assertEqual(g.get_operation_by_name("myop"), op)
    self.assertEqual(g.get_tensor_by_name("myop:0"), op.outputs[0])

  def testShape(self):
    """The Identity output's static shape matches its (2, 3) input."""
    g = ops.Graph()
    with g.as_default():
      x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
      if ops._USE_C_API:
        c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), [x], [])
        op = g._create_op_from_tf_operation(c_op)
      else:
        # Test pure-Python version to make sure C API has same behavior.
        op = array_ops.identity(x, name="myop").op

    self.assertEqual(op.name, "myop")
    self.assertEqual(op.type, "Identity")
    self.assertEqual(len(op.outputs), 1)
    self.assertEqual(op.outputs[0].shape, tensor_shape.matrix(2, 3))

  def testUniqueName(self):
    """Names taken by wrapped ops are honored by later name uniquification."""
    g = ops.Graph()
    with g.as_default():
      if ops._USE_C_API:
        c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], [])
        c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], [])
        op = g._create_op_from_tf_operation(c_op)
        op2 = g._create_op_from_tf_operation(c_op2)
      else:
        # Test pure-Python version to make sure C API has same behavior.
        op = test_ops.int_output(name="myop").op
        op2 = test_ops.int_output(name="myop_1").op

      # Create ops with same names as op1 and op2. We expect the new names to be
      # uniquified.
      op3 = test_ops.int_output(name="myop").op
      op4 = test_ops.int_output(name="myop_1").op

    self.assertEqual(op.name, "myop")
    self.assertEqual(op2.name, "myop_1")
    self.assertEqual(op3.name, "myop_2")
    self.assertEqual(op4.name, "myop_1_1")

  def testCond(self):
    """An op added inside a cond branch is placed in the cond's context."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def true_fn():
        if ops._USE_C_API:
          # Create the raw op, then let the graph pick up the single new
          # TF_Operation while the cond context is active.
          ops._create_c_op(ops.get_default_graph(),
                           ops._NodeDef("IntInput", "cond/myop"), [x], [])
          new_ops = g._add_new_tf_operations()
          self.assertEqual(len(new_ops), 1)
        else:
          # Test pure-Python version to make sure C API has same behavior.
          test_ops.int_input(x, name="myop")
        return x

      control_flow_ops.cond(x < 10, true_fn, lambda: x)

    op = g.get_operation_by_name("cond/myop")
    self.assertIsNotNone(op)
    self.assertEqual(op.name, "cond/myop")
    self.assertEqual(op.type, "IntInput")
    self.assertEqual(op.outputs, [])
    # The external input `x` must be routed through a Switch op.
    op_input = op.inputs[0].op
    self.assertEqual(op_input.type, "Switch")
    self.assertEqual(op_input.inputs[0], x)
    self.assertEqual(op.graph, g)
    # pylint: disable=protected-access
    self.assertIsNotNone(op._get_control_flow_context())
    self.assertEqual(op._get_control_flow_context().name,
                     "cond/cond_text")
    # pylint: enable=protected-access

  def testWhileLoop(self):
    """An op added inside a while body is placed in the loop's context."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def body(i):
        if ops._USE_C_API:
          ops._create_c_op(ops.get_default_graph(),
                           ops._NodeDef("IntInput", "myloop/myop"), [x], [])
          new_ops = g._add_new_tf_operations()
          self.assertEqual(len(new_ops), 1)
        else:
          # Test pure-Python version to make sure C API has same behavior.
          test_ops.int_input(x, name="myop")
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    self.assertEqual(op.name, "myloop/myop")
    self.assertEqual(op.type, "IntInput")
    self.assertEqual(op.outputs, [])
    # The external input `x` must enter the loop through an Enter op.
    op_input = op.inputs[0].op
    self.assertEqual(op_input.type, "Enter")
    self.assertEqual(list(op_input.inputs), [x])
    self.assertEqual(op.graph, g)
    # pylint: disable=protected-access
    self.assertIsNotNone(op._get_control_flow_context())
    self.assertEqual(op._get_control_flow_context().name,
                     "myloop/while_context")
    # pylint: enable=protected-access

  def testWhileLoopWithInternalControlDep(self):
    """A control dependency on an op inside the loop body is kept as-is."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def body(i):
        c = constant_op.constant(1.0, name="c")
        if ops._USE_C_API:
          ops._create_c_op(ops.get_default_graph(),
                           ops._NodeDef("IntInput", "myloop/myop"), [x], [])
          # Wrapping happens under control_dependencies([c]); the assertions
          # below expect `c` to become the op's control input.
          with ops.control_dependencies([c]):
            new_ops = g._add_new_tf_operations()
            self.assertEqual(len(new_ops), 1)
        else:
          with ops.control_dependencies([c]):
            test_ops.int_input(x, name="myop")
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    c = g.get_operation_by_name("myloop/c")
    self.assertIsNotNone(c)
    # Internal control dep is preserved
    self.assertEqual(op.control_inputs, [c])

  def testWhileLoopWithExternalControlDep(self):
    """A control dependency on an op outside the loop is not kept verbatim."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()
      # `c` lives outside the while loop.
      c = constant_op.constant(1.0)

      def body(i):
        if ops._USE_C_API:
          ops._create_c_op(ops.get_default_graph(),
                           ops._NodeDef("IntInput", "myloop/myop"), [x], [])
          with ops.control_dependencies([c]):
            new_ops = g._add_new_tf_operations()
            self.assertEqual(len(new_ops), 1)
        else:
          with ops.control_dependencies([c]):
            test_ops.int_input(x, name="myop")
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    # External control dep is removed and replaced with internal control dep
    self.assertNotEqual(op.control_inputs[0], c.op)
    self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
    922 
    923 
    924 @test_util.with_c_api
    925 class ApplyOpTest(test_util.TensorFlowTestCase):
    926 
    927   def testNodeDefArgs(self):
    928     g = ops.Graph()
    929     t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1")
    930     with g.device("/device:GPU:0"):
    931       t2 = _apply_op(
    932           g, "TwoIntOutputs", [], [dtypes.int32, dtypes.int32], name="myop2")
    933     t3 = _apply_op(
    934         g,
    935         "Foo1", [t1, t2[1], t2[0]], [dtypes.float32, dtypes.int32],
    936         name="myop3")
    937     self.assertTrue(isinstance(t1, ops.Tensor))
    938     self.assertTrue(isinstance(t2, list))
    939     self.assertTrue(isinstance(t3, list))
    940     self.assertTrue(isinstance(t3[0], ops.Tensor))
    941     self.assertEqual("myop1", t1._as_node_def_input())
    942     self.assertEqual("myop2", t2[0]._as_node_def_input())
    943     self.assertEqual("myop2:1", t2[1]._as_node_def_input())
    944     self.assertEqual("myop3", t3[0]._as_node_def_input())
    945     # Validate that we got the right ops as well
    946     self.assertProtoEquals("name:'myop1' op:'FloatOutput'", t1.op.node_def)
    947     self.assertProtoEquals(
    948         "name:'myop2' op:'TwoIntOutputs' device:'/device:GPU:0'",
    949         t2[0].op.node_def)
    950     self.assertProtoEquals(
    951         "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo1'",
    952         t3[0].op.node_def)
    953 
    954   def testReferenceInput(self):
    955     g = ops.Graph()
    956     ref_t, nonref_t = _apply_op(
    957         g, "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
    958         name="op1")
    959     self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'",
    960                            ref_t.op.node_def)
    961     # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    962     out_2 = _apply_op(
    963         g,
    964         "RefInputFloatInputIntOutput", [ref_t, nonref_t], [dtypes.int32],
    965         input_types=[dtypes.float32_ref, dtypes.float32],
    966         name="op2")
    967     self.assertProtoEquals(
    968         "op:'RefInputFloatInputIntOutput' name:'op2' input:'op1' input:'op1:1'",
    969         out_2.op.node_def)
    970     out_3 = _apply_op(
    971         g, "TwoFloatInputsIntOutput", [ref_t, nonref_t], [dtypes.int32],
    972         name="op3")
    973     self.assertProtoEquals(
    974         "op:'TwoFloatInputsIntOutput' name:'op3' input:'op1' input:'op1:1'",
    975         out_3.op.node_def)
    976 
    977 
    978 @test_util.with_c_api
    979 class NameStackTest(test_util.TensorFlowTestCase):
    980 
    981   def testBasics(self):
    982     g = ops.Graph()
    983     self.assertEqual("foo", g.unique_name("foo", mark_as_used=False))
    984     self.assertEqual("foo", g.unique_name("foo", mark_as_used=False))
    985     self.assertEqual("foo", g.unique_name("foo"))
    986     self.assertEqual("foo_1", g.unique_name("foo", mark_as_used=False))
    987     self.assertEqual("foo_1", g.unique_name("foo"))
    988     self.assertEqual("foo_2", g.unique_name("foo", mark_as_used=False))
    989     self.assertEqual("foo_2", g.unique_name("foo"))
    990     self.assertEqual("foo_1_1", g.unique_name("foo_1", mark_as_used=False))
    991     self.assertEqual("foo_1_1", g.unique_name("foo_1"))
    992     self.assertEqual("foo_1_2", g.unique_name("foo_1", mark_as_used=False))
    993     self.assertEqual("foo_1_2", g.unique_name("foo_1"))
    994     self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2", mark_as_used=False))
    995     self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2"))
    996     with g.name_scope("bar"):
    997       self.assertEqual("bar/foo", g.unique_name("foo", mark_as_used=False))
    998       self.assertEqual("bar/foo", g.unique_name("foo"))
    999       self.assertEqual("bar/foo_1", g.unique_name("foo", mark_as_used=False))
   1000       self.assertEqual("bar/foo_1", g.unique_name("foo"))
   1001       with g.name_scope(None):
   1002         self.assertEqual("foo_3", g.unique_name("foo", mark_as_used=False))
   1003         self.assertEqual("foo_3", g.unique_name("foo"))
   1004       with g.name_scope("baz"):
   1005         self.assertEqual(
   1006             "bar/baz/foo", g.unique_name(
   1007                 "foo", mark_as_used=False))
   1008         self.assertEqual("bar/baz/foo", g.unique_name("foo"))
   1009         self.assertEqual(
   1010             "bar/baz/foo_1", g.unique_name(
   1011                 "foo", mark_as_used=False))
   1012         self.assertEqual("bar/baz/foo_1", g.unique_name("foo"))
   1013       with g.name_scope("baz"):
   1014         self.assertEqual(
   1015             "bar/baz_1/foo", g.unique_name(
   1016                 "foo", mark_as_used=False))
   1017         self.assertEqual("bar/baz_1/foo", g.unique_name("foo"))
   1018         self.assertEqual(
   1019             "bar/baz_1/foo_1", g.unique_name(
   1020                 "foo", mark_as_used=False))
   1021         self.assertEqual("bar/baz_1/foo_1", g.unique_name("foo"))
   1022     with g.name_scope("quux"):
   1023       self.assertEqual("quux/foo", g.unique_name("foo", mark_as_used=False))
   1024       self.assertEqual("quux/foo", g.unique_name("foo"))
   1025     with g.name_scope("bar"):
   1026       with g.name_scope("baz"):
   1027         self.assertEqual(
   1028             "bar_1/baz/foo", g.unique_name(
   1029                 "foo", mark_as_used=False))
   1030         self.assertEqual("bar_1/baz/foo", g.unique_name("foo"))
   1031     self.assertEqual("foo_4", g.unique_name("foo", mark_as_used=False))
   1032     self.assertEqual("foo_4", g.unique_name("foo"))
   1033     self.assertEqual("bar_2", g.unique_name("bar", mark_as_used=False))
   1034     self.assertEqual("bar_2", g.unique_name("bar"))
   1035 
   1036   def testNameAndVariableScope(self):
   1037     with self.test_session() as sess:
   1038       with sess.graph.name_scope("l0"):
   1039         with variable_scope.variable_scope("l1"):
   1040           with sess.graph.name_scope("l1") as scope:
   1041             self.assertEqual("l0/l1/l1/", scope)
   1042             self.assertEqual(
   1043                 "l0/l1/l1/foo",
   1044                 sess.graph.unique_name(
   1045                     "foo", mark_as_used=False))
   1046             self.assertEqual("l0/l1/l1/foo", sess.graph.unique_name("foo"))
   1047           with sess.graph.name_scope("l2") as scope:
   1048             self.assertEqual("l0/l1/l2/", scope)
   1049             self.assertEqual(
   1050                 "l0/l1/l2/foo",
   1051                 sess.graph.unique_name(
   1052                     "foo", mark_as_used=False))
   1053             self.assertEqual("l0/l1/l2/foo", sess.graph.unique_name("foo"))
   1054 
   1055   def testOutOfOrderUniqueName(self):
   1056     g = ops.Graph()
   1057     self.assertEqual("foo_2", g.unique_name("foo_2"))
   1058     self.assertEqual("foo", g.unique_name("foo"))
   1059     self.assertEqual("foo_1", g.unique_name("foo"))
   1060     self.assertEqual("foo_3", g.unique_name("foo"))
   1061 
   1062   def testInvalidNameRaisesError(self):
   1063     g = ops.Graph()
   1064     with g.name_scope(""):  # Should not raise
   1065       pass
   1066     with g.name_scope("foo/"):  # Should not raise
   1067       with g.name_scope("_bar"):  # Should not raise
   1068         pass
   1069     with self.assertRaises(ValueError):
   1070       with g.name_scope("foo:0"):
   1071         pass
   1072     with self.assertRaises(ValueError):
   1073       with g.name_scope("_bar"):
   1074         pass
   1075 
   1076 
   1077 @test_util.with_c_api
   1078 class NameTest(test_util.TensorFlowTestCase):
   1079 
   1080   def testGenerateName(self):
   1081     g = ops.Graph()
   1082     op0 = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
   1083     self.assertEqual("TwoFloatOutputs", op0.name)
   1084     self.assertEqual("TwoFloatOutputs:0", op0.outputs[0].name)
   1085     self.assertEqual("TwoFloatOutputs:1", op0.outputs[1].name)
   1086 
   1087     op1 = g.create_op("FloatOutput", [], [dtypes.float32])
   1088     self.assertEqual("FloatOutput", op1.name)
   1089     self.assertEqual("FloatOutput:0", op1.outputs[0].name)
   1090 
   1091     op2 = g.create_op("FloatOutput", [], [dtypes.float32])
   1092     self.assertEqual("FloatOutput_1", op2.name)
   1093     self.assertEqual("FloatOutput_1:0", op2.outputs[0].name)
   1094 
   1095     op3 = g.create_op("FloatOutput", [], [dtypes.float32], name="my_op")
   1096     self.assertEqual("my_op", op3.name)
   1097     self.assertEqual("my_op:0", op3.outputs[0].name)
   1098 
   1099   def testNameScope(self):
   1100     g = ops.Graph()
   1101 
   1102     with g.name_scope("foo") as foo:
   1103       self.assertEqual("foo/", foo)
   1104       with g.name_scope("foo2") as foo2:
   1105         self.assertEqual("foo/foo2/", foo2)
   1106       with g.name_scope(None) as empty1:
   1107         self.assertEqual("", empty1)
   1108         with g.name_scope("foo3") as foo3:
   1109           self.assertEqual("foo3/", foo3)
   1110       with g.name_scope("") as empty2:
   1111         self.assertEqual("", empty2)
   1112 
   1113     self.assertEqual("FloatOutput",
   1114                      g.create_op("FloatOutput", [], [dtypes.float32]).name)
   1115     with g.name_scope("bar") as scope:
   1116       self.assertEqual("bar/FloatOutput",
   1117                        g.create_op("FloatOutput", [], [dtypes.float32]).name)
   1118       self.assertEqual("bar/FloatOutput_1",
   1119                        g.create_op("FloatOutput", [], [dtypes.float32]).name)
   1120       # If you use the value from "with .. as", that values is used as-is.
   1121       self.assertEqual(
   1122           "bar", g.create_op(
   1123               "FloatOutput", [], [dtypes.float32], name=scope).name)
   1124     with g.name_scope("baz") as scope:
   1125       with g.name_scope("quux"):
   1126         self.assertEqual("baz/quux/FloatOutput",
   1127                          g.create_op("FloatOutput", [], [dtypes.float32]).name)
   1128       # If you use the value from the enclosing "with .. as", nothing is pushed.
   1129       with g.name_scope(scope):
   1130         self.assertEqual("baz/FloatOutput",
   1131                          g.create_op("FloatOutput", [], [dtypes.float32]).name)
   1132         self.assertEqual(
   1133             "baz", g.create_op(
   1134                 "FloatOutput", [], [dtypes.float32], name=scope).name)
   1135         self.assertEqual(
   1136             "trailing",
   1137             g.create_op(
   1138                 "FloatOutput", [], [dtypes.float32], name="trailing/").name)
   1139     with g.name_scope("bar"):
   1140       self.assertEqual("bar_1/FloatOutput",
   1141                        g.create_op("FloatOutput", [], [dtypes.float32]).name)
   1142     with g.name_scope("bar/"):
   1143       self.assertEqual("bar/FloatOutput_2",
   1144                        g.create_op("FloatOutput", [], [dtypes.float32]).name)
   1145 
   1146 
   1147 @test_util.with_c_api
   1148 class DeviceTest(test_util.TensorFlowTestCase):
   1149 
   1150   def testNoDevice(self):
   1151     g = ops.Graph()
   1152     op = g.create_op("FloatOutput", [], [dtypes.float32])
   1153     self.assertDeviceEqual(None, op.device)
   1154     gd = g.as_graph_def()
   1155     self.assertProtoEqualsVersion("""
   1156       node { name: "FloatOutput" op: "FloatOutput" }
   1157     """, gd)
   1158 
   1159   def testDevicePartialString(self):
   1160     g = ops.Graph()
   1161     with g.device("/job:worker/replica:2"):
   1162       g.create_op("FloatOutput", [], [dtypes.float32])
   1163     gd = g.as_graph_def()
   1164     self.assertProtoEqualsVersion("""
   1165       node { name: "FloatOutput" op: "FloatOutput"
   1166              device: "/job:worker/replica:2" }
   1167     """, gd)
   1168 
   1169   def testDeviceFull(self):
   1170     g = ops.Graph()
   1171     with g.device(
   1172         pydev.DeviceSpec(
   1173             job="worker", replica=2, task=0, device_type="CPU",
   1174             device_index=3)):
   1175       g.create_op("FloatOutput", [], [dtypes.float32])
   1176     gd = g.as_graph_def()
   1177     self.assertProtoEqualsVersion("""
   1178       node { name: "FloatOutput" op: "FloatOutput"
   1179              device: "/job:worker/replica:2/task:0/device:CPU:3" }
   1180     """, gd)
   1181 
   1182   def testNesting(self):
   1183     g = ops.Graph()
   1184     with g.device("/job:worker/replica:2"):
   1185       g.create_op("FloatOutput", [], [dtypes.float32])
   1186       with g.device("/job:worker/replica:3/task:0"):
   1187         g.create_op("FloatOutput", [], [dtypes.float32])
   1188       g.create_op("FloatOutput", [], [dtypes.float32])
   1189     gd = g.as_graph_def()
   1190     self.assertProtoEqualsVersion("""
   1191       node { name: "FloatOutput" op: "FloatOutput"
   1192              device: "/job:worker/replica:2" }
   1193       node { name: "FloatOutput_1" op: "FloatOutput"
   1194              device: "/job:worker/replica:3/task:0" }
   1195       node { name: "FloatOutput_2" op: "FloatOutput"
   1196              device: "/job:worker/replica:2" }
   1197     """, gd)
   1198 
   1199   def testNestingString(self):
   1200     g = ops.Graph()
   1201     with g.device("/job:worker/replica:2"):
   1202       g.create_op("FloatOutput", [], [dtypes.float32])
   1203       with g.device("/job:worker/replica:3/task:0"):
   1204         g.create_op("FloatOutput", [], [dtypes.float32])
   1205       g.create_op("FloatOutput", [], [dtypes.float32])
   1206     gd = g.as_graph_def()
   1207     self.assertProtoEqualsVersion("""
   1208       node { name: "FloatOutput" op: "FloatOutput"
   1209              device: "/job:worker/replica:2" }
   1210       node { name: "FloatOutput_1" op: "FloatOutput"
   1211              device: "/job:worker/replica:3/task:0" }
   1212       node { name: "FloatOutput_2" op: "FloatOutput"
   1213              device: "/job:worker/replica:2" }
   1214     """, gd)
   1215 
   1216   def testNestingOverrideGpuCpu(self):
   1217     g = ops.Graph()
   1218     with g.device("/job:worker/replica:2/device:CPU:1"):
   1219       g.create_op("FloatOutput", [], [dtypes.float32])
   1220       with g.device("/job:worker/replica:2/device:GPU:2"):
   1221         g.create_op("FloatOutput", [], [dtypes.float32])
   1222       g.create_op("FloatOutput", [], [dtypes.float32])
   1223     gd = g.as_graph_def()
   1224     self.assertProtoEqualsVersion("""
   1225       node { name: "FloatOutput" op: "FloatOutput"
   1226              device: "/job:worker/replica:2/device:CPU:1"  }
   1227       node { name: "FloatOutput_1" op: "FloatOutput"
   1228              device: "/job:worker/replica:2/device:GPU:2" }
   1229       node { name: "FloatOutput_2" op: "FloatOutput"
   1230              device: "/job:worker/replica:2/device:CPU:1" }
   1231     """, gd)
   1232 
   1233   def testNestingWithMergeDeviceFunction(self):
   1234     g = ops.Graph()
   1235 
   1236     with g.device(pydev.merge_device("/device:GPU:0")):
   1237       g.create_op("FloatOutput", [], [dtypes.float32])
   1238       with g.device(pydev.merge_device("/job:worker")):
   1239         g.create_op("FloatOutput", [], [dtypes.float32])
   1240         with g.device(pydev.merge_device("/device:CPU:0")):
   1241           g.create_op("FloatOutput", [], [dtypes.float32])
   1242           with g.device(pydev.merge_device("/job:ps")):
   1243             g.create_op("FloatOutput", [], [dtypes.float32])
   1244             with g.device(pydev.merge_device(None)):
   1245               g.create_op("FloatOutput", [], [dtypes.float32])
   1246 
   1247     gd = g.as_graph_def()
   1248     self.assertProtoEqualsVersion("""
   1249       node { name: "FloatOutput" op: "FloatOutput"
   1250              device: "/device:GPU:0" }
   1251       node { name: "FloatOutput_1" op: "FloatOutput"
   1252              device: "/job:worker/device:GPU:0" }
   1253       node { name: "FloatOutput_2" op: "FloatOutput"
   1254              device: "/job:worker/device:CPU:0" }
   1255       node { name: "FloatOutput_3" op: "FloatOutput"
   1256              device: "/job:ps/device:CPU:0" }
   1257       node { name: "FloatOutput_4" op: "FloatOutput"
   1258              device: "/job:ps/device:CPU:0" }
   1259     """, gd)
   1260 
   1261   def testNestingWithDeviceStrings(self):
   1262     g = ops.Graph()
   1263 
   1264     with g.device("/device:GPU:0"):
   1265       g.create_op("FloatOutput", [], [dtypes.float32])
   1266       with g.device("/job:worker"):
   1267         g.create_op("FloatOutput", [], [dtypes.float32])
   1268         with g.device("/device:CPU:0"):
   1269           g.create_op("FloatOutput", [], [dtypes.float32])
   1270           with g.device("/job:ps"):
   1271             g.create_op("FloatOutput", [], [dtypes.float32])
   1272             with g.device(""):
   1273               g.create_op("FloatOutput", [], [dtypes.float32])
   1274 
   1275     gd = g.as_graph_def()
   1276     self.assertProtoEqualsVersion("""
   1277       node { name: "FloatOutput" op: "FloatOutput"
   1278              device: "/device:GPU:0" }
   1279       node { name: "FloatOutput_1" op: "FloatOutput"
   1280              device: "/job:worker/device:GPU:0" }
   1281       node { name: "FloatOutput_2" op: "FloatOutput"
   1282              device: "/job:worker/device:CPU:0" }
   1283       node { name: "FloatOutput_3" op: "FloatOutput"
   1284              device: "/job:ps/device:CPU:0" }
   1285       node { name: "FloatOutput_4" op: "FloatOutput"
   1286              device: "/job:ps/device:CPU:0" }
   1287     """, gd)
   1288 
   1289   def testNestingWithDeviceStringWildcard(self):
   1290     g = ops.Graph()
   1291 
   1292     with g.device("/device:GPU:7"):
   1293       g.create_op("FloatOutput", [], [dtypes.float32])
   1294       with g.device("/device:GPU:*"):
   1295         g.create_op("FloatOutput", [], [dtypes.float32])
   1296 
   1297     with g.device("/device:CPU:*"):
   1298       g.create_op("FloatOutput", [], [dtypes.float32])
   1299       with g.device("/device:CPU:5"):
   1300         g.create_op("FloatOutput", [], [dtypes.float32])
   1301 
   1302     gd = g.as_graph_def()
   1303     self.assertProtoEqualsVersion("""
   1304       node { name: "FloatOutput" op: "FloatOutput"
   1305              device: "/device:GPU:7" }
   1306       node { name: "FloatOutput_1" op: "FloatOutput"
   1307              device: "/device:GPU:7" }
   1308       node { name: "FloatOutput_2" op: "FloatOutput"
   1309              device: "/device:CPU:*" }
   1310       node { name: "FloatOutput_3" op: "FloatOutput"
   1311              device: "/device:CPU:5" }
   1312     """, gd)
   1313 
   1314   def testNoneClearsDefault(self):
   1315     g = ops.Graph()
   1316     with g.device("/job:worker/replica:2/device:CPU:1"):
   1317       g.create_op("FloatOutput", [], [dtypes.float32])
   1318       with g.device(None):
   1319         g.create_op("FloatOutput", [], [dtypes.float32])
   1320       g.create_op("FloatOutput", [], [dtypes.float32])
   1321     gd = g.as_graph_def()
   1322     self.assertProtoEqualsVersion("""
   1323       node { name: "FloatOutput" op: "FloatOutput"
   1324              device: "/job:worker/replica:2/device:CPU:1" }
   1325       node { name: "FloatOutput_1" op: "FloatOutput" }
   1326       node { name: "FloatOutput_2" op: "FloatOutput"
   1327              device: "/job:worker/replica:2/device:CPU:1" }
   1328     """, gd)
   1329 
   1330   def testNoneIgnoresOuterDeviceFunction(self):
   1331     g = ops.Graph()
   1332     with g.device(lambda op: "/job:worker/replica:2/device:CPU:1"):
   1333       g.create_op("FloatOutput", [], [dtypes.float32])
   1334       with g.device(None):
   1335         g.create_op("FloatOutput", [], [dtypes.float32])
   1336       g.create_op("FloatOutput", [], [dtypes.float32])
   1337     gd = g.as_graph_def()
   1338     self.assertProtoEqualsVersion("""
   1339       node { name: "FloatOutput" op: "FloatOutput"
   1340              device: "/job:worker/replica:2/device:CPU:1" }
   1341       node { name: "FloatOutput_1" op: "FloatOutput" }
   1342       node { name: "FloatOutput_2" op: "FloatOutput"
   1343              device: "/job:worker/replica:2/device:CPU:1" }
   1344     """, gd)
   1345 
   1346   def _overwritingDeviceFunction(self, unused_op):
   1347     # This device function unconditionally overwrites the device of ops.
   1348     #
   1349     # NOTE(mrry): Writing device functions like this is not
   1350     # recommended. Instead, in most cases you should use
   1351     # `pydev.merge_device("/job:ps")` or simply `"/job:ps"` as the
   1352     # argument to `tf.device()` and the device component will be merged in.
   1353     return "/job:overwrite"
   1354 
   1355   def testOverwritingBehavior(self):
   1356     g = ops.Graph()
   1357     with g.device(self._overwritingDeviceFunction):
   1358       g.create_op("FloatOutput", [], [dtypes.float32])
   1359       with g.device("/job:ps"):  # Will be overwritten.
   1360         g.create_op("FloatOutput", [], [dtypes.float32])
   1361       with g.device(pydev.merge_device("/job:ps")):  # Will be overwritten.
   1362         g.create_op("FloatOutput", [], [dtypes.float32])
   1363       with g.device(None):  # Disables overwriting device function
   1364         with g.device("/job:ps"):
   1365           g.create_op("FloatOutput", [], [dtypes.float32])
   1366       with g.device(None):  # Disables overwriting device function
   1367         with g.device(pydev.merge_device("/job:ps")):
   1368           g.create_op("FloatOutput", [], [dtypes.float32])
   1369     gd = g.as_graph_def()
   1370     self.assertProtoEqualsVersion("""
   1371       node { name: "FloatOutput" op: "FloatOutput"
   1372              device: "/job:overwrite" }
   1373       node { name: "FloatOutput_1" op: "FloatOutput"
   1374              device: "/job:overwrite" }
   1375       node { name: "FloatOutput_2" op: "FloatOutput"
   1376              device: "/job:overwrite" }
   1377       node { name: "FloatOutput_3" op: "FloatOutput"
   1378              device: "/job:ps" }
   1379       node { name: "FloatOutput_4" op: "FloatOutput"
   1380              device: "/job:ps" }
   1381     """, gd)
   1382 
   1383 
   1384 @test_util.with_c_api
   1385 class ObjectWithName(object):
   1386 
   1387   def __init__(self, name):
   1388     self._name = name
   1389 
   1390   @property
   1391   def name(self):
   1392     return self._name
   1393 
   1394 
   1395 @test_util.with_c_api
   1396 class CollectionTest(test_util.TensorFlowTestCase):
   1397 
   1398   def test_get_collections(self):
   1399     g = ops.Graph()
   1400     self.assertSequenceEqual(g.collections, [])
   1401     g.add_to_collection("key", 12)
   1402     g.add_to_collection("key", 15)
   1403     self.assertSequenceEqual(g.collections, ["key"])
   1404     g.add_to_collection("other", "foo")
   1405     self.assertSequenceEqual(sorted(g.collections), ["key", "other"])
   1406 
   1407   def test_add_to_collection(self):
   1408     g = ops.Graph()
   1409     g.add_to_collection("key", 12)
   1410     g.add_to_collection("other", "foo")
   1411     g.add_to_collection("key", 34)
   1412 
   1413     # Note that only blank1 is returned.
   1414     g.add_to_collection("blah", 27)
   1415     blank1 = ObjectWithName("prefix/foo")
   1416     g.add_to_collection("blah", blank1)
   1417     blank2 = ObjectWithName("junk/foo")
   1418     g.add_to_collection("blah", blank2)
   1419 
   1420     self.assertEqual([12, 34], g.get_collection("key"))
   1421     self.assertEqual([], g.get_collection("nothing"))
   1422     self.assertEqual([27, blank1, blank2], g.get_collection("blah"))
   1423     self.assertEqual([blank1], g.get_collection("blah", "prefix"))
   1424     self.assertEqual([blank1], g.get_collection("blah", ".*x"))
   1425 
   1426     # Make sure that get_collection() returns a first-level
   1427     # copy of the collection, while get_collection_ref() returns
   1428     # the original list.
   1429     other_collection_snapshot = g.get_collection("other")
   1430     other_collection_ref = g.get_collection_ref("other")
   1431     self.assertEqual(["foo"], other_collection_snapshot)
   1432     self.assertEqual(["foo"], other_collection_ref)
   1433     g.add_to_collection("other", "bar")
   1434     self.assertEqual(["foo"], other_collection_snapshot)
   1435     self.assertEqual(["foo", "bar"], other_collection_ref)
   1436     self.assertEqual(["foo", "bar"], g.get_collection("other"))
   1437     self.assertTrue(other_collection_ref is g.get_collection_ref("other"))
   1438 
   1439     # Verify that getting an empty collection ref returns a modifiable list.
   1440     empty_coll_ref = g.get_collection_ref("empty")
   1441     self.assertEqual([], empty_coll_ref)
   1442     empty_coll = g.get_collection("empty")
   1443     self.assertEqual([], empty_coll)
   1444     self.assertFalse(empty_coll is empty_coll_ref)
   1445     empty_coll_ref2 = g.get_collection_ref("empty")
   1446     self.assertTrue(empty_coll_ref2 is empty_coll_ref)
   1447     # Add to the collection.
   1448     empty_coll_ref.append("something")
   1449     self.assertEqual(["something"], empty_coll_ref)
   1450     self.assertEqual(["something"], empty_coll_ref2)
   1451     self.assertEqual([], empty_coll)
   1452     self.assertEqual(["something"], g.get_collection("empty"))
   1453     empty_coll_ref3 = g.get_collection_ref("empty")
   1454     self.assertTrue(empty_coll_ref3 is empty_coll_ref)
   1455 
   1456   def test_add_to_collections_uniquify(self):
   1457     g = ops.Graph()
   1458     g.add_to_collections([1, 2, 1], "key")
   1459     # Make sure "key" is not added twice
   1460     self.assertEqual(["key"], g.get_collection(1))
   1461 
   1462   def test_add_to_collections_from_list(self):
   1463     g = ops.Graph()
   1464     g.add_to_collections(["abc", "123"], "key")
   1465     self.assertEqual(["key"], g.get_collection("abc"))
   1466     self.assertEqual(["key"], g.get_collection("123"))
   1467 
   1468   def test_add_to_collections_from_tuple(self):
   1469     g = ops.Graph()
   1470     g.add_to_collections(("abc", "123"), "key")
   1471     self.assertEqual(["key"], g.get_collection("abc"))
   1472     self.assertEqual(["key"], g.get_collection("123"))
   1473 
   1474   def test_add_to_collections_from_generator(self):
   1475     g = ops.Graph()
   1476 
   1477     def generator():
   1478       yield "abc"
   1479       yield "123"
   1480 
   1481     g.add_to_collections(generator(), "key")
   1482     self.assertEqual(["key"], g.get_collection("abc"))
   1483     self.assertEqual(["key"], g.get_collection("123"))
   1484 
   1485   def test_add_to_collections_from_set(self):
   1486     g = ops.Graph()
   1487     g.add_to_collections(set(["abc", "123"]), "key")
   1488     self.assertEqual(["key"], g.get_collection("abc"))
   1489     self.assertEqual(["key"], g.get_collection("123"))
   1490 
   1491   def test_add_to_collections_from_string(self):
   1492     g = ops.Graph()
   1493     g.add_to_collections("abc", "key")
   1494     self.assertEqual(["key"], g.get_collection("abc"))
   1495 
   1496   def test_default_graph(self):
   1497     with ops.Graph().as_default():
   1498       ops.add_to_collection("key", 90)
   1499       ops.add_to_collection("key", 100)
   1500       # Collections are ordered.
   1501       self.assertEqual([90, 100], ops.get_collection("key"))
   1502 
   1503 
# Register the test-only "FloatOutput" op type as not differentiable at
# import time, before any of the test classes below run.
ops.NotDifferentiable("FloatOutput")
   1505 
   1506 
   1507 @ops.RegisterGradient("CopyOp")
   1508 def _CopyGrad(op, x_grad):  # pylint: disable=invalid-name
   1509   _ = op
   1510   return x_grad
   1511 
   1512 
   1513 @ops.RegisterGradient("copy_override")
   1514 def _CopyOverrideGrad(op, x_grad):  # pylint: disable=invalid-name
   1515   _ = op
   1516   return x_grad
   1517 
   1518 
   1519 @test_util.with_c_api
   1520 class RegistrationTest(test_util.TensorFlowTestCase):
   1521 
   1522   def testRegisterGradients(self):
   1523     x = test_ops.float_output()
   1524     y = test_ops.copy_op(x)
   1525     fn = ops.get_gradient_function(y.op)
   1526     self.assertEqual(_CopyGrad, fn)
   1527 
   1528   def testOverrideGradients(self):
   1529     g = ops.Graph()
   1530     with g.as_default():
   1531       x = test_ops.float_output()
   1532       with g.gradient_override_map({"CopyOp": "copy_override"}):
   1533         y = test_ops.copy_op(x)
   1534       fn = ops.get_gradient_function(y.op)
   1535       self.assertEqual(_CopyOverrideGrad, fn)
   1536 
   1537   def testNonExistentOverride(self):
   1538     g = ops.Graph()
   1539     with g.as_default():
   1540       x = test_ops.float_output()
   1541       with g.gradient_override_map({"CopyOp": "unknown_override"}):
   1542         y = test_ops.copy_op(x)
   1543       with self.assertRaisesRegexp(LookupError, "unknown_override"):
   1544         ops.get_gradient_function(y.op)
   1545 
   1546 
   1547 @test_util.with_c_api
   1548 class ComparisonTest(test_util.TensorFlowTestCase):
   1549 
   1550   def testMembershipAllowed(self):
   1551     g = ops.Graph()
   1552     t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1")
   1553     t2 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop2")
   1554     self.assertTrue(isinstance(t1, ops.Tensor))
   1555     self.assertTrue(isinstance(t2, ops.Tensor))
   1556     self.assertTrue(t1 in [t1])
   1557     self.assertTrue(t1 not in [t2])
   1558 
   1559 
@test_util.with_c_api
class ControlDependenciesTest(test_util.TensorFlowTestCase):
  """Tests for `Graph.control_dependencies()` scoping and pruning."""

  @test_util.enable_c_api
  def testBasic(self):
    """Ops built inside control_dependencies([a]) get `a` as a control input."""
    g = ops.Graph()
    with g.as_default():
      # Creating unregistered ops with _apply_op() doesn't work with the C API
      # TODO(skyewm): address this more consistently. Possible solutions are
      # to use registered ops in all tests, create a way to register ops in
      # Python tests, or conditionally disable the op registration check in
      # the C API.
      a = constant_op.constant(1.0)
      b = constant_op.constant(1.0)
      with g.control_dependencies([a]):
        c = constant_op.constant(1.0)
        d = array_ops.identity(b)
        e = array_ops.identity(c)

    self.assertEqual(c.op.control_inputs, [a.op])
    self.assertEqual(d.op.control_inputs, [a.op])
    # e should be dominated by c: its data input c was itself created inside
    # the same control_dependencies scope, so e gets no control input.
    self.assertEqual(e.op.control_inputs, [])

  @test_util.run_in_graph_and_eager_modes()
  def testEager(self):
    """control_dependencies() accepts tensors in both graph and eager mode."""
    # `future` counts its invocations so the test can check it ran exactly
    # once (the explicit call below) in either mode.
    def future():
      future.calls += 1
      return constant_op.constant(2.0)
    future.calls = 0

    if context.in_graph_mode():
      g = ops.Graph()
      with g.as_default():
        a = constant_op.constant(1.0)
        b = future()
        with g.control_dependencies([a, b]):
          c = constant_op.constant(3.0)
      self.assertEqual(c.op.control_inputs, [a.op, b.op])
      self.assertEqual(future.calls, 1)
    else:
      # The eager branch makes no assertion about control inputs; it only
      # checks that the scope can be entered and `future` ran once.
      a = constant_op.constant(1.0)
      b = future()
      with ops.control_dependencies([a, b]):
        c = constant_op.constant(3.0)
      self.assertEqual(future.calls, 1)

  def testBasicWithConversion(self):
    """Objects with `_as_graph_element()` are usable as control inputs."""
    g = ops.Graph()
    a = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    # Not a tensor or op itself; converts to `a` via _as_graph_element().
    class ConvertibleObj(object):

      def _as_graph_element(self):
        return a

    with g.control_dependencies([ConvertibleObj()]):
      c = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertEqual(c.op.control_inputs, [a.op])

  def testNested(self):
    """Nested single-input scopes accumulate like one combined scope."""
    g = ops.Graph()
    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1, a_2, a_3, a_4]):
      b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1]):
      with g.control_dependencies([a_2]):
        with g.control_dependencies([a_3]):
          with g.control_dependencies([a_4]):
            b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op],
                          b_1.op.control_inputs)
    self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs)

  def testClear(self):
    """control_dependencies(None) clears outer deps until its scope exits."""
    g = ops.Graph()
    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1]):
      with g.control_dependencies([a_2]):
        with g.control_dependencies(None):
          with g.control_dependencies([a_3]):
            with g.control_dependencies([a_4]):
              # deps [a_3, a_4]
              b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
            # deps = [a_3]
            b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
          # deps back to None
          b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
        # deps back to [a_1, a_2]
        b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      # deps back to [a_1]
      b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      with g.control_dependencies(None):
        # deps are None again
        b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
    self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
    self.assertItemsEqual([], b_none.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
    self.assertItemsEqual([], b_none2.op.control_inputs)

  def testComplex(self):
    """Control deps are pruned when data inputs already imply them."""
    g = ops.Graph()

    # Usage pattern:
    # * Nodes a_i are FloatOutput ops defined at the outermost scope, and are
    #   used as control inputs for the ith nested scope.
    # * Nodes b_i are TwoFloatInputsFloatOutput(a_3, a_4) at each scope.
    # * Nodes c_i are TwoFloatInputsFloatOutput(a_1, b_1) at each scope.
    # * Nodes d_i are TwoFloatInputsFloatOutput(b_i, c_i) at each scope.
    # * Node e_1 is a FloatOutput op; for i > 1, e_i is
    #   TwoFloatInputsFloatOutput(e_i-1, e_i-1).

    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1]):
      b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                      [dtypes.float32])
      c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                      [dtypes.float32])
      d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1],
                      [dtypes.float32])
      e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      with g.control_dependencies([a_2]):
        b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                        [dtypes.float32])
        c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                        [dtypes.float32])
        d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2],
                        [dtypes.float32])
        e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1],
                        [dtypes.float32])
        with g.control_dependencies([a_3]):
          b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                          [dtypes.float32])
          c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                          [dtypes.float32])
          d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3],
                          [dtypes.float32])
          e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2],
                          [dtypes.float32])
          with g.control_dependencies([a_4]):
            b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                            [dtypes.float32])
            c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                            [dtypes.float32])
            d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4],
                            [dtypes.float32])
            e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3],
                            [dtypes.float32])

    # b_i reads a_3 and a_4 as data, so neither appears as a control input;
    # only the other enclosing scopes (a_1, a_2) contribute.
    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs)

    # c_i reads b_1, created under a_1's scope, so a_1 is dominated and
    # dropped; only scopes entered after b_1 contribute.
    self.assertItemsEqual([], c_1.op.control_inputs)
    self.assertItemsEqual([a_2.op], c_2.op.control_inputs)
    self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs)
    self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs)

    # d_i reads b_i and c_i, both created in d_i's own scope, so every
    # enclosing control dependency is dominated.
    self.assertItemsEqual([], d_1.op.control_inputs)
    self.assertItemsEqual([], d_2.op.control_inputs)
    self.assertItemsEqual([], d_3.op.control_inputs)
    self.assertItemsEqual([], d_4.op.control_inputs)

    # e_i reads e_i-1, created one scope further out, so only the innermost
    # a_i is added.
    self.assertItemsEqual([a_1.op], e_1.op.control_inputs)
    self.assertItemsEqual([a_2.op], e_2.op.control_inputs)
    self.assertItemsEqual([a_3.op], e_3.op.control_inputs)
    self.assertItemsEqual([a_4.op], e_4.op.control_inputs)

  def testRepeatedDependency(self):
    """Two outputs of one op add that op as a control input only once."""
    g = ops.Graph()
    a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
    a_0, a_1 = a.outputs
    with g.control_dependencies([a_0]):
      b = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      with g.control_dependencies([a_1]):
        c = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertEqual(b.op.control_inputs, [a])
    self.assertEqual(c.op.control_inputs, [a])

  def testNoControlDependencyWithDataDependency(self):
    """No control input is added for an op that is already a data input."""
    g = ops.Graph()
    a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    with g.control_dependencies([a]):
      b = _apply_op(g, "Identity", [a], [dtypes.float32])

    self.assertEqual(b.op.control_inputs, [])
   1765 
   1766 
   1767 @test_util.with_c_api
   1768 class OpScopeTest(test_util.TensorFlowTestCase):
   1769 
   1770   @test_util.run_in_graph_and_eager_modes()
   1771   def testNames(self):
   1772     with ops.name_scope("foo") as foo:
   1773       self.assertEqual("foo/", foo)
   1774       with ops.name_scope("foo2") as foo2:
   1775         self.assertEqual("foo/foo2/", foo2)
   1776       with ops.name_scope(None) as empty1:
   1777         self.assertEqual("", empty1)
   1778         with ops.name_scope("foo3") as foo3:
   1779           self.assertEqual("foo3/", foo3)
   1780       with ops.name_scope("") as empty2:
   1781         self.assertEqual("", empty2)
   1782     with ops.name_scope("foo/") as outer_foo:
   1783       self.assertEqual("foo/", outer_foo)
   1784       with ops.name_scope("") as empty3:
   1785         self.assertEqual("", empty3)
   1786       with ops.name_scope("foo4") as foo4:
   1787         self.assertEqual("foo/foo4/", foo4)
   1788       with ops.name_scope("foo5//") as foo5:
   1789         self.assertEqual("foo5//", foo5)
   1790         with ops.name_scope("foo6") as foo6:
   1791           self.assertEqual("foo5//foo6/", foo6)
   1792       with ops.name_scope("/") as foo7:
   1793         self.assertEqual("/", foo7)
   1794       with ops.name_scope("//") as foo8:
   1795         self.assertEqual("//", foo8)
   1796       with ops.name_scope("a//b/c") as foo9:
   1797         self.assertEqual("foo/a//b/c/", foo9)
   1798     with ops.name_scope("a//b/c") as foo10:
   1799       self.assertEqual("a//b/c/", foo10)
   1800 
   1801   @test_util.run_in_graph_and_eager_modes()
   1802   def testEagerDefaultScopeName(self):
   1803     with ops.name_scope(None, "default") as scope:
   1804       self.assertEqual(scope, "default/")
   1805       with ops.name_scope(None, "default2") as scope2:
   1806         self.assertEqual(scope2, "default/default2/")
   1807 
   1808   def testNoScopeName(self):
   1809     g0 = ops.Graph()
   1810     values = [
   1811         g0.create_op("A", [], [dtypes.float32]),
   1812         g0.create_op("B", [], [dtypes.float32])
   1813     ]
   1814     with self.assertRaises(ValueError):
   1815       with ops.name_scope(None, values=values):
   1816         pass
   1817     with self.assertRaises(ValueError):
   1818       with ops.name_scope(None, None, values):
   1819         pass
   1820 
   1821   def testEmptyScopeName(self):
   1822     g0 = ops.Graph()
   1823     a = g0.create_op("A", [], [dtypes.float32])
   1824     b = g0.create_op("B", [], [dtypes.float32])
   1825     with ops.name_scope("", values=[a, b]) as scope:
   1826       self.assertEqual("", scope)
   1827       self.assertEqual(g0, ops.get_default_graph())
   1828     with ops.name_scope("", "my_default_scope", [a, b]) as scope:
   1829       self.assertEqual("", scope)
   1830       self.assertEqual(g0, ops.get_default_graph())
   1831 
   1832   def testDefaultScopeName(self):
   1833     g0 = ops.Graph()
   1834     a = g0.create_op("A", [], [dtypes.float32])
   1835     b = g0.create_op("B", [], [dtypes.float32])
   1836     scope_name = "my_scope"
   1837     default_scope_name = "my_default_scope"
   1838     with ops.name_scope(scope_name, default_scope_name, [a, b]) as scope:
   1839       self.assertEqual("%s/" % scope_name, scope)
   1840       self.assertEqual(g0, ops.get_default_graph())
   1841     with ops.name_scope(None, default_scope_name, [a, b]) as scope:
   1842       self.assertEqual("%s/" % default_scope_name, scope)
   1843       self.assertEqual(g0, ops.get_default_graph())
   1844 
   1845   def _testGraphElements(self, graph_elements):
   1846     scope_name = "my_scope"
   1847     with ops.name_scope(scope_name, values=graph_elements) as scope:
   1848       self.assertEqual("%s/" % scope_name, scope)
   1849       self.assertEqual(graph_elements[0].graph, ops.get_default_graph())
   1850     g1 = ops.Graph()
   1851     a = g1.create_op("A", [], [dtypes.float32])
   1852     with self.assertRaises(ValueError):
   1853       with ops.name_scope(scope_name, values=graph_elements + [a]):
   1854         pass
   1855 
   1856   def testTensor(self):
   1857     g0 = ops.Graph()
   1858     a = g0.create_op("A", [], [dtypes.float32])
   1859     b = g0.create_op("B", [], [dtypes.float32])
   1860     self._testGraphElements([a, b])
   1861 
   1862   def testSparseTensor(self):
   1863     g0 = ops.Graph()
   1864     a = g0.create_op("A", [], [dtypes.float32])
   1865     b = g0.create_op("B", [], [dtypes.float32])
   1866     sparse = sparse_tensor.SparseTensor(
   1867         _apply_op(g0, "Int64Output", [], [dtypes.int64]),
   1868         _apply_op(g0, "FloatOutput", [], [dtypes.float32]),
   1869         _apply_op(g0, "Int64Output", [], [dtypes.int64]))
   1870     self._testGraphElements([a, sparse, b])
   1871 
   1872   def testVariable(self):
   1873     g0 = ops.Graph()
   1874     with g0.as_default():
   1875       variable = variables.Variable([1.0])
   1876     a = g0.create_op("A", [], [dtypes.float32])
   1877     b = g0.create_op("B", [], [dtypes.float32])
   1878     self._testGraphElements([a, variable, b])
   1879 
   1880 
   1881 class InitScopeTest(test_util.TensorFlowTestCase):
   1882 
   1883   def testClearsControlDependencies(self):
   1884     g = ops.Graph()
   1885     a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
   1886     a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
   1887     a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
   1888     a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
   1889 
   1890     with g.as_default():
   1891       with g.control_dependencies([a_1]):
   1892         with g.control_dependencies([a_2]):
   1893           with ops.init_scope():
   1894             with g.control_dependencies([a_3]):
   1895               with g.control_dependencies([a_4]):
   1896                 # deps [a_3, a_4]
   1897                 b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
   1898               # deps = [a_3]
   1899               b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
   1900             # deps back to None
   1901             b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
   1902           # deps back to [a_1, a_2]
   1903           b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
   1904         # deps back to [a_1]
   1905         b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
   1906         with ops.init_scope():
   1907           # deps are None again
   1908           b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
   1909 
   1910     self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
   1911     self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
   1912     self.assertItemsEqual([], b_none.op.control_inputs)
   1913     self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
   1914     self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
   1915     self.assertItemsEqual([], b_none2.op.control_inputs)
   1916 
   1917   def testLiftsOpsFromFunctions(self):
   1918     g0 = ops.Graph()
   1919     g1 = ops.Graph()
   1920     g1._building_function = True  # pylint: disable=protected-access
   1921     g2 = ops.Graph()
   1922     g2._building_function = True  # pylint: disable=protected-access
   1923 
   1924     with g0.as_default():
   1925       with g1.as_default():
   1926         with g2.as_default():
   1927           with ops.init_scope():
   1928             _ = constant_op.constant(1.0)
   1929 
   1930     self.assertEqual(len(g2.get_operations()), 0)
   1931     self.assertEqual(len(g1.get_operations()), 0)
   1932     self.assertEqual(len(g0.get_operations()), 1)
   1933 
   1934   def testComposes(self):
   1935     g0 = ops.Graph()
   1936     g1 = ops.Graph()
   1937     g1._building_function = True  # pylint: disable=protected-access
   1938     g2 = ops.Graph()
   1939     g2._building_function = True  # pylint: disable=protected-access
   1940     g3 = ops.Graph()
   1941     g3._building_function = False  # pylint: disable=protected-access
   1942 
   1943     with g0.as_default():
   1944       with g1.as_default():
   1945         with ops.init_scope():
   1946           # This op should be lifted into g0.
   1947           _ = constant_op.constant(1.0)
   1948           self.assertIs(g0, ops.get_default_graph())
   1949           self.assertEqual(len(g2.get_operations()), 0)
   1950           self.assertEqual(len(g1.get_operations()), 0)
   1951           self.assertEqual(len(g0.get_operations()), 1)
   1952         with g2.as_default():
   1953           with ops.init_scope():
   1954             # This op should be lifted into g0.
   1955             _ = constant_op.constant(1.0)
   1956             self.assertIs(g0, ops.get_default_graph())
   1957             with g3.as_default():
   1958               with ops.init_scope():
   1959                 # This op should be lifted into g3, because g3 is not building a
   1960                 # function.
   1961                 _ = constant_op.constant(1.0)
   1962                 self.assertIs(g3, ops.get_default_graph())
   1963 
   1964     self.assertEqual(len(g3.get_operations()), 1)
   1965     self.assertEqual(len(g2.get_operations()), 0)
   1966     self.assertEqual(len(g1.get_operations()), 0)
   1967     self.assertEqual(len(g0.get_operations()), 2)
   1968 
   1969   def testEscapesToEagerContext(self):
   1970     g = ops.Graph()
   1971     g._building_function = True  # pylint: disable=protected-access
   1972     with context.eager_mode():
   1973       with context.graph_mode():
   1974         with g.as_default():
   1975           with ops.init_scope():
   1976             # Because g is building a function, init_scope should
   1977             # escape out to the eager context.
   1978             self.assertTrue(context.in_eager_mode())
   1979           # g should be reinstated as the default graph, and the
   1980           # graph context should be re-entered.
   1981           self.assertIs(g, ops.get_default_graph())
   1982           self.assertTrue(context.in_graph_mode())
   1983 
   1984   def testAllGraphsBuildingFunctionsRaisesError(self):
   1985     g = ops.Graph()
   1986     g._building_function = True  # pylint: disable=protected-access
   1987     with g.as_default():
   1988       with self.assertRaises(AssertionError):
   1989         with ops.init_scope():
   1990           pass
   1991 
   1992   def testStaysInEagerWhenOnlyEagerContextActive(self):
   1993     with context.eager_mode():
   1994       with ops.init_scope():
   1995         self.assertTrue(context.eager_mode())
   1996       self.assertTrue(context.eager_mode())
   1997 
   1998   def testEscapesDefunWhenInEagerMode(self):
   1999 
   2000     def function_with_variables():
   2001       with ops.init_scope():
   2002         v = resource_variable_ops.ResourceVariable(3)
   2003       return v.assign_add(1)
   2004 
   2005     with context.eager_mode():
   2006       # Each invocation of function_with_variables recreates a variable.
   2007       self.assertEqual(4, int(function_with_variables()))
   2008       self.assertEqual(4, int(function_with_variables()))
   2009 
   2010       compiled = eager_function.defun(function_with_variables)
   2011       # The init_scope in function_with_variables lifts the variable out
   2012       # of the graph function constructed by defun; hence,
   2013       # compiled now appears to be stateful.
   2014       self.assertEqual(4, int(compiled()))
   2015       self.assertEqual(5, int(compiled()))
   2016 
   2017   def testEscapesDefunWhenInGraphMode(self):
   2018     def function_with_variables(name):
   2019       with ops.init_scope():
   2020         _ = variable_scope.get_variable(name, shape=(1,))
   2021 
   2022     g = ops.Graph()
   2023     with g.as_default():
   2024       with self.test_session():
   2025         # First ensure that graphs that are not building functions are
   2026         # not escaped.
   2027         function_with_variables("foo")
   2028         with self.assertRaisesRegexp(ValueError,
   2029                                      r"Variable foo already exists.*"):
   2030           # This will fail because reuse is not set to True.
   2031           function_with_variables("foo")
   2032 
   2033         compiled = eager_function.defun(function_with_variables)
   2034         compiled("bar")
   2035         self.assertEqual(
   2036             len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2)
   2037 
   2038         # The second call to `compiled` should not create variables: the
   2039         # init_scope has lifted the variable creation code out of the defun.
   2040         compiled("bar")
   2041         self.assertEqual(
   2042             len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2)
   2043 
   2044   def testEscapesNestedDefun(self):
   2045 
   2046     def inner_function():
   2047       with ops.init_scope():
   2048         v = resource_variable_ops.ResourceVariable(1)
   2049       return v.assign_add(2)
   2050 
   2051     def outer_function(inner=None):
   2052       with ops.init_scope():
   2053         v0 = resource_variable_ops.ResourceVariable(0)
   2054       return v0.assign_add(1) + inner()
   2055 
   2056     with context.eager_mode():
   2057       # Each invocation of outer_function recreates variables.
   2058       self.assertEqual(4, int(outer_function(inner=inner_function)))
   2059       self.assertEqual(4, int(outer_function(inner=inner_function)))
   2060 
   2061       compiled_inner = eager_function.defun(inner_function)
   2062       compiled_outer = eager_function.defun(outer_function)
   2063       # The init_scope lifts variables out of the graph functions
   2064       # constructed by defun; hence, compiled_outer should now appear to be
   2065       # stateful.
   2066       self.assertEqual(4, int(compiled_outer(inner=compiled_inner)))
   2067       self.assertEqual(7, int(compiled_outer(inner=compiled_inner)))
   2068 
   2069   def testInstallsDefaultGraphWhenGraphStackIsEmptyInGraphMode(self):
   2070     with context.graph_mode():
   2071       # pylint: disable=protected-access
   2072       self.assertEqual(len(ops._default_graph_stack.stack), 0)
   2073       with ops.init_scope():
   2074         self.assertGreater(len(ops._default_graph_stack.stack), 0)
   2075       self.assertEqual(len(ops._default_graph_stack.stack), 0)
   2076       # pylint: enable=protected-access
   2077 
   2078   def testPreservesNameScopeInGraphConstruction(self):
   2079     with ops.Graph().as_default():
   2080       function_graph = ops.Graph()
   2081       with function_graph.as_default():
   2082         with ops.name_scope("inner"), ops.init_scope():
   2083           self.assertEqual(ops.get_name_scope(), "inner")
   2084       self.assertEqual(ops.get_name_scope(), "")
   2085 
   2086   def testPreservesNameScopeInEagerExecution(self):
   2087     with context.eager_mode():
   2088       def foo():
   2089         with ops.name_scope("inner"), ops.init_scope():
   2090           if context.in_graph_mode():
   2091             self.assertEqual(ops.get_name_scope(), "inner")
   2092           else:
   2093             # A trailing slash is always appended when eager execution is
   2094             # enabled.
   2095             self.assertEqual(context.context().scope_name, "inner/")
   2096       foo()
   2097       self.assertEqual(ops.get_name_scope(), "")
   2098       foo_compiled = eager_function.defun(foo)
   2099       foo_compiled()
   2100       self.assertEqual(ops.get_name_scope(), "")
   2101 
   2102 
   2103 @test_util.with_c_api
   2104 class GraphTest(test_util.TensorFlowTestCase):
   2105 
   2106   def setUp(self):
   2107     ops.reset_default_graph()
   2108 
   2109   def _AssertDefault(self, expected):
   2110     self.assertIs(expected, ops.get_default_graph())
   2111 
   2112   def testResetDefaultGraphNesting(self):
   2113     g0 = ops.Graph()
   2114     with self.assertRaises(AssertionError):
   2115       with g0.as_default():
   2116         ops.reset_default_graph()
   2117 
   2118   def testGraphContextManager(self):
   2119     g0 = ops.Graph()
   2120     with g0.as_default() as g1:
   2121       self.assertIs(g0, g1)
   2122 
   2123   def testDefaultGraph(self):
   2124     orig = ops.get_default_graph()
   2125     self._AssertDefault(orig)
   2126     g0 = ops.Graph()
   2127     self._AssertDefault(orig)
   2128     context_manager_0 = g0.as_default()
   2129     self._AssertDefault(orig)
   2130     with context_manager_0 as g0:
   2131       self._AssertDefault(g0)
   2132       with ops.Graph().as_default() as g1:
   2133         self._AssertDefault(g1)
   2134       self._AssertDefault(g0)
   2135     self._AssertDefault(orig)
   2136 
   2137   def testPreventFeeding(self):
   2138     g = ops.Graph()
   2139     a = constant_op.constant(2.0)
   2140     self.assertTrue(g.is_feedable(a))
   2141     g.prevent_feeding(a)
   2142     self.assertFalse(g.is_feedable(a))
   2143 
   2144   def testPreventFetching(self):
   2145     g = ops.Graph()
   2146     a = constant_op.constant(2.0)
   2147     self.assertTrue(g.is_fetchable(a))
   2148     g.prevent_fetching(a.op)
   2149     self.assertFalse(g.is_fetchable(a))
   2150 
   2151   def testAsGraphElementConversions(self):
   2152 
   2153     class ConvertibleObj(object):
   2154 
   2155       def _as_graph_element(self):
   2156         return "FloatOutput:0"
   2157 
   2158     class NonConvertibleObj(object):
   2159 
   2160       pass
   2161 
   2162     g = ops.Graph()
   2163     a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
   2164     self.assertEqual(a, g.as_graph_element(ConvertibleObj()))
   2165     with self.assertRaises(TypeError):
   2166       g.as_graph_element(NonConvertibleObj())
   2167 
   2168   # Regression test against creating custom __del__ functions in classes
   2169   # involved in cyclic references, e.g. Graph and Operation. (Python won't gc
   2170   # cycles that require calling a __del__ method, because the __del__ method can
   2171   # theoretically increase the object's refcount to "save" it from gc, and any
   2172   # already-deleted objects in the cycle would have be to restored.)
   2173   def testGarbageCollected(self):
   2174     # Create a graph we can delete and a weak reference to monitor if it's gc'd
   2175     g = ops.Graph()
   2176     g_ref = weakref.ref(g)
   2177     # Create some ops
   2178     with g.as_default():
   2179       a = constant_op.constant(2.0)
   2180       b = constant_op.constant(3.0)
   2181       c = math_ops.add(a, b)
   2182     # Create a session we can delete
   2183     with session.Session(graph=g) as sess:
   2184       sess.run(c)
   2185     # Delete all references and trigger gc
   2186     del g
   2187     del a
   2188     del b
   2189     del c
   2190     del sess
   2191     gc.collect()
   2192     self.assertIsNone(g_ref())
   2193 
   2194   def testRunnableAfterInvalidShape(self):
   2195     with ops.Graph().as_default():
   2196       with self.assertRaises(ValueError):
   2197         math_ops.add([1, 2], [1, 2, 3])
   2198       a = constant_op.constant(1)
   2199       with session.Session() as sess:
   2200         sess.run(a)
   2201 
   2202   def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
   2203     g = ops.Graph()
   2204     with g.as_default():
   2205       with g._kernel_label_map({"KernelLabelRequired": "overload_1"}):
   2206         with self.assertRaises(ValueError):
   2207           test_ops.kernel_label_required(1)
   2208       a = constant_op.constant(1)
   2209       with session.Session() as sess:
   2210         sess.run(a)
   2211 
   2212 
   2213 @test_util.with_c_api
   2214 class AttrScopeTest(test_util.TensorFlowTestCase):
   2215 
   2216   def _get_test_attrs(self):
   2217     x = control_flow_ops.no_op()
   2218     try:
   2219       a = compat.as_text(x.get_attr("_A"))
   2220     except ValueError:
   2221       a = None
   2222     try:
   2223       b = compat.as_text(x.get_attr("_B"))
   2224     except ValueError:
   2225       b = None
   2226     return (a, b)
   2227 
   2228   def testNoLabel(self):
   2229     with self.test_session():
   2230       self.assertAllEqual((None, None), self._get_test_attrs())
   2231 
   2232   def testLabelMap(self):
   2233     with self.test_session() as sess:
   2234       a1 = self._get_test_attrs()
   2235       with sess.graph._attr_scope({
   2236           "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo"))
   2237       }):
   2238         a2 = self._get_test_attrs()
   2239         with sess.graph._attr_scope({
   2240             "_A": None,
   2241             "_B": attr_value_pb2.AttrValue(s=compat.as_bytes("bar"))
   2242         }):
   2243           a3 = self._get_test_attrs()
   2244           with sess.graph._attr_scope({
   2245               "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("baz"))
   2246           }):
   2247             a4 = self._get_test_attrs()
   2248           a5 = self._get_test_attrs()
   2249         a6 = self._get_test_attrs()
   2250       a7 = self._get_test_attrs()
   2251 
   2252       self.assertAllEqual((None, None), a1)
   2253       self.assertAllEqual(("foo", None), a2)
   2254       self.assertAllEqual((None, "bar"), a3)
   2255       self.assertAllEqual(("baz", "bar"), a4)
   2256       self.assertAllEqual((None, "bar"), a5)
   2257       self.assertAllEqual(("foo", None), a6)
   2258       self.assertAllEqual((None, None), a7)
   2259 
   2260 
# Register common_shapes.scalar_shape as the Python shape function for the
# "KernelLabel" test op (exercised via test_ops.kernel_label below).
ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape)
   2262 
   2263 
   2264 @test_util.with_c_api
   2265 class KernelLabelTest(test_util.TensorFlowTestCase):
   2266 
   2267   @test_util.enable_c_api
   2268   def testNoLabel(self):
   2269     with self.test_session():
   2270       self.assertAllEqual(b"My label is: default",
   2271                           test_ops.kernel_label().eval())
   2272 
   2273   def testLabelMap(self):
   2274     with self.test_session() as sess:
   2275       default_1 = test_ops.kernel_label()
   2276       # pylint: disable=protected-access
   2277       with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}):
   2278         overload_1_1 = test_ops.kernel_label()
   2279         with sess.graph._kernel_label_map({"KernelLabel": "overload_2"}):
   2280           overload_2 = test_ops.kernel_label()
   2281           with sess.graph._kernel_label_map({"KernelLabel": ""}):
   2282             default_2 = test_ops.kernel_label()
   2283         overload_1_2 = test_ops.kernel_label()
   2284       # pylint: enable=protected-access
   2285       default_3 = test_ops.kernel_label()
   2286 
   2287       self.assertAllEqual(b"My label is: default", default_1.eval())
   2288       self.assertAllEqual(b"My label is: default", default_2.eval())
   2289       self.assertAllEqual(b"My label is: default", default_3.eval())
   2290       self.assertAllEqual(b"My label is: overload_1", overload_1_1.eval())
   2291       self.assertAllEqual(b"My label is: overload_1", overload_1_2.eval())
   2292       self.assertAllEqual(b"My label is: overload_2", overload_2.eval())
   2293 
   2294 
   2295 @test_util.with_c_api
   2296 class AsGraphDefTest(test_util.TensorFlowTestCase):
   2297 
   2298   def testGraphDefVersion(self):
   2299     """Test that the graphdef version is plumbed through to kernels."""
   2300     with ops.Graph().as_default() as g:
   2301       version = g.graph_def_versions.producer
   2302       with self.test_session(graph=g):
   2303         v = test_ops.graph_def_version().eval()
   2304         self.assertEqual(version, v)
   2305 
   2306   def testAddShapes(self):
   2307     with ops.Graph().as_default() as g:
   2308       t1, t2, t3, t4, t5 = _apply_op(g, "FiveFloatOutputs", [],
   2309                                      [dtypes.float32] * 5)
   2310       t1.set_shape(None)
   2311       t2.set_shape([])
   2312       t3.set_shape([None])
   2313       t4.set_shape([43, 37])
   2314       t5.set_shape([43, None])
   2315 
   2316       b = constant_op.constant(1.0)  # pylint: disable=unused-variable
   2317 
   2318       gd = g.as_graph_def(add_shapes=True)
   2319       self.assertProtoEqualsVersion("""
   2320       node { name: "FiveFloatOutputs" op: "FiveFloatOutputs"
   2321         attr {
   2322           key: "_output_shapes"
   2323           value {
   2324             list {
   2325               shape { unknown_rank: true }
   2326               shape { }
   2327               shape { dim { size: -1 } }
   2328               shape { dim { size: 43 } dim { size: 37 } }
   2329               shape { dim { size: 43 } dim { size: -1 } }
   2330             }
   2331           }
   2332         }
   2333       }
   2334     node { name: "Const" op: "Const"
   2335       attr {
   2336         key: "_output_shapes"
   2337         value {
   2338           list {
   2339             shape { }
   2340           }
   2341         }
   2342       }
   2343       attr {
   2344         key: "dtype"
   2345         value { type: DT_FLOAT }
   2346       }
   2347       attr {
   2348         key: "value"
   2349         value {
   2350           tensor {
   2351             dtype: DT_FLOAT
   2352             tensor_shape { }
   2353          float_val: 1.0  } } } }
   2354       """, gd)
   2355 
   2356 
   2357 @ops.RegisterStatistics("a", "flops")
   2358 def _calc_a_forward_flops(unused_graph, unused_node):
   2359   return ops.OpStats("flops", 20)
   2360 
   2361 
   2362 @test_util.with_c_api
   2363 class StatisticsTest(test_util.TensorFlowTestCase):
   2364 
   2365   def testRegisteredNode(self):
   2366     graph = ops.Graph()
   2367     node = ops._NodeDef("a", "an_a")
   2368     flops = ops.get_stats_for_node_def(graph, node, "flops")
   2369     self.assertEqual(20, flops.value)
   2370     missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat")
   2371     self.assertEqual(None, missing_stat.value)
   2372 
   2373   def testUnregisteredNode(self):
   2374     graph = ops.Graph()
   2375     node = ops._NodeDef("b", "a_b")
   2376     weight_params = ops.get_stats_for_node_def(graph, node, "weight_params")
   2377     self.assertEqual(None, weight_params.value)
   2378 
   2379   def testAccumulateStatistics(self):
   2380     flops_total = ops.OpStats("flops")
   2381     self.assertEqual(None, flops_total.value)
   2382     second_flops = ops.OpStats("flops", 3)
   2383     flops_total += second_flops
   2384     self.assertEqual(3, flops_total.value)
   2385 
   2386 
   2387 @test_util.with_c_api
   2388 class ColocationGroupTest(test_util.TensorFlowTestCase):
   2389 
   2390   def testBasic(self):
   2391     a = constant_op.constant([2.0], name="a")
   2392     with ops.colocate_with(a.op):
   2393       b = constant_op.constant(3.0)
   2394     c = constant_op.constant(4.0)
   2395     self.assertEqual([b"loc:@a"], a.op.colocation_groups())
   2396     self.assertEqual([b"loc:@a"], b.op.colocation_groups())
   2397     with self.assertRaises(ValueError):
   2398       c.op.get_attr("_class")
   2399 
   2400   def testColocationDeviceInteraction(self):
   2401     with ops.device("/cpu:0"):
   2402       with ops.device("/device:GPU:0"):
   2403         a = constant_op.constant([2.0], name="a")
   2404       with ops.colocate_with(a.op):
   2405         # 'b' is created in the scope of /cpu:0, but it is
   2406         # colocated with 'a', which is on '/device:GPU:0'.  colocate_with
   2407         # overrides devices because it is a stronger constraint.
   2408         b = constant_op.constant(3.0)
   2409     self.assertEqual([b"loc:@a"], b.op.colocation_groups())
   2410     self.assertEqual(a.op.device, b.op.device)
   2411 
   2412   def testColocationCanonicalization(self):
   2413     with ops.device("/device:GPU:0"):
   2414       _ = constant_op.constant(2.0)
   2415     with ops.device(lambda op: "/device:GPU:0"):
   2416       b = constant_op.constant(3.0)
   2417     with ops.get_default_graph().colocate_with(b):
   2418       with ops.device("/device:GPU:0"):
   2419         c = constant_op.constant(4.0)
   2420 
   2421     # A's device will be /device:GPU:0
   2422     # B's device will be /device:GPU:0
   2423     # C's device will be /device:GPU:0 because it
   2424     # inherits B's device name, after canonicalizing the names.
   2425     self.assertEqual(b.op.device, c.op.device)
   2426 
   2427   def testLocationOverrides(self):
   2428     with ops.device("/cpu:0"):
   2429       with ops.device("/device:GPU:0"):
   2430         a = constant_op.constant([2.0], name="a")
   2431         # Note that this colocation is "redundant", since we are
   2432         # within the scope of "/device:GPU:0".  However, we would like to
   2433         # preserve in the GraphDef that these two ops should be
   2434         # colocated in a portable way.
   2435         with ops.colocate_with(a.op):
   2436           b = constant_op.constant(3.0)
   2437         c = constant_op.constant(4.0)
   2438       d = constant_op.constant(5.0)
   2439 
   2440     self.assertEqual([b"loc:@a"], b.op.colocation_groups())
   2441     self.assertEqual("/device:GPU:0", a.op.device)
   2442     self.assertEqual(a.op.device, b.op.device)
   2443 
   2444     # Test that device function stack is restored.
   2445     self.assertEqual("/device:GPU:0", c.op.device)
   2446     self.assertEqual("/device:CPU:0", d.op.device)
   2447 
   2448   def testNestedColocateWith(self):
   2449     a = constant_op.constant([2.0], name="a")
   2450     with ops.colocate_with(a.op):
   2451       b = constant_op.constant(3.0)
   2452       with ops.colocate_with(b.op):
   2453         c = constant_op.constant(4.0)
   2454     self.assertEqual([b"loc:@a"], b.op.colocation_groups())
   2455     self.assertEqual([b"loc:@a"], c.op.colocation_groups())
   2456 
   2457   def testMultiColocationGroups(self):
   2458     a = constant_op.constant([2.0], name="a")
   2459     b = constant_op.constant(3.0, name="b")
   2460     with ops.colocate_with(a.op):
   2461       with ops.colocate_with(b.op):
   2462         c = constant_op.constant(4.0)
   2463     self.assertEqual(set([b"loc:@a", b"loc:@b"]), set(c.op.colocation_groups()))
   2464 
   2465   def testColocationIgnoreStack(self):
   2466     a = constant_op.constant([2.0], name="a")
   2467     b = constant_op.constant(3.0, name="b")
   2468     with ops.colocate_with(a.op):
   2469       with ops.colocate_with(b.op, ignore_existing=True):
   2470         c = constant_op.constant(4.0)
   2471     self.assertEqual(set([b"loc:@b"]), set(c.op.colocation_groups()))
   2472 
   2473   def testColocateWithReset(self):
   2474     a = constant_op.constant([2.0], name="a")
   2475     with ops.colocate_with(a.op):
   2476       b = constant_op.constant(3.0, name="b")
   2477       with ops.colocate_with(None, ignore_existing=True):
   2478         c = constant_op.constant(4.0, name="c")
   2479     self.assertEqual([b"loc:@a"], b.op.colocation_groups())
   2480     self.assertEqual([b"loc:@c"], c.op.colocation_groups())
   2481 
   2482   def testColocateWithInitialNoneThenNested(self):
   2483     a = constant_op.constant([2.0], name="a")
   2484     with ops.colocate_with(a.op):
   2485       with ops.colocate_with(None, ignore_existing=True):
   2486         b = constant_op.constant(3.0, name="b")
   2487         with ops.colocate_with(b.op):
   2488           c = constant_op.constant(4.0, name="c")
   2489     self.assertEqual([b"loc:@b"], b.op.colocation_groups())
   2490     self.assertEqual([b"loc:@b"], c.op.colocation_groups())
   2491 
   2492   def testColocateVariables(self):
   2493     a = variables.Variable([2.0], name="a")
   2494     with ops.colocate_with(a.op):
   2495       b = variables.Variable([3.0], name="b")
   2496     self.assertEqual([b"loc:@a"], b.op.colocation_groups())
   2497 
   2498   def testInconsistentDeviceWithinColocate(self):
   2499     with ops.device("/device:GPU:0"):
   2500       a = constant_op.constant([2.0], name="a")
   2501       with ops.colocate_with(a.op):
   2502         # This is allowed due to legacy but clearly wrong, since we
   2503         # should really be colocating with 'a'.  We allow devices to
   2504         # override colocate_with, but we log warnings to suggest that
   2505         # this is probably unintentional or misguided.
   2506         with ops.device("/cpu:0"):
   2507           b = constant_op.constant([3.0], name="b")
   2508 
   2509     self.assertEqual("/device:CPU:0", b.device)
   2510 
   2511 
   2512 @test_util.with_c_api
   2513 class DeprecatedTest(test_util.TensorFlowTestCase):
   2514 
   2515   def testSuccess(self):
   2516     # TODO(skyewm): make g.graph_def_versions work with the C API enabled
   2517     if ops._USE_C_API: return
   2518 
   2519     with ops.Graph().as_default() as g:
   2520       g.graph_def_versions.producer = 7
   2521       old = test_ops.old()
   2522       with self.test_session(graph=g):
   2523         old.run()
   2524 
   2525   def _error(self):
   2526     return ((r"Op Old is not available in GraphDef version %d\. "
   2527              r"It has been removed in version 8\. For reasons\.") %
   2528             versions.GRAPH_DEF_VERSION)
   2529 
   2530   def testGraphConstructionFail(self):
   2531     with ops.Graph().as_default():
   2532       with self.assertRaisesRegexp(NotImplementedError, self._error()):
   2533         test_ops.old()
   2534 
   2535   def testGraphExecutionFail(self):
   2536     # TODO(skyewm): make g.graph_def_versions work with the C API enabled
   2537     if ops._USE_C_API: return
   2538 
   2539     with ops.Graph().as_default() as g:
   2540       g.graph_def_versions.producer = 7
   2541       old = test_ops.old()
   2542       g.graph_def_versions.producer = versions.GRAPH_DEF_VERSION
   2543       with self.test_session(graph=g):
   2544         with self.assertRaisesRegexp(errors.UnimplementedError, self._error()):
   2545           old.run()
   2546 
   2547 
   2548 @test_util.with_c_api
   2549 class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase):
   2550 
   2551   def testSuccess(self):
   2552     op = ops.Operation(
   2553         ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
   2554     t = op.outputs[0]
   2555     self.assertTrue(ops.is_dense_tensor_like(t))
   2556 
   2557     v = variables.Variable([17])
   2558     self.assertTrue(ops.is_dense_tensor_like(v))
   2559 
   2560   class BadClassNoName(object):
   2561     pass
   2562 
   2563   class BadClassBadName(object):
   2564 
   2565     def name(self):
   2566       pass
   2567 
   2568   class BadClassNoDtype(object):
   2569 
   2570     @property
   2571     def name(self):
   2572       pass
   2573 
   2574   class BadClassBadDtype(object):
   2575 
   2576     @property
   2577     def name(self):
   2578       pass
   2579 
   2580     def dtype(self):
   2581       pass
   2582 
   2583   def testBadClass(self):
   2584     with self.assertRaisesRegexp(TypeError, "`name`"):
   2585       ops.register_dense_tensor_like_type(
   2586           DenseTensorLikeTypeTest.BadClassNoName)
   2587     with self.assertRaisesRegexp(TypeError, "`name`"):
   2588       ops.register_dense_tensor_like_type(
   2589           DenseTensorLikeTypeTest.BadClassBadName)
   2590     with self.assertRaisesRegexp(TypeError, "`dtype`"):
   2591       ops.register_dense_tensor_like_type(
   2592           DenseTensorLikeTypeTest.BadClassNoDtype)
   2593     with self.assertRaisesRegexp(TypeError, "`dtype`"):
   2594       ops.register_dense_tensor_like_type(
   2595           DenseTensorLikeTypeTest.BadClassBadDtype)
   2596 
   2597 
   2598 @test_util.with_c_api
   2599 class NameScopeTest(test_util.TensorFlowTestCase):
   2600 
   2601   def testStripAndPrependScope(self):
   2602     strs = [
   2603         "hidden1/hidden1/weights",  # Same prefix. Should strip.
   2604         "hidden1///hidden1/weights",  # Extra "/". Should strip.
   2605         "^hidden1/hidden1/weights",  # Same prefix. Should strip.
   2606         "loc:@hidden1/hidden1/weights",  # Same prefix. Should strip.
   2607         "hhidden1/hidden1/weights",  # Different prefix. Should keep.
   2608         "hidden1"
   2609     ]  # Not a prefix. Should keep.
   2610     expected_striped = [
   2611         "hidden1/weights", "hidden1/weights", "^hidden1/weights",
   2612         "loc:@hidden1/weights", "hhidden1/hidden1/weights", "hidden1"
   2613     ]
   2614     expected_prepended = [
   2615         "hidden2/hidden1/weights", "hidden2/hidden1/weights",
   2616         "^hidden2/hidden1/weights", "loc:@hidden2/hidden1/weights",
   2617         "hidden2/hhidden1/hidden1/weights", "hidden2/hidden1"
   2618     ]
   2619     name_scope_to_strip = "hidden1"
   2620     name_scope_to_add = "hidden2"
   2621     for es, ep, s in zip(expected_striped, expected_prepended, strs):
   2622       striped = ops.strip_name_scope(s, name_scope_to_strip)
   2623       self.assertEqual(es, striped)
   2624       self.assertEqual(ep, ops.prepend_name_scope(striped, name_scope_to_add))
   2625 
   2626   def testGetNameScope(self):
   2627     with ops.Graph().as_default() as g:
   2628       with ops.name_scope("scope1"):
   2629         with ops.name_scope("scope2"):
   2630           with ops.name_scope("scope3"):
   2631             self.assertEqual("scope1/scope2/scope3", g.get_name_scope())
   2632           self.assertEqual("scope1/scope2", g.get_name_scope())
   2633         self.assertEqual("scope1", g.get_name_scope())
   2634       self.assertEqual("", g.get_name_scope())
   2635 
   2636   def testTwoGraphs(self):
   2637 
   2638     def f():
   2639       g1 = ops.Graph()
   2640       g2 = ops.Graph()
   2641       with g1.as_default():
   2642         with g2.as_default():
   2643           with ops.name_scope("_"):
   2644             pass
   2645 
   2646     self.assertRaisesRegexp(ValueError, "'_' is not a valid scope name", f)
   2647 
   2648 
   2649 @test_util.with_c_api
   2650 class TracebackTest(test_util.TensorFlowTestCase):
   2651 
   2652   def testTracebackWithStartLines(self):
   2653     with self.test_session() as sess:
   2654       a = constant_op.constant(2.0)
   2655       sess.run(
   2656           a,
   2657           options=config_pb2.RunOptions(
   2658               trace_level=config_pb2.RunOptions.FULL_TRACE))
   2659       self.assertTrue(sess.graph.get_operations())
   2660 
   2661       # Tests that traceback_with_start_lines is the same as traceback
   2662       # but includes one more element at the end.
   2663       for op in sess.graph.get_operations():
   2664         self.assertEquals(len(op.traceback), len(op.traceback_with_start_lines))
   2665         for frame, frame_with_start_line in zip(
   2666             op.traceback, op.traceback_with_start_lines):
   2667           self.assertEquals(5, len(frame_with_start_line))
   2668           self.assertEquals(frame, frame_with_start_line[:-1])
   2669 
   2670 
   2671 @test_util.with_c_api
   2672 class OutputTypesTest(test_util.TensorFlowTestCase):
   2673   """Tests Operation._output_types property.
   2674 
   2675   This test should not exist as _output_types is a private property.
   2676   This property is used by util.copy_elements and its tests would normally
   2677   cover Operation._output_types. However, we can't yet run these tests in C
   2678   API mode because their use _set_device method. This test will be deleted
   2679   once we port _set_device and run the copy tests with C API on.
   2680   """
   2681   # TODO(iga): Remove this test
   2682 
   2683   def setUp(self):
   2684     self.prev_use_c_api = ops._USE_C_API  # pylint: disable=protected-access
   2685     ops._USE_C_API = True  # pylint: disable=protected-access
   2686 
   2687   def tearDown(self):
   2688     ops._USE_C_API = self.prev_use_c_api  # pylint: disable=protected-access
   2689 
   2690   def testOneOutput(self):
   2691     g = ops.Graph()
   2692     with g.as_default():
   2693       # Using a constant because creating unregistered ops
   2694       # doesn't work with the C API.
   2695       op = constant_op.constant(12, dtype=dtypes.uint16).op
   2696       # pylint: disable=protected-access
   2697       self.assertEqual([types_pb2.DT_UINT16], op._output_types)
   2698       # pylint: enable=protected-access
   2699 
   2700   def testTwoDifferentOutputs(self):
   2701     g = ops.Graph()
   2702     with g.as_default():
   2703       x = constant_op.constant([1, 1, 2, 4, 4, 4, 7, 8, 8],
   2704                                dtype=dtypes.double)
   2705       y, _ = gen_array_ops._unique(x)
   2706       self.assertEqual([types_pb2.DT_DOUBLE, types_pb2.DT_INT32],
   2707                        y.op._output_types)  # pylint: disable=protected-access
   2708 
   2709   def testThreeOutputs(self):
   2710     g = ops.Graph()
   2711     with g.as_default():
   2712       # Using a split operationt because creating unregistered ops
   2713       # doesn't work with the C API.
   2714       a = constant_op.constant("abc", dtype=dtypes.string, shape=[5, 30])
   2715       split0, _, _ = array_ops.split(a, [4, 15, 11], 1)
   2716       # pylint: disable=protected-access
   2717       self.assertEqual([types_pb2.DT_STRING] * 3, split0.op._output_types)
   2718       # pylint: enable=protected-access
   2719 
   2720 
   2721 @test_util.with_c_api
   2722 class EnableEagerExecutionTest(test_util.TensorFlowTestCase):
   2723 
   2724   def testBadArgumentsToEnableEagerExecution(self):
   2725     with self.assertRaisesRegexp(TypeError, "config must be a tf.ConfigProto"):
   2726       ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT)
   2727     with self.assertRaisesRegexp(ValueError, "device_policy must be one of"):
   2728       c = config_pb2.ConfigProto()
   2729       ops.enable_eager_execution(c, c)
   2730 
   2731 
if __name__ == "__main__":
  # Script entry point: hand control to the googletest runner.
  googletest.main()
   2734