Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tf.py."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import operator
     22 
     23 import numpy as np
     24 
     25 from tensorflow.python.eager import function
     26 from tensorflow.python.framework import constant_op
     27 from tensorflow.python.framework import dtypes
     28 from tensorflow.python.framework import errors_impl
     29 from tensorflow.python.framework import ops
     30 from tensorflow.python.ops import array_ops
     31 from tensorflow.python.ops import control_flow_ops
     32 from tensorflow.python.ops import gen_state_ops
     33 from tensorflow.python.ops import math_ops
     34 from tensorflow.python.ops import random_ops
     35 from tensorflow.python.ops import resource_variable_ops
     36 from tensorflow.python.ops import variables
     37 from tensorflow.python.platform import test
     38 from tensorflow.python.training import gradient_descent
     39 from tensorflow.python.util import compat
     40 
     41 
     42 class VariablesTestCase(test.TestCase):
     43 
     44   def testInitialization(self):
     45     with self.test_session():
     46       var0 = variables.Variable(0.0)
     47       self.assertEqual("Variable:0", var0.name)
     48       self.assertEqual("Variable", var0._shared_name)
     49       self.assertEqual([], var0.get_shape())
     50       self.assertEqual([], var0.get_shape())
     51       self.assertEqual([], var0.shape)
     52 
     53       var1 = variables.Variable(1.1)
     54       self.assertEqual("Variable_1:0", var1.name)
     55       self.assertEqual("Variable_1", var1._shared_name)
     56       self.assertEqual([], var1.get_shape())
     57       self.assertEqual([], var1.get_shape())
     58       self.assertEqual([], var1.shape)
     59 
     60       with self.assertRaisesOpError("Attempting to use uninitialized value"):
     61         var0.eval()
     62 
     63       with self.assertRaisesOpError("Attempting to use uninitialized value"):
     64         var1.eval()
     65 
     66       variables.global_variables_initializer().run()
     67 
     68       self.assertAllClose(0.0, var0.eval())
     69       self.assertAllClose(1.1, var1.eval())
     70 
     71   def testInitializationOrder(self):
     72     with self.test_session():
     73       rnd = variables.Variable(random_ops.random_uniform([3, 6]), name="rnd")
     74       self.assertEqual("rnd:0", rnd.name)
     75       self.assertEqual([3, 6], rnd.get_shape())
     76       self.assertEqual([3, 6], rnd.get_shape())
     77       self.assertEqual([3, 6], rnd.shape)
     78 
     79       dep = variables.Variable(rnd.initialized_value(), name="dep")
     80       self.assertEqual("dep:0", dep.name)
     81       self.assertEqual([3, 6], dep.get_shape())
     82       self.assertEqual([3, 6], dep.get_shape())
     83       self.assertEqual([3, 6], dep.shape)
     84 
     85       # Currently have to set the shape manually for Add.
     86       added_val = rnd.initialized_value() + dep.initialized_value() + 2.0
     87       added_val.set_shape(rnd.get_shape())
     88 
     89       depdep = variables.Variable(added_val, name="depdep")
     90       self.assertEqual("depdep:0", depdep.name)
     91       self.assertEqual([3, 6], depdep.get_shape())
     92       self.assertEqual([3, 6], depdep.get_shape())
     93       self.assertEqual([3, 6], depdep.shape)
     94 
     95       variables.global_variables_initializer().run()
     96 
     97       self.assertAllClose(rnd.eval(), dep.eval())
     98       self.assertAllClose(rnd.eval() + dep.eval() + 2.0, depdep.eval())
     99 
    100   def testIterable(self):
    101     with self.assertRaisesRegexp(TypeError, "not iterable"):
    102       for _ in variables.Variable(0.0):
    103         pass
    104     with self.assertRaisesRegexp(TypeError, "not iterable"):
    105       for _ in variables.Variable([0.0, 1.0]):
    106         pass
    107 
    108   def testAssignments(self):
    109     with self.test_session():
    110       var = variables.Variable(0.0)
    111       plus_one = var.assign_add(1.0)
    112       minus_one = var.assign_sub(2.0)
    113       four = var.assign(4.0)
    114       variables.global_variables_initializer().run()
    115       self.assertAllClose(0.0, var.eval())
    116 
    117       self.assertAllClose(1.0, plus_one.eval())
    118       self.assertAllClose(1.0, var.eval())
    119 
    120       self.assertAllClose(-1.0, minus_one.eval())
    121       self.assertAllClose(-1.0, var.eval())
    122 
    123       self.assertAllClose(4.0, four.eval())
    124       self.assertAllClose(4.0, var.eval())
    125 
    126   def testResourceAssignments(self):
    127     with self.test_session(use_gpu=True):
    128       var = resource_variable_ops.ResourceVariable(0.0)
    129       plus_one = var.assign_add(1.0)
    130       minus_one = var.assign_sub(2.0)
    131       four = var.assign(4.0)
    132       variables.global_variables_initializer().run()
    133       self.assertAllClose(0.0, var.eval())
    134 
    135       plus_one.eval()
    136       self.assertAllClose(1.0, var.eval())
    137 
    138       minus_one.eval()
    139       self.assertAllClose(-1.0, var.eval())
    140 
    141       four.eval()
    142       self.assertAllClose(4.0, var.eval())
    143 
    144   def testZeroSizeStringAssign(self):
    145     with self.test_session() as sess:
    146       array = variables.Variable(
    147           initial_value=array_ops.zeros((0,), dtype=dtypes.string),
    148           name="foo",
    149           trainable=False,
    150           collections=[ops.GraphKeys.LOCAL_VARIABLES])
    151       sess.run(variables.local_variables_initializer())
    152       old_value = array.value()
    153       copy_op = array.assign(old_value)
    154       self.assertEqual([], list(sess.run(copy_op)))
    155 
    156   def _countUpToTest(self, dtype):
    157     with self.test_session():
    158       zero = constant_op.constant(0, dtype=dtype)
    159       var = variables.Variable(zero)
    160       count_up_to = var.count_up_to(3)
    161 
    162       variables.global_variables_initializer().run()
    163       self.assertEqual(0, var.eval())
    164 
    165       self.assertEqual(0, count_up_to.eval())
    166       self.assertEqual(1, var.eval())
    167 
    168       self.assertEqual(1, count_up_to.eval())
    169       self.assertEqual(2, var.eval())
    170 
    171       self.assertEqual(2, count_up_to.eval())
    172       self.assertEqual(3, var.eval())
    173 
    174       with self.assertRaisesOpError("Reached limit of 3"):
    175         count_up_to.eval()
    176       self.assertEqual(3, var.eval())
    177 
    178       with self.assertRaisesOpError("Reached limit of 3"):
    179         count_up_to.eval()
    180       self.assertEqual(3, var.eval())
    181 
    182   def testCountUpToInt32(self):
    183     self._countUpToTest(dtypes.int32)
    184 
    185   def testCountUpToInt64(self):
    186     self._countUpToTest(dtypes.int64)
    187 
    188   def testControlDepsNone(self):
    189     with self.test_session():
    190       c = constant_op.constant(1.0)
    191       with ops.control_dependencies([c]):
    192         # d get the control dep.
    193         d = constant_op.constant(2.0)
    194         # variables do not.
    195         var_x = variables.Variable(2.0)
    196       self.assertEqual([c.op], d.op.control_inputs)
    197       self.assertEqual([], var_x.initializer.control_inputs)
    198       self.assertEqual([], var_x.value().op.control_inputs)
    199       self.assertEqual([], var_x._ref().op.control_inputs)  # pylint: disable=protected-access
    200 
    201   def testControlFlow(self):
    202     with self.test_session() as sess:
    203       v0 = variables.Variable(0, name="v0")
    204       var_dict = {}
    205 
    206       # Call get_variable in each of the cond clauses.
    207       def var_in_then_clause():
    208         v1 = variables.Variable(1, name="v1")
    209         var_dict["v1"] = v1
    210         return v1 + v0
    211 
    212       def var_in_else_clause():
    213         v2 = variables.Variable(2, name="v2")
    214         var_dict["v2"] = v2
    215         return v2 + v0
    216 
    217       add = control_flow_ops.cond(
    218           math_ops.less(v0, 10), var_in_then_clause, var_in_else_clause)
    219       v1 = var_dict["v1"]
    220       v2 = var_dict["v2"]
    221       # We should be able to initialize and run v1 and v2 without initializing
    222       # v0, even if the variable was created with a control dep on v0.
    223       sess.run(v1.initializer)
    224       self.assertEqual([1], sess.run(v1))
    225       sess.run(v2.initializer)
    226       self.assertEqual([2], sess.run(v2))
    227       # v0 should still be uninitialized.
    228       with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"):
    229         sess.run(v0)
    230       # We should not be able to run 'add' yet.
    231       with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"):
    232         sess.run(add)
    233       # If we initialize v0 we should be able to run 'add'.
    234       sess.run(v0.initializer)
    235       sess.run(add)
    236 
    237   def testControlFlowInitialization(self):
    238     """Expects an error if an initializer is in a control-flow scope."""
    239     def cond(i, _):
    240       return i < 10
    241 
    242     def body(i, _):
    243       zero = array_ops.zeros([], dtype=dtypes.int32)
    244       v = variables.Variable(initial_value=zero)
    245       return (i + 1, v.read_value())
    246 
    247     with self.assertRaisesRegexp(ValueError, "inside a control-flow"):
    248       control_flow_ops.while_loop(cond, body, [0, 0])
    249 
    250   def testUseVariableAsTensor(self):
    251     with self.test_session():
    252       var_x = variables.Variable(2.0)
    253       var_y = variables.Variable(3.0)
    254       variables.global_variables_initializer().run()
    255       self.assertAllClose(2.0, var_x.eval())
    256       self.assertAllClose(3.0, var_y.eval())
    257       self.assertAllClose(5.0, math_ops.add(var_x, var_y).eval())
    258 
    259   def testZeroSizeVarSameAsConst(self):
    260     with self.test_session():
    261       zero_size_var = variables.Variable(array_ops.zeros([0, 2]))
    262       zero_size_const = array_ops.ones([2, 0])
    263       variable_mul = math_ops.matmul(zero_size_const, zero_size_var)
    264       const_mul = math_ops.matmul(
    265           zero_size_const, zero_size_const, transpose_b=True)
    266       variables.global_variables_initializer().run()
    267       variable_output = variable_mul.eval()
    268       self.assertAllClose(const_mul.eval(), variable_output)
    269       self.assertAllClose([[0., 0.], [0., 0.]], variable_output)
    270 
    271   def testCachingDevice(self):
    272     with self.test_session():
    273       var = variables.Variable(2.0)
    274       self.assertEqual(var.device, var.value().device)
    275       self.assertEqual(var.device, var.initialized_value().device)
    276 
    277       var_cached = variables.Variable(2.0, caching_device="/job:foo")
    278       self.assertFalse(var_cached.device.startswith("/job:foo"))
    279       self.assertTrue(var_cached.value().device.startswith("/job:foo"))
    280 
    281   def testCollections(self):
    282     with self.test_session():
    283       var_x = variables.Variable(2.0)
    284       var_y = variables.Variable(2.0, trainable=False)
    285       var_z = variables.Variable(2.0, trainable=True)
    286       var_t = variables.Variable(
    287           2.0,
    288           trainable=True,
    289           collections=[
    290               ops.GraphKeys.TRAINABLE_VARIABLES, ops.GraphKeys.GLOBAL_VARIABLES
    291           ])
    292       self.assertEqual([var_x, var_y, var_z, var_t],
    293                        variables.global_variables())
    294       self.assertEqual([var_x, var_z, var_t], variables.trainable_variables())
    295 
    296   def testCollectionsWithScope(self):
    297     with self.test_session():
    298       with ops.name_scope("scope_1"):
    299         var_x = variables.Variable(2.0)
    300       with ops.name_scope("scope_2"):
    301         var_y = variables.Variable(2.0)
    302 
    303       self.assertEqual([var_x, var_y], variables.global_variables())
    304       self.assertEqual([var_x], variables.global_variables("scope_1"))
    305       self.assertEqual([var_y], variables.global_variables("scope_2"))
    306 
    307       self.assertEqual([var_x, var_y], variables.trainable_variables())
    308       self.assertEqual([var_x], variables.trainable_variables("scope_1"))
    309       self.assertEqual([var_y], variables.trainable_variables("scope_2"))
    310 
    311   def testOperators(self):
    312     with self.test_session():
    313       var_f = variables.Variable([2.0])
    314       add = var_f + 0.0
    315       radd = 1.0 + var_f
    316       sub = var_f - 1.0
    317       rsub = 1.0 - var_f
    318       mul = var_f * 10.0
    319       rmul = 10.0 * var_f
    320       div = var_f / 10.0
    321       rdiv = 10.0 / var_f
    322       lt = var_f < 3.0
    323       rlt = 3.0 < var_f
    324       le = var_f <= 2.0
    325       rle = 2.0 <= var_f
    326       gt = var_f > 3.0
    327       rgt = 3.0 > var_f
    328       ge = var_f >= 2.0
    329       rge = 2.0 >= var_f
    330       neg = -var_f
    331       abs_v = abs(var_f)
    332 
    333       var_i = variables.Variable([20])
    334       mod = var_i % 7
    335       rmod = 103 % var_i
    336 
    337       var_b = variables.Variable([True, False])
    338       and_v = operator.and_(var_b, [True, True])
    339       or_v = operator.or_(var_b, [False, True])
    340       xor_v = operator.xor(var_b, [False, False])
    341       invert_v = ~var_b
    342 
    343       rnd = np.random.rand(4, 4).astype("f")
    344       var_t = variables.Variable(rnd)
    345       slice_v = var_t[2, 0:0]
    346 
    347       var_m = variables.Variable([[2.0, 3.0]])
    348       matmul = var_m.__matmul__([[10.0], [20.0]])
    349       rmatmul = var_m.__rmatmul__([[10.0], [20.0]])
    350 
    351       variables.global_variables_initializer().run()
    352       self.assertAllClose([2.0], add.eval())
    353       self.assertAllClose([3.0], radd.eval())
    354       self.assertAllClose([1.0], sub.eval())
    355       self.assertAllClose([-1.0], rsub.eval())
    356       self.assertAllClose([20.0], mul.eval())
    357       self.assertAllClose([20.0], rmul.eval())
    358       self.assertAllClose([0.2], div.eval())
    359       self.assertAllClose([5.0], rdiv.eval())
    360       self.assertAllClose([-2.0], neg.eval())
    361       self.assertAllClose([2.0], abs_v.eval())
    362       self.assertAllClose([True], lt.eval())
    363       self.assertAllClose([False], rlt.eval())
    364       self.assertAllClose([True], le.eval())
    365       self.assertAllClose([True], rle.eval())
    366       self.assertAllClose([False], gt.eval())
    367       self.assertAllClose([True], rgt.eval())
    368       self.assertAllClose([True], ge.eval())
    369       self.assertAllClose([True], rge.eval())
    370 
    371       self.assertAllClose([6], mod.eval())
    372       self.assertAllClose([3], rmod.eval())
    373 
    374       self.assertAllClose([True, False], and_v.eval())
    375       self.assertAllClose([True, True], or_v.eval())
    376       self.assertAllClose([True, False], xor_v.eval())
    377       self.assertAllClose([False, True], invert_v.eval())
    378 
    379       self.assertAllClose(rnd[2, 0:0], slice_v.eval())
    380 
    381       self.assertAllClose([[80.0]], matmul.eval())
    382       self.assertAllClose([[20.0, 30.0], [40.0, 60.0]], rmatmul.eval())
    383 
    384   def testSession(self):
    385     with self.test_session() as sess:
    386       var = variables.Variable([1, 12])
    387       variables.global_variables_initializer().run()
    388       self.assertAllClose([1, 12], sess.run(var))
    389 
    390   def testDevicePlacement(self):
    391     with self.test_session() as sess:
    392       with ops.device("/cpu:0"):
    393         var = variables.Variable([1, 12])
    394       init_value = var.initialized_value()
    395       init_op = variables.global_variables_initializer()
    396       self.assertEqual(var.op.device, init_value.device)
    397       self.assertEqual(var.op.device, init_op.device)
    398       sess.run(init_op)
    399 
    400   def testColocation(self):
    401     with ops.device("/job:ps"):
    402       var = variables.Variable(0, name="v")
    403     with ops.device("/job:worker/task:7"):
    404       assign_op = var.assign(1)
    405     self.assertDeviceEqual("/job:ps", assign_op.device)
    406     self.assertEqual([b"loc:@v"], assign_op.op.colocation_groups())
    407 
    408   def testInitializerFunction(self):
    409     value = [[-42], [133.7]]
    410     shape = [2, 1]
    411     with self.test_session():
    412       initializer = lambda: constant_op.constant(value)
    413 
    414       v1 = variables.Variable(initializer, dtype=dtypes.float32)
    415       self.assertEqual(shape, v1.get_shape())
    416       self.assertEqual(shape, v1.shape)
    417       self.assertAllClose(value, v1.initial_value.eval())
    418       with self.assertRaises(errors_impl.FailedPreconditionError):
    419         v1.eval()
    420 
    421       v2 = variables.Variable(
    422           math_ops.negative(v1.initialized_value()), dtype=dtypes.float32)
    423       self.assertEqual(v1.get_shape(), v2.get_shape())
    424       self.assertEqual(v1.shape, v2.shape)
    425       self.assertAllClose(np.negative(value), v2.initial_value.eval())
    426 
    427       with self.assertRaises(errors_impl.FailedPreconditionError):
    428         v2.eval()
    429       variables.global_variables_initializer().run()
    430       self.assertAllClose(np.negative(value), v2.eval())
    431 
    432   def testConstraintArg(self):
    433     constraint = lambda x: x
    434     v = variables.Variable(
    435         lambda: constant_op.constant(1.),
    436         constraint=constraint)
    437     self.assertEqual(v.constraint, constraint)
    438 
    439     constraint = 0
    440     with self.assertRaises(ValueError):
    441       v = variables.Variable(
    442           lambda: constant_op.constant(1.),
    443           constraint=constraint)
    444 
    445   def testNoRefDataRace(self):
    446     with self.test_session():
    447       a = variables.Variable([1, 2, 3], dtype=dtypes.float32)
    448       b = variables.Variable(a.initialized_value() + 2)
    449       c = variables.Variable(b.initialized_value() + 2)
    450       variables.global_variables_initializer().run()
    451       self.assertAllEqual(a.eval(), [1, 2, 3])
    452       self.assertAllEqual(b.eval(), [3, 4, 5])
    453       self.assertAllEqual(c.eval(), [5, 6, 7])
    454 
    455   def testInitializerFunctionDevicePlacement(self):
    456     with self.test_session():
    457       initializer = lambda: constant_op.constant(42.0)
    458       with ops.device("/cpu:100"):
    459         v1 = variables.Variable(initializer, dtype=dtypes.float32, name="v1")
    460       expected_device = "/device:CPU:100"
    461       expected_group_v1 = [b"loc:@v1"]
    462       self.assertEqual(expected_device, v1.op.device)
    463       self.assertEqual(expected_group_v1, v1.op.colocation_groups())
    464       for i in v1.initializer.inputs:
    465         self.assertEqual(expected_group_v1, i.op.colocation_groups())
    466 
    467       v2 = variables.Variable(initializer, dtype=dtypes.float32, name="v2")
    468       expected_group_v2 = [b"loc:@v2"]
    469       self.assertEqual(expected_group_v2, v2.op.colocation_groups())
    470       for i in v2.initializer.inputs:
    471         self.assertEqual(expected_group_v2, i.op.colocation_groups())
    472 
    473   def testVariableDefInitializedInstances(self):
    474     with ops.Graph().as_default(), self.test_session() as sess:
    475       v_def = variables.Variable(
    476           initial_value=constant_op.constant(3.0)).to_proto()
    477 
    478     with ops.Graph().as_default(), self.test_session() as sess:
    479       # v describes a VariableDef-based variable without an initial value.
    480       v = variables.Variable(variable_def=v_def)
    481       self.assertEqual(3.0, sess.run(v.initialized_value()))
    482 
    483       # initialized_value should not rerun the initializer_op if the variable
    484       # has already been initialized elsewhere.
    485       sess.run(v.assign(1.0))
    486       self.assertEqual(1.0, v.initialized_value().eval())
    487 
    488     v_def.ClearField("initial_value_name")
    489     with ops.Graph().as_default(), self.test_session() as sess:
    490       # Restoring a legacy VariableDef proto that does not have
    491       # initial_value_name set should still work.
    492       v = variables.Variable(variable_def=v_def)
    493       # We should also be able to re-export the variable to a new meta graph.
    494       self.assertProtoEquals(v_def, v.to_proto())
    495       # But attempts to use initialized_value will result in errors.
    496       with self.assertRaises(ValueError):
    497         sess.run(v.initialized_value())
    498 
    499   def testLoad(self):
    500     with self.test_session():
    501       var = variables.Variable(np.zeros((5, 5), np.float32))
    502       variables.global_variables_initializer().run()
    503       var.load(np.ones((5, 5), np.float32))
    504 
    505       self.assertAllClose(np.ones((5, 5), np.float32), var.eval())
    506 
    507   def testRepr(self):
    508     var = variables.Variable(np.zeros((5, 5), np.float32), name="noop")
    509     self.assertEqual(
    510         "<tf.Variable 'noop:0' shape=(5, 5) dtype=float32_ref>",
    511         repr(var))
    512 
    513   def testVariableNamesPreserveNameScopesWithDefun(self):
    514     @function.defun
    515     def create_variable():
    516       with ops.name_scope("foo"):
    517         v = variables.Variable(0.0, name="bar")
    518       self.assertEqual(v.name, "foo/bar:0")
    519     with ops.get_default_graph().as_default():
    520       create_variable()
    521 
    522 
    523 class IsInitializedTest(test.TestCase):
    524 
    525   def testNoVars(self):
    526     with ops.Graph().as_default(), self.test_session() as sess:
    527       uninited = variables.report_uninitialized_variables()
    528       self.assertEqual(0, sess.run(uninited).size)
    529 
    530   def testAssertVariablesInitialized(self):
    531     with ops.Graph().as_default(), self.test_session() as sess:
    532       v = variables.Variable([1, 2], name="v")
    533       w = variables.Variable([3, 4], name="w")
    534       _ = v, w
    535       uninited = variables.report_uninitialized_variables()
    536       self.assertAllEqual(np.array([b"v", b"w"]), sess.run(uninited))
    537       variables.global_variables_initializer().run()
    538       self.assertEqual(0, sess.run(uninited).size)
    539 
    540   def testVariableList(self):
    541     with ops.Graph().as_default(), self.test_session() as sess:
    542       v = variables.Variable([1, 2], name="v")
    543       w = variables.Variable([3, 4], name="w")
    544       uninited = variables.report_uninitialized_variables()
    545       self.assertAllEqual(np.array([b"v", b"w"]), sess.run(uninited))
    546       sess.run(w.initializer)
    547       self.assertAllEqual(np.array([b"v"]), sess.run(uninited))
    548       v.initializer.run()
    549       self.assertEqual(0, sess.run(uninited).size)
    550 
    551   def testZeroSizeVarInitialized(self):
    552     with ops.Graph().as_default(), self.test_session() as sess:
    553       v = variables.Variable(array_ops.zeros([0, 2]), name="v")
    554       uninited = variables.report_uninitialized_variables()
    555       v.initializer.run()  # not strictly necessary
    556       self.assertEqual(0, sess.run(uninited).size)
    557 
    558   def testTrainingWithZeroSizeVar(self):
    559     with ops.Graph().as_default(), self.test_session() as sess:
    560       a = variables.Variable(array_ops.zeros([0, 2]))
    561       b = variables.Variable(array_ops.ones([2, 2]))
    562       objective = math_ops.reduce_sum(b + math_ops.matmul(
    563           a, a, transpose_a=True))
    564       variables.global_variables_initializer().run()
    565       do_opt = gradient_descent.GradientDescentOptimizer(0.1).minimize(
    566           objective)
    567       sess.run([do_opt])
    568       self.assertAllClose([[0.9, 0.9], [0.9, 0.9]], b.eval())
    569 
    570 
    571 class ObsoleteIsInitializedTest(test.TestCase):
    572 
    573   def testNoVars(self):
    574     with ops.Graph().as_default():
    575       self.assertEqual(None, variables.assert_variables_initialized())
    576 
    577   def testVariables(self):
    578     with ops.Graph().as_default(), self.test_session() as sess:
    579       v = variables.Variable([1, 2])
    580       w = variables.Variable([3, 4])
    581       _ = v, w
    582       inited = variables.assert_variables_initialized()
    583       with self.assertRaisesOpError("Attempting to use uninitialized value"):
    584         sess.run(inited)
    585       variables.global_variables_initializer().run()
    586       sess.run(inited)
    587 
    588   def testVariableList(self):
    589     with ops.Graph().as_default(), self.test_session() as sess:
    590       v = variables.Variable([1, 2])
    591       w = variables.Variable([3, 4])
    592       inited = variables.assert_variables_initialized([v])
    593       with self.assertRaisesOpError("Attempting to use uninitialized value"):
    594         inited.op.run()
    595       sess.run(w.initializer)
    596       with self.assertRaisesOpError("Attempting to use uninitialized value"):
    597         inited.op.run()
    598       v.initializer.run()
    599       inited.op.run()
    600 
    601 
    602 class PartitionedVariableTest(test.TestCase):
    603 
    604   def testPartitionedVariable(self):
    605     with ops.Graph().as_default():
    606       v0 = variables.Variable([0])
    607       v1 = variables.Variable([1])
    608       v0._set_save_slice_info(
    609           variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1]))
    610       v1._set_save_slice_info(
    611           variables.Variable.SaveSliceInfo(v0.name, [2], [1], [1]))
    612       partitions = [2]
    613 
    614       # Pass variable_list as [v1, v0] to ensure they are properly
    615       # re-sorted to [v0, v1] based on their slice info offsets.
    616       partitioned_variable = variables.PartitionedVariable(
    617           name="two_vars",
    618           shape=[2],
    619           dtype=v0.dtype,
    620           variable_list=[v1, v0],
    621           partitions=partitions)
    622 
    623       concatenated = ops.convert_to_tensor(partitioned_variable)
    624       num_partitions = len(partitioned_variable)
    625       iterated_partitions = list(partitioned_variable)
    626       self.assertEqual(2, num_partitions)
    627       self.assertEqual([v0, v1], iterated_partitions)
    628       self.assertEqual([2], concatenated.get_shape())
    629       self.assertEqual([2], concatenated.shape)
    630 
    631   def testPartitionedVariableFailures(self):
    632     with ops.Graph().as_default():
    633       with self.assertRaisesRegexp(ValueError, "empty"):
    634         variables.PartitionedVariable(
    635             name="fail",
    636             shape=2,
    637             dtype=dtypes.int32,
    638             variable_list=[],
    639             partitions=[])
    640 
    641       with self.assertRaisesRegexp(ValueError, "must have a save_slice_info"):
    642         v0 = variables.Variable([0])
    643         partitions = [1]
    644         variables.PartitionedVariable(
    645             name="two_vars",
    646             shape=[1],
    647             dtype=v0.dtype,
    648             variable_list=[v0],
    649             partitions=partitions)
    650 
    651       with self.assertRaisesRegexp(ValueError, "full shapes must match"):
    652         v0 = variables.Variable([0])
    653         v1 = variables.Variable([1])
    654         v0._set_save_slice_info(
    655             variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1]))
    656         v1._set_save_slice_info(
    657             variables.Variable.SaveSliceInfo(v0.name, [2], [1], [1]))
    658         partitions = [2]
    659 
    660         variables.PartitionedVariable(
    661             name="two_vars",
    662             shape=[3],
    663             dtype=v0.dtype,
    664             variable_list=[v1, v0],
    665             partitions=partitions)
    666 
    667       with self.assertRaisesRegexp(ValueError, "must be positive"):
    668         v0 = variables.Variable([0])
    669         v0._set_save_slice_info(
    670             variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1]))
    671         partitions = [0]
    672 
    673         variables.PartitionedVariable(
    674             name="two_vars",
    675             shape=[2],
    676             dtype=v0.dtype,
    677             variable_list=[v0],
    678             partitions=partitions)
    679 
    680 
    681 class VariableContainerTest(test.TestCase):
    682 
    683   def testContainer(self):
    684     with ops.Graph().as_default():
    685       v0 = variables.Variable([0])
    686       with ops.container("l1"):
    687         v1 = variables.Variable([1])
    688         with ops.container("l2"):
    689           v2 = variables.Variable([2])
    690           special_v = gen_state_ops._variable(
    691               shape=[1],
    692               dtype=dtypes.float32,
    693               name="VariableInL3",
    694               container="l3",
    695               shared_name="")
    696         v3 = variables.Variable([3])
    697       v4 = variables.Variable([4])
    698     self.assertEqual(compat.as_bytes(""), v0.op.get_attr("container"))
    699     self.assertEqual(compat.as_bytes("l1"), v1.op.get_attr("container"))
    700     self.assertEqual(compat.as_bytes("l2"), v2.op.get_attr("container"))
    701     self.assertEqual(compat.as_bytes("l3"), special_v.op.get_attr("container"))
    702     self.assertEqual(compat.as_bytes("l1"), v3.op.get_attr("container"))
    703     self.assertEqual(compat.as_bytes(""), v4.op.get_attr("container"))
    704 
    705 
    706 if __name__ == "__main__":
    707   test.main()
    708