# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.framework.ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gc
import weakref

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
import tensorflow.python.ops.gradients  # pylint: disable=unused-import
from tensorflow.python.platform import googletest
from tensorflow.python.util import compat

ops._set_call_cpp_shape_fn(common_shapes.call_cpp_shape_fn)


@test_util.with_c_api
class ResourceTest(test_util.TensorFlowTestCase):
  """Tests for resource handle ops and the `resources` registry helpers."""

  def testBuildGraph(self):
    with self.test_session():
      pt = test_ops.stub_resource_handle_op(container="a", shared_name="b")
      test_ops.resource_create_op(pt).run()

  def testInitialize(self):
    with self.test_session():
      handle = test_ops.stub_resource_handle_op(container="a", shared_name="b")
      resources.register_resource(
          handle=handle,
          create_op=test_ops.resource_create_op(handle),
          is_initialized_op=test_ops.resource_initialized_op(handle))
      # Before initialization the one registered resource is reported as
      # uninitialized; after running the initializer, none are.
      self.assertEqual(
          len(
              resources.report_uninitialized_resources(
                  resources.shared_resources()).eval()), 1)
      resources.initialize_resources(resources.shared_resources()).run()
      self.assertEqual(
          len(
              resources.report_uninitialized_resources(
                  resources.shared_resources()).eval()), 0)


