Home | History | Annotate | Download | only in tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for reading and writing variables."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.compiler.tests.xla_test import XLATestCase
     24 from tensorflow.python.framework import constant_op
     25 from tensorflow.python.framework import dtypes
     26 from tensorflow.python.framework import errors
     27 from tensorflow.python.framework import ops
     28 from tensorflow.python.ops import array_ops
     29 from tensorflow.python.ops import init_ops
     30 from tensorflow.python.ops import math_ops
     31 from tensorflow.python.ops import resource_variable_ops
     32 from tensorflow.python.ops import state_ops
     33 from tensorflow.python.ops import variable_scope
     34 from tensorflow.python.ops import variables
     35 from tensorflow.python.platform import googletest
     36 from tensorflow.python.training.gradient_descent import GradientDescentOptimizer
     37 
     38 
     39 class VariableOpsTest(XLATestCase):
     40   """Test cases for resource variable operators."""
     41 
     42   def testOneWriteOneOutput(self):
     43     # Regression test for a bug where computations with one non-constant
     44     # output and one variable update were mishandled.
     45     for dtype in self.numeric_types:
     46       init = np.array([[1, 2j], [3, 4]]).astype(dtype)
     47       with self.test_session() as sess, self.test_scope():
     48         v = resource_variable_ops.ResourceVariable(init)
     49         sess.run(variables.variables_initializer([v]))
     50         p = array_ops.placeholder(dtype)
     51         x = v.assign_add(p)
     52         with ops.control_dependencies([x]):
     53           y = v.read_value()
     54         self.assertAllClose(
     55             np.array([[2, 1 + 2j], [4, 5]]).astype(dtype), sess.run(y, {
     56                 p: 1
     57             }))
     58 
     59   def testSparseRead0DIndices(self):
     60     for dtype in self.numeric_types:
     61       init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8j, 9, 10,
     62                                                     11]]).astype(dtype)
     63       with self.test_session() as sess, self.test_scope():
     64         v = resource_variable_ops.ResourceVariable(init)
     65         sess.run(variables.variables_initializer([v]))
     66         x = v.sparse_read(2)
     67         self.assertAllClose(
     68             np.array([8j, 9, 10, 11]).astype(dtype), sess.run(x))
     69 
     70   def testSparseRead1DIndices(self):
     71     for dtype in self.numeric_types:
     72       init = np.array([[0, 1, 2, 3], [4, 5, 6j, 7], [8, 9, 10,
     73                                                      11]]).astype(dtype)
     74       with self.test_session() as sess, self.test_scope():
     75         v = resource_variable_ops.ResourceVariable(init)
     76         sess.run(variables.variables_initializer([v]))
     77         x = v.sparse_read([2, 1])
     78         self.assertAllClose(
     79             np.array([[8, 9, 10, 11], [4, 5, 6j, 7]]).astype(dtype),
     80             sess.run(x))
     81 
     82   def testSparseRead2DIndices(self):
     83     for dtype in self.numeric_types:
     84       init = np.array([[0, 1, 2j, 3], [4, 5, 6, 7], [8, 9, 10,
     85                                                      11]]).astype(dtype)
     86       with self.test_session() as sess, self.test_scope():
     87         v = resource_variable_ops.ResourceVariable(init)
     88         sess.run(variables.variables_initializer([v]))
     89         x = v.sparse_read([[2, 1], [0, 2]])
     90         self.assertAllClose(
     91             np.array([[[8, 9, 10, 11], [4, 5, 6, 7]],
     92                       [[0, 1, 2j, 3], [8, 9, 10, 11]]]).astype(dtype),
     93             sess.run(x))
     94 
     95   def testSparseRead2DIndices3DTensor(self):
     96     for dtype in self.numeric_types:
     97       init = np.array([[[0, 1, 2], [3, 4, 5]], [[10, 11, 12], [13, 14, 15]],
     98                        [[20, 21, 22], [23, 24j, 25]],
     99                        [[30, 31, 32], [33, 34, 35]]]).astype(dtype)
    100       with self.test_session() as sess, self.test_scope():
    101         v = resource_variable_ops.ResourceVariable(init)
    102         sess.run(variables.variables_initializer([v]))
    103         x = v.sparse_read([[2, 1], [3, 0]])
    104         self.assertAllClose(
    105             np.array(
    106                 [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]],
    107                  [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]],
    108             ).astype(dtype), sess.run(x))
    109 
    110   def testShape(self):
    111     for dtype in self.numeric_types:
    112       init = np.ones([2, 3]).astype(dtype)
    113       with self.test_session() as session, self.test_scope():
    114         v = resource_variable_ops.ResourceVariable(init)
    115         session.run(variables.variables_initializer([v]))
    116         h = v.handle
    117         s32, s64 = session.run([
    118             resource_variable_ops.variable_shape(h),
    119             resource_variable_ops.variable_shape(h, out_type=dtypes.int64)
    120         ])
    121         self.assertEqual(s32.dtype, np.int32)
    122         self.assertEqual(s64.dtype, np.int64)
    123         self.assertAllEqual(s32, [2, 3])
    124         self.assertAllEqual(s64, [2, 3])
    125 
    126   def testReadWrite(self):
    127     """Tests initialization, reading, and writing a resource variable."""
    128     for dtype in self.numeric_types:
    129       with self.test_session() as session:
    130         with self.test_scope():
    131           with variable_scope.variable_scope("ascope", use_resource=True):
    132             x = variable_scope.get_variable(
    133                 "x",
    134                 shape=[],
    135                 dtype=dtype,
    136                 initializer=init_ops.constant_initializer(2))
    137             a = x.read_value()
    138             with ops.control_dependencies([a]):
    139               b = state_ops.assign(x, dtype(47))
    140             with ops.control_dependencies([b]):
    141               c = x.read_value()
    142             with ops.control_dependencies([c]):
    143               d = state_ops.assign_add(x, np.array(6 + 2j).astype(dtype))
    144             with ops.control_dependencies([d]):
    145               e = state_ops.assign_sub(x, dtype(3))
    146             with ops.control_dependencies([e]):
    147               f = x.read_value()
    148 
    149         session.run(variables.global_variables_initializer())
    150         v1, v2, v3 = session.run([a, c, f])
    151         self.assertAllClose(dtype(2), v1)
    152         self.assertAllClose(dtype(47), v2)
    153         self.assertAllClose(np.array(50 + 2j).astype(dtype), v3)
    154 
    155   def testTraining(self):
    156     """Tests a gradient descent step for a simple model."""
    157     with self.test_session() as session:
    158       with self.test_scope():
    159         with variable_scope.variable_scope("ascope", use_resource=True):
    160           w = variable_scope.get_variable(
    161               "w",
    162               shape=[4, 2],
    163               dtype=dtypes.float32,
    164               initializer=init_ops.constant_initializer(
    165                   np.array([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=np.float32)))
    166           b = variable_scope.get_variable(
    167               "b",
    168               shape=[2],
    169               dtype=dtypes.float32,
    170               initializer=init_ops.constant_initializer(
    171                   np.array([2, 3], dtype=np.float32)))
    172 
    173           x = array_ops.placeholder(dtypes.float32, shape=[1, 4])
    174           y = math_ops.matmul(x, w) + b
    175           loss = math_ops.reduce_sum(y)
    176           optimizer = GradientDescentOptimizer(0.1)
    177           train = optimizer.minimize(loss)
    178 
    179       session.run(variables.global_variables_initializer())
    180       session.run(train, {x: np.array([[7, 3, 5, 9]], dtype=np.float32)})
    181       vw, vb = session.run([w, b])
    182       self.assertAllClose(
    183           np.array(
    184               [[0.3, 1.3], [2.7, 3.7], [4.5, 5.5], [6.1, 7.1]],
    185               dtype=np.float32),
    186           vw,
    187           rtol=1e-4)
    188       self.assertAllClose(np.array([1.9, 2.9], dtype=np.float32), vb, rtol=1e-4)
    189 
    190 
    191 class StridedSliceAssignChecker(object):
    192   """Compares the results of a slice assignment using Tensorflow and numpy."""
    193 
    194   def __init__(self, test, x, dtype):
    195     self.dtype = dtype
    196     self.test = test
    197     self.x_np = np.array(x).astype(dtype)
    198     # Randomly start on mode 0 or 1.
    199     self.which_mode = np.random.randint(2, size=1)[0]
    200 
    201   def __setitem__(self, index, value):
    202     self.which_mode = 1 - self.which_mode
    203     value = np.array(value).astype(self.dtype)
    204 
    205     with self.test.test_session() as sess, self.test.test_scope():
    206       x = constant_op.constant(self.x_np, dtype=self.dtype)
    207       var = resource_variable_ops.ResourceVariable(x)
    208       sess.run(variables.variables_initializer([var]))
    209 
    210       if self.which_mode == 0:
    211         val = sess.run(var[index].assign(value))
    212       else:
    213         assert self.which_mode == 1
    214         val = sess.run(state_ops.assign(var[index], value))
    215       valnp = np.copy(self.x_np)
    216       valnp[index] = np.array(value)
    217       self.test.assertAllEqual(val, valnp)
    218 
    219 
    220 class SliceAssignTest(XLATestCase):
    221 
    222   def testSliceAssign(self):
    223     for dtype in self.numeric_types:
    224       checker = StridedSliceAssignChecker(self, [[1, 2, 3], [4, 5, 6]],
    225                                           dtype=dtype)
    226       # No-op assignment
    227       checker[:] = [[10, 20, 30], [40, 50, 60]]
    228       # Checks trivial (1,1) shape tensor
    229       checker[1:2, 1:2] = [[66]]
    230       # shrink shape changes
    231       checker[1:2, 1] = [66]
    232       checker[1, 1:2] = [66]
    233       checker[1, 1] = 66
    234       # newaxis shape changes
    235       checker[:, None, :] = [[[10, 20, 30]], [[40, 50, 50]]]
    236       # shrink and newaxis
    237       checker[None, None, 0, 0:1] = [[[99]]]
    238       # Non unit strides
    239       checker[::1, 1::-1] = [[3, 33], [4, 44]]
    240       # degenerate interval
    241       checker[8:10, 0] = []
    242       checker[8:10, 8:10] = [[]]
    243 
    244       # Assign vector to scalar (rank-0) using newaxis
    245       checker2 = StridedSliceAssignChecker(self, 222, dtype=dtype)
    246       checker2[()] = 6  # no indices
    247       checker2[...] = 6  # ellipsis
    248       checker2[None] = [6]  # new axis
    249 
    250   def testUninitialized(self):
    251     with self.assertRaisesRegexp(errors.InvalidArgumentError,
    252                                  "uninitialized variable"):
    253       with self.test_session() as sess, self.test_scope():
    254         v = resource_variable_ops.ResourceVariable([1, 2])
    255         sess.run(v[:].assign([1, 2]))
    256 
    257 
    258 if __name__ == "__main__":
    259   googletest.main()
    260