1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for partitioned_variables.py.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 from six.moves import xrange # pylint: disable=redefined-builtin 23 24 from tensorflow.python.framework import constant_op 25 from tensorflow.python.framework import dtypes 26 from tensorflow.python.framework import ops 27 from tensorflow.python.ops import array_ops 28 from tensorflow.python.ops import init_ops 29 from tensorflow.python.ops import partitioned_variables 30 from tensorflow.python.ops import random_ops 31 from tensorflow.python.ops import variable_scope 32 from tensorflow.python.ops import variables 33 from tensorflow.python.platform import test 34 35 36 class PartitionerCreatorsTest(test.TestCase): 37 38 def testFixedSizePartitioner(self): 39 with self.test_session(): 40 partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0) 41 with variable_scope.variable_scope("root", partitioner=partitioner): 42 v0 = variable_scope.get_variable( 43 "v0", dtype=dtypes.float32, shape=(10, 10)) 44 v0_list = v0._get_variable_list() 45 v0_part = v0._get_partitions() 46 self.assertEqual(len(v0_list), 5) 47 self.assertAllEqual(v0_part, (5, 1)) 48 49 def testFixedSizePartitionerInt64(self): 50 with self.test_session(): 51 partitioner = partitioned_variables.fixed_size_partitioner(4, axis=0) 52 with variable_scope.variable_scope("root", partitioner=partitioner): 53 v0 = variable_scope.get_variable("v0", dtype=dtypes.int64, shape=[20]) 54 v0_list = v0._get_variable_list() 55 self.assertEqual(len(v0_list), 4) 56 57 def testResourceFixedSizePartitioner(self): 58 with self.test_session(): 59 partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0) 60 with variable_scope.variable_scope( 61 "root", partitioner=partitioner, use_resource=True): 62 v0 = variable_scope.get_variable( 63 "v0", dtype=dtypes.float32, shape=(10, 10)) 64 v0_list = v0._get_variable_list() 65 v0_part = v0._get_partitions() 66 self.assertEqual(len(v0_list), 5) 67 self.assertAllEqual(v0_part, (5, 1)) 68 69 def _testVariableAxisSizePartitioner(self, 70 name, 71 axis, 72 max_shard_bytes, 73 expected_axis_shards, 74 expected_partitions, 75 max_shards=None): 76 partitioner = partitioned_variables.variable_axis_size_partitioner( 77 axis=axis, max_shard_bytes=max_shard_bytes, max_shards=max_shards) 78 79 with variable_scope.variable_scope("root", partitioner=partitioner): 80 v0 = variable_scope.get_variable( 81 name, dtype=dtypes.float32, shape=(4, 8, 16, 32)) 82 v0_list = v0._get_variable_list() 83 v0_part = v0._get_partitions() 84 self.assertEqual(len(v0_list), expected_axis_shards) 85 self.assertAllEqual(v0_part, expected_partitions) 86 87 def testVariableAxisSizePartitioner(self): 88 with self.test_session(): 89 # Create a partitioned variable of shape (4, 8, 16, 32) type float32 90 # Bytes per slice along the given axes: 91 92 # 8 * 16 * 32 * sizeof(float32) = 16384 / slice on axis 0 93 # 4 * 16 * 32 * sizeof(float32) = 8192 / slice on axis 1 94 # 4 * 8 * 32 * sizeof(float32) = 4096 / slice on axis 2 95 # 4 * 8 * 16 * sizeof(float32) = 2048 / slice on axis 3 96 97 # Now partition it in different ways... 98 99 # No need to slice: bytes_per_slice * dim0 = 65536 < max_shard_bytes 100 self._testVariableAxisSizePartitioner( 101 "v0", 102 axis=0, 103 max_shard_bytes=131072, 104 expected_axis_shards=1, 105 expected_partitions=(1, 1, 1, 1)) 106 107 # Slice exactly once: bytes_per_slice * dim1 = 65536 = max_shard_bytes 108 self._testVariableAxisSizePartitioner( 109 "v1", 110 axis=1, 111 max_shard_bytes=65536, 112 expected_axis_shards=1, 113 expected_partitions=(1, 1, 1, 1)) 114 115 # Slice into 2 parts: 116 # bytes_per_slice = 4096 117 # slices_per_shard = 32768 / 4096 = 8 118 # axis_shards = 16 / 8 = 2 119 self._testVariableAxisSizePartitioner( 120 "v2", 121 axis=2, 122 max_shard_bytes=32768, 123 expected_axis_shards=2, 124 expected_partitions=(1, 1, 2, 1)) 125 126 # This partitioner makes sure we maximize the number of shards along 127 # axis 3. Slice it into 32 parts: 128 # bytes_per_slice = 2048 129 # slices_per_shard = 2048 / 2048 = 1 130 # axis_shards = 32 / 1 = 32 131 self._testVariableAxisSizePartitioner( 132 "v3a", 133 axis=3, 134 max_shard_bytes=2048, 135 expected_axis_shards=32, 136 expected_partitions=(1, 1, 1, 32)) 137 138 # This partitioner makes sure we do not go past the bound of allowable 139 # number of shards along axis 3. 140 # Slice into 32 parts: 141 # bytes_per_slice = 2048 142 # slices_per_shard = max(1, 1024 / 2048) = 1 143 # axis_shards = 32 / 1 = 32 144 # Slice into max of 32 parts because: max_shard_bytes < bytes_per_slice 145 self._testVariableAxisSizePartitioner( 146 "v3b", 147 axis=3, 148 max_shard_bytes=1024, 149 expected_axis_shards=32, 150 expected_partitions=(1, 1, 1, 32)) 151 152 # Specify max_shards so that it won't affect sharding. 153 self._testVariableAxisSizePartitioner( 154 "v3c", 155 axis=3, 156 max_shard_bytes=1024, 157 expected_axis_shards=32, 158 expected_partitions=(1, 1, 1, 32), 159 max_shards=33) 160 161 # Specify max_shards so that it will affect sharding. 162 self._testVariableAxisSizePartitioner( 163 "v3d", 164 axis=3, 165 max_shard_bytes=1024, 166 expected_axis_shards=2, 167 expected_partitions=(1, 1, 1, 2), 168 max_shards=2) 169 170 # Use the partitioner with strings 171 partitioner_axis3_str = partitioned_variables.variable_axis_size_partitioner( # pylint: disable=line-too-long 172 axis=3, 173 max_shard_bytes=32768, 174 bytes_per_string_element=8) 175 176 with variable_scope.variable_scope( 177 "root", partitioner=partitioner_axis3_str): 178 v3str = variable_scope.get_variable( 179 "v3str", 180 initializer=np.array([""] * 4 * 8 * 16 * 32).reshape(4, 8, 16, 32), 181 dtype=dtypes.string, 182 shape=(4, 8, 16, 32)) 183 v3str_list = v3str._get_variable_list() 184 v3str_part = v3str._get_partitions() 185 186 # Now the estimated bytes_per_slice = 4*8*16*bytes_per_string_element 187 # which is equal to 4096. Setting a max_shard_bytes of 32768 188 # and we should get a split of 4. 189 # Slice into 4 parts: 190 # bytes_per_slice = 4096 191 # slices_per_shard = 32768 / 4096 = 8 192 # axis_shards = 32 / 8 = 4 193 self.assertEqual(len(v3str_list), 4) 194 self.assertAllEqual(v3str_part, (1, 1, 1, 4)) 195 196 def _testMinMaxVariablePartitioner(self, max_partitions, axis, min_slice_size, 197 var_name, var_shape, expected_axis_shards, 198 expected_partitions): 199 partitioner = partitioned_variables.min_max_variable_partitioner( 200 max_partitions=max_partitions, axis=axis, min_slice_size=min_slice_size) 201 with variable_scope.variable_scope("root", partitioner=partitioner): 202 v0 = variable_scope.get_variable( 203 var_name, dtype=dtypes.float32, shape=var_shape) 204 v0_list = v0._get_variable_list() 205 v0_part = v0._get_partitions() 206 self.assertEqual(len(v0_list), expected_axis_shards) 207 self.assertAllEqual(v0_part, expected_partitions) 208 209 def testMinMaxVariablePartitioner(self): 210 with self.test_session(): 211 # Partitioning a variable of shape=[2048] with a minimum of 2K per slice. 212 self._testMinMaxVariablePartitioner( 213 max_partitions=100, 214 axis=0, 215 min_slice_size=2 << 10, 216 var_name="v0_0", 217 var_shape=[2048], 218 expected_axis_shards=4, 219 expected_partitions=[4]) 220 221 # Partitioning a variable of shape=[2048, 1024] with a minimum of 256K per 222 # slice. 223 self._testMinMaxVariablePartitioner( 224 max_partitions=100, 225 axis=0, 226 min_slice_size=256 << 10, 227 var_name="v0", 228 var_shape=[2048, 1024], 229 expected_axis_shards=32, 230 expected_partitions=[32, 1]) 231 232 # max_partitions restricts partitioning of the variable. 233 self._testMinMaxVariablePartitioner( 234 max_partitions=16, 235 axis=0, 236 min_slice_size=256 << 10, 237 var_name="v1_max", 238 var_shape=[2048, 1024], 239 expected_axis_shards=16, 240 expected_partitions=[16, 1]) 241 self._testMinMaxVariablePartitioner( 242 max_partitions=1, 243 axis=0, 244 min_slice_size=256 << 10, 245 var_name="v2_max", 246 var_shape=[2048, 1024], 247 expected_axis_shards=1, 248 expected_partitions=[1, 1]) 249 250 # Reducing/Increasing min_slice_size proportionately increases/reduces the 251 # number of partitions. 252 self._testMinMaxVariablePartitioner( 253 max_partitions=100, 254 axis=0, 255 min_slice_size=128 << 10, 256 var_name="v3_slice", 257 var_shape=[2048, 1024], 258 expected_axis_shards=64, 259 expected_partitions=[64, 1]) 260 self._testMinMaxVariablePartitioner( 261 max_partitions=100, 262 axis=0, 263 min_slice_size=512 << 10, 264 var_name="v4_slice", 265 var_shape=[2048, 1024], 266 expected_axis_shards=16, 267 expected_partitions=[16, 1]) 268 269 # Partitioning the variable along a different axis. 270 self._testMinMaxVariablePartitioner( 271 max_partitions=100, 272 axis=1, 273 min_slice_size=256 << 10, 274 var_name="v5_axis", 275 var_shape=[64, 1024, 1, 3], 276 expected_axis_shards=3, 277 expected_partitions=[1, 3, 1, 1]) 278 self._testMinMaxVariablePartitioner( 279 max_partitions=100, 280 axis=3, 281 min_slice_size=256 << 10, 282 var_name="v6_axis", 283 var_shape=[64, 1024, 1, 3], 284 expected_axis_shards=3, 285 expected_partitions=[1, 1, 1, 3]) 286 287 # Can not partition the variable more than what its shape allows. 288 self._testMinMaxVariablePartitioner( 289 max_partitions=100, 290 axis=0, 291 min_slice_size=256 << 10, 292 var_name="v7_shape", 293 var_shape=[16, 128, 1024], 294 expected_axis_shards=16, 295 expected_partitions=[16, 1, 1]) 296 self._testMinMaxVariablePartitioner( 297 max_partitions=100, 298 axis=0, 299 min_slice_size=256 << 10, 300 var_name="v8_shape", 301 var_shape=[4, 512, 1024], 302 expected_axis_shards=4, 303 expected_partitions=[4, 1, 1]) 304 305 306 def _IotaInitializer(shape, dtype=dtypes.float32, partition_info=None): 307 assert dtype == dtypes.float32 308 if len(shape) == 1: 309 return range(shape[0]) 310 else: 311 val = _IotaInitializer(shape[1:], dtype) 312 return [[(10**i) * v for v in val] for i in range(shape[0])] 313 314 315 class PartitionedVariablesTestCase(test.TestCase): 316 317 def _TestSaveSpec(self, slices, expected_specs): 318 self.assertEqual(len(expected_specs), len(slices)) 319 for i in xrange(len(expected_specs)): 320 self.assertEquals(expected_specs[i], slices[i]._save_slice_info.spec) 321 322 def testVecConstantInit(self): 323 with self.test_session(): 324 rnd_par = constant_op.constant([1, 2, 3, 4]) 325 vs = partitioned_variables.create_partitioned_variables([4], [4], rnd_par) 326 variables.global_variables_initializer().run() 327 val = array_ops.concat(vs, 0).eval() 328 rnd = rnd_par.eval() 329 self.assertAllClose(rnd, val) 330 self.assertEqual([dtypes.int32] * 4, [v.dtype.base_dtype for v in vs]) 331 self._TestSaveSpec(vs, ["4 0,1", "4 1,1", "4 2,1", "4 3,1"]) 332 333 def testConstantInit(self): 334 with self.test_session(): 335 rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]]) 336 vs = partitioned_variables.create_partitioned_variables([2, 4], [1, 2], 337 rnd_par) 338 variables.global_variables_initializer().run() 339 val = array_ops.concat(vs, 1).eval() 340 rnd = rnd_par.eval() 341 self.assertAllClose(rnd, val) 342 self.assertEqual([dtypes.int32] * 2, [v.dtype.base_dtype for v in vs]) 343 self._TestSaveSpec(vs, ["2 4 0,2:0,2", "2 4 0,2:2,2"]) 344 345 def _testNameHelper(self, use_resource=False): 346 with self.test_session(): 347 rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]]) 348 with variable_scope.variable_scope("hi", use_resource=use_resource): 349 vs1 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2], 350 rnd_par) 351 vs2 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2], 352 rnd_par) 353 variables.global_variables_initializer().run() 354 var1_name = vs1[0]._save_slice_info.full_name 355 var2_name = vs2[0]._save_slice_info.full_name 356 self.assertEqual("hi/PartitionedVariable", var1_name) 357 self.assertEqual("hi/PartitionedVariable_1", var2_name) 358 self.assertEqual(var1_name + "/part_0:0", vs1[0].name) 359 self.assertEqual(var1_name + "/part_1:0", vs1[1].name) 360 self.assertEqual(var2_name + "/part_0:0", vs2[0].name) 361 self.assertEqual(var2_name + "/part_1:0", vs2[1].name) 362 # Test same variable. 363 with self.test_session(): 364 rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]]) 365 with variable_scope.variable_scope( 366 "hola", use_resource=use_resource) as vs: 367 vs1 = partitioned_variables.create_partitioned_variables( 368 [2, 4], [1, 2], rnd_par, dtype=dtypes.int32) 369 with variable_scope.variable_scope( 370 vs, reuse=True, use_resource=use_resource): 371 vs2 = partitioned_variables.create_partitioned_variables( 372 [2, 4], [1, 2], rnd_par, dtype=dtypes.int32) 373 variables.global_variables_initializer().run() 374 var1_name = vs1[0]._save_slice_info.full_name 375 var2_name = vs2[0]._save_slice_info.full_name 376 self.assertEqual("hola/PartitionedVariable", var1_name) 377 self.assertEqual("hola/PartitionedVariable", var2_name) 378 self.assertEqual(var1_name + "/part_0:0", vs1[0].name) 379 self.assertEqual(var1_name + "/part_1:0", vs1[1].name) 380 self.assertEqual(var2_name + "/part_0:0", vs2[0].name) 381 self.assertEqual(var2_name + "/part_1:0", vs2[1].name) 382 # Test name_scope 383 with self.test_session(): 384 rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]]) 385 with ops.name_scope("ola"): 386 vs1 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2], 387 rnd_par) 388 vs2 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2], 389 rnd_par) 390 variables.global_variables_initializer().run() 391 var1_name = vs1[0]._save_slice_info.full_name 392 var2_name = vs2[0]._save_slice_info.full_name 393 # Currently, the name scope 'ola' has no effect. 394 self.assertEqual("PartitionedVariable", var1_name) 395 self.assertEqual("PartitionedVariable_1", var2_name) 396 self.assertEqual(var1_name + "/part_0:0", vs1[0].name) 397 self.assertEqual(var1_name + "/part_1:0", vs1[1].name) 398 self.assertEqual(var2_name + "/part_0:0", vs2[0].name) 399 self.assertEqual(var2_name + "/part_1:0", vs2[1].name) 400 401 def testName(self): 402 self._testNameHelper(use_resource=False) 403 404 def testResourceName(self): 405 self._testNameHelper(use_resource=True) 406 407 def testRandomInitValue(self): 408 with self.test_session(): 409 rnd = variables.Variable(random_ops.random_uniform([200, 40])) 410 vs = partitioned_variables.create_partitioned_variables( 411 rnd.get_shape(), [1, 10], rnd.initialized_value()) 412 variables.global_variables_initializer().run() 413 val = array_ops.concat(vs, 1).eval() 414 rnd = rnd.eval() 415 self.assertAllClose(rnd, val) 416 self.assertEqual([dtypes.float32] * 10, [v.dtype.base_dtype for v in vs]) 417 self._TestSaveSpec(vs, [ 418 "200 40 0,200:0,4", "200 40 0,200:4,4", "200 40 0,200:8,4", 419 "200 40 0,200:12,4", "200 40 0,200:16,4", "200 40 0,200:20,4", 420 "200 40 0,200:24,4", "200 40 0,200:28,4", "200 40 0,200:32,4", 421 "200 40 0,200:36,4" 422 ]) 423 424 def testRandomInitUnevenPartitions(self): 425 with self.test_session(): 426 rnd = variables.Variable( 427 random_ops.random_uniform([20, 43], dtype=dtypes.float64)) 428 var_lists = [ 429 partitioned_variables.create_partitioned_variables( 430 rnd.get_shape(), [1, i], rnd.initialized_value()) 431 for i in xrange(1, 10) 432 ] 433 variables.global_variables_initializer().run() 434 rnd_val = rnd.eval() 435 # Only check the slice save specs for the first 5 tf. 436 save_specs = [ 437 # One slice 438 ["20 43 0,20:0,43"], 439 # Two slices 440 ["20 43 0,20:0,22", "20 43 0,20:22,21"], 441 # Three slices 442 ["20 43 0,20:0,15", "20 43 0,20:15,14", "20 43 0,20:29,14"], 443 # Four slices 444 [ 445 "20 43 0,20:0,11", "20 43 0,20:11,11", "20 43 0,20:22,11", 446 "20 43 0,20:33,10" 447 ], 448 # Five slices 449 [ 450 "20 43 0,20:0,9", "20 43 0,20:9,9", "20 43 0,20:18,9", 451 "20 43 0,20:27,8", "20 43 0,20:35,8" 452 ] 453 ] 454 for i, vs in enumerate(var_lists): 455 var_val = array_ops.concat(vs, 1).eval() 456 self.assertAllClose(rnd_val, var_val) 457 self.assertEqual([dtypes.float64] * len(vs), 458 [v.dtype.base_dtype for v in vs]) 459 if i < len(save_specs): 460 self._TestSaveSpec(vs, save_specs[i]) 461 462 def testDegenerate(self): 463 with self.test_session(): 464 rnd = variables.Variable(random_ops.random_uniform([10, 43])) 465 vs = partitioned_variables.create_partitioned_variables( 466 rnd.get_shape(), [1, 1], rnd.initialized_value()) 467 variables.global_variables_initializer().run() 468 val = array_ops.concat(vs, 0).eval() 469 rnd = rnd.eval() 470 self.assertAllClose(rnd, val) 471 self._TestSaveSpec(vs, ["10 43 0,10:0,43"]) 472 473 def testSliceSizeOne(self): 474 with self.test_session(): 475 rnd = variables.Variable(random_ops.random_uniform([10, 43])) 476 vs = partitioned_variables.create_partitioned_variables( 477 rnd.get_shape(), [10, 1], rnd.initialized_value()) 478 variables.global_variables_initializer().run() 479 val = array_ops.concat(vs, 0).eval() 480 rnd = rnd.eval() 481 self.assertAllClose(rnd, val) 482 self._TestSaveSpec(vs, [ 483 "10 43 0,1:0,43", "10 43 1,1:0,43", "10 43 2,1:0,43", 484 "10 43 3,1:0,43", "10 43 4,1:0,43", "10 43 5,1:0,43", 485 "10 43 6,1:0,43", "10 43 7,1:0,43", "10 43 8,1:0,43", "10 43 9,1:0,43" 486 ]) 487 488 def testIotaInitializer(self): 489 self.assertAllClose([0., 1., 2., 3.], _IotaInitializer([4])) 490 self.assertAllClose([[0., 1.], [0., 10.], [0., 100.], [0., 1000.]], 491 _IotaInitializer([4, 2])) 492 with self.test_session(): 493 vs = partitioned_variables.create_partitioned_variables([13, 5], [3, 1], 494 _IotaInitializer) 495 variables.global_variables_initializer().run() 496 slice0 = _IotaInitializer([5, 5]) 497 slice1 = _IotaInitializer([4, 5]) 498 slice2 = _IotaInitializer([4, 5]) 499 val = array_ops.concat(vs, 0).eval() 500 self.assertAllClose(slice0 + slice1 + slice2, val) 501 self._TestSaveSpec(vs, ["13 5 0,5:0,5", "13 5 5,4:0,5", "13 5 9,4:0,5"]) 502 503 def testRandomInitializer(self): 504 # Sanity check that the slices uses a different seed when using a random 505 # initializer function. 506 with self.test_session(): 507 var0, var1 = partitioned_variables.create_partitioned_variables( 508 [20, 12], [1, 2], init_ops.random_uniform_initializer()) 509 variables.global_variables_initializer().run() 510 val0, val1 = var0.eval().flatten(), var1.eval().flatten() 511 self.assertTrue(np.linalg.norm(val0 - val1) > 1e-6) 512 # Negative test that proves that slices have the same values if 513 # the random initializer uses a seed. 514 with self.test_session(): 515 var0, var1 = partitioned_variables.create_partitioned_variables( 516 [20, 12], [1, 2], init_ops.random_uniform_initializer(seed=201)) 517 variables.global_variables_initializer().run() 518 val0, val1 = var0.eval().flatten(), var1.eval().flatten() 519 self.assertAllClose(val0, val1) 520 521 def testSomeErrors(self): 522 with self.test_session(): 523 rnd = variables.Variable(random_ops.random_uniform([10, 43])) 524 with self.assertRaises(ValueError): 525 partitioned_variables.create_partitioned_variables( 526 [10], [1, 1], rnd.initialized_value()) 527 with self.assertRaises(ValueError): 528 partitioned_variables.create_partitioned_variables( 529 [10, 20], [1], rnd.initialized_value()) 530 with self.assertRaises(ValueError): 531 partitioned_variables.create_partitioned_variables( 532 [10, 43], [1], rnd.initialized_value()) 533 with self.assertRaises(ValueError): 534 partitioned_variables.create_partitioned_variables( 535 [10, 43], [1, 2, 3], rnd.initialized_value()) 536 with self.assertRaises(ValueError): 537 partitioned_variables.create_partitioned_variables( 538 [10, 43], [11, 1], rnd.initialized_value()) 539 with self.assertRaises(ValueError): 540 partitioned_variables.create_partitioned_variables( 541 [10, 43], [20, 1], rnd.initialized_value()) 542 with self.assertRaises(ValueError): 543 partitioned_variables.create_partitioned_variables( 544 [10, 43], [1, 50], rnd.initialized_value()) 545 546 def testControlDepsNone(self): 547 with self.test_session() as session: 548 c = constant_op.constant(1.0) 549 with ops.control_dependencies([c]): 550 # d get the control dependency. 551 d = constant_op.constant(2.0) 552 # Partitioned variables do not. 553 var_x = variable_scope.get_variable( 554 "x", 555 shape=[2], 556 initializer=init_ops.ones_initializer(), 557 partitioner=partitioned_variables.variable_axis_size_partitioner(4)) 558 559 ops_before_read = session.graph.get_operations() 560 var_x.as_tensor() # Caches the ops for subsequent reads. 561 reading_ops = [ 562 op for op in session.graph.get_operations() 563 if op not in ops_before_read 564 ] 565 566 self.assertEqual([c.op], d.op.control_inputs) 567 # Tests that no control dependencies are added to reading a partitioned 568 # variable which is similar to reading a variable. 569 for op in reading_ops: 570 self.assertEqual([], op.control_inputs) 571 572 def testConcat(self): 573 with self.test_session() as session: 574 var_x = variable_scope.get_variable( 575 "x", 576 initializer=constant_op.constant([1., 2.]), 577 partitioner=partitioned_variables.variable_axis_size_partitioner(4)) 578 579 c = constant_op.constant(1.0) 580 with ops.control_dependencies([c]): 581 ops_before_concat = session.graph.get_operations() 582 value = var_x._concat() # pylint: disable=protected-access 583 concat_ops = [ 584 op for op in session.graph.get_operations() 585 if op not in ops_before_concat 586 ] 587 588 concat_control_inputs = [ 589 ci for op in concat_ops for ci in op.control_inputs 590 ] 591 self.assertTrue( 592 c.op in concat_control_inputs, 593 "var_x._concat() should get control dependencies from its scope.") 594 variables.global_variables_initializer().run() 595 self.assertAllClose(value.eval(), var_x.as_tensor().eval()) 596 597 598 if __name__ == "__main__": 599 test.main() 600