@test_util.with_c_api
class TensorAndShapeTest(test_util.TensorFlowTestCase):
  """Tests for static shapes of graph Tensors and shape-inference errors."""

  def testShape(self):
    op = ops.Operation(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
    t = op.outputs[0]
    self.assertEqual(tensor_shape.unknown_shape(), t.get_shape())
    t.set_shape([1, 2, 3])
    self.assertEqual([1, 2, 3], t.get_shape())

  def testIterable(self):
    op = ops.Operation(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
    t = op.outputs[0]
    self.assertIsInstance(t, ops.Tensor)
    # Iterating over a graph Tensor is expected to raise a TypeError.
    with self.assertRaisesRegexp(TypeError, "iter"):
      for _ in t:
        pass

  def testAddShape(self):
    with self.test_session():
      a = array_ops.zeros([2, 3])
      b = array_ops.ones([1, 3])
      c = a + b
      self.assertEqual([2, 3], c.shape)

  def testUnknownDim(self):
    with self.test_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
      b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
      c = a + b
      self.assertEqual([2, None, 3], c.shape.as_list())

  def testUnknownShape(self):
    with self.test_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
      b = array_ops.ones([1, 3])
      c = a + b
      self.assertEqual(tensor_shape.unknown_shape(), c.shape)

  def testScalarShape(self):
    with self.test_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=[])
      b = array_ops.ones([])
      c = a + b
      self.assertEqual(tensor_shape.scalar(), c.shape)

  def testShapeFunctionError(self):
    with self.test_session():
      a = array_ops.ones([1, 2, 3])
      b = array_ops.ones([4, 5, 6])
      # The trailing period is escaped so it matches a literal "." rather
      # than any character.
      with self.assertRaisesRegexp(
          ValueError,
          r"Dimensions must be equal, but are 2 and 5 for 'add' \(op: 'Add'\) "
          r"with input shapes: \[1,2,3\], \[4,5,6\]\."):
        _ = a + b


@test_util.with_c_api
class IndexedSlicesTest(test_util.TensorFlowTestCase):
  """Tests for ops.IndexedSlices densification and arithmetic."""

  def testToTensor(self):
    with self.test_session():
      values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
      indices = constant_op.constant([0, 2])
      dense_shape = constant_op.constant([3, 2])
      x = ops.IndexedSlices(values, indices, dense_shape)
      tensor = ops.convert_to_tensor(x, name="tensor")
      self.assertAllEqual(tensor.eval(), [[2, 3], [0, 0], [5, 7]])

  def testNegation(self):
    with self.test_session():
      values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
      indices = constant_op.constant([0, 2])
      x = -ops.IndexedSlices(values, indices)
      self.assertAllEqual(x.values.eval(), [[-2, -3], [-5, -7]])
      self.assertAllEqual(x.indices.eval(), [0, 2])

  def testScalarMul(self):
    with self.test_session():
      values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
      indices = constant_op.constant([0, 2])
      x = math_ops.scalar_mul(-2, ops.IndexedSlices(values, indices))
      self.assertAllEqual(x.values.eval(), [[-4, -6], [-10, -14]])
      self.assertAllEqual(x.indices.eval(), [0, 2])


@test_util.with_c_api
class NodeDefConstructorTest(test_util.TensorFlowTestCase):
  """Tests for the private ops._NodeDef constructor."""

  def testNoArgs(self):
    nodedef = ops._NodeDef("None", "bar")
    self.assertProtoEquals("op: 'None' name: 'bar'", nodedef)

  def testArgs(self):
    nodedef = ops._NodeDef("foo", "bar", device="/device:baz:*")
    self.assertProtoEquals("op:'foo' name:'bar' device:'/device:baz:*'",
                           nodedef)
    nodedef = ops._NodeDef("foo", "bar", device=pydev.DeviceSpec(job="j"))
    self.assertProtoEquals("op:'foo' name:'bar' device:'/job:j'", nodedef)


def _apply_op(g, *args, **kwargs):
  """Creates an op in graph `g` and returns its output(s).

  Args:
    g: The `ops.Graph` in which to create the op.
    *args: Positional arguments forwarded unchanged to `g.create_op`.
    **kwargs: Keyword arguments forwarded unchanged to `g.create_op`.

  Returns:
    The op's only output when it has exactly one; otherwise `op.outputs`
    as-is (for zero or several outputs).
  """
  op = g.create_op(*args, **kwargs)
  if len(op.outputs) == 1:
    return op.outputs[0]
  else:
    return op.outputs


@test_util.with_c_api
class OperationTest(test_util.TensorFlowTestCase):
  """Tests for ops.Operation: inputs/outputs, attrs, and graph edits."""

  def testNoInputs(self):
    op = test_ops.float_output_string_output(name="myop").a.op
    self.assertEqual(2, len(op.values()))
    self.assertEqual(0, len(op.inputs))
    self.assertEqual("myop", op.name)

    float_t, label_str_t = op.values()
    self.assertEqual(dtypes.float32, float_t.dtype)
    self.assertEqual(op, float_t.op)
    self.assertEqual(0, float_t._value_index)
    self.assertEqual(0, len(float_t.consumers()))
    self.assertEqual("myop", float_t._as_node_def_input())

    self.assertEqual(dtypes.string, label_str_t.dtype)
    self.assertEqual(op, label_str_t.op)
    self.assertEqual(1, label_str_t._value_index)
    self.assertEqual(0, len(label_str_t.consumers()))
    self.assertEqual("myop:1", label_str_t._as_node_def_input())

    self.assertProtoEquals("op:'FloatOutputStringOutput' name:'myop'",
                           op.node_def)

  def testNoOutputs(self):
    op1 = test_ops.float_output(name="myop1").op
    float_t, = op1.values()
    op2 = test_ops.float_input(float_t, name="myop2")
    self.assertEqual(0, len(op2.values()))
    self.assertEqual(1, len(op2.inputs))
    self.assertIs(float_t, op2.inputs[0])

    self.assertEqual(1, len(float_t.consumers()))
    self.assertEqual(op2, float_t.consumers()[0])

    self.assertProtoEquals("op:'FloatOutput' name:'myop1'", op1.node_def)
    self.assertProtoEquals("op:'FloatInput' name:'myop2' input:'myop1'",
                           op2.node_def)

  def testInputsAndOutputs(self):
    op1 = test_ops.float_output(name="myop1").op
    self.assertEqual(1, len(op1.values()))
    float1_t, = op1.values()

    op2 = test_ops.float_output_string_output(name="myop2").a.op
    self.assertEqual(2, len(op2.values()))
    float2_t, label2_str_t = op2.values()

    # Note that we consume label2_str_t twice here.
    op3 = test_ops.foo2(float1_t, label2_str_t, label2_str_t, name="myop3").d.op
    self.assertEqual(2, len(op3.values()))

    self.assertEqual(1, len(float1_t.consumers()))
    self.assertEqual(op3, float1_t.consumers()[0])

    self.assertEqual(0, len(float2_t.consumers()))

    self.assertEqual(2, len(label2_str_t.consumers()))
    self.assertEqual(op3, label2_str_t.consumers()[0])
    self.assertEqual(op3, label2_str_t.consumers()[1])

    self.assertProtoEquals("""
    op:'Foo2' name:'myop3'
    input:'myop1' input:'myop2:1' input:'myop2:1'
    """, op3.node_def)

  def testDeviceObject(self):
    op = ops.Operation(ops._NodeDef("None", "myop"), ops.Graph(), [], [])
    op._set_device("/job:goo/device:GPU:0")
    self.assertProtoEquals(
        "op:'None' name:'myop' device:'/job:goo/device:GPU:0' ", op.node_def)
    op = ops.Operation(ops._NodeDef("None", "op2"), ops.Graph(), [], [])
    op._set_device(
        pydev.DeviceSpec(
            job="muu", device_type="CPU", device_index=0))
    self.assertProtoEquals(
        "op:'None' name:'op2' device:'/job:muu/device:CPU:0'", op.node_def)

  def testReferenceInput(self):
    g = ops.Graph()
    op1 = ops.Operation(
        ops._NodeDef("RefOutputFloatOutput", "op1"), g, [],
        [dtypes.float32_ref, dtypes.float32])
    g._add_op(op1)
    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def)
    self.assertEqual([], list(op1.inputs))
    ref_t, nonref_t = op1.values()
    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    op2 = ops.Operation(
        ops._NodeDef("RefInputFloatInput", "op2"),
        g, [ref_t, nonref_t], [],
        input_types=[dtypes.float32_ref, dtypes.float32])
    g._add_op(op2)
    self.assertProtoEquals(
        "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
        op2.node_def)
    self.assertEqual([ref_t, nonref_t], list(op2.inputs))
    op3 = ops.Operation(
        ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], [])
    g._add_op(op3)
    self.assertProtoEquals(
        "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
        op3.node_def)

  def testInvalidNames(self):
    g = ops.Graph()
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", ""), g)
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", "_invalid"), g)
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", "-invalid"), g)
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", "/invalid"), g)
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", "invalid:0"), g)

  def testNoShapeFunction(self):
    op = test_ops.a()
    self.assertEqual(tensor_shape.unknown_shape(), op.get_shape())

  def testConvertToTensorNestedArray(self):
    with self.test_session():
      values = [[2], [3], [5], [7]]
      tensor = ops.convert_to_tensor(values)
      self.assertAllEqual((4, 1), tensor.get_shape().as_list())
      self.assertAllEqual(values, tensor.eval())

  def testShapeTuple(self):
    with self.test_session():
      c = constant_op.constant(1)
      self.assertEqual(c._shape_tuple(), ())  # pylint: disable=protected-access

  def testConvertToTensorEager(self):
    with context.eager_mode():
      t = constant_op.constant(1)
      self.assertIsInstance(t, ops.EagerTensor)
      converted = ops.convert_to_tensor(t)
      self.assertIsInstance(converted, ops.EagerTensor)
      converted = ops.convert_to_tensor(1)
      self.assertIsInstance(converted, ops.EagerTensor)

  def testConvertToTensorNestedTuple(self):
    with self.test_session():
      values = ((2,), (3,), (5,), (7,))
      tensor = ops.convert_to_tensor(values)
      self.assertAllEqual((4, 1), tensor.get_shape().as_list())
      # Evaluate the same tensor whose shape was checked above, as the
      # sibling nested-conversion tests do.
      self.assertAllEqual(values, tensor.eval())

  def testConvertToTensorNestedTensors(self):
    with self.test_session():
      values = ((2,), (3,), (5,), (7,))
      tensor = ops.convert_to_tensor(
          [constant_op.constant(row) for row in values])
      self.assertAllEqual((4, 1), tensor.get_shape().as_list())
      self.assertAllEqual(values, tensor.eval())
      tensor = ops.convert_to_tensor(
          [[constant_op.constant(v) for v in row] for row in values])
      self.assertAllEqual((4, 1), tensor.get_shape().as_list())
      self.assertAllEqual(values, tensor.eval())

  def testConvertToTensorNestedMix(self):
    with self.test_session():
      values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7]))
      tensor = ops.convert_to_tensor(values)
      self.assertAllEqual((4, 1), tensor.get_shape().as_list())
      self.assertAllEqual(((2,), (3,), (5,), (7,)), tensor.eval())

  def testConvertToTensorPreferred(self):
    with self.test_session():
      values = [2, 3, 5, 7]
      tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32)
      self.assertEqual(dtypes.float32, tensor.dtype)

    with self.test_session():
      # Convert empty tensor to anything.
      values = []
      tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
      self.assertEqual(dtypes.int64, tensor.dtype)

    with self.test_session():
      # The preferred dtype is a type error and will convert to
      # float32 instead.
      values = [1.23]
      tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
      self.assertEqual(dtypes.float32, tensor.dtype)

  def testConvertToInvalidTensorType(self):
    values = [1.23]
    with self.assertRaises(TypeError):
      # Forcing an invalid dtype should fail with a type error.
      _ = ops.convert_to_tensor(values, dtype=dtypes.int64)

  def testNoConvert(self):
    # Operation cannot be converted to Tensor.
    op = control_flow_ops.no_op()
    with self.assertRaisesRegexp(TypeError,
                                 r"Can't convert Operation '.*' to Tensor"):
      ops.convert_to_tensor(op)

  def testStr(self):
    node_def = ops._NodeDef("None", "op1")
    op = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32])
    self.assertEqual(str(node_def), str(op))

  def testRepr(self):
    op = ops.Operation(
        ops._NodeDef("None", "op1"), ops.Graph(), [], [dtypes.float32])
    self.assertEqual("<tf.Operation 'op1' type=None>", repr(op))

  def testGetAttr(self):
    op = test_ops.default_attrs()
    self.assertEqual(op.get_attr("string_val"), b"abc")
    self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""])
    self.assertEqual(op.get_attr("int_val"), 123)
    self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3])
    self.assertEqual(op.get_attr("float_val"), 10.0)
    self.assertEqual(op.get_attr("float_list_val"), [10.0])
    self.assertEqual(op.get_attr("bool_val"), True)
    self.assertEqual(op.get_attr("bool_list_val"), [True, False])
    self.assertEqual(op.get_attr("shape_val"),
                     tensor_shape.as_shape([2, 1]).as_proto())
    self.assertEqual(op.get_attr("shape_list_val"),
                     [tensor_shape.as_shape([]).as_proto(),
                      tensor_shape.as_shape([1]).as_proto()])
    self.assertEqual(op.get_attr("tensor_val"),
                     tensor_util.make_tensor_proto(1, dtypes.int32))
    self.assertEqual(op.get_attr("tensor_list_val"),
                     [tensor_util.make_tensor_proto(1, dtypes.int32)])

    type_val = op.get_attr("type_val")
    # First check that type_val is a DType, because the assertEqual will work
    # no matter what since DType overrides __eq__
    self.assertIsInstance(type_val, dtypes.DType)
    self.assertEqual(type_val, dtypes.int32)

    type_list_val = op.get_attr("type_list_val")
    for list_elem in type_list_val:
      self.assertIsInstance(list_elem, dtypes.DType)
    self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32])

    @function.Defun(dtypes.float32, func_name="MyFunc")
    def func(x):
      return x

    op = test_ops.func_attr(func)
    self.assertEqual(op.get_attr("f"),
                     attr_value_pb2.NameAttrList(name="MyFunc"))

    # Try fetching missing attr
    if ops._USE_C_API:
      error_msg = "Operation 'FuncAttr' has no attr named 'FakeAttr'."
    else:
      error_msg = "No attr named 'FakeAttr' in name: \"FuncAttr\""

    with self.assertRaisesRegexp(ValueError, error_msg):
      op.get_attr("FakeAttr")

  # TODO(b/65162920): remove this test when users who are directly mutating the
  # node_def have been updated to proper usage.
  def testSetAttr(self):
    op = test_ops.int_attr().op
    op._set_attr("foo", attr_value_pb2.AttrValue(i=2))
    # TODO(skyewm): add node_def check
    self.assertEqual(op.get_attr("foo"), 2)

  # TODO(nolivia): test all error cases
  def testAddControlInput(self):
    # The C API dedups redundant control edges, pure Python does not
    if ops._USE_C_API:
      return
    with ops.Graph().as_default():
      x = constant_op.constant(1).op
      y = constant_op.constant(2).op
      z = constant_op.constant(3).op
      z._add_control_input(x)  # pylint: disable=protected-access
      self.assertEqual(z.control_inputs, [x])
      z._add_control_input(x)  # pylint: disable=protected-access
      self.assertEqual(z.control_inputs, [x, x])
      z._add_control_inputs([x, y, y])  # pylint: disable=protected-access
      self.assertEqual(z.control_inputs, [x, x, x, y, y])

  def testAddControlInputC(self):
    # The C API dedups redundant control edges, pure Python does not
    if not ops._USE_C_API:
      return
    with ops.Graph().as_default():
      x = constant_op.constant(1).op
      y = constant_op.constant(2).op
      z = constant_op.constant(3).op
      z._add_control_input(x)  # pylint: disable=protected-access
      self.assertEqual(z.control_inputs, [x])
      z._add_control_input(x)  # pylint: disable=protected-access
      self.assertEqual(z.control_inputs, [x])
      z._add_control_inputs([x, y, y])  # pylint: disable=protected-access
      self.assertEqual(z.control_inputs, [x, y])

  def testRemoveAllControlInputs(self):
    a = constant_op.constant(1)
    with ops.control_dependencies([a]):
      b = constant_op.constant(2)
    c = constant_op.constant(3)
    d = constant_op.constant(4)
    e = constant_op.constant(5)
    with ops.control_dependencies([a, c]):
      f = d + e

    self.assertEqual(a.op.control_inputs, [])
    self.assertEqual(b.op.control_inputs, [a.op])
    self.assertEqual(f.op.control_inputs, [a.op, c.op])

    a.op._remove_all_control_inputs()  # pylint: disable=protected-access
    self.assertEqual(a.op.control_inputs, [])

    b.op._remove_all_control_inputs()  # pylint: disable=protected-access
    self.assertEqual(b.op.control_inputs, [])

    f.op._remove_all_control_inputs()  # pylint: disable=protected-access
    self.assertEqual(f.op.control_inputs, [])
    # Data inputs must survive removal of control inputs.
    self.assertEqual(list(f.op.inputs), [d, e])

  def testControlInputCycle(self):
    # Non-C API path has a different error message
    if not ops._USE_C_API:
      return
    graph = ops.Graph()
    with graph.as_default():
      z = constant_op.constant(0)
      x = constant_op.constant(1)
      y = constant_op.constant(2)
      y.op._add_control_input(z.op)  # pylint: disable=protected-access
      y.op._add_control_input(x.op)  # pylint: disable=protected-access
      x.op._add_control_input(y.op)  # pylint: disable=protected-access
    with self.test_session(graph=graph) as sess:
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          "Graph is invalid, contains a cycle with 2 nodes"):
        sess.run(x)

  def testUpdateInput(self):
    g = ops.Graph()
    with g.as_default():
      x = constant_op.constant(1)
      y = constant_op.constant(2)
      z = x + y

    z.op._update_input(0, y)  # pylint: disable=protected-access
    self.assertEqual(list(z.op.inputs), [y, y])
    self.assertEqual(x.consumers(), [])
    self.assertEqual(y.consumers(), [z.op, z.op])
    with session.Session(graph=g) as sess:
      self.assertEqual(sess.run(z), 4)

    z.op._update_input(0, x)  # pylint: disable=protected-access
    self.assertEqual(list(z.op.inputs), [x, y])
    self.assertEqual(x.consumers(), [z.op])
    self.assertEqual(y.consumers(), [z.op])
    with session.Session(graph=g) as sess:
      self.assertEqual(sess.run(z), 3)

    z.op._update_input(1, y)  # pylint: disable=protected-access
    self.assertEqual(list(z.op.inputs), [x, y])
    self.assertEqual(x.consumers(), [z.op])
    self.assertEqual(y.consumers(), [z.op])
    with session.Session(graph=g) as sess:
      self.assertEqual(sess.run(z), 3)

  def testUpdateInputGraphError(self):
    g_0 = ops.Graph()
    g_1 = ops.Graph()
    with g_0.as_default():
      x = constant_op.constant(1)
    with g_1.as_default():
      y = constant_op.constant(2)
      z = y * 2
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        z.op._update_input(0, x)  # pylint: disable=protected-access

  def testUpdateInputTypeError(self):
    g = ops.Graph()
    with g.as_default():
      w = constant_op.constant(0)
      x = constant_op.constant("")
      y = constant_op.constant(1)
      z = y + w
      z.op._update_input(0, x)  # pylint: disable=protected-access
    with session.Session(graph=g) as sess:
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          "Input 0 of node add was passed string from Const_1:0 incompatible "
          "with expected int32"):
        sess.run(z)

  def testUpdateInputShapeError(self):
    # C-API throws the error differently.
    if ops._USE_C_API:
      return
    g = ops.Graph()
    with g.as_default():
      w = constant_op.constant(2, shape=[3, 1])
      x = constant_op.constant(0, shape=[3, 1])
      y = constant_op.constant(1, shape=[2, 2])
      z = w + x
      z.op._update_input(0, y)  # pylint: disable=protected-access

    with session.Session(graph=g) as sess:
      with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                   r"Incompatible shapes: \[2,2\] vs. \[3,1\]"):
        sess.run(z)

  def testUpdateInputShapeErrorC(self):
    if not ops._USE_C_API:
      return
    g = ops.Graph()
    with g.as_default():
      w = constant_op.constant(2, shape=[3, 1])
      x = constant_op.constant(0, shape=[3, 1])
      y = constant_op.constant(1, shape=[2, 2])
      z = w + x
    with self.assertRaisesRegexp(
        errors.InvalidArgumentError,
        r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"):
      z.op._update_input(0, y)  # pylint: disable=protected-access

  def testUpdateInputOutOfRange(self):
    # C-API throws the error differently.
    if ops._USE_C_API:
      return
    g = ops.Graph()
    with g.as_default():
      x = constant_op.constant(1)
    with self.assertRaisesRegexp(IndexError, "list index out of range"):
      x.op._update_input(1, x)  # pylint: disable=protected-access

  def testUpdateInputOutOfRangeC(self):
    # C-API throws the error differently.
    if not ops._USE_C_API:
      return
    g = ops.Graph()
    with g.as_default():
      x = constant_op.constant(1)
    with self.assertRaisesRegexp(
        errors.OutOfRangeError,
        r"Cannot update edge. Input index \[1\] is greater than the number of "
        r"total inputs \[0\]."
    ):
      x.op._update_input(1, x)  # pylint: disable=protected-access

  def testOpDef(self):
    x = constant_op.constant(0)
    y = constant_op.constant(1)
    z = x + y

    # Pure Python mode doesn't create OpDefs for constants
    if ops._USE_C_API:
      self.assertEqual(x.op.op_def.name, "Const")
      self.assertEqual(len(x.op.op_def.input_arg), 0)
      self.assertEqual(len(x.op.op_def.output_arg), 1)

    self.assertEqual(z.op.op_def.name, "Add")
    self.assertEqual(len(z.op.op_def.input_arg), 2)
    self.assertEqual(len(z.op.op_def.output_arg), 1)

  def testInputFromDifferentGraphError(self):
    g_0 = ops.Graph()
    g_1 = ops.Graph()
    with g_0.as_default():
      x = constant_op.constant(1)
    with g_1.as_default():
      y = constant_op.constant(2)
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        y * x  # pylint: disable=pointless-statement

  def testInputsAreImmutable(self):
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()
      op = test_ops.int_input_int_output(x, name="myop").op
    with self.assertRaisesRegexp(
        AttributeError, "'_InputList' object has no attribute 'append'"):
      op.inputs.append(None)


@test_util.with_c_api
class CreateOpTest(test_util.TensorFlowTestCase):
  """Tests for Graph.create_op: NodeDef contents, devices, finalization."""

  def testNodeDefArgs(self):
    g = ops.Graph()
    op1 = g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")
    with g.device("/device:GPU:0"):
      op2 = g.create_op(
          "FloatOutputStringOutput", [], [dtypes.float32, dtypes.string], None,
          name="myop2")
    op3 = g.create_op(
        "Foo3",
        [list(op1.values())[0], list(op2.values())[1], list(op2.values())[0]],
        [dtypes.float32, dtypes.int32],
        None,
        name="myop3")
    self.assertDeviceEqual(None, op1.device)
    self.assertDeviceEqual("/device:GPU:0", op2.device)
    self.assertDeviceEqual(None, op3.device)
    self.assertProtoEquals("name:'myop1' op:'FloatOutput'", op1.node_def)
    self.assertProtoEquals(
        "name:'myop2' op:'FloatOutputStringOutput' device:'/device:GPU:0'",
        op2.node_def)
    self.assertProtoEquals(
        "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo3'",
        op3.node_def)

  def testReferenceInput(self):
    g = ops.Graph()
    op1 = g.create_op(
        "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
        name="op1")
    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def)
    ref_t, nonref_t = op1.values()
    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    op2 = g.create_op(
        "RefInputFloatInput", [ref_t, nonref_t], [],
        input_types=[dtypes.float32_ref, dtypes.float32],
        name="op2")
    self.assertProtoEquals(
        "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
        op2.node_def)
    op3 = g.create_op("TwoFloatInputs", [ref_t, nonref_t], [], name="op3")
    self.assertProtoEquals(
        "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
        op3.node_def)

  def testFinalized(self):
    g = ops.Graph()
    g.finalize()
    with self.assertRaises(RuntimeError):
      g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")

    # Test unfinalize.
    g._unsafe_unfinalize()
    g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")


