# Scraped-listing navigation header (not code): Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for partitioned_variables.py."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 from six.moves import xrange  # pylint: disable=redefined-builtin
     23 
     24 from tensorflow.python.framework import constant_op
     25 from tensorflow.python.framework import dtypes
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.ops import array_ops
     28 from tensorflow.python.ops import init_ops
     29 from tensorflow.python.ops import partitioned_variables
     30 from tensorflow.python.ops import random_ops
     31 from tensorflow.python.ops import variable_scope
     32 from tensorflow.python.ops import variables
     33 from tensorflow.python.platform import test
     34 
     35 
class PartitionerCreatorsTest(test.TestCase):
  """Tests for the partitioner factory functions in partitioned_variables.

  These tests verify how many shards each partitioner produces and the full
  per-axis partition tuple.  They read the private accessors
  `_get_variable_list()` and `_get_partitions()` of the partitioned variable,
  so they are intentionally coupled to that implementation.
  """

  def testFixedSizePartitioner(self):
    """fixed_size_partitioner yields exactly the requested shard count."""
    with self.test_session():
      partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0)
      with variable_scope.variable_scope("root", partitioner=partitioner):
        v0 = variable_scope.get_variable(
            "v0", dtype=dtypes.float32, shape=(10, 10))
        # Private accessors on the partitioned variable; see class docstring.
        v0_list = v0._get_variable_list()
        v0_part = v0._get_partitions()
        self.assertEqual(len(v0_list), 5)
        self.assertAllEqual(v0_part, (5, 1))

  def testFixedSizePartitionerInt64(self):
    """fixed_size_partitioner also works for a non-float dtype (int64)."""
    with self.test_session():
      partitioner = partitioned_variables.fixed_size_partitioner(4, axis=0)
      with variable_scope.variable_scope("root", partitioner=partitioner):
        v0 = variable_scope.get_variable("v0", dtype=dtypes.int64, shape=[20])
        v0_list = v0._get_variable_list()
        self.assertEqual(len(v0_list), 4)

  def testResourceFixedSizePartitioner(self):
    """Same as testFixedSizePartitioner, but with resource variables."""
    with self.test_session():
      partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0)
      with variable_scope.variable_scope(
          "root", partitioner=partitioner, use_resource=True):
        v0 = variable_scope.get_variable(
            "v0", dtype=dtypes.float32, shape=(10, 10))
        v0_list = v0._get_variable_list()
        v0_part = v0._get_partitions()
        self.assertEqual(len(v0_list), 5)
        self.assertAllEqual(v0_part, (5, 1))

  def _testVariableAxisSizePartitioner(self,
                                       name,
                                       axis,
                                       max_shard_bytes,
                                       expected_axis_shards,
                                       expected_partitions,
                                       max_shards=None):
    """Creates a (4, 8, 16, 32) float32 variable and checks its sharding.

    Args:
      name: Variable name; must be unique per call within the "root" scope.
      axis: Axis to shard along.
      max_shard_bytes: Byte budget per shard given to the partitioner.
      expected_axis_shards: Expected number of variable slices.
      expected_partitions: Expected full partition tuple, e.g. (1, 1, 2, 1).
      max_shards: Optional hard cap on the number of shards.
    """
    partitioner = partitioned_variables.variable_axis_size_partitioner(
        axis=axis, max_shard_bytes=max_shard_bytes, max_shards=max_shards)

    with variable_scope.variable_scope("root", partitioner=partitioner):
      v0 = variable_scope.get_variable(
          name, dtype=dtypes.float32, shape=(4, 8, 16, 32))
      v0_list = v0._get_variable_list()
      v0_part = v0._get_partitions()
      self.assertEqual(len(v0_list), expected_axis_shards)
      self.assertAllEqual(v0_part, expected_partitions)

  def testVariableAxisSizePartitioner(self):
    with self.test_session():
      # Create a partitioned variable of shape (4, 8, 16, 32) type float32
      # Bytes per slice along the given axes:

      # 8 * 16 * 32 * sizeof(float32) = 16384 / slice on axis 0
      # 4 * 16 * 32 * sizeof(float32) = 8192 / slice on axis 1
      # 4 * 8 * 32 * sizeof(float32) = 4096 / slice on axis 2
      # 4 * 8 * 16 * sizeof(float32) = 2048 / slice on axis 3

      # Now partition it in different ways...

      # No need to slice: bytes_per_slice * dim0 = 65536 < max_shard_bytes
      self._testVariableAxisSizePartitioner(
          "v0",
          axis=0,
          max_shard_bytes=131072,
          expected_axis_shards=1,
          expected_partitions=(1, 1, 1, 1))

      # Slice exactly once: bytes_per_slice * dim1 = 65536 = max_shard_bytes
      self._testVariableAxisSizePartitioner(
          "v1",
          axis=1,
          max_shard_bytes=65536,
          expected_axis_shards=1,
          expected_partitions=(1, 1, 1, 1))

      # Slice into 2 parts:
      # bytes_per_slice = 4096
      # slices_per_shard = 32768 / 4096 = 8
      # axis_shards = 16 / 8 = 2
      self._testVariableAxisSizePartitioner(
          "v2",
          axis=2,
          max_shard_bytes=32768,
          expected_axis_shards=2,
          expected_partitions=(1, 1, 2, 1))

      # This partitioner makes sure we maximize the number of shards along
      # axis 3. Slice it into 32 parts:
      # bytes_per_slice = 2048
      # slices_per_shard = 2048 / 2048 = 1
      # axis_shards = 32 / 1 = 32
      self._testVariableAxisSizePartitioner(
          "v3a",
          axis=3,
          max_shard_bytes=2048,
          expected_axis_shards=32,
          expected_partitions=(1, 1, 1, 32))

      # This partitioner makes sure we do not go past the bound of allowable
      # number of shards along axis 3.
      # Slice into 32 parts:
      # bytes_per_slice = 2048
      # slices_per_shard = max(1, 1024 / 2048) = 1
      # axis_shards = 32 / 1 = 32
      # Slice into max of 32 parts because: max_shard_bytes < bytes_per_slice
      self._testVariableAxisSizePartitioner(
          "v3b",
          axis=3,
          max_shard_bytes=1024,
          expected_axis_shards=32,
          expected_partitions=(1, 1, 1, 32))

      # Specify max_shards so that it won't affect sharding.
      self._testVariableAxisSizePartitioner(
          "v3c",
          axis=3,
          max_shard_bytes=1024,
          expected_axis_shards=32,
          expected_partitions=(1, 1, 1, 32),
          max_shards=33)

      # Specify max_shards so that it will affect sharding.
      self._testVariableAxisSizePartitioner(
          "v3d",
          axis=3,
          max_shard_bytes=1024,
          expected_axis_shards=2,
          expected_partitions=(1, 1, 1, 2),
          max_shards=2)

      # Use the partitioner with strings
      partitioner_axis3_str = partitioned_variables.variable_axis_size_partitioner(  # pylint: disable=line-too-long
          axis=3,
          max_shard_bytes=32768,
          bytes_per_string_element=8)

      with variable_scope.variable_scope(
          "root", partitioner=partitioner_axis3_str):
        v3str = variable_scope.get_variable(
            "v3str",
            initializer=np.array([""] * 4 * 8 * 16 * 32).reshape(4, 8, 16, 32),
            dtype=dtypes.string,
            shape=(4, 8, 16, 32))
        v3str_list = v3str._get_variable_list()
        v3str_part = v3str._get_partitions()

        # Now the estimated bytes_per_slice = 4*8*16*bytes_per_string_element
        # which is equal to 4096.  Setting a max_shard_bytes of 32768
        # and we should get a split of 4.
        # Slice into 4 parts:
        # bytes_per_slice = 4096
        # slices_per_shard = 32768 / 4096 = 8
        # axis_shards = 32 / 8 = 4
        self.assertEqual(len(v3str_list), 4)
        self.assertAllEqual(v3str_part, (1, 1, 1, 4))

  def _testMinMaxVariablePartitioner(self, max_partitions, axis, min_slice_size,
                                     var_name, var_shape, expected_axis_shards,
                                     expected_partitions):
    """Creates a variable under min_max_variable_partitioner and checks shards.

    Args:
      max_partitions: Upper bound on the number of partitions.
      axis: Axis to shard along.
      min_slice_size: Minimum slice size in bytes passed to the partitioner.
      var_name: Variable name; unique per call within the "root" scope.
      var_shape: Shape of the float32 variable to create.
      expected_axis_shards: Expected number of variable slices.
      expected_partitions: Expected full partition list.
    """
    partitioner = partitioned_variables.min_max_variable_partitioner(
        max_partitions=max_partitions, axis=axis, min_slice_size=min_slice_size)
    with variable_scope.variable_scope("root", partitioner=partitioner):
      v0 = variable_scope.get_variable(
          var_name, dtype=dtypes.float32, shape=var_shape)
      v0_list = v0._get_variable_list()
      v0_part = v0._get_partitions()
      self.assertEqual(len(v0_list), expected_axis_shards)
      self.assertAllEqual(v0_part, expected_partitions)

  def testMinMaxVariablePartitioner(self):
    with self.test_session():
      # Partitioning a variable of shape=[2048] with a minimum of 2K per slice.
      self._testMinMaxVariablePartitioner(
          max_partitions=100,
          axis=0,
          min_slice_size=2 << 10,
          var_name="v0_0",
          var_shape=[2048],
          expected_axis_shards=4,
          expected_partitions=[4])

      # Partitioning a variable of shape=[2048, 1024] with a minimum of 256K per
      # slice.
      self._testMinMaxVariablePartitioner(
          max_partitions=100,
          axis=0,
          min_slice_size=256 << 10,
          var_name="v0",
          var_shape=[2048, 1024],
          expected_axis_shards=32,
          expected_partitions=[32, 1])

      # max_partitions restricts partitioning of the variable.
      self._testMinMaxVariablePartitioner(
          max_partitions=16,
          axis=0,
          min_slice_size=256 << 10,
          var_name="v1_max",
          var_shape=[2048, 1024],
          expected_axis_shards=16,
          expected_partitions=[16, 1])
      self._testMinMaxVariablePartitioner(
          max_partitions=1,
          axis=0,
          min_slice_size=256 << 10,
          var_name="v2_max",
          var_shape=[2048, 1024],
          expected_axis_shards=1,
          expected_partitions=[1, 1])

      # Reducing/Increasing min_slice_size proportionately increases/reduces the
      # number of partitions.
      self._testMinMaxVariablePartitioner(
          max_partitions=100,
          axis=0,
          min_slice_size=128 << 10,
          var_name="v3_slice",
          var_shape=[2048, 1024],
          expected_axis_shards=64,
          expected_partitions=[64, 1])
      self._testMinMaxVariablePartitioner(
          max_partitions=100,
          axis=0,
          min_slice_size=512 << 10,
          var_name="v4_slice",
          var_shape=[2048, 1024],
          expected_axis_shards=16,
          expected_partitions=[16, 1])

      # Partitioning the variable along a different axis.
      self._testMinMaxVariablePartitioner(
          max_partitions=100,
          axis=1,
          min_slice_size=256 << 10,
          var_name="v5_axis",
          var_shape=[64, 1024, 1, 3],
          expected_axis_shards=3,
          expected_partitions=[1, 3, 1, 1])
      self._testMinMaxVariablePartitioner(
          max_partitions=100,
          axis=3,
          min_slice_size=256 << 10,
          var_name="v6_axis",
          var_shape=[64, 1024, 1, 3],
          expected_axis_shards=3,
          expected_partitions=[1, 1, 1, 3])

      # Can not partition the variable more than what its shape allows.
      self._testMinMaxVariablePartitioner(
          max_partitions=100,
          axis=0,
          min_slice_size=256 << 10,
          var_name="v7_shape",
          var_shape=[16, 128, 1024],
          expected_axis_shards=16,
          expected_partitions=[16, 1, 1])
      self._testMinMaxVariablePartitioner(
          max_partitions=100,
          axis=0,
          min_slice_size=256 << 10,
          var_name="v8_shape",
          var_shape=[4, 512, 1024],
          expected_axis_shards=4,
          expected_partitions=[4, 1, 1])
    304 
    305 
    306 def _IotaInitializer(shape, dtype=dtypes.float32, partition_info=None):
    307   assert dtype == dtypes.float32
    308   if len(shape) == 1:
    309     return range(shape[0])
    310   else:
    311     val = _IotaInitializer(shape[1:], dtype)
    312     return [[(10**i) * v for v in val] for i in range(shape[0])]
    313 
    314 
    315 class PartitionedVariablesTestCase(test.TestCase):
    316 
    317   def _TestSaveSpec(self, slices, expected_specs):
    318     self.assertEqual(len(expected_specs), len(slices))
    319     for i in xrange(len(expected_specs)):
    320       self.assertEquals(expected_specs[i], slices[i]._save_slice_info.spec)
    321 
    322   def testVecConstantInit(self):
    323     with self.test_session():
    324       rnd_par = constant_op.constant([1, 2, 3, 4])
    325       vs = partitioned_variables.create_partitioned_variables([4], [4], rnd_par)
    326       variables.global_variables_initializer().run()
    327       val = array_ops.concat(vs, 0).eval()
    328       rnd = rnd_par.eval()
    329       self.assertAllClose(rnd, val)
    330       self.assertEqual([dtypes.int32] * 4, [v.dtype.base_dtype for v in vs])
    331       self._TestSaveSpec(vs, ["4 0,1", "4 1,1", "4 2,1", "4 3,1"])
    332 
    333   def testConstantInit(self):
    334     with self.test_session():
    335       rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
    336       vs = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
    337                                                               rnd_par)
    338       variables.global_variables_initializer().run()
    339       val = array_ops.concat(vs, 1).eval()
    340       rnd = rnd_par.eval()
    341       self.assertAllClose(rnd, val)
    342       self.assertEqual([dtypes.int32] * 2, [v.dtype.base_dtype for v in vs])
    343       self._TestSaveSpec(vs, ["2 4 0,2:0,2", "2 4 0,2:2,2"])
    344 
    345   def _testNameHelper(self, use_resource=False):
    346     with self.test_session():
    347       rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
    348       with variable_scope.variable_scope("hi", use_resource=use_resource):
    349         vs1 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
    350                                                                  rnd_par)
    351         vs2 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
    352                                                                  rnd_par)
    353       variables.global_variables_initializer().run()
    354       var1_name = vs1[0]._save_slice_info.full_name
    355       var2_name = vs2[0]._save_slice_info.full_name
    356       self.assertEqual("hi/PartitionedVariable", var1_name)
    357       self.assertEqual("hi/PartitionedVariable_1", var2_name)
    358       self.assertEqual(var1_name + "/part_0:0", vs1[0].name)
    359       self.assertEqual(var1_name + "/part_1:0", vs1[1].name)
    360       self.assertEqual(var2_name + "/part_0:0", vs2[0].name)
    361       self.assertEqual(var2_name + "/part_1:0", vs2[1].name)
    362     # Test same variable.
    363     with self.test_session():
    364       rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
    365       with variable_scope.variable_scope(
    366           "hola", use_resource=use_resource) as vs:
    367         vs1 = partitioned_variables.create_partitioned_variables(
    368             [2, 4], [1, 2], rnd_par, dtype=dtypes.int32)
    369       with variable_scope.variable_scope(
    370           vs, reuse=True, use_resource=use_resource):
    371         vs2 = partitioned_variables.create_partitioned_variables(
    372             [2, 4], [1, 2], rnd_par, dtype=dtypes.int32)
    373       variables.global_variables_initializer().run()
    374       var1_name = vs1[0]._save_slice_info.full_name
    375       var2_name = vs2[0]._save_slice_info.full_name
    376       self.assertEqual("hola/PartitionedVariable", var1_name)
    377       self.assertEqual("hola/PartitionedVariable", var2_name)
    378       self.assertEqual(var1_name + "/part_0:0", vs1[0].name)
    379       self.assertEqual(var1_name + "/part_1:0", vs1[1].name)
    380       self.assertEqual(var2_name + "/part_0:0", vs2[0].name)
    381       self.assertEqual(var2_name + "/part_1:0", vs2[1].name)
    382     # Test name_scope
    383     with self.test_session():
    384       rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
    385       with ops.name_scope("ola"):
    386         vs1 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
    387                                                                  rnd_par)
    388         vs2 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
    389                                                                  rnd_par)
    390       variables.global_variables_initializer().run()
    391       var1_name = vs1[0]._save_slice_info.full_name
    392       var2_name = vs2[0]._save_slice_info.full_name
    393       # Currently, the name scope 'ola' has no effect.
    394       self.assertEqual("PartitionedVariable", var1_name)
    395       self.assertEqual("PartitionedVariable_1", var2_name)
    396       self.assertEqual(var1_name + "/part_0:0", vs1[0].name)
    397       self.assertEqual(var1_name + "/part_1:0", vs1[1].name)
    398       self.assertEqual(var2_name + "/part_0:0", vs2[0].name)
    399       self.assertEqual(var2_name + "/part_1:0", vs2[1].name)
    400 
    401   def testName(self):
    402     self._testNameHelper(use_resource=False)
    403 
    404   def testResourceName(self):
    405     self._testNameHelper(use_resource=True)
    406 
    407   def testRandomInitValue(self):
    408     with self.test_session():
    409       rnd = variables.Variable(random_ops.random_uniform([200, 40]))
    410       vs = partitioned_variables.create_partitioned_variables(
    411           rnd.get_shape(), [1, 10], rnd.initialized_value())
    412       variables.global_variables_initializer().run()
    413       val = array_ops.concat(vs, 1).eval()
    414       rnd = rnd.eval()
    415       self.assertAllClose(rnd, val)
    416       self.assertEqual([dtypes.float32] * 10, [v.dtype.base_dtype for v in vs])
    417       self._TestSaveSpec(vs, [
    418           "200 40 0,200:0,4", "200 40 0,200:4,4", "200 40 0,200:8,4",
    419           "200 40 0,200:12,4", "200 40 0,200:16,4", "200 40 0,200:20,4",
    420           "200 40 0,200:24,4", "200 40 0,200:28,4", "200 40 0,200:32,4",
    421           "200 40 0,200:36,4"
    422       ])
    423 
    424   def testRandomInitUnevenPartitions(self):
    425     with self.test_session():
    426       rnd = variables.Variable(
    427           random_ops.random_uniform([20, 43], dtype=dtypes.float64))
    428       var_lists = [
    429           partitioned_variables.create_partitioned_variables(
    430               rnd.get_shape(), [1, i], rnd.initialized_value())
    431           for i in xrange(1, 10)
    432       ]
    433       variables.global_variables_initializer().run()
    434       rnd_val = rnd.eval()
    435       # Only check the slice save specs for the first 5 tf.
    436       save_specs = [
    437           # One slice
    438           ["20 43 0,20:0,43"],
    439           # Two slices
    440           ["20 43 0,20:0,22", "20 43 0,20:22,21"],
    441           # Three slices
    442           ["20 43 0,20:0,15", "20 43 0,20:15,14", "20 43 0,20:29,14"],
    443           # Four slices
    444           [
    445               "20 43 0,20:0,11", "20 43 0,20:11,11", "20 43 0,20:22,11",
    446               "20 43 0,20:33,10"
    447           ],
    448           # Five slices
    449           [
    450               "20 43 0,20:0,9", "20 43 0,20:9,9", "20 43 0,20:18,9",
    451               "20 43 0,20:27,8", "20 43 0,20:35,8"
    452           ]
    453       ]
    454       for i, vs in enumerate(var_lists):
    455         var_val = array_ops.concat(vs, 1).eval()
    456         self.assertAllClose(rnd_val, var_val)
    457         self.assertEqual([dtypes.float64] * len(vs),
    458                          [v.dtype.base_dtype for v in vs])
    459         if i < len(save_specs):
    460           self._TestSaveSpec(vs, save_specs[i])
    461 
    462   def testDegenerate(self):
    463     with self.test_session():
    464       rnd = variables.Variable(random_ops.random_uniform([10, 43]))
    465       vs = partitioned_variables.create_partitioned_variables(
    466           rnd.get_shape(), [1, 1], rnd.initialized_value())
    467       variables.global_variables_initializer().run()
    468       val = array_ops.concat(vs, 0).eval()
    469       rnd = rnd.eval()
    470       self.assertAllClose(rnd, val)
    471       self._TestSaveSpec(vs, ["10 43 0,10:0,43"])
    472 
    473   def testSliceSizeOne(self):
    474     with self.test_session():
    475       rnd = variables.Variable(random_ops.random_uniform([10, 43]))
    476       vs = partitioned_variables.create_partitioned_variables(
    477           rnd.get_shape(), [10, 1], rnd.initialized_value())
    478       variables.global_variables_initializer().run()
    479       val = array_ops.concat(vs, 0).eval()
    480       rnd = rnd.eval()
    481       self.assertAllClose(rnd, val)
    482       self._TestSaveSpec(vs, [
    483           "10 43 0,1:0,43", "10 43 1,1:0,43", "10 43 2,1:0,43",
    484           "10 43 3,1:0,43", "10 43 4,1:0,43", "10 43 5,1:0,43",
    485           "10 43 6,1:0,43", "10 43 7,1:0,43", "10 43 8,1:0,43", "10 43 9,1:0,43"
    486       ])
    487 
    488   def testIotaInitializer(self):
    489     self.assertAllClose([0., 1., 2., 3.], _IotaInitializer([4]))
    490     self.assertAllClose([[0., 1.], [0., 10.], [0., 100.], [0., 1000.]],
    491                         _IotaInitializer([4, 2]))
    492     with self.test_session():
    493       vs = partitioned_variables.create_partitioned_variables([13, 5], [3, 1],
    494                                                               _IotaInitializer)
    495       variables.global_variables_initializer().run()
    496       slice0 = _IotaInitializer([5, 5])
    497       slice1 = _IotaInitializer([4, 5])
    498       slice2 = _IotaInitializer([4, 5])
    499       val = array_ops.concat(vs, 0).eval()
    500       self.assertAllClose(slice0 + slice1 + slice2, val)
    501       self._TestSaveSpec(vs, ["13 5 0,5:0,5", "13 5 5,4:0,5", "13 5 9,4:0,5"])
    502 
    503   def testRandomInitializer(self):
    504     # Sanity check that the slices uses a different seed when using a random
    505     # initializer function.
    506     with self.test_session():
    507       var0, var1 = partitioned_variables.create_partitioned_variables(
    508           [20, 12], [1, 2], init_ops.random_uniform_initializer())
    509       variables.global_variables_initializer().run()
    510       val0, val1 = var0.eval().flatten(), var1.eval().flatten()
    511       self.assertTrue(np.linalg.norm(val0 - val1) > 1e-6)
    512     # Negative test that proves that slices have the same values if
    513     # the random initializer uses a seed.
    514     with self.test_session():
    515       var0, var1 = partitioned_variables.create_partitioned_variables(
    516           [20, 12], [1, 2], init_ops.random_uniform_initializer(seed=201))
    517       variables.global_variables_initializer().run()
    518       val0, val1 = var0.eval().flatten(), var1.eval().flatten()
    519       self.assertAllClose(val0, val1)
    520 
    521   def testSomeErrors(self):
    522     with self.test_session():
    523       rnd = variables.Variable(random_ops.random_uniform([10, 43]))
    524       with self.assertRaises(ValueError):
    525         partitioned_variables.create_partitioned_variables(
    526             [10], [1, 1], rnd.initialized_value())
    527       with self.assertRaises(ValueError):
    528         partitioned_variables.create_partitioned_variables(
    529             [10, 20], [1], rnd.initialized_value())
    530       with self.assertRaises(ValueError):
    531         partitioned_variables.create_partitioned_variables(
    532             [10, 43], [1], rnd.initialized_value())
    533       with self.assertRaises(ValueError):
    534         partitioned_variables.create_partitioned_variables(
    535             [10, 43], [1, 2, 3], rnd.initialized_value())
    536       with self.assertRaises(ValueError):
    537         partitioned_variables.create_partitioned_variables(
    538             [10, 43], [11, 1], rnd.initialized_value())
    539       with self.assertRaises(ValueError):
    540         partitioned_variables.create_partitioned_variables(
    541             [10, 43], [20, 1], rnd.initialized_value())
    542       with self.assertRaises(ValueError):
    543         partitioned_variables.create_partitioned_variables(
    544             [10, 43], [1, 50], rnd.initialized_value())
    545 
    546   def testControlDepsNone(self):
    547     with self.test_session() as session:
    548       c = constant_op.constant(1.0)
    549       with ops.control_dependencies([c]):
    550         # d get the control dependency.
    551         d = constant_op.constant(2.0)
    552         # Partitioned variables do not.
    553         var_x = variable_scope.get_variable(
    554             "x",
    555             shape=[2],
    556             initializer=init_ops.ones_initializer(),
    557             partitioner=partitioned_variables.variable_axis_size_partitioner(4))
    558 
    559         ops_before_read = session.graph.get_operations()
    560         var_x.as_tensor()  # Caches the ops for subsequent reads.
    561         reading_ops = [
    562             op for op in session.graph.get_operations()
    563             if op not in ops_before_read
    564         ]
    565 
    566       self.assertEqual([c.op], d.op.control_inputs)
    567       # Tests that no control dependencies are added to reading a partitioned
    568       # variable which is similar to reading a variable.
    569       for op in reading_ops:
    570         self.assertEqual([], op.control_inputs)
    571 
    572   def testConcat(self):
    573     with self.test_session() as session:
    574       var_x = variable_scope.get_variable(
    575           "x",
    576           initializer=constant_op.constant([1., 2.]),
    577           partitioner=partitioned_variables.variable_axis_size_partitioner(4))
    578 
    579       c = constant_op.constant(1.0)
    580       with ops.control_dependencies([c]):
    581         ops_before_concat = session.graph.get_operations()
    582         value = var_x._concat()  # pylint: disable=protected-access
    583         concat_ops = [
    584             op for op in session.graph.get_operations()
    585             if op not in ops_before_concat
    586         ]
    587 
    588       concat_control_inputs = [
    589           ci for op in concat_ops for ci in op.control_inputs
    590       ]
    591       self.assertTrue(
    592           c.op in concat_control_inputs,
    593           "var_x._concat() should get control dependencies from its scope.")
    594       variables.global_variables_initializer().run()
    595       self.assertAllClose(value.eval(), var_x.as_tensor().eval())
    596 
    597 
# Run the tests under TensorFlow's test runner when executed directly.
if __name__ == "__main__":
  test.main()
    600