1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tf.py.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import operator 22 23 import numpy as np 24 25 from tensorflow.python.eager import function 26 from tensorflow.python.framework import constant_op 27 from tensorflow.python.framework import dtypes 28 from tensorflow.python.framework import errors_impl 29 from tensorflow.python.framework import ops 30 from tensorflow.python.ops import array_ops 31 from tensorflow.python.ops import control_flow_ops 32 from tensorflow.python.ops import gen_state_ops 33 from tensorflow.python.ops import math_ops 34 from tensorflow.python.ops import random_ops 35 from tensorflow.python.ops import resource_variable_ops 36 from tensorflow.python.ops import variables 37 from tensorflow.python.platform import test 38 from tensorflow.python.training import gradient_descent 39 from tensorflow.python.util import compat 40 41 42 class VariablesTestCase(test.TestCase): 43 44 def testInitialization(self): 45 with self.test_session(): 46 var0 = variables.Variable(0.0) 47 self.assertEqual("Variable:0", var0.name) 48 self.assertEqual("Variable", var0._shared_name) 49 self.assertEqual([], var0.get_shape()) 50 self.assertEqual([], var0.get_shape()) 51 self.assertEqual([], var0.shape) 52 53 var1 = variables.Variable(1.1) 54 self.assertEqual("Variable_1:0", var1.name) 55 self.assertEqual("Variable_1", var1._shared_name) 56 self.assertEqual([], var1.get_shape()) 57 self.assertEqual([], var1.get_shape()) 58 self.assertEqual([], var1.shape) 59 60 with self.assertRaisesOpError("Attempting to use uninitialized value"): 61 var0.eval() 62 63 with self.assertRaisesOpError("Attempting to use uninitialized value"): 64 var1.eval() 65 66 variables.global_variables_initializer().run() 67 68 self.assertAllClose(0.0, var0.eval()) 69 self.assertAllClose(1.1, var1.eval()) 70 71 def testInitializationOrder(self): 72 with self.test_session(): 73 rnd = variables.Variable(random_ops.random_uniform([3, 6]), name="rnd") 74 self.assertEqual("rnd:0", rnd.name) 75 self.assertEqual([3, 6], rnd.get_shape()) 76 self.assertEqual([3, 6], rnd.get_shape()) 77 self.assertEqual([3, 6], rnd.shape) 78 79 dep = variables.Variable(rnd.initialized_value(), name="dep") 80 self.assertEqual("dep:0", dep.name) 81 self.assertEqual([3, 6], dep.get_shape()) 82 self.assertEqual([3, 6], dep.get_shape()) 83 self.assertEqual([3, 6], dep.shape) 84 85 # Currently have to set the shape manually for Add. 86 added_val = rnd.initialized_value() + dep.initialized_value() + 2.0 87 added_val.set_shape(rnd.get_shape()) 88 89 depdep = variables.Variable(added_val, name="depdep") 90 self.assertEqual("depdep:0", depdep.name) 91 self.assertEqual([3, 6], depdep.get_shape()) 92 self.assertEqual([3, 6], depdep.get_shape()) 93 self.assertEqual([3, 6], depdep.shape) 94 95 variables.global_variables_initializer().run() 96 97 self.assertAllClose(rnd.eval(), dep.eval()) 98 self.assertAllClose(rnd.eval() + dep.eval() + 2.0, depdep.eval()) 99 100 def testIterable(self): 101 with self.assertRaisesRegexp(TypeError, "not iterable"): 102 for _ in variables.Variable(0.0): 103 pass 104 with self.assertRaisesRegexp(TypeError, "not iterable"): 105 for _ in variables.Variable([0.0, 1.0]): 106 pass 107 108 def testAssignments(self): 109 with self.test_session(): 110 var = variables.Variable(0.0) 111 plus_one = var.assign_add(1.0) 112 minus_one = var.assign_sub(2.0) 113 four = var.assign(4.0) 114 variables.global_variables_initializer().run() 115 self.assertAllClose(0.0, var.eval()) 116 117 self.assertAllClose(1.0, plus_one.eval()) 118 self.assertAllClose(1.0, var.eval()) 119 120 self.assertAllClose(-1.0, minus_one.eval()) 121 self.assertAllClose(-1.0, var.eval()) 122 123 self.assertAllClose(4.0, four.eval()) 124 self.assertAllClose(4.0, var.eval()) 125 126 def testResourceAssignments(self): 127 with self.test_session(use_gpu=True): 128 var = resource_variable_ops.ResourceVariable(0.0) 129 plus_one = var.assign_add(1.0) 130 minus_one = var.assign_sub(2.0) 131 four = var.assign(4.0) 132 variables.global_variables_initializer().run() 133 self.assertAllClose(0.0, var.eval()) 134 135 plus_one.eval() 136 self.assertAllClose(1.0, var.eval()) 137 138 minus_one.eval() 139 self.assertAllClose(-1.0, var.eval()) 140 141 four.eval() 142 self.assertAllClose(4.0, var.eval()) 143 144 def testZeroSizeStringAssign(self): 145 with self.test_session() as sess: 146 array = variables.Variable( 147 initial_value=array_ops.zeros((0,), dtype=dtypes.string), 148 name="foo", 149 trainable=False, 150 collections=[ops.GraphKeys.LOCAL_VARIABLES]) 151 sess.run(variables.local_variables_initializer()) 152 old_value = array.value() 153 copy_op = array.assign(old_value) 154 self.assertEqual([], list(sess.run(copy_op))) 155 156 def _countUpToTest(self, dtype): 157 with self.test_session(): 158 zero = constant_op.constant(0, dtype=dtype) 159 var = variables.Variable(zero) 160 count_up_to = var.count_up_to(3) 161 162 variables.global_variables_initializer().run() 163 self.assertEqual(0, var.eval()) 164 165 self.assertEqual(0, count_up_to.eval()) 166 self.assertEqual(1, var.eval()) 167 168 self.assertEqual(1, count_up_to.eval()) 169 self.assertEqual(2, var.eval()) 170 171 self.assertEqual(2, count_up_to.eval()) 172 self.assertEqual(3, var.eval()) 173 174 with self.assertRaisesOpError("Reached limit of 3"): 175 count_up_to.eval() 176 self.assertEqual(3, var.eval()) 177 178 with self.assertRaisesOpError("Reached limit of 3"): 179 count_up_to.eval() 180 self.assertEqual(3, var.eval()) 181 182 def testCountUpToInt32(self): 183 self._countUpToTest(dtypes.int32) 184 185 def testCountUpToInt64(self): 186 self._countUpToTest(dtypes.int64) 187 188 def testControlDepsNone(self): 189 with self.test_session(): 190 c = constant_op.constant(1.0) 191 with ops.control_dependencies([c]): 192 # d get the control dep. 193 d = constant_op.constant(2.0) 194 # variables do not. 195 var_x = variables.Variable(2.0) 196 self.assertEqual([c.op], d.op.control_inputs) 197 self.assertEqual([], var_x.initializer.control_inputs) 198 self.assertEqual([], var_x.value().op.control_inputs) 199 self.assertEqual([], var_x._ref().op.control_inputs) # pylint: disable=protected-access 200 201 def testControlFlow(self): 202 with self.test_session() as sess: 203 v0 = variables.Variable(0, name="v0") 204 var_dict = {} 205 206 # Call get_variable in each of the cond clauses. 207 def var_in_then_clause(): 208 v1 = variables.Variable(1, name="v1") 209 var_dict["v1"] = v1 210 return v1 + v0 211 212 def var_in_else_clause(): 213 v2 = variables.Variable(2, name="v2") 214 var_dict["v2"] = v2 215 return v2 + v0 216 217 add = control_flow_ops.cond( 218 math_ops.less(v0, 10), var_in_then_clause, var_in_else_clause) 219 v1 = var_dict["v1"] 220 v2 = var_dict["v2"] 221 # We should be able to initialize and run v1 and v2 without initializing 222 # v0, even if the variable was created with a control dep on v0. 223 sess.run(v1.initializer) 224 self.assertEqual([1], sess.run(v1)) 225 sess.run(v2.initializer) 226 self.assertEqual([2], sess.run(v2)) 227 # v0 should still be uninitialized. 228 with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"): 229 sess.run(v0) 230 # We should not be able to run 'add' yet. 231 with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"): 232 sess.run(add) 233 # If we initialize v0 we should be able to run 'add'. 234 sess.run(v0.initializer) 235 sess.run(add) 236 237 def testControlFlowInitialization(self): 238 """Expects an error if an initializer is in a control-flow scope.""" 239 def cond(i, _): 240 return i < 10 241 242 def body(i, _): 243 zero = array_ops.zeros([], dtype=dtypes.int32) 244 v = variables.Variable(initial_value=zero) 245 return (i + 1, v.read_value()) 246 247 with self.assertRaisesRegexp(ValueError, "inside a control-flow"): 248 control_flow_ops.while_loop(cond, body, [0, 0]) 249 250 def testUseVariableAsTensor(self): 251 with self.test_session(): 252 var_x = variables.Variable(2.0) 253 var_y = variables.Variable(3.0) 254 variables.global_variables_initializer().run() 255 self.assertAllClose(2.0, var_x.eval()) 256 self.assertAllClose(3.0, var_y.eval()) 257 self.assertAllClose(5.0, math_ops.add(var_x, var_y).eval()) 258 259 def testZeroSizeVarSameAsConst(self): 260 with self.test_session(): 261 zero_size_var = variables.Variable(array_ops.zeros([0, 2])) 262 zero_size_const = array_ops.ones([2, 0]) 263 variable_mul = math_ops.matmul(zero_size_const, zero_size_var) 264 const_mul = math_ops.matmul( 265 zero_size_const, zero_size_const, transpose_b=True) 266 variables.global_variables_initializer().run() 267 variable_output = variable_mul.eval() 268 self.assertAllClose(const_mul.eval(), variable_output) 269 self.assertAllClose([[0., 0.], [0., 0.]], variable_output) 270 271 def testCachingDevice(self): 272 with self.test_session(): 273 var = variables.Variable(2.0) 274 self.assertEqual(var.device, var.value().device) 275 self.assertEqual(var.device, var.initialized_value().device) 276 277 var_cached = variables.Variable(2.0, caching_device="/job:foo") 278 self.assertFalse(var_cached.device.startswith("/job:foo")) 279 self.assertTrue(var_cached.value().device.startswith("/job:foo")) 280 281 def testCollections(self): 282 with self.test_session(): 283 var_x = variables.Variable(2.0) 284 var_y = variables.Variable(2.0, trainable=False) 285 var_z = variables.Variable(2.0, trainable=True) 286 var_t = variables.Variable( 287 2.0, 288 trainable=True, 289 collections=[ 290 ops.GraphKeys.TRAINABLE_VARIABLES, ops.GraphKeys.GLOBAL_VARIABLES 291 ]) 292 self.assertEqual([var_x, var_y, var_z, var_t], 293 variables.global_variables()) 294 self.assertEqual([var_x, var_z, var_t], variables.trainable_variables()) 295 296 def testCollectionsWithScope(self): 297 with self.test_session(): 298 with ops.name_scope("scope_1"): 299 var_x = variables.Variable(2.0) 300 with ops.name_scope("scope_2"): 301 var_y = variables.Variable(2.0) 302 303 self.assertEqual([var_x, var_y], variables.global_variables()) 304 self.assertEqual([var_x], variables.global_variables("scope_1")) 305 self.assertEqual([var_y], variables.global_variables("scope_2")) 306 307 self.assertEqual([var_x, var_y], variables.trainable_variables()) 308 self.assertEqual([var_x], variables.trainable_variables("scope_1")) 309 self.assertEqual([var_y], variables.trainable_variables("scope_2")) 310 311 def testOperators(self): 312 with self.test_session(): 313 var_f = variables.Variable([2.0]) 314 add = var_f + 0.0 315 radd = 1.0 + var_f 316 sub = var_f - 1.0 317 rsub = 1.0 - var_f 318 mul = var_f * 10.0 319 rmul = 10.0 * var_f 320 div = var_f / 10.0 321 rdiv = 10.0 / var_f 322 lt = var_f < 3.0 323 rlt = 3.0 < var_f 324 le = var_f <= 2.0 325 rle = 2.0 <= var_f 326 gt = var_f > 3.0 327 rgt = 3.0 > var_f 328 ge = var_f >= 2.0 329 rge = 2.0 >= var_f 330 neg = -var_f 331 abs_v = abs(var_f) 332 333 var_i = variables.Variable([20]) 334 mod = var_i % 7 335 rmod = 103 % var_i 336 337 var_b = variables.Variable([True, False]) 338 and_v = operator.and_(var_b, [True, True]) 339 or_v = operator.or_(var_b, [False, True]) 340 xor_v = operator.xor(var_b, [False, False]) 341 invert_v = ~var_b 342 343 rnd = np.random.rand(4, 4).astype("f") 344 var_t = variables.Variable(rnd) 345 slice_v = var_t[2, 0:0] 346 347 var_m = variables.Variable([[2.0, 3.0]]) 348 matmul = var_m.__matmul__([[10.0], [20.0]]) 349 rmatmul = var_m.__rmatmul__([[10.0], [20.0]]) 350 351 variables.global_variables_initializer().run() 352 self.assertAllClose([2.0], add.eval()) 353 self.assertAllClose([3.0], radd.eval()) 354 self.assertAllClose([1.0], sub.eval()) 355 self.assertAllClose([-1.0], rsub.eval()) 356 self.assertAllClose([20.0], mul.eval()) 357 self.assertAllClose([20.0], rmul.eval()) 358 self.assertAllClose([0.2], div.eval()) 359 self.assertAllClose([5.0], rdiv.eval()) 360 self.assertAllClose([-2.0], neg.eval()) 361 self.assertAllClose([2.0], abs_v.eval()) 362 self.assertAllClose([True], lt.eval()) 363 self.assertAllClose([False], rlt.eval()) 364 self.assertAllClose([True], le.eval()) 365 self.assertAllClose([True], rle.eval()) 366 self.assertAllClose([False], gt.eval()) 367 self.assertAllClose([True], rgt.eval()) 368 self.assertAllClose([True], ge.eval()) 369 self.assertAllClose([True], rge.eval()) 370 371 self.assertAllClose([6], mod.eval()) 372 self.assertAllClose([3], rmod.eval()) 373 374 self.assertAllClose([True, False], and_v.eval()) 375 self.assertAllClose([True, True], or_v.eval()) 376 self.assertAllClose([True, False], xor_v.eval()) 377 self.assertAllClose([False, True], invert_v.eval()) 378 379 self.assertAllClose(rnd[2, 0:0], slice_v.eval()) 380 381 self.assertAllClose([[80.0]], matmul.eval()) 382 self.assertAllClose([[20.0, 30.0], [40.0, 60.0]], rmatmul.eval()) 383 384 def testSession(self): 385 with self.test_session() as sess: 386 var = variables.Variable([1, 12]) 387 variables.global_variables_initializer().run() 388 self.assertAllClose([1, 12], sess.run(var)) 389 390 def testDevicePlacement(self): 391 with self.test_session() as sess: 392 with ops.device("/cpu:0"): 393 var = variables.Variable([1, 12]) 394 init_value = var.initialized_value() 395 init_op = variables.global_variables_initializer() 396 self.assertEqual(var.op.device, init_value.device) 397 self.assertEqual(var.op.device, init_op.device) 398 sess.run(init_op) 399 400 def testColocation(self): 401 with ops.device("/job:ps"): 402 var = variables.Variable(0, name="v") 403 with ops.device("/job:worker/task:7"): 404 assign_op = var.assign(1) 405 self.assertDeviceEqual("/job:ps", assign_op.device) 406 self.assertEqual([b"loc:@v"], assign_op.op.colocation_groups()) 407 408 def testInitializerFunction(self): 409 value = [[-42], [133.7]] 410 shape = [2, 1] 411 with self.test_session(): 412 initializer = lambda: constant_op.constant(value) 413 414 v1 = variables.Variable(initializer, dtype=dtypes.float32) 415 self.assertEqual(shape, v1.get_shape()) 416 self.assertEqual(shape, v1.shape) 417 self.assertAllClose(value, v1.initial_value.eval()) 418 with self.assertRaises(errors_impl.FailedPreconditionError): 419 v1.eval() 420 421 v2 = variables.Variable( 422 math_ops.negative(v1.initialized_value()), dtype=dtypes.float32) 423 self.assertEqual(v1.get_shape(), v2.get_shape()) 424 self.assertEqual(v1.shape, v2.shape) 425 self.assertAllClose(np.negative(value), v2.initial_value.eval()) 426 427 with self.assertRaises(errors_impl.FailedPreconditionError): 428 v2.eval() 429 variables.global_variables_initializer().run() 430 self.assertAllClose(np.negative(value), v2.eval()) 431 432 def testConstraintArg(self): 433 constraint = lambda x: x 434 v = variables.Variable( 435 lambda: constant_op.constant(1.), 436 constraint=constraint) 437 self.assertEqual(v.constraint, constraint) 438 439 constraint = 0 440 with self.assertRaises(ValueError): 441 v = variables.Variable( 442 lambda: constant_op.constant(1.), 443 constraint=constraint) 444 445 def testNoRefDataRace(self): 446 with self.test_session(): 447 a = variables.Variable([1, 2, 3], dtype=dtypes.float32) 448 b = variables.Variable(a.initialized_value() + 2) 449 c = variables.Variable(b.initialized_value() + 2) 450 variables.global_variables_initializer().run() 451 self.assertAllEqual(a.eval(), [1, 2, 3]) 452 self.assertAllEqual(b.eval(), [3, 4, 5]) 453 self.assertAllEqual(c.eval(), [5, 6, 7]) 454 455 def testInitializerFunctionDevicePlacement(self): 456 with self.test_session(): 457 initializer = lambda: constant_op.constant(42.0) 458 with ops.device("/cpu:100"): 459 v1 = variables.Variable(initializer, dtype=dtypes.float32, name="v1") 460 expected_device = "/device:CPU:100" 461 expected_group_v1 = [b"loc:@v1"] 462 self.assertEqual(expected_device, v1.op.device) 463 self.assertEqual(expected_group_v1, v1.op.colocation_groups()) 464 for i in v1.initializer.inputs: 465 self.assertEqual(expected_group_v1, i.op.colocation_groups()) 466 467 v2 = variables.Variable(initializer, dtype=dtypes.float32, name="v2") 468 expected_group_v2 = [b"loc:@v2"] 469 self.assertEqual(expected_group_v2, v2.op.colocation_groups()) 470 for i in v2.initializer.inputs: 471 self.assertEqual(expected_group_v2, i.op.colocation_groups()) 472 473 def testVariableDefInitializedInstances(self): 474 with ops.Graph().as_default(), self.test_session() as sess: 475 v_def = variables.Variable( 476 initial_value=constant_op.constant(3.0)).to_proto() 477 478 with ops.Graph().as_default(), self.test_session() as sess: 479 # v describes a VariableDef-based variable without an initial value. 480 v = variables.Variable(variable_def=v_def) 481 self.assertEqual(3.0, sess.run(v.initialized_value())) 482 483 # initialized_value should not rerun the initializer_op if the variable 484 # has already been initialized elsewhere. 485 sess.run(v.assign(1.0)) 486 self.assertEqual(1.0, v.initialized_value().eval()) 487 488 v_def.ClearField("initial_value_name") 489 with ops.Graph().as_default(), self.test_session() as sess: 490 # Restoring a legacy VariableDef proto that does not have 491 # initial_value_name set should still work. 492 v = variables.Variable(variable_def=v_def) 493 # We should also be able to re-export the variable to a new meta graph. 494 self.assertProtoEquals(v_def, v.to_proto()) 495 # But attempts to use initialized_value will result in errors. 496 with self.assertRaises(ValueError): 497 sess.run(v.initialized_value()) 498 499 def testLoad(self): 500 with self.test_session(): 501 var = variables.Variable(np.zeros((5, 5), np.float32)) 502 variables.global_variables_initializer().run() 503 var.load(np.ones((5, 5), np.float32)) 504 505 self.assertAllClose(np.ones((5, 5), np.float32), var.eval()) 506 507 def testRepr(self): 508 var = variables.Variable(np.zeros((5, 5), np.float32), name="noop") 509 self.assertEqual( 510 "<tf.Variable 'noop:0' shape=(5, 5) dtype=float32_ref>", 511 repr(var)) 512 513 def testVariableNamesPreserveNameScopesWithDefun(self): 514 @function.defun 515 def create_variable(): 516 with ops.name_scope("foo"): 517 v = variables.Variable(0.0, name="bar") 518 self.assertEqual(v.name, "foo/bar:0") 519 with ops.get_default_graph().as_default(): 520 create_variable() 521 522 523 class IsInitializedTest(test.TestCase): 524 525 def testNoVars(self): 526 with ops.Graph().as_default(), self.test_session() as sess: 527 uninited = variables.report_uninitialized_variables() 528 self.assertEqual(0, sess.run(uninited).size) 529 530 def testAssertVariablesInitialized(self): 531 with ops.Graph().as_default(), self.test_session() as sess: 532 v = variables.Variable([1, 2], name="v") 533 w = variables.Variable([3, 4], name="w") 534 _ = v, w 535 uninited = variables.report_uninitialized_variables() 536 self.assertAllEqual(np.array([b"v", b"w"]), sess.run(uninited)) 537 variables.global_variables_initializer().run() 538 self.assertEqual(0, sess.run(uninited).size) 539 540 def testVariableList(self): 541 with ops.Graph().as_default(), self.test_session() as sess: 542 v = variables.Variable([1, 2], name="v") 543 w = variables.Variable([3, 4], name="w") 544 uninited = variables.report_uninitialized_variables() 545 self.assertAllEqual(np.array([b"v", b"w"]), sess.run(uninited)) 546 sess.run(w.initializer) 547 self.assertAllEqual(np.array([b"v"]), sess.run(uninited)) 548 v.initializer.run() 549 self.assertEqual(0, sess.run(uninited).size) 550 551 def testZeroSizeVarInitialized(self): 552 with ops.Graph().as_default(), self.test_session() as sess: 553 v = variables.Variable(array_ops.zeros([0, 2]), name="v") 554 uninited = variables.report_uninitialized_variables() 555 v.initializer.run() # not strictly necessary 556 self.assertEqual(0, sess.run(uninited).size) 557 558 def testTrainingWithZeroSizeVar(self): 559 with ops.Graph().as_default(), self.test_session() as sess: 560 a = variables.Variable(array_ops.zeros([0, 2])) 561 b = variables.Variable(array_ops.ones([2, 2])) 562 objective = math_ops.reduce_sum(b + math_ops.matmul( 563 a, a, transpose_a=True)) 564 variables.global_variables_initializer().run() 565 do_opt = gradient_descent.GradientDescentOptimizer(0.1).minimize( 566 objective) 567 sess.run([do_opt]) 568 self.assertAllClose([[0.9, 0.9], [0.9, 0.9]], b.eval()) 569 570 571 class ObsoleteIsInitializedTest(test.TestCase): 572 573 def testNoVars(self): 574 with ops.Graph().as_default(): 575 self.assertEqual(None, variables.assert_variables_initialized()) 576 577 def testVariables(self): 578 with ops.Graph().as_default(), self.test_session() as sess: 579 v = variables.Variable([1, 2]) 580 w = variables.Variable([3, 4]) 581 _ = v, w 582 inited = variables.assert_variables_initialized() 583 with self.assertRaisesOpError("Attempting to use uninitialized value"): 584 sess.run(inited) 585 variables.global_variables_initializer().run() 586 sess.run(inited) 587 588 def testVariableList(self): 589 with ops.Graph().as_default(), self.test_session() as sess: 590 v = variables.Variable([1, 2]) 591 w = variables.Variable([3, 4]) 592 inited = variables.assert_variables_initialized([v]) 593 with self.assertRaisesOpError("Attempting to use uninitialized value"): 594 inited.op.run() 595 sess.run(w.initializer) 596 with self.assertRaisesOpError("Attempting to use uninitialized value"): 597 inited.op.run() 598 v.initializer.run() 599 inited.op.run() 600 601 602 class PartitionedVariableTest(test.TestCase): 603 604 def testPartitionedVariable(self): 605 with ops.Graph().as_default(): 606 v0 = variables.Variable([0]) 607 v1 = variables.Variable([1]) 608 v0._set_save_slice_info( 609 variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1])) 610 v1._set_save_slice_info( 611 variables.Variable.SaveSliceInfo(v0.name, [2], [1], [1])) 612 partitions = [2] 613 614 # Pass variable_list as [v1, v0] to ensure they are properly 615 # re-sorted to [v0, v1] based on their slice info offsets. 616 partitioned_variable = variables.PartitionedVariable( 617 name="two_vars", 618 shape=[2], 619 dtype=v0.dtype, 620 variable_list=[v1, v0], 621 partitions=partitions) 622 623 concatenated = ops.convert_to_tensor(partitioned_variable) 624 num_partitions = len(partitioned_variable) 625 iterated_partitions = list(partitioned_variable) 626 self.assertEqual(2, num_partitions) 627 self.assertEqual([v0, v1], iterated_partitions) 628 self.assertEqual([2], concatenated.get_shape()) 629 self.assertEqual([2], concatenated.shape) 630 631 def testPartitionedVariableFailures(self): 632 with ops.Graph().as_default(): 633 with self.assertRaisesRegexp(ValueError, "empty"): 634 variables.PartitionedVariable( 635 name="fail", 636 shape=2, 637 dtype=dtypes.int32, 638 variable_list=[], 639 partitions=[]) 640 641 with self.assertRaisesRegexp(ValueError, "must have a save_slice_info"): 642 v0 = variables.Variable([0]) 643 partitions = [1] 644 variables.PartitionedVariable( 645 name="two_vars", 646 shape=[1], 647 dtype=v0.dtype, 648 variable_list=[v0], 649 partitions=partitions) 650 651 with self.assertRaisesRegexp(ValueError, "full shapes must match"): 652 v0 = variables.Variable([0]) 653 v1 = variables.Variable([1]) 654 v0._set_save_slice_info( 655 variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1])) 656 v1._set_save_slice_info( 657 variables.Variable.SaveSliceInfo(v0.name, [2], [1], [1])) 658 partitions = [2] 659 660 variables.PartitionedVariable( 661 name="two_vars", 662 shape=[3], 663 dtype=v0.dtype, 664 variable_list=[v1, v0], 665 partitions=partitions) 666 667 with self.assertRaisesRegexp(ValueError, "must be positive"): 668 v0 = variables.Variable([0]) 669 v0._set_save_slice_info( 670 variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1])) 671 partitions = [0] 672 673 variables.PartitionedVariable( 674 name="two_vars", 675 shape=[2], 676 dtype=v0.dtype, 677 variable_list=[v0], 678 partitions=partitions) 679 680 681 class VariableContainerTest(test.TestCase): 682 683 def testContainer(self): 684 with ops.Graph().as_default(): 685 v0 = variables.Variable([0]) 686 with ops.container("l1"): 687 v1 = variables.Variable([1]) 688 with ops.container("l2"): 689 v2 = variables.Variable([2]) 690 special_v = gen_state_ops._variable( 691 shape=[1], 692 dtype=dtypes.float32, 693 name="VariableInL3", 694 container="l3", 695 shared_name="") 696 v3 = variables.Variable([3]) 697 v4 = variables.Variable([4]) 698 self.assertEqual(compat.as_bytes(""), v0.op.get_attr("container")) 699 self.assertEqual(compat.as_bytes("l1"), v1.op.get_attr("container")) 700 self.assertEqual(compat.as_bytes("l2"), v2.op.get_attr("container")) 701 self.assertEqual(compat.as_bytes("l3"), special_v.op.get_attr("container")) 702 self.assertEqual(compat.as_bytes("l1"), v3.op.get_attr("container")) 703 self.assertEqual(compat.as_bytes(""), v4.op.get_attr("container")) 704 705 706 if __name__ == "__main__": 707 test.main() 708