# NOTE(skyewm): these cases test the private Graph._create_op_from_tf_operation
# method. Arguably we should only test the public APIs that depend on this
# method. However, this logic is complex and tricky, and it can be difficult to
# ascertain if we have adequate coverage (e.g. a graph may run successfully if
# the control flow context isn't set properly, but a more complicated use case
# that might not be obvious to test will fail). Thus we instead explicitly test
# the low-level behavior.
@test_util.with_c_api
class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase):
  """Tests Graph._create_op_from_tf_operation / _add_new_tf_operations.

  When `ops._USE_C_API` is false, each case instead builds the same op through
  the pure-Python path so that both code paths are held to the same checks.
  """

  def testBasic(self):
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()
      if ops._USE_C_API:
        c_op = ops._create_c_op(
            g, ops._NodeDef("IntInputIntOutput", "myop"), [x], [])
        op = g._create_op_from_tf_operation(c_op)
      else:
        # Test pure-Python version to make sure C API has same behavior.
        op = test_ops.int_input_int_output(x, name="myop").op

    self.assertEqual(op.name, "myop")
    self.assertEqual(op.type, "IntInputIntOutput")
    self.assertEqual(len(op.outputs), 1)
    self.assertEqual(op.outputs[0].shape, tensor_shape.unknown_shape())
    self.assertEqual(list(op.inputs), [x])
    self.assertEqual(op.control_inputs, [])
    self.assertEqual(op.graph, g)
    self.assertEqual(x.consumers(), [op])
    self.assertIsNotNone(op.traceback)
    self.assertEqual(g.get_operation_by_name("myop"), op)
    self.assertEqual(g.get_tensor_by_name("myop:0"), op.outputs[0])

  def testShape(self):
    g = ops.Graph()
    with g.as_default():
      x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
      if ops._USE_C_API:
        c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), [x], [])
        op = g._create_op_from_tf_operation(c_op)
      else:
        # Test pure-Python version to make sure C API has same behavior.
        op = array_ops.identity(x, name="myop").op

    self.assertEqual(op.name, "myop")
    self.assertEqual(op.type, "Identity")
    self.assertEqual(len(op.outputs), 1)
    self.assertEqual(op.outputs[0].shape, tensor_shape.matrix(2, 3))

  def testUniqueName(self):
    g = ops.Graph()
    with g.as_default():
      if ops._USE_C_API:
        c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], [])
        c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], [])
        op = g._create_op_from_tf_operation(c_op)
        op2 = g._create_op_from_tf_operation(c_op2)
      else:
        # Test pure-Python version to make sure C API has same behavior.
        op = test_ops.int_output(name="myop").op
        op2 = test_ops.int_output(name="myop_1").op

      # Create ops with the same names as op and op2. We expect the new names
      # to be uniquified.
      op3 = test_ops.int_output(name="myop").op
      op4 = test_ops.int_output(name="myop_1").op

    self.assertEqual(op.name, "myop")
    self.assertEqual(op2.name, "myop_1")
    self.assertEqual(op3.name, "myop_2")
    self.assertEqual(op4.name, "myop_1_1")

  def testCond(self):
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def true_fn():
        if ops._USE_C_API:
          ops._create_c_op(ops.get_default_graph(),
                           ops._NodeDef("IntInput", "cond/myop"), [x], [])
          new_ops = g._add_new_tf_operations()
          self.assertEqual(len(new_ops), 1)
        else:
          # Test pure-Python version to make sure C API has same behavior.
          test_ops.int_input(x, name="myop")
        return x

      control_flow_ops.cond(x < 10, true_fn, lambda: x)

    op = g.get_operation_by_name("cond/myop")
    self.assertIsNotNone(op)
    self.assertEqual(op.name, "cond/myop")
    self.assertEqual(op.type, "IntInput")
    self.assertEqual(op.outputs, [])
    op_input = op.inputs[0].op
    self.assertEqual(op_input.type, "Switch")
    self.assertEqual(op_input.inputs[0], x)
    self.assertEqual(op.graph, g)
    # pylint: disable=protected-access
    self.assertIsNotNone(op._get_control_flow_context())
    self.assertEqual(op._get_control_flow_context().name,
                     "cond/cond_text")
    # pylint: enable=protected-access

  def testWhileLoop(self):
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def body(i):
        if ops._USE_C_API:
          ops._create_c_op(ops.get_default_graph(),
                           ops._NodeDef("IntInput", "myloop/myop"), [x], [])
          new_ops = g._add_new_tf_operations()
          self.assertEqual(len(new_ops), 1)
        else:
          # Test pure-Python version to make sure C API has same behavior.
          test_ops.int_input(x, name="myop")
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    self.assertEqual(op.name, "myloop/myop")
    self.assertEqual(op.type, "IntInput")
    self.assertEqual(op.outputs, [])
    op_input = op.inputs[0].op
    self.assertEqual(op_input.type, "Enter")
    self.assertEqual(list(op_input.inputs), [x])
    self.assertEqual(op.graph, g)
    # pylint: disable=protected-access
    self.assertIsNotNone(op._get_control_flow_context())
    self.assertEqual(op._get_control_flow_context().name,
                     "myloop/while_context")
    # pylint: enable=protected-access

  def testWhileLoopWithInternalControlDep(self):
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def body(i):
        c = constant_op.constant(1.0, name="c")
        if ops._USE_C_API:
          ops._create_c_op(ops.get_default_graph(),
                           ops._NodeDef("IntInput", "myloop/myop"), [x], [])
          with ops.control_dependencies([c]):
            new_ops = g._add_new_tf_operations()
            self.assertEqual(len(new_ops), 1)
        else:
          with ops.control_dependencies([c]):
            test_ops.int_input(x, name="myop")
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    c = g.get_operation_by_name("myloop/c")
    self.assertIsNotNone(c)
    # Internal control dep is preserved
    self.assertEqual(op.control_inputs, [c])

  def testWhileLoopWithExternalControlDep(self):
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()
      c = constant_op.constant(1.0)

      def body(i):
        if ops._USE_C_API:
          ops._create_c_op(ops.get_default_graph(),
                           ops._NodeDef("IntInput", "myloop/myop"), [x], [])
          with ops.control_dependencies([c]):
            new_ops = g._add_new_tf_operations()
            self.assertEqual(len(new_ops), 1)
        else:
          with ops.control_dependencies([c]):
            test_ops.int_input(x, name="myop")
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    # External control dep is removed and replaced with internal control dep
    self.assertNotEqual(op.control_inputs[0], c.op)
    self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())


@test_util.with_c_api
class ApplyOpTest(test_util.TensorFlowTestCase):
  """Tests the module-level `_apply_op` helper built on Graph.create_op."""

  def testNodeDefArgs(self):
    g = ops.Graph()
    t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1")
    with g.device("/device:GPU:0"):
      t2 = _apply_op(
          g, "TwoIntOutputs", [], [dtypes.int32, dtypes.int32], name="myop2")
    t3 = _apply_op(
        g,
        "Foo1", [t1, t2[1], t2[0]], [dtypes.float32, dtypes.int32],
        name="myop3")
    # A single-output op yields a Tensor; multi-output ops yield a list.
    self.assertIsInstance(t1, ops.Tensor)
    self.assertIsInstance(t2, list)
    self.assertIsInstance(t3, list)
    self.assertIsInstance(t3[0], ops.Tensor)
    self.assertEqual("myop1", t1._as_node_def_input())
    self.assertEqual("myop2", t2[0]._as_node_def_input())
    self.assertEqual("myop2:1", t2[1]._as_node_def_input())
    self.assertEqual("myop3", t3[0]._as_node_def_input())
    # Validate that we got the right ops as well
    self.assertProtoEquals("name:'myop1' op:'FloatOutput'", t1.op.node_def)
    self.assertProtoEquals(
        "name:'myop2' op:'TwoIntOutputs' device:'/device:GPU:0'",
        t2[0].op.node_def)
    self.assertProtoEquals(
        "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo1'",
        t3[0].op.node_def)

  def testReferenceInput(self):
    g = ops.Graph()
    ref_t, nonref_t = _apply_op(
        g, "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
        name="op1")
    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'",
                           ref_t.op.node_def)
    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    out_2 = _apply_op(
        g,
        "RefInputFloatInputIntOutput", [ref_t, nonref_t], [dtypes.int32],
        input_types=[dtypes.float32_ref, dtypes.float32],
        name="op2")
    self.assertProtoEquals(
        "op:'RefInputFloatInputIntOutput' name:'op2' input:'op1' input:'op1:1'",
        out_2.op.node_def)
    out_3 = _apply_op(
        g, "TwoFloatInputsIntOutput", [ref_t, nonref_t], [dtypes.int32],
        name="op3")
    self.assertProtoEquals(
        "op:'TwoFloatInputsIntOutput' name:'op3' input:'op1' input:'op1:1'",
        out_3.op.node_def)


