Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tensorflow.ops.tf.Assign*."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.framework import dtypes
     24 from tensorflow.python.ops import array_ops
     25 from tensorflow.python.ops import state_ops
     26 from tensorflow.python.ops import variables
     27 from tensorflow.python.platform import test
     28 
     29 
     30 class AssignOpTest(test.TestCase):
     31 
     32   def _initAssignFetch(self, x, y, use_gpu=False):
     33     """Initialize a param to init and update it with y."""
     34     super(AssignOpTest, self).setUp()
     35     with self.test_session(use_gpu=use_gpu):
     36       p = variables.Variable(x)
     37       assign = state_ops.assign(p, y)
     38       p.initializer.run()
     39       new_value = assign.eval()
     40       return p.eval(), new_value
     41 
     42   def _initAssignAddFetch(self, x, y, use_gpu=False):
     43     """Initialize a param to init, and compute param += y."""
     44     with self.test_session(use_gpu=use_gpu):
     45       p = variables.Variable(x)
     46       add = state_ops.assign_add(p, y)
     47       p.initializer.run()
     48       new_value = add.eval()
     49       return p.eval(), new_value
     50 
     51   def _initAssignSubFetch(self, x, y, use_gpu=False):
     52     """Initialize a param to init, and compute param -= y."""
     53     with self.test_session(use_gpu=use_gpu):
     54       p = variables.Variable(x)
     55       sub = state_ops.assign_sub(p, y)
     56       p.initializer.run()
     57       new_value = sub.eval()
     58       return p.eval(), new_value
     59 
     60   def _testTypes(self, vals):
     61     for dtype in [np.float32, np.float64, np.int32, np.int64]:
     62       x = np.zeros(vals.shape).astype(dtype)
     63       y = vals.astype(dtype)
     64       var_value, op_value = self._initAssignFetch(x, y, use_gpu=False)
     65       self.assertAllEqual(y, var_value)
     66       self.assertAllEqual(y, op_value)
     67       var_value, op_value = self._initAssignAddFetch(x, y, use_gpu=False)
     68       self.assertAllEqual(x + y, var_value)
     69       self.assertAllEqual(x + y, op_value)
     70       var_value, op_value = self._initAssignSubFetch(x, y, use_gpu=False)
     71       self.assertAllEqual(x - y, var_value)
     72       self.assertAllEqual(x - y, op_value)
     73       if test.is_built_with_cuda() and dtype in [np.float32, np.float64]:
     74         var_value, op_value = self._initAssignFetch(x, y, use_gpu=True)
     75         self.assertAllEqual(y, var_value)
     76         self.assertAllEqual(y, op_value)
     77         var_value, op_value = self._initAssignAddFetch(x, y, use_gpu=True)
     78         self.assertAllEqual(x + y, var_value)
     79         self.assertAllEqual(x + y, op_value)
     80         var_value, op_value = self._initAssignSubFetch(x, y, use_gpu=False)
     81         self.assertAllEqual(x - y, var_value)
     82         self.assertAllEqual(x - y, op_value)
     83 
     84   def testBasic(self):
     85     self._testTypes(np.arange(0, 20).reshape([4, 5]))
     86 
     87   def testAssignNonStrictShapeChecking(self):
     88     with self.test_session():
     89       data = array_ops.fill([1024, 1024], 0)
     90       p = variables.Variable([1])
     91       a = state_ops.assign(p, data, validate_shape=False)
     92       a.op.run()
     93       self.assertAllEqual(p.eval(), data.eval())
     94 
     95       # Assign to yet another shape
     96       data2 = array_ops.fill([10, 10], 1)
     97       a2 = state_ops.assign(p, data2, validate_shape=False)
     98       a2.op.run()
     99       self.assertAllEqual(p.eval(), data2.eval())
    100 
    101   def testInitRequiredAssignAdd(self):
    102     with self.test_session():
    103       p = variables.Variable(array_ops.fill([1024, 1024], 1), dtypes.int32)
    104       a = state_ops.assign_add(p, array_ops.fill([1024, 1024], 0))
    105       with self.assertRaisesOpError("use uninitialized"):
    106         a.op.run()
    107 
    108   def testInitRequiredAssignSub(self):
    109     with self.test_session():
    110       p = variables.Variable(array_ops.fill([1024, 1024], 1), dtypes.int32)
    111       a = state_ops.assign_sub(p, array_ops.fill([1024, 1024], 0))
    112       with self.assertRaisesOpError("use uninitialized"):
    113         a.op.run()
    114 
    115 
    116 if __name__ == "__main__":
    117   test.main()
    118