@test_util.with_c_api
class NameStackTest(test_util.TensorFlowTestCase):

  def testBasics(self):
    g = ops.Graph()
    self.assertEqual("foo", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo", g.unique_name("foo"))
    self.assertEqual("foo_1", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo_1", g.unique_name("foo"))
    self.assertEqual("foo_2", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo_2", g.unique_name("foo"))
    self.assertEqual("foo_1_1", g.unique_name("foo_1", mark_as_used=False))
    self.assertEqual("foo_1_1", g.unique_name("foo_1"))
    self.assertEqual("foo_1_2", g.unique_name("foo_1", mark_as_used=False))
    self.assertEqual("foo_1_2", g.unique_name("foo_1"))
    self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2", mark_as_used=False))
    self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2"))
    with g.name_scope("bar"):
      self.assertEqual("bar/foo", g.unique_name("foo", mark_as_used=False))
      self.assertEqual("bar/foo", g.unique_name("foo"))
      self.assertEqual("bar/foo_1", g.unique_name("foo", mark_as_used=False))
      self.assertEqual("bar/foo_1", g.unique_name("foo"))
      with g.name_scope(None):
        self.assertEqual("foo_3", g.unique_name("foo", mark_as_used=False))
        self.assertEqual("foo_3", g.unique_name("foo"))
1004 with g.name_scope("baz"): 1005 self.assertEqual( 1006 "bar/baz/foo", g.unique_name( 1007 "foo", mark_as_used=False)) 1008 self.assertEqual("bar/baz/foo", g.unique_name("foo")) 1009 self.assertEqual( 1010 "bar/baz/foo_1", g.unique_name( 1011 "foo", mark_as_used=False)) 1012 self.assertEqual("bar/baz/foo_1", g.unique_name("foo")) 1013 with g.name_scope("baz"): 1014 self.assertEqual( 1015 "bar/baz_1/foo", g.unique_name( 1016 "foo", mark_as_used=False)) 1017 self.assertEqual("bar/baz_1/foo", g.unique_name("foo")) 1018 self.assertEqual( 1019 "bar/baz_1/foo_1", g.unique_name( 1020 "foo", mark_as_used=False)) 1021 self.assertEqual("bar/baz_1/foo_1", g.unique_name("foo")) 1022 with g.name_scope("quux"): 1023 self.assertEqual("quux/foo", g.unique_name("foo", mark_as_used=False)) 1024 self.assertEqual("quux/foo", g.unique_name("foo")) 1025 with g.name_scope("bar"): 1026 with g.name_scope("baz"): 1027 self.assertEqual( 1028 "bar_1/baz/foo", g.unique_name( 1029 "foo", mark_as_used=False)) 1030 self.assertEqual("bar_1/baz/foo", g.unique_name("foo")) 1031 self.assertEqual("foo_4", g.unique_name("foo", mark_as_used=False)) 1032 self.assertEqual("foo_4", g.unique_name("foo")) 1033 self.assertEqual("bar_2", g.unique_name("bar", mark_as_used=False)) 1034 self.assertEqual("bar_2", g.unique_name("bar")) 1035 1036 def testNameAndVariableScope(self): 1037 with self.test_session() as sess: 1038 with sess.graph.name_scope("l0"): 1039 with variable_scope.variable_scope("l1"): 1040 with sess.graph.name_scope("l1") as scope: 1041 self.assertEqual("l0/l1/l1/", scope) 1042 self.assertEqual( 1043 "l0/l1/l1/foo", 1044 sess.graph.unique_name( 1045 "foo", mark_as_used=False)) 1046 self.assertEqual("l0/l1/l1/foo", sess.graph.unique_name("foo")) 1047 with sess.graph.name_scope("l2") as scope: 1048 self.assertEqual("l0/l1/l2/", scope) 1049 self.assertEqual( 1050 "l0/l1/l2/foo", 1051 sess.graph.unique_name( 1052 "foo", mark_as_used=False)) 1053 self.assertEqual("l0/l1/l2/foo", 
sess.graph.unique_name("foo")) 1054 1055 def testOutOfOrderUniqueName(self): 1056 g = ops.Graph() 1057 self.assertEqual("foo_2", g.unique_name("foo_2")) 1058 self.assertEqual("foo", g.unique_name("foo")) 1059 self.assertEqual("foo_1", g.unique_name("foo")) 1060 self.assertEqual("foo_3", g.unique_name("foo")) 1061 1062 def testInvalidNameRaisesError(self): 1063 g = ops.Graph() 1064 with g.name_scope(""): # Should not raise 1065 pass 1066 with g.name_scope("foo/"): # Should not raise 1067 with g.name_scope("_bar"): # Should not raise 1068 pass 1069 with self.assertRaises(ValueError): 1070 with g.name_scope("foo:0"): 1071 pass 1072 with self.assertRaises(ValueError): 1073 with g.name_scope("_bar"): 1074 pass 1075 1076 1077 @test_util.with_c_api 1078 class NameTest(test_util.TensorFlowTestCase): 1079 1080 def testGenerateName(self): 1081 g = ops.Graph() 1082 op0 = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32]) 1083 self.assertEqual("TwoFloatOutputs", op0.name) 1084 self.assertEqual("TwoFloatOutputs:0", op0.outputs[0].name) 1085 self.assertEqual("TwoFloatOutputs:1", op0.outputs[1].name) 1086 1087 op1 = g.create_op("FloatOutput", [], [dtypes.float32]) 1088 self.assertEqual("FloatOutput", op1.name) 1089 self.assertEqual("FloatOutput:0", op1.outputs[0].name) 1090 1091 op2 = g.create_op("FloatOutput", [], [dtypes.float32]) 1092 self.assertEqual("FloatOutput_1", op2.name) 1093 self.assertEqual("FloatOutput_1:0", op2.outputs[0].name) 1094 1095 op3 = g.create_op("FloatOutput", [], [dtypes.float32], name="my_op") 1096 self.assertEqual("my_op", op3.name) 1097 self.assertEqual("my_op:0", op3.outputs[0].name) 1098 1099 def testNameScope(self): 1100 g = ops.Graph() 1101 1102 with g.name_scope("foo") as foo: 1103 self.assertEqual("foo/", foo) 1104 with g.name_scope("foo2") as foo2: 1105 self.assertEqual("foo/foo2/", foo2) 1106 with g.name_scope(None) as empty1: 1107 self.assertEqual("", empty1) 1108 with g.name_scope("foo3") as foo3: 1109 
self.assertEqual("foo3/", foo3) 1110 with g.name_scope("") as empty2: 1111 self.assertEqual("", empty2) 1112 1113 self.assertEqual("FloatOutput", 1114 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1115 with g.name_scope("bar") as scope: 1116 self.assertEqual("bar/FloatOutput", 1117 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1118 self.assertEqual("bar/FloatOutput_1", 1119 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1120 # If you use the value from "with .. as", that values is used as-is. 1121 self.assertEqual( 1122 "bar", g.create_op( 1123 "FloatOutput", [], [dtypes.float32], name=scope).name) 1124 with g.name_scope("baz") as scope: 1125 with g.name_scope("quux"): 1126 self.assertEqual("baz/quux/FloatOutput", 1127 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1128 # If you use the value from the enclosing "with .. as", nothing is pushed. 1129 with g.name_scope(scope): 1130 self.assertEqual("baz/FloatOutput", 1131 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1132 self.assertEqual( 1133 "baz", g.create_op( 1134 "FloatOutput", [], [dtypes.float32], name=scope).name) 1135 self.assertEqual( 1136 "trailing", 1137 g.create_op( 1138 "FloatOutput", [], [dtypes.float32], name="trailing/").name) 1139 with g.name_scope("bar"): 1140 self.assertEqual("bar_1/FloatOutput", 1141 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1142 with g.name_scope("bar/"): 1143 self.assertEqual("bar/FloatOutput_2", 1144 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1145 1146 1147 @test_util.with_c_api 1148 class DeviceTest(test_util.TensorFlowTestCase): 1149 1150 def testNoDevice(self): 1151 g = ops.Graph() 1152 op = g.create_op("FloatOutput", [], [dtypes.float32]) 1153 self.assertDeviceEqual(None, op.device) 1154 gd = g.as_graph_def() 1155 self.assertProtoEqualsVersion(""" 1156 node { name: "FloatOutput" op: "FloatOutput" } 1157 """, gd) 1158 1159 def testDevicePartialString(self): 1160 g = ops.Graph() 1161 with 
g.device("/job:worker/replica:2"): 1162 g.create_op("FloatOutput", [], [dtypes.float32]) 1163 gd = g.as_graph_def() 1164 self.assertProtoEqualsVersion(""" 1165 node { name: "FloatOutput" op: "FloatOutput" 1166 device: "/job:worker/replica:2" } 1167 """, gd) 1168 1169 def testDeviceFull(self): 1170 g = ops.Graph() 1171 with g.device( 1172 pydev.DeviceSpec( 1173 job="worker", replica=2, task=0, device_type="CPU", 1174 device_index=3)): 1175 g.create_op("FloatOutput", [], [dtypes.float32]) 1176 gd = g.as_graph_def() 1177 self.assertProtoEqualsVersion(""" 1178 node { name: "FloatOutput" op: "FloatOutput" 1179 device: "/job:worker/replica:2/task:0/device:CPU:3" } 1180 """, gd) 1181 1182 def testNesting(self): 1183 g = ops.Graph() 1184 with g.device("/job:worker/replica:2"): 1185 g.create_op("FloatOutput", [], [dtypes.float32]) 1186 with g.device("/job:worker/replica:3/task:0"): 1187 g.create_op("FloatOutput", [], [dtypes.float32]) 1188 g.create_op("FloatOutput", [], [dtypes.float32]) 1189 gd = g.as_graph_def() 1190 self.assertProtoEqualsVersion(""" 1191 node { name: "FloatOutput" op: "FloatOutput" 1192 device: "/job:worker/replica:2" } 1193 node { name: "FloatOutput_1" op: "FloatOutput" 1194 device: "/job:worker/replica:3/task:0" } 1195 node { name: "FloatOutput_2" op: "FloatOutput" 1196 device: "/job:worker/replica:2" } 1197 """, gd) 1198 1199 def testNestingString(self): 1200 g = ops.Graph() 1201 with g.device("/job:worker/replica:2"): 1202 g.create_op("FloatOutput", [], [dtypes.float32]) 1203 with g.device("/job:worker/replica:3/task:0"): 1204 g.create_op("FloatOutput", [], [dtypes.float32]) 1205 g.create_op("FloatOutput", [], [dtypes.float32]) 1206 gd = g.as_graph_def() 1207 self.assertProtoEqualsVersion(""" 1208 node { name: "FloatOutput" op: "FloatOutput" 1209 device: "/job:worker/replica:2" } 1210 node { name: "FloatOutput_1" op: "FloatOutput" 1211 device: "/job:worker/replica:3/task:0" } 1212 node { name: "FloatOutput_2" op: "FloatOutput" 1213 device: 
"/job:worker/replica:2" } 1214 """, gd) 1215 1216 def testNestingOverrideGpuCpu(self): 1217 g = ops.Graph() 1218 with g.device("/job:worker/replica:2/device:CPU:1"): 1219 g.create_op("FloatOutput", [], [dtypes.float32]) 1220 with g.device("/job:worker/replica:2/device:GPU:2"): 1221 g.create_op("FloatOutput", [], [dtypes.float32]) 1222 g.create_op("FloatOutput", [], [dtypes.float32]) 1223 gd = g.as_graph_def() 1224 self.assertProtoEqualsVersion(""" 1225 node { name: "FloatOutput" op: "FloatOutput" 1226 device: "/job:worker/replica:2/device:CPU:1" } 1227 node { name: "FloatOutput_1" op: "FloatOutput" 1228 device: "/job:worker/replica:2/device:GPU:2" } 1229 node { name: "FloatOutput_2" op: "FloatOutput" 1230 device: "/job:worker/replica:2/device:CPU:1" } 1231 """, gd) 1232 1233 def testNestingWithMergeDeviceFunction(self): 1234 g = ops.Graph() 1235 1236 with g.device(pydev.merge_device("/device:GPU:0")): 1237 g.create_op("FloatOutput", [], [dtypes.float32]) 1238 with g.device(pydev.merge_device("/job:worker")): 1239 g.create_op("FloatOutput", [], [dtypes.float32]) 1240 with g.device(pydev.merge_device("/device:CPU:0")): 1241 g.create_op("FloatOutput", [], [dtypes.float32]) 1242 with g.device(pydev.merge_device("/job:ps")): 1243 g.create_op("FloatOutput", [], [dtypes.float32]) 1244 with g.device(pydev.merge_device(None)): 1245 g.create_op("FloatOutput", [], [dtypes.float32]) 1246 1247 gd = g.as_graph_def() 1248 self.assertProtoEqualsVersion(""" 1249 node { name: "FloatOutput" op: "FloatOutput" 1250 device: "/device:GPU:0" } 1251 node { name: "FloatOutput_1" op: "FloatOutput" 1252 device: "/job:worker/device:GPU:0" } 1253 node { name: "FloatOutput_2" op: "FloatOutput" 1254 device: "/job:worker/device:CPU:0" } 1255 node { name: "FloatOutput_3" op: "FloatOutput" 1256 device: "/job:ps/device:CPU:0" } 1257 node { name: "FloatOutput_4" op: "FloatOutput" 1258 device: "/job:ps/device:CPU:0" } 1259 """, gd) 1260 1261 def testNestingWithDeviceStrings(self): 1262 g = ops.Graph() 
1263 1264 with g.device("/device:GPU:0"): 1265 g.create_op("FloatOutput", [], [dtypes.float32]) 1266 with g.device("/job:worker"): 1267 g.create_op("FloatOutput", [], [dtypes.float32]) 1268 with g.device("/device:CPU:0"): 1269 g.create_op("FloatOutput", [], [dtypes.float32]) 1270 with g.device("/job:ps"): 1271 g.create_op("FloatOutput", [], [dtypes.float32]) 1272 with g.device(""): 1273 g.create_op("FloatOutput", [], [dtypes.float32]) 1274 1275 gd = g.as_graph_def() 1276 self.assertProtoEqualsVersion(""" 1277 node { name: "FloatOutput" op: "FloatOutput" 1278 device: "/device:GPU:0" } 1279 node { name: "FloatOutput_1" op: "FloatOutput" 1280 device: "/job:worker/device:GPU:0" } 1281 node { name: "FloatOutput_2" op: "FloatOutput" 1282 device: "/job:worker/device:CPU:0" } 1283 node { name: "FloatOutput_3" op: "FloatOutput" 1284 device: "/job:ps/device:CPU:0" } 1285 node { name: "FloatOutput_4" op: "FloatOutput" 1286 device: "/job:ps/device:CPU:0" } 1287 """, gd) 1288 1289 def testNestingWithDeviceStringWildcard(self): 1290 g = ops.Graph() 1291 1292 with g.device("/device:GPU:7"): 1293 g.create_op("FloatOutput", [], [dtypes.float32]) 1294 with g.device("/device:GPU:*"): 1295 g.create_op("FloatOutput", [], [dtypes.float32]) 1296 1297 with g.device("/device:CPU:*"): 1298 g.create_op("FloatOutput", [], [dtypes.float32]) 1299 with g.device("/device:CPU:5"): 1300 g.create_op("FloatOutput", [], [dtypes.float32]) 1301 1302 gd = g.as_graph_def() 1303 self.assertProtoEqualsVersion(""" 1304 node { name: "FloatOutput" op: "FloatOutput" 1305 device: "/device:GPU:7" } 1306 node { name: "FloatOutput_1" op: "FloatOutput" 1307 device: "/device:GPU:7" } 1308 node { name: "FloatOutput_2" op: "FloatOutput" 1309 device: "/device:CPU:*" } 1310 node { name: "FloatOutput_3" op: "FloatOutput" 1311 device: "/device:CPU:5" } 1312 """, gd) 1313 1314 def testNoneClearsDefault(self): 1315 g = ops.Graph() 1316 with g.device("/job:worker/replica:2/device:CPU:1"): 1317 g.create_op("FloatOutput", [], 
[dtypes.float32]) 1318 with g.device(None): 1319 g.create_op("FloatOutput", [], [dtypes.float32]) 1320 g.create_op("FloatOutput", [], [dtypes.float32]) 1321 gd = g.as_graph_def() 1322 self.assertProtoEqualsVersion(""" 1323 node { name: "FloatOutput" op: "FloatOutput" 1324 device: "/job:worker/replica:2/device:CPU:1" } 1325 node { name: "FloatOutput_1" op: "FloatOutput" } 1326 node { name: "FloatOutput_2" op: "FloatOutput" 1327 device: "/job:worker/replica:2/device:CPU:1" } 1328 """, gd) 1329 1330 def testNoneIgnoresOuterDeviceFunction(self): 1331 g = ops.Graph() 1332 with g.device(lambda op: "/job:worker/replica:2/device:CPU:1"): 1333 g.create_op("FloatOutput", [], [dtypes.float32]) 1334 with g.device(None): 1335 g.create_op("FloatOutput", [], [dtypes.float32]) 1336 g.create_op("FloatOutput", [], [dtypes.float32]) 1337 gd = g.as_graph_def() 1338 self.assertProtoEqualsVersion(""" 1339 node { name: "FloatOutput" op: "FloatOutput" 1340 device: "/job:worker/replica:2/device:CPU:1" } 1341 node { name: "FloatOutput_1" op: "FloatOutput" } 1342 node { name: "FloatOutput_2" op: "FloatOutput" 1343 device: "/job:worker/replica:2/device:CPU:1" } 1344 """, gd) 1345 1346 def _overwritingDeviceFunction(self, unused_op): 1347 # This device function unconditionally overwrites the device of ops. 1348 # 1349 # NOTE(mrry): Writing device functions like this is not 1350 # recommended. Instead, in most cases you should use 1351 # `pydev.merge_device("/job:ps")` or simply `"/job:ps"` as the 1352 # argument to `tf.device()` and the device component will be merged in. 1353 return "/job:overwrite" 1354 1355 def testOverwritingBehavior(self): 1356 g = ops.Graph() 1357 with g.device(self._overwritingDeviceFunction): 1358 g.create_op("FloatOutput", [], [dtypes.float32]) 1359 with g.device("/job:ps"): # Will be overwritten. 1360 g.create_op("FloatOutput", [], [dtypes.float32]) 1361 with g.device(pydev.merge_device("/job:ps")): # Will be overwritten. 
1362 g.create_op("FloatOutput", [], [dtypes.float32]) 1363 with g.device(None): # Disables overwriting device function 1364 with g.device("/job:ps"): 1365 g.create_op("FloatOutput", [], [dtypes.float32]) 1366 with g.device(None): # Disables overwriting device function 1367 with g.device(pydev.merge_device("/job:ps")): 1368 g.create_op("FloatOutput", [], [dtypes.float32]) 1369 gd = g.as_graph_def() 1370 self.assertProtoEqualsVersion(""" 1371 node { name: "FloatOutput" op: "FloatOutput" 1372 device: "/job:overwrite" } 1373 node { name: "FloatOutput_1" op: "FloatOutput" 1374 device: "/job:overwrite" } 1375 node { name: "FloatOutput_2" op: "FloatOutput" 1376 device: "/job:overwrite" } 1377 node { name: "FloatOutput_3" op: "FloatOutput" 1378 device: "/job:ps" } 1379 node { name: "FloatOutput_4" op: "FloatOutput" 1380 device: "/job:ps" } 1381 """, gd) 1382 1383 1384 @test_util.with_c_api 1385 class ObjectWithName(object): 1386 1387 def __init__(self, name): 1388 self._name = name 1389 1390 @property 1391 def name(self): 1392 return self._name 1393 1394 1395 @test_util.with_c_api 1396 class CollectionTest(test_util.TensorFlowTestCase): 1397 1398 def test_get_collections(self): 1399 g = ops.Graph() 1400 self.assertSequenceEqual(g.collections, []) 1401 g.add_to_collection("key", 12) 1402 g.add_to_collection("key", 15) 1403 self.assertSequenceEqual(g.collections, ["key"]) 1404 g.add_to_collection("other", "foo") 1405 self.assertSequenceEqual(sorted(g.collections), ["key", "other"]) 1406 1407 def test_add_to_collection(self): 1408 g = ops.Graph() 1409 g.add_to_collection("key", 12) 1410 g.add_to_collection("other", "foo") 1411 g.add_to_collection("key", 34) 1412 1413 # Note that only blank1 is returned. 
1414 g.add_to_collection("blah", 27) 1415 blank1 = ObjectWithName("prefix/foo") 1416 g.add_to_collection("blah", blank1) 1417 blank2 = ObjectWithName("junk/foo") 1418 g.add_to_collection("blah", blank2) 1419 1420 self.assertEqual([12, 34], g.get_collection("key")) 1421 self.assertEqual([], g.get_collection("nothing")) 1422 self.assertEqual([27, blank1, blank2], g.get_collection("blah")) 1423 self.assertEqual([blank1], g.get_collection("blah", "prefix")) 1424 self.assertEqual([blank1], g.get_collection("blah", ".*x")) 1425 1426 # Make sure that get_collection() returns a first-level 1427 # copy of the collection, while get_collection_ref() returns 1428 # the original list. 1429 other_collection_snapshot = g.get_collection("other") 1430 other_collection_ref = g.get_collection_ref("other") 1431 self.assertEqual(["foo"], other_collection_snapshot) 1432 self.assertEqual(["foo"], other_collection_ref) 1433 g.add_to_collection("other", "bar") 1434 self.assertEqual(["foo"], other_collection_snapshot) 1435 self.assertEqual(["foo", "bar"], other_collection_ref) 1436 self.assertEqual(["foo", "bar"], g.get_collection("other")) 1437 self.assertTrue(other_collection_ref is g.get_collection_ref("other")) 1438 1439 # Verify that getting an empty collection ref returns a modifiable list. 1440 empty_coll_ref = g.get_collection_ref("empty") 1441 self.assertEqual([], empty_coll_ref) 1442 empty_coll = g.get_collection("empty") 1443 self.assertEqual([], empty_coll) 1444 self.assertFalse(empty_coll is empty_coll_ref) 1445 empty_coll_ref2 = g.get_collection_ref("empty") 1446 self.assertTrue(empty_coll_ref2 is empty_coll_ref) 1447 # Add to the collection. 
1448 empty_coll_ref.append("something") 1449 self.assertEqual(["something"], empty_coll_ref) 1450 self.assertEqual(["something"], empty_coll_ref2) 1451 self.assertEqual([], empty_coll) 1452 self.assertEqual(["something"], g.get_collection("empty")) 1453 empty_coll_ref3 = g.get_collection_ref("empty") 1454 self.assertTrue(empty_coll_ref3 is empty_coll_ref) 1455 1456 def test_add_to_collections_uniquify(self): 1457 g = ops.Graph() 1458 g.add_to_collections([1, 2, 1], "key") 1459 # Make sure "key" is not added twice 1460 self.assertEqual(["key"], g.get_collection(1)) 1461 1462 def test_add_to_collections_from_list(self): 1463 g = ops.Graph() 1464 g.add_to_collections(["abc", "123"], "key") 1465 self.assertEqual(["key"], g.get_collection("abc")) 1466 self.assertEqual(["key"], g.get_collection("123")) 1467 1468 def test_add_to_collections_from_tuple(self): 1469 g = ops.Graph() 1470 g.add_to_collections(("abc", "123"), "key") 1471 self.assertEqual(["key"], g.get_collection("abc")) 1472 self.assertEqual(["key"], g.get_collection("123")) 1473 1474 def test_add_to_collections_from_generator(self): 1475 g = ops.Graph() 1476 1477 def generator(): 1478 yield "abc" 1479 yield "123" 1480 1481 g.add_to_collections(generator(), "key") 1482 self.assertEqual(["key"], g.get_collection("abc")) 1483 self.assertEqual(["key"], g.get_collection("123")) 1484 1485 def test_add_to_collections_from_set(self): 1486 g = ops.Graph() 1487 g.add_to_collections(set(["abc", "123"]), "key") 1488 self.assertEqual(["key"], g.get_collection("abc")) 1489 self.assertEqual(["key"], g.get_collection("123")) 1490 1491 def test_add_to_collections_from_string(self): 1492 g = ops.Graph() 1493 g.add_to_collections("abc", "key") 1494 self.assertEqual(["key"], g.get_collection("abc")) 1495 1496 def test_default_graph(self): 1497 with ops.Graph().as_default(): 1498 ops.add_to_collection("key", 90) 1499 ops.add_to_collection("key", 100) 1500 # Collections are ordered. 
1501 self.assertEqual([90, 100], ops.get_collection("key")) 1502 1503 1504 ops.NotDifferentiable("FloatOutput") 1505 1506 1507 @ops.RegisterGradient("CopyOp") 1508 def _CopyGrad(op, x_grad): # pylint: disable=invalid-name 1509 _ = op 1510 return x_grad 1511 1512 1513 @ops.RegisterGradient("copy_override") 1514 def _CopyOverrideGrad(op, x_grad): # pylint: disable=invalid-name 1515 _ = op 1516 return x_grad 1517 1518 1519 @test_util.with_c_api 1520 class RegistrationTest(test_util.TensorFlowTestCase): 1521 1522 def testRegisterGradients(self): 1523 x = test_ops.float_output() 1524 y = test_ops.copy_op(x) 1525 fn = ops.get_gradient_function(y.op) 1526 self.assertEqual(_CopyGrad, fn) 1527 1528 def testOverrideGradients(self): 1529 g = ops.Graph() 1530 with g.as_default(): 1531 x = test_ops.float_output() 1532 with g.gradient_override_map({"CopyOp": "copy_override"}): 1533 y = test_ops.copy_op(x) 1534 fn = ops.get_gradient_function(y.op) 1535 self.assertEqual(_CopyOverrideGrad, fn) 1536 1537 def testNonExistentOverride(self): 1538 g = ops.Graph() 1539 with g.as_default(): 1540 x = test_ops.float_output() 1541 with g.gradient_override_map({"CopyOp": "unknown_override"}): 1542 y = test_ops.copy_op(x) 1543 with self.assertRaisesRegexp(LookupError, "unknown_override"): 1544 ops.get_gradient_function(y.op) 1545 1546 1547 @test_util.with_c_api 1548 class ComparisonTest(test_util.TensorFlowTestCase): 1549 1550 def testMembershipAllowed(self): 1551 g = ops.Graph() 1552 t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1") 1553 t2 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop2") 1554 self.assertTrue(isinstance(t1, ops.Tensor)) 1555 self.assertTrue(isinstance(t2, ops.Tensor)) 1556 self.assertTrue(t1 in [t1]) 1557 self.assertTrue(t1 not in [t2]) 1558 1559 1560 @test_util.with_c_api 1561 class ControlDependenciesTest(test_util.TensorFlowTestCase): 1562 1563 @test_util.enable_c_api 1564 def testBasic(self): 1565 g = ops.Graph() 1566 with 
g.as_default(): 1567 # Creating unregistered ops with _apply_op() doesn't work with the C API 1568 # TODO(skyewm): address this more consistently. Possible solutions are 1569 # to use registered ops in all tests, create a way to register ops in 1570 # Python tests, or conditionally disable the op registration check in 1571 # the C API. 1572 a = constant_op.constant(1.0) 1573 b = constant_op.constant(1.0) 1574 with g.control_dependencies([a]): 1575 c = constant_op.constant(1.0) 1576 d = array_ops.identity(b) 1577 e = array_ops.identity(c) 1578 1579 self.assertEqual(c.op.control_inputs, [a.op]) 1580 self.assertEqual(d.op.control_inputs, [a.op]) 1581 # e should be dominated by c. 1582 self.assertEqual(e.op.control_inputs, []) 1583 1584 @test_util.run_in_graph_and_eager_modes() 1585 def testEager(self): 1586 def future(): 1587 future.calls += 1 1588 return constant_op.constant(2.0) 1589 future.calls = 0 1590 1591 if context.in_graph_mode(): 1592 g = ops.Graph() 1593 with g.as_default(): 1594 a = constant_op.constant(1.0) 1595 b = future() 1596 with g.control_dependencies([a, b]): 1597 c = constant_op.constant(3.0) 1598 self.assertEqual(c.op.control_inputs, [a.op, b.op]) 1599 self.assertEqual(future.calls, 1) 1600 else: 1601 a = constant_op.constant(1.0) 1602 b = future() 1603 with ops.control_dependencies([a, b]): 1604 c = constant_op.constant(3.0) 1605 self.assertEqual(future.calls, 1) 1606 1607 def testBasicWithConversion(self): 1608 g = ops.Graph() 1609 a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1610 1611 class ConvertibleObj(object): 1612 1613 def _as_graph_element(self): 1614 return a 1615 1616 with g.control_dependencies([ConvertibleObj()]): 1617 c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1618 1619 self.assertEqual(c.op.control_inputs, [a.op]) 1620 1621 def testNested(self): 1622 g = ops.Graph() 1623 a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1624 a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1625 a_3 = _apply_op(g, 
"FloatOutput", [], [dtypes.float32]) 1626 a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1627 1628 with g.control_dependencies([a_1, a_2, a_3, a_4]): 1629 b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1630 1631 with g.control_dependencies([a_1]): 1632 with g.control_dependencies([a_2]): 1633 with g.control_dependencies([a_3]): 1634 with g.control_dependencies([a_4]): 1635 b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1636 1637 self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op], 1638 b_1.op.control_inputs) 1639 self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs) 1640 1641 def testClear(self): 1642 g = ops.Graph() 1643 a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1644 a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1645 a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1646 a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1647 1648 with g.control_dependencies([a_1]): 1649 with g.control_dependencies([a_2]): 1650 with g.control_dependencies(None): 1651 with g.control_dependencies([a_3]): 1652 with g.control_dependencies([a_4]): 1653 # deps [a_3, a_4] 1654 b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1655 # deps = [a_3] 1656 b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1657 # deps back to None 1658 b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1659 # deps back to [a_1, a_2] 1660 b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1661 # deps back to [a_1] 1662 b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1663 with g.control_dependencies(None): 1664 # deps are None again 1665 b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1666 1667 self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs) 1668 self.assertItemsEqual([a_3.op], b_3.op.control_inputs) 1669 self.assertItemsEqual([], b_none.op.control_inputs) 1670 self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs) 1671 self.assertItemsEqual([a_1.op], 
b_1.op.control_inputs) 1672 self.assertItemsEqual([], b_none2.op.control_inputs) 1673 1674 def testComplex(self): 1675 g = ops.Graph() 1676 1677 # Usage pattern: 1678 # * Nodes a_i are constants defined at the outermost scope, and are used 1679 # as control inputs for the ith nested scope. 1680 # * Nodes b_i are defined as Mul(a_3, a_4) at each scope. 1681 # * Nodes c_i are defined as Mul(a_1, b_1) at each scope. 1682 # * Nodes d_i are defined as Mul(b_i, c_i) at each scope. 1683 # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1. 1684 1685 a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1686 a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1687 a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1688 a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1689 1690 with g.control_dependencies([a_1]): 1691 b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], 1692 [dtypes.float32]) 1693 c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], 1694 [dtypes.float32]) 1695 d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1], 1696 [dtypes.float32]) 1697 e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1698 with g.control_dependencies([a_2]): 1699 b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], 1700 [dtypes.float32]) 1701 c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], 1702 [dtypes.float32]) 1703 d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2], 1704 [dtypes.float32]) 1705 e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1], 1706 [dtypes.float32]) 1707 with g.control_dependencies([a_3]): 1708 b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], 1709 [dtypes.float32]) 1710 c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], 1711 [dtypes.float32]) 1712 d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3], 1713 [dtypes.float32]) 1714 e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2], 1715 [dtypes.float32]) 1716 with g.control_dependencies([a_4]): 
1717 b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], 1718 [dtypes.float32]) 1719 c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], 1720 [dtypes.float32]) 1721 d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4], 1722 [dtypes.float32]) 1723 e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3], 1724 [dtypes.float32]) 1725 1726 self.assertItemsEqual([a_1.op], b_1.op.control_inputs) 1727 self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs) 1728 self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs) 1729 self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs) 1730 1731 self.assertItemsEqual([], c_1.op.control_inputs) 1732 self.assertItemsEqual([a_2.op], c_2.op.control_inputs) 1733 self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs) 1734 self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs) 1735 1736 self.assertItemsEqual([], d_1.op.control_inputs) 1737 self.assertItemsEqual([], d_2.op.control_inputs) 1738 self.assertItemsEqual([], d_3.op.control_inputs) 1739 self.assertItemsEqual([], d_4.op.control_inputs) 1740 1741 self.assertItemsEqual([a_1.op], e_1.op.control_inputs) 1742 self.assertItemsEqual([a_2.op], e_2.op.control_inputs) 1743 self.assertItemsEqual([a_3.op], e_3.op.control_inputs) 1744 self.assertItemsEqual([a_4.op], e_4.op.control_inputs) 1745 1746 def testRepeatedDependency(self): 1747 g = ops.Graph() 1748 a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32]) 1749 a_0, a_1 = a.outputs 1750 with g.control_dependencies([a_0]): 1751 b = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1752 with g.control_dependencies([a_1]): 1753 c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1754 1755 self.assertEqual(b.op.control_inputs, [a]) 1756 self.assertEqual(c.op.control_inputs, [a]) 1757 1758 def testNoControlDependencyWithDataDependency(self): 1759 g = ops.Graph() 1760 a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1761 with 
g.control_dependencies([a]): 1762 b = _apply_op(g, "Identity", [a], [dtypes.float32]) 1763 1764 self.assertEqual(b.op.control_inputs, []) 1765 1766 1767 @test_util.with_c_api 1768 class OpScopeTest(test_util.TensorFlowTestCase): 1769 1770 @test_util.run_in_graph_and_eager_modes() 1771 def testNames(self): 1772 with ops.name_scope("foo") as foo: 1773 self.assertEqual("foo/", foo) 1774 with ops.name_scope("foo2") as foo2: 1775 self.assertEqual("foo/foo2/", foo2) 1776 with ops.name_scope(None) as empty1: 1777 self.assertEqual("", empty1) 1778 with ops.name_scope("foo3") as foo3: 1779 self.assertEqual("foo3/", foo3) 1780 with ops.name_scope("") as empty2: 1781 self.assertEqual("", empty2) 1782 with ops.name_scope("foo/") as outer_foo: 1783 self.assertEqual("foo/", outer_foo) 1784 with ops.name_scope("") as empty3: 1785 self.assertEqual("", empty3) 1786 with ops.name_scope("foo4") as foo4: 1787 self.assertEqual("foo/foo4/", foo4) 1788 with ops.name_scope("foo5//") as foo5: 1789 self.assertEqual("foo5//", foo5) 1790 with ops.name_scope("foo6") as foo6: 1791 self.assertEqual("foo5//foo6/", foo6) 1792 with ops.name_scope("/") as foo7: 1793 self.assertEqual("/", foo7) 1794 with ops.name_scope("//") as foo8: 1795 self.assertEqual("//", foo8) 1796 with ops.name_scope("a//b/c") as foo9: 1797 self.assertEqual("foo/a//b/c/", foo9) 1798 with ops.name_scope("a//b/c") as foo10: 1799 self.assertEqual("a//b/c/", foo10) 1800 1801 @test_util.run_in_graph_and_eager_modes() 1802 def testEagerDefaultScopeName(self): 1803 with ops.name_scope(None, "default") as scope: 1804 self.assertEqual(scope, "default/") 1805 with ops.name_scope(None, "default2") as scope2: 1806 self.assertEqual(scope2, "default/default2/") 1807 1808 def testNoScopeName(self): 1809 g0 = ops.Graph() 1810 values = [ 1811 g0.create_op("A", [], [dtypes.float32]), 1812 g0.create_op("B", [], [dtypes.float32]) 1813 ] 1814 with self.assertRaises(ValueError): 1815 with ops.name_scope(None, values=values): 1816 pass 1817 with 
self.assertRaises(ValueError): 1818 with ops.name_scope(None, None, values): 1819 pass 1820 1821 def testEmptyScopeName(self): 1822 g0 = ops.Graph() 1823 a = g0.create_op("A", [], [dtypes.float32]) 1824 b = g0.create_op("B", [], [dtypes.float32]) 1825 with ops.name_scope("", values=[a, b]) as scope: 1826 self.assertEqual("", scope) 1827 self.assertEqual(g0, ops.get_default_graph()) 1828 with ops.name_scope("", "my_default_scope", [a, b]) as scope: 1829 self.assertEqual("", scope) 1830 self.assertEqual(g0, ops.get_default_graph()) 1831 1832 def testDefaultScopeName(self): 1833 g0 = ops.Graph() 1834 a = g0.create_op("A", [], [dtypes.float32]) 1835 b = g0.create_op("B", [], [dtypes.float32]) 1836 scope_name = "my_scope" 1837 default_scope_name = "my_default_scope" 1838 with ops.name_scope(scope_name, default_scope_name, [a, b]) as scope: 1839 self.assertEqual("%s/" % scope_name, scope) 1840 self.assertEqual(g0, ops.get_default_graph()) 1841 with ops.name_scope(None, default_scope_name, [a, b]) as scope: 1842 self.assertEqual("%s/" % default_scope_name, scope) 1843 self.assertEqual(g0, ops.get_default_graph()) 1844 1845 def _testGraphElements(self, graph_elements): 1846 scope_name = "my_scope" 1847 with ops.name_scope(scope_name, values=graph_elements) as scope: 1848 self.assertEqual("%s/" % scope_name, scope) 1849 self.assertEqual(graph_elements[0].graph, ops.get_default_graph()) 1850 g1 = ops.Graph() 1851 a = g1.create_op("A", [], [dtypes.float32]) 1852 with self.assertRaises(ValueError): 1853 with ops.name_scope(scope_name, values=graph_elements + [a]): 1854 pass 1855 1856 def testTensor(self): 1857 g0 = ops.Graph() 1858 a = g0.create_op("A", [], [dtypes.float32]) 1859 b = g0.create_op("B", [], [dtypes.float32]) 1860 self._testGraphElements([a, b]) 1861 1862 def testSparseTensor(self): 1863 g0 = ops.Graph() 1864 a = g0.create_op("A", [], [dtypes.float32]) 1865 b = g0.create_op("B", [], [dtypes.float32]) 1866 sparse = sparse_tensor.SparseTensor( 1867 _apply_op(g0, 
"Int64Output", [], [dtypes.int64]), 1868 _apply_op(g0, "FloatOutput", [], [dtypes.float32]), 1869 _apply_op(g0, "Int64Output", [], [dtypes.int64])) 1870 self._testGraphElements([a, sparse, b]) 1871 1872 def testVariable(self): 1873 g0 = ops.Graph() 1874 with g0.as_default(): 1875 variable = variables.Variable([1.0]) 1876 a = g0.create_op("A", [], [dtypes.float32]) 1877 b = g0.create_op("B", [], [dtypes.float32]) 1878 self._testGraphElements([a, variable, b]) 1879 1880 1881 class InitScopeTest(test_util.TensorFlowTestCase): 1882 1883 def testClearsControlDependencies(self): 1884 g = ops.Graph() 1885 a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1886 a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1887 a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1888 a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1889 1890 with g.as_default(): 1891 with g.control_dependencies([a_1]): 1892 with g.control_dependencies([a_2]): 1893 with ops.init_scope(): 1894 with g.control_dependencies([a_3]): 1895 with g.control_dependencies([a_4]): 1896 # deps [a_3, a_4] 1897 b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1898 # deps = [a_3] 1899 b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1900 # deps back to None 1901 b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1902 # deps back to [a_1, a_2] 1903 b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1904 # deps back to [a_1] 1905 b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1906 with ops.init_scope(): 1907 # deps are None again 1908 b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1909 1910 self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs) 1911 self.assertItemsEqual([a_3.op], b_3.op.control_inputs) 1912 self.assertItemsEqual([], b_none.op.control_inputs) 1913 self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs) 1914 self.assertItemsEqual([a_1.op], b_1.op.control_inputs) 1915 self.assertItemsEqual([], 
b_none2.op.control_inputs) 1916 1917 def testLiftsOpsFromFunctions(self): 1918 g0 = ops.Graph() 1919 g1 = ops.Graph() 1920 g1._building_function = True # pylint: disable=protected-access 1921 g2 = ops.Graph() 1922 g2._building_function = True # pylint: disable=protected-access 1923 1924 with g0.as_default(): 1925 with g1.as_default(): 1926 with g2.as_default(): 1927 with ops.init_scope(): 1928 _ = constant_op.constant(1.0) 1929 1930 self.assertEqual(len(g2.get_operations()), 0) 1931 self.assertEqual(len(g1.get_operations()), 0) 1932 self.assertEqual(len(g0.get_operations()), 1) 1933 1934 def testComposes(self): 1935 g0 = ops.Graph() 1936 g1 = ops.Graph() 1937 g1._building_function = True # pylint: disable=protected-access 1938 g2 = ops.Graph() 1939 g2._building_function = True # pylint: disable=protected-access 1940 g3 = ops.Graph() 1941 g3._building_function = False # pylint: disable=protected-access 1942 1943 with g0.as_default(): 1944 with g1.as_default(): 1945 with ops.init_scope(): 1946 # This op should be lifted into g0. 1947 _ = constant_op.constant(1.0) 1948 self.assertIs(g0, ops.get_default_graph()) 1949 self.assertEqual(len(g2.get_operations()), 0) 1950 self.assertEqual(len(g1.get_operations()), 0) 1951 self.assertEqual(len(g0.get_operations()), 1) 1952 with g2.as_default(): 1953 with ops.init_scope(): 1954 # This op should be lifted into g0. 1955 _ = constant_op.constant(1.0) 1956 self.assertIs(g0, ops.get_default_graph()) 1957 with g3.as_default(): 1958 with ops.init_scope(): 1959 # This op should be lifted into g3, because g3 is not building a 1960 # function. 
1961 _ = constant_op.constant(1.0) 1962 self.assertIs(g3, ops.get_default_graph()) 1963 1964 self.assertEqual(len(g3.get_operations()), 1) 1965 self.assertEqual(len(g2.get_operations()), 0) 1966 self.assertEqual(len(g1.get_operations()), 0) 1967 self.assertEqual(len(g0.get_operations()), 2) 1968 1969 def testEscapesToEagerContext(self): 1970 g = ops.Graph() 1971 g._building_function = True # pylint: disable=protected-access 1972 with context.eager_mode(): 1973 with context.graph_mode(): 1974 with g.as_default(): 1975 with ops.init_scope(): 1976 # Because g is building a function, init_scope should 1977 # escape out to the eager context. 1978 self.assertTrue(context.in_eager_mode()) 1979 # g should be reinstated as the default graph, and the 1980 # graph context should be re-entered. 1981 self.assertIs(g, ops.get_default_graph()) 1982 self.assertTrue(context.in_graph_mode()) 1983 1984 def testAllGraphsBuildingFunctionsRaisesError(self): 1985 g = ops.Graph() 1986 g._building_function = True # pylint: disable=protected-access 1987 with g.as_default(): 1988 with self.assertRaises(AssertionError): 1989 with ops.init_scope(): 1990 pass 1991 1992 def testStaysInEagerWhenOnlyEagerContextActive(self): 1993 with context.eager_mode(): 1994 with ops.init_scope(): 1995 self.assertTrue(context.eager_mode()) 1996 self.assertTrue(context.eager_mode()) 1997 1998 def testEscapesDefunWhenInEagerMode(self): 1999 2000 def function_with_variables(): 2001 with ops.init_scope(): 2002 v = resource_variable_ops.ResourceVariable(3) 2003 return v.assign_add(1) 2004 2005 with context.eager_mode(): 2006 # Each invocation of function_with_variables recreates a variable. 
2007 self.assertEqual(4, int(function_with_variables())) 2008 self.assertEqual(4, int(function_with_variables())) 2009 2010 compiled = eager_function.defun(function_with_variables) 2011 # The init_scope in function_with_variables lifts the variable out 2012 # of the graph function constructed by defun; hence, 2013 # compiled now appears to be stateful. 2014 self.assertEqual(4, int(compiled())) 2015 self.assertEqual(5, int(compiled())) 2016 2017 def testEscapesDefunWhenInGraphMode(self): 2018 def function_with_variables(name): 2019 with ops.init_scope(): 2020 _ = variable_scope.get_variable(name, shape=(1,)) 2021 2022 g = ops.Graph() 2023 with g.as_default(): 2024 with self.test_session(): 2025 # First ensure that graphs that are not building functions are 2026 # not escaped. 2027 function_with_variables("foo") 2028 with self.assertRaisesRegexp(ValueError, 2029 r"Variable foo already exists.*"): 2030 # This will fail because reuse is not set to True. 2031 function_with_variables("foo") 2032 2033 compiled = eager_function.defun(function_with_variables) 2034 compiled("bar") 2035 self.assertEqual( 2036 len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2) 2037 2038 # The second call to `compiled` should not create variables: the 2039 # init_scope has lifted the variable creation code out of the defun. 2040 compiled("bar") 2041 self.assertEqual( 2042 len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2) 2043 2044 def testEscapesNestedDefun(self): 2045 2046 def inner_function(): 2047 with ops.init_scope(): 2048 v = resource_variable_ops.ResourceVariable(1) 2049 return v.assign_add(2) 2050 2051 def outer_function(inner=None): 2052 with ops.init_scope(): 2053 v0 = resource_variable_ops.ResourceVariable(0) 2054 return v0.assign_add(1) + inner() 2055 2056 with context.eager_mode(): 2057 # Each invocation of outer_function recreates variables. 
2058 self.assertEqual(4, int(outer_function(inner=inner_function))) 2059 self.assertEqual(4, int(outer_function(inner=inner_function))) 2060 2061 compiled_inner = eager_function.defun(inner_function) 2062 compiled_outer = eager_function.defun(outer_function) 2063 # The init_scope lifts variables out of the graph functions 2064 # constructed by defun; hence, compiled_outer should now appear to be 2065 # stateful. 2066 self.assertEqual(4, int(compiled_outer(inner=compiled_inner))) 2067 self.assertEqual(7, int(compiled_outer(inner=compiled_inner))) 2068 2069 def testInstallsDefaultGraphWhenGraphStackIsEmptyInGraphMode(self): 2070 with context.graph_mode(): 2071 # pylint: disable=protected-access 2072 self.assertEqual(len(ops._default_graph_stack.stack), 0) 2073 with ops.init_scope(): 2074 self.assertGreater(len(ops._default_graph_stack.stack), 0) 2075 self.assertEqual(len(ops._default_graph_stack.stack), 0) 2076 # pylint: enable=protected-access 2077 2078 def testPreservesNameScopeInGraphConstruction(self): 2079 with ops.Graph().as_default(): 2080 function_graph = ops.Graph() 2081 with function_graph.as_default(): 2082 with ops.name_scope("inner"), ops.init_scope(): 2083 self.assertEqual(ops.get_name_scope(), "inner") 2084 self.assertEqual(ops.get_name_scope(), "") 2085 2086 def testPreservesNameScopeInEagerExecution(self): 2087 with context.eager_mode(): 2088 def foo(): 2089 with ops.name_scope("inner"), ops.init_scope(): 2090 if context.in_graph_mode(): 2091 self.assertEqual(ops.get_name_scope(), "inner") 2092 else: 2093 # A trailing slash is always appended when eager execution is 2094 # enabled. 
2095 self.assertEqual(context.context().scope_name, "inner/") 2096 foo() 2097 self.assertEqual(ops.get_name_scope(), "") 2098 foo_compiled = eager_function.defun(foo) 2099 foo_compiled() 2100 self.assertEqual(ops.get_name_scope(), "") 2101 2102 2103 @test_util.with_c_api 2104 class GraphTest(test_util.TensorFlowTestCase): 2105 2106 def setUp(self): 2107 ops.reset_default_graph() 2108 2109 def _AssertDefault(self, expected): 2110 self.assertIs(expected, ops.get_default_graph()) 2111 2112 def testResetDefaultGraphNesting(self): 2113 g0 = ops.Graph() 2114 with self.assertRaises(AssertionError): 2115 with g0.as_default(): 2116 ops.reset_default_graph() 2117 2118 def testGraphContextManager(self): 2119 g0 = ops.Graph() 2120 with g0.as_default() as g1: 2121 self.assertIs(g0, g1) 2122 2123 def testDefaultGraph(self): 2124 orig = ops.get_default_graph() 2125 self._AssertDefault(orig) 2126 g0 = ops.Graph() 2127 self._AssertDefault(orig) 2128 context_manager_0 = g0.as_default() 2129 self._AssertDefault(orig) 2130 with context_manager_0 as g0: 2131 self._AssertDefault(g0) 2132 with ops.Graph().as_default() as g1: 2133 self._AssertDefault(g1) 2134 self._AssertDefault(g0) 2135 self._AssertDefault(orig) 2136 2137 def testPreventFeeding(self): 2138 g = ops.Graph() 2139 a = constant_op.constant(2.0) 2140 self.assertTrue(g.is_feedable(a)) 2141 g.prevent_feeding(a) 2142 self.assertFalse(g.is_feedable(a)) 2143 2144 def testPreventFetching(self): 2145 g = ops.Graph() 2146 a = constant_op.constant(2.0) 2147 self.assertTrue(g.is_fetchable(a)) 2148 g.prevent_fetching(a.op) 2149 self.assertFalse(g.is_fetchable(a)) 2150 2151 def testAsGraphElementConversions(self): 2152 2153 class ConvertibleObj(object): 2154 2155 def _as_graph_element(self): 2156 return "FloatOutput:0" 2157 2158 class NonConvertibleObj(object): 2159 2160 pass 2161 2162 g = ops.Graph() 2163 a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2164 self.assertEqual(a, g.as_graph_element(ConvertibleObj())) 2165 with 
self.assertRaises(TypeError): 2166 g.as_graph_element(NonConvertibleObj()) 2167 2168 # Regression test against creating custom __del__ functions in classes 2169 # involved in cyclic references, e.g. Graph and Operation. (Python won't gc 2170 # cycles that require calling a __del__ method, because the __del__ method can 2171 # theoretically increase the object's refcount to "save" it from gc, and any 2172 # already-deleted objects in the cycle would have be to restored.) 2173 def testGarbageCollected(self): 2174 # Create a graph we can delete and a weak reference to monitor if it's gc'd 2175 g = ops.Graph() 2176 g_ref = weakref.ref(g) 2177 # Create some ops 2178 with g.as_default(): 2179 a = constant_op.constant(2.0) 2180 b = constant_op.constant(3.0) 2181 c = math_ops.add(a, b) 2182 # Create a session we can delete 2183 with session.Session(graph=g) as sess: 2184 sess.run(c) 2185 # Delete all references and trigger gc 2186 del g 2187 del a 2188 del b 2189 del c 2190 del sess 2191 gc.collect() 2192 self.assertIsNone(g_ref()) 2193 2194 def testRunnableAfterInvalidShape(self): 2195 with ops.Graph().as_default(): 2196 with self.assertRaises(ValueError): 2197 math_ops.add([1, 2], [1, 2, 3]) 2198 a = constant_op.constant(1) 2199 with session.Session() as sess: 2200 sess.run(a) 2201 2202 def testRunnableAfterInvalidShapeWithKernelLabelMap(self): 2203 g = ops.Graph() 2204 with g.as_default(): 2205 with g._kernel_label_map({"KernelLabelRequired": "overload_1"}): 2206 with self.assertRaises(ValueError): 2207 test_ops.kernel_label_required(1) 2208 a = constant_op.constant(1) 2209 with session.Session() as sess: 2210 sess.run(a) 2211 2212 2213 @test_util.with_c_api 2214 class AttrScopeTest(test_util.TensorFlowTestCase): 2215 2216 def _get_test_attrs(self): 2217 x = control_flow_ops.no_op() 2218 try: 2219 a = compat.as_text(x.get_attr("_A")) 2220 except ValueError: 2221 a = None 2222 try: 2223 b = compat.as_text(x.get_attr("_B")) 2224 except ValueError: 2225 b = None 2226 
return (a, b) 2227 2228 def testNoLabel(self): 2229 with self.test_session(): 2230 self.assertAllEqual((None, None), self._get_test_attrs()) 2231 2232 def testLabelMap(self): 2233 with self.test_session() as sess: 2234 a1 = self._get_test_attrs() 2235 with sess.graph._attr_scope({ 2236 "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo")) 2237 }): 2238 a2 = self._get_test_attrs() 2239 with sess.graph._attr_scope({ 2240 "_A": None, 2241 "_B": attr_value_pb2.AttrValue(s=compat.as_bytes("bar")) 2242 }): 2243 a3 = self._get_test_attrs() 2244 with sess.graph._attr_scope({ 2245 "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("baz")) 2246 }): 2247 a4 = self._get_test_attrs() 2248 a5 = self._get_test_attrs() 2249 a6 = self._get_test_attrs() 2250 a7 = self._get_test_attrs() 2251 2252 self.assertAllEqual((None, None), a1) 2253 self.assertAllEqual(("foo", None), a2) 2254 self.assertAllEqual((None, "bar"), a3) 2255 self.assertAllEqual(("baz", "bar"), a4) 2256 self.assertAllEqual((None, "bar"), a5) 2257 self.assertAllEqual(("foo", None), a6) 2258 self.assertAllEqual((None, None), a7) 2259 2260 2261 ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape) 2262 2263 2264 @test_util.with_c_api 2265 class KernelLabelTest(test_util.TensorFlowTestCase): 2266 2267 @test_util.enable_c_api 2268 def testNoLabel(self): 2269 with self.test_session(): 2270 self.assertAllEqual(b"My label is: default", 2271 test_ops.kernel_label().eval()) 2272 2273 def testLabelMap(self): 2274 with self.test_session() as sess: 2275 default_1 = test_ops.kernel_label() 2276 # pylint: disable=protected-access 2277 with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}): 2278 overload_1_1 = test_ops.kernel_label() 2279 with sess.graph._kernel_label_map({"KernelLabel": "overload_2"}): 2280 overload_2 = test_ops.kernel_label() 2281 with sess.graph._kernel_label_map({"KernelLabel": ""}): 2282 default_2 = test_ops.kernel_label() 2283 overload_1_2 = test_ops.kernel_label() 2284 # pylint: 
enable=protected-access 2285 default_3 = test_ops.kernel_label() 2286 2287 self.assertAllEqual(b"My label is: default", default_1.eval()) 2288 self.assertAllEqual(b"My label is: default", default_2.eval()) 2289 self.assertAllEqual(b"My label is: default", default_3.eval()) 2290 self.assertAllEqual(b"My label is: overload_1", overload_1_1.eval()) 2291 self.assertAllEqual(b"My label is: overload_1", overload_1_2.eval()) 2292 self.assertAllEqual(b"My label is: overload_2", overload_2.eval()) 2293 2294 2295 @test_util.with_c_api 2296 class AsGraphDefTest(test_util.TensorFlowTestCase): 2297 2298 def testGraphDefVersion(self): 2299 """Test that the graphdef version is plumbed through to kernels.""" 2300 with ops.Graph().as_default() as g: 2301 version = g.graph_def_versions.producer 2302 with self.test_session(graph=g): 2303 v = test_ops.graph_def_version().eval() 2304 self.assertEqual(version, v) 2305 2306 def testAddShapes(self): 2307 with ops.Graph().as_default() as g: 2308 t1, t2, t3, t4, t5 = _apply_op(g, "FiveFloatOutputs", [], 2309 [dtypes.float32] * 5) 2310 t1.set_shape(None) 2311 t2.set_shape([]) 2312 t3.set_shape([None]) 2313 t4.set_shape([43, 37]) 2314 t5.set_shape([43, None]) 2315 2316 b = constant_op.constant(1.0) # pylint: disable=unused-variable 2317 2318 gd = g.as_graph_def(add_shapes=True) 2319 self.assertProtoEqualsVersion(""" 2320 node { name: "FiveFloatOutputs" op: "FiveFloatOutputs" 2321 attr { 2322 key: "_output_shapes" 2323 value { 2324 list { 2325 shape { unknown_rank: true } 2326 shape { } 2327 shape { dim { size: -1 } } 2328 shape { dim { size: 43 } dim { size: 37 } } 2329 shape { dim { size: 43 } dim { size: -1 } } 2330 } 2331 } 2332 } 2333 } 2334 node { name: "Const" op: "Const" 2335 attr { 2336 key: "_output_shapes" 2337 value { 2338 list { 2339 shape { } 2340 } 2341 } 2342 } 2343 attr { 2344 key: "dtype" 2345 value { type: DT_FLOAT } 2346 } 2347 attr { 2348 key: "value" 2349 value { 2350 tensor { 2351 dtype: DT_FLOAT 2352 tensor_shape { } 
2353 float_val: 1.0 } } } } 2354 """, gd) 2355 2356 2357 @ops.RegisterStatistics("a", "flops") 2358 def _calc_a_forward_flops(unused_graph, unused_node): 2359 return ops.OpStats("flops", 20) 2360 2361 2362 @test_util.with_c_api 2363 class StatisticsTest(test_util.TensorFlowTestCase): 2364 2365 def testRegisteredNode(self): 2366 graph = ops.Graph() 2367 node = ops._NodeDef("a", "an_a") 2368 flops = ops.get_stats_for_node_def(graph, node, "flops") 2369 self.assertEqual(20, flops.value) 2370 missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat") 2371 self.assertEqual(None, missing_stat.value) 2372 2373 def testUnregisteredNode(self): 2374 graph = ops.Graph() 2375 node = ops._NodeDef("b", "a_b") 2376 weight_params = ops.get_stats_for_node_def(graph, node, "weight_params") 2377 self.assertEqual(None, weight_params.value) 2378 2379 def testAccumulateStatistics(self): 2380 flops_total = ops.OpStats("flops") 2381 self.assertEqual(None, flops_total.value) 2382 second_flops = ops.OpStats("flops", 3) 2383 flops_total += second_flops 2384 self.assertEqual(3, flops_total.value) 2385 2386 2387 @test_util.with_c_api 2388 class ColocationGroupTest(test_util.TensorFlowTestCase): 2389 2390 def testBasic(self): 2391 a = constant_op.constant([2.0], name="a") 2392 with ops.colocate_with(a.op): 2393 b = constant_op.constant(3.0) 2394 c = constant_op.constant(4.0) 2395 self.assertEqual([b"loc:@a"], a.op.colocation_groups()) 2396 self.assertEqual([b"loc:@a"], b.op.colocation_groups()) 2397 with self.assertRaises(ValueError): 2398 c.op.get_attr("_class") 2399 2400 def testColocationDeviceInteraction(self): 2401 with ops.device("/cpu:0"): 2402 with ops.device("/device:GPU:0"): 2403 a = constant_op.constant([2.0], name="a") 2404 with ops.colocate_with(a.op): 2405 # 'b' is created in the scope of /cpu:0, but it is 2406 # colocated with 'a', which is on '/device:GPU:0'. colocate_with 2407 # overrides devices because it is a stronger constraint. 
2408 b = constant_op.constant(3.0) 2409 self.assertEqual([b"loc:@a"], b.op.colocation_groups()) 2410 self.assertEqual(a.op.device, b.op.device) 2411 2412 def testColocationCanonicalization(self): 2413 with ops.device("/device:GPU:0"): 2414 _ = constant_op.constant(2.0) 2415 with ops.device(lambda op: "/device:GPU:0"): 2416 b = constant_op.constant(3.0) 2417 with ops.get_default_graph().colocate_with(b): 2418 with ops.device("/device:GPU:0"): 2419 c = constant_op.constant(4.0) 2420 2421 # A's device will be /device:GPU:0 2422 # B's device will be /device:GPU:0 2423 # C's device will be /device:GPU:0 because it 2424 # inherits B's device name, after canonicalizing the names. 2425 self.assertEqual(b.op.device, c.op.device) 2426 2427 def testLocationOverrides(self): 2428 with ops.device("/cpu:0"): 2429 with ops.device("/device:GPU:0"): 2430 a = constant_op.constant([2.0], name="a") 2431 # Note that this colocation is "redundant", since we are 2432 # within the scope of "/device:GPU:0". However, we would like to 2433 # preserve in the GraphDef that these two ops should be 2434 # colocated in a portable way. 2435 with ops.colocate_with(a.op): 2436 b = constant_op.constant(3.0) 2437 c = constant_op.constant(4.0) 2438 d = constant_op.constant(5.0) 2439 2440 self.assertEqual([b"loc:@a"], b.op.colocation_groups()) 2441 self.assertEqual("/device:GPU:0", a.op.device) 2442 self.assertEqual(a.op.device, b.op.device) 2443 2444 # Test that device function stack is restored. 
2445 self.assertEqual("/device:GPU:0", c.op.device) 2446 self.assertEqual("/device:CPU:0", d.op.device) 2447 2448 def testNestedColocateWith(self): 2449 a = constant_op.constant([2.0], name="a") 2450 with ops.colocate_with(a.op): 2451 b = constant_op.constant(3.0) 2452 with ops.colocate_with(b.op): 2453 c = constant_op.constant(4.0) 2454 self.assertEqual([b"loc:@a"], b.op.colocation_groups()) 2455 self.assertEqual([b"loc:@a"], c.op.colocation_groups()) 2456 2457 def testMultiColocationGroups(self): 2458 a = constant_op.constant([2.0], name="a") 2459 b = constant_op.constant(3.0, name="b") 2460 with ops.colocate_with(a.op): 2461 with ops.colocate_with(b.op): 2462 c = constant_op.constant(4.0) 2463 self.assertEqual(set([b"loc:@a", b"loc:@b"]), set(c.op.colocation_groups())) 2464 2465 def testColocationIgnoreStack(self): 2466 a = constant_op.constant([2.0], name="a") 2467 b = constant_op.constant(3.0, name="b") 2468 with ops.colocate_with(a.op): 2469 with ops.colocate_with(b.op, ignore_existing=True): 2470 c = constant_op.constant(4.0) 2471 self.assertEqual(set([b"loc:@b"]), set(c.op.colocation_groups())) 2472 2473 def testColocateWithReset(self): 2474 a = constant_op.constant([2.0], name="a") 2475 with ops.colocate_with(a.op): 2476 b = constant_op.constant(3.0, name="b") 2477 with ops.colocate_with(None, ignore_existing=True): 2478 c = constant_op.constant(4.0, name="c") 2479 self.assertEqual([b"loc:@a"], b.op.colocation_groups()) 2480 self.assertEqual([b"loc:@c"], c.op.colocation_groups()) 2481 2482 def testColocateWithInitialNoneThenNested(self): 2483 a = constant_op.constant([2.0], name="a") 2484 with ops.colocate_with(a.op): 2485 with ops.colocate_with(None, ignore_existing=True): 2486 b = constant_op.constant(3.0, name="b") 2487 with ops.colocate_with(b.op): 2488 c = constant_op.constant(4.0, name="c") 2489 self.assertEqual([b"loc:@b"], b.op.colocation_groups()) 2490 self.assertEqual([b"loc:@b"], c.op.colocation_groups()) 2491 2492 def 
testColocateVariables(self): 2493 a = variables.Variable([2.0], name="a") 2494 with ops.colocate_with(a.op): 2495 b = variables.Variable([3.0], name="b") 2496 self.assertEqual([b"loc:@a"], b.op.colocation_groups()) 2497 2498 def testInconsistentDeviceWithinColocate(self): 2499 with ops.device("/device:GPU:0"): 2500 a = constant_op.constant([2.0], name="a") 2501 with ops.colocate_with(a.op): 2502 # This is allowed due to legacy but clearly wrong, since we 2503 # should really be colocating with 'a'. We allow devices to 2504 # override colocate_with, but we log warnings to suggest that 2505 # this is probably unintentional or misguided. 2506 with ops.device("/cpu:0"): 2507 b = constant_op.constant([3.0], name="b") 2508 2509 self.assertEqual("/device:CPU:0", b.device) 2510 2511 2512 @test_util.with_c_api 2513 class DeprecatedTest(test_util.TensorFlowTestCase): 2514 2515 def testSuccess(self): 2516 # TODO(skyewm): make g.graph_def_versions work with the C API enabled 2517 if ops._USE_C_API: return 2518 2519 with ops.Graph().as_default() as g: 2520 g.graph_def_versions.producer = 7 2521 old = test_ops.old() 2522 with self.test_session(graph=g): 2523 old.run() 2524 2525 def _error(self): 2526 return ((r"Op Old is not available in GraphDef version %d\. " 2527 r"It has been removed in version 8\. 
For reasons\.") % 2528 versions.GRAPH_DEF_VERSION) 2529 2530 def testGraphConstructionFail(self): 2531 with ops.Graph().as_default(): 2532 with self.assertRaisesRegexp(NotImplementedError, self._error()): 2533 test_ops.old() 2534 2535 def testGraphExecutionFail(self): 2536 # TODO(skyewm): make g.graph_def_versions work with the C API enabled 2537 if ops._USE_C_API: return 2538 2539 with ops.Graph().as_default() as g: 2540 g.graph_def_versions.producer = 7 2541 old = test_ops.old() 2542 g.graph_def_versions.producer = versions.GRAPH_DEF_VERSION 2543 with self.test_session(graph=g): 2544 with self.assertRaisesRegexp(errors.UnimplementedError, self._error()): 2545 old.run() 2546 2547 2548 @test_util.with_c_api 2549 class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase): 2550 2551 def testSuccess(self): 2552 op = ops.Operation( 2553 ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]) 2554 t = op.outputs[0] 2555 self.assertTrue(ops.is_dense_tensor_like(t)) 2556 2557 v = variables.Variable([17]) 2558 self.assertTrue(ops.is_dense_tensor_like(v)) 2559 2560 class BadClassNoName(object): 2561 pass 2562 2563 class BadClassBadName(object): 2564 2565 def name(self): 2566 pass 2567 2568 class BadClassNoDtype(object): 2569 2570 @property 2571 def name(self): 2572 pass 2573 2574 class BadClassBadDtype(object): 2575 2576 @property 2577 def name(self): 2578 pass 2579 2580 def dtype(self): 2581 pass 2582 2583 def testBadClass(self): 2584 with self.assertRaisesRegexp(TypeError, "`name`"): 2585 ops.register_dense_tensor_like_type( 2586 DenseTensorLikeTypeTest.BadClassNoName) 2587 with self.assertRaisesRegexp(TypeError, "`name`"): 2588 ops.register_dense_tensor_like_type( 2589 DenseTensorLikeTypeTest.BadClassBadName) 2590 with self.assertRaisesRegexp(TypeError, "`dtype`"): 2591 ops.register_dense_tensor_like_type( 2592 DenseTensorLikeTypeTest.BadClassNoDtype) 2593 with self.assertRaisesRegexp(TypeError, "`dtype`"): 2594 ops.register_dense_tensor_like_type( 2595 
DenseTensorLikeTypeTest.BadClassBadDtype) 2596 2597 2598 @test_util.with_c_api 2599 class NameScopeTest(test_util.TensorFlowTestCase): 2600 2601 def testStripAndPrependScope(self): 2602 strs = [ 2603 "hidden1/hidden1/weights", # Same prefix. Should strip. 2604 "hidden1///hidden1/weights", # Extra "/". Should strip. 2605 "^hidden1/hidden1/weights", # Same prefix. Should strip. 2606 "loc:@hidden1/hidden1/weights", # Same prefix. Should strip. 2607 "hhidden1/hidden1/weights", # Different prefix. Should keep. 2608 "hidden1" 2609 ] # Not a prefix. Should keep. 2610 expected_striped = [ 2611 "hidden1/weights", "hidden1/weights", "^hidden1/weights", 2612 "loc:@hidden1/weights", "hhidden1/hidden1/weights", "hidden1" 2613 ] 2614 expected_prepended = [ 2615 "hidden2/hidden1/weights", "hidden2/hidden1/weights", 2616 "^hidden2/hidden1/weights", "loc:@hidden2/hidden1/weights", 2617 "hidden2/hhidden1/hidden1/weights", "hidden2/hidden1" 2618 ] 2619 name_scope_to_strip = "hidden1" 2620 name_scope_to_add = "hidden2" 2621 for es, ep, s in zip(expected_striped, expected_prepended, strs): 2622 striped = ops.strip_name_scope(s, name_scope_to_strip) 2623 self.assertEqual(es, striped) 2624 self.assertEqual(ep, ops.prepend_name_scope(striped, name_scope_to_add)) 2625 2626 def testGetNameScope(self): 2627 with ops.Graph().as_default() as g: 2628 with ops.name_scope("scope1"): 2629 with ops.name_scope("scope2"): 2630 with ops.name_scope("scope3"): 2631 self.assertEqual("scope1/scope2/scope3", g.get_name_scope()) 2632 self.assertEqual("scope1/scope2", g.get_name_scope()) 2633 self.assertEqual("scope1", g.get_name_scope()) 2634 self.assertEqual("", g.get_name_scope()) 2635 2636 def testTwoGraphs(self): 2637 2638 def f(): 2639 g1 = ops.Graph() 2640 g2 = ops.Graph() 2641 with g1.as_default(): 2642 with g2.as_default(): 2643 with ops.name_scope("_"): 2644 pass 2645 2646 self.assertRaisesRegexp(ValueError, "'_' is not a valid scope name", f) 2647 2648 2649 @test_util.with_c_api 2650 class 
TracebackTest(test_util.TensorFlowTestCase): 2651 2652 def testTracebackWithStartLines(self): 2653 with self.test_session() as sess: 2654 a = constant_op.constant(2.0) 2655 sess.run( 2656 a, 2657 options=config_pb2.RunOptions( 2658 trace_level=config_pb2.RunOptions.FULL_TRACE)) 2659 self.assertTrue(sess.graph.get_operations()) 2660 2661 # Tests that traceback_with_start_lines is the same as traceback 2662 # but includes one more element at the end. 2663 for op in sess.graph.get_operations(): 2664 self.assertEquals(len(op.traceback), len(op.traceback_with_start_lines)) 2665 for frame, frame_with_start_line in zip( 2666 op.traceback, op.traceback_with_start_lines): 2667 self.assertEquals(5, len(frame_with_start_line)) 2668 self.assertEquals(frame, frame_with_start_line[:-1]) 2669 2670 2671 @test_util.with_c_api 2672 class OutputTypesTest(test_util.TensorFlowTestCase): 2673 """Tests Operation._output_types property. 2674 2675 This test should not exist as _output_types is a private property. 2676 This property is used by util.copy_elements and its tests would normally 2677 cover Operation._output_types. However, we can't yet run these tests in C 2678 API mode because their use _set_device method. This test will be deleted 2679 once we port _set_device and run the copy tests with C API on. 2680 """ 2681 # TODO(iga): Remove this test 2682 2683 def setUp(self): 2684 self.prev_use_c_api = ops._USE_C_API # pylint: disable=protected-access 2685 ops._USE_C_API = True # pylint: disable=protected-access 2686 2687 def tearDown(self): 2688 ops._USE_C_API = self.prev_use_c_api # pylint: disable=protected-access 2689 2690 def testOneOutput(self): 2691 g = ops.Graph() 2692 with g.as_default(): 2693 # Using a constant because creating unregistered ops 2694 # doesn't work with the C API. 
2695 op = constant_op.constant(12, dtype=dtypes.uint16).op 2696 # pylint: disable=protected-access 2697 self.assertEqual([types_pb2.DT_UINT16], op._output_types) 2698 # pylint: enable=protected-access 2699 2700 def testTwoDifferentOutputs(self): 2701 g = ops.Graph() 2702 with g.as_default(): 2703 x = constant_op.constant([1, 1, 2, 4, 4, 4, 7, 8, 8], 2704 dtype=dtypes.double) 2705 y, _ = gen_array_ops._unique(x) 2706 self.assertEqual([types_pb2.DT_DOUBLE, types_pb2.DT_INT32], 2707 y.op._output_types) # pylint: disable=protected-access 2708 2709 def testThreeOutputs(self): 2710 g = ops.Graph() 2711 with g.as_default(): 2712 # Using a split operationt because creating unregistered ops 2713 # doesn't work with the C API. 2714 a = constant_op.constant("abc", dtype=dtypes.string, shape=[5, 30]) 2715 split0, _, _ = array_ops.split(a, [4, 15, 11], 1) 2716 # pylint: disable=protected-access 2717 self.assertEqual([types_pb2.DT_STRING] * 3, split0.op._output_types) 2718 # pylint: enable=protected-access 2719 2720 2721 @test_util.with_c_api 2722 class EnableEagerExecutionTest(test_util.TensorFlowTestCase): 2723 2724 def testBadArgumentsToEnableEagerExecution(self): 2725 with self.assertRaisesRegexp(TypeError, "config must be a tf.ConfigProto"): 2726 ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT) 2727 with self.assertRaisesRegexp(ValueError, "device_policy must be one of"): 2728 c = config_pb2.ConfigProto() 2729 ops.enable_eager_execution(c, c) 2730 2731 2732 if __name__ == "__main__": 2733 googletest.main() 2734