1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tensorflow.ops.tf.Assign*.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.python.framework import dtypes 24 from tensorflow.python.ops import array_ops 25 from tensorflow.python.ops import state_ops 26 from tensorflow.python.ops import variables 27 from tensorflow.python.platform import test 28 29 30 class AssignOpTest(test.TestCase): 31 32 def _initAssignFetch(self, x, y, use_gpu=False): 33 """Initialize a param to init and update it with y.""" 34 super(AssignOpTest, self).setUp() 35 with self.test_session(use_gpu=use_gpu): 36 p = variables.Variable(x) 37 assign = state_ops.assign(p, y) 38 p.initializer.run() 39 new_value = assign.eval() 40 return p.eval(), new_value 41 42 def _initAssignAddFetch(self, x, y, use_gpu=False): 43 """Initialize a param to init, and compute param += y.""" 44 with self.test_session(use_gpu=use_gpu): 45 p = variables.Variable(x) 46 add = state_ops.assign_add(p, y) 47 p.initializer.run() 48 new_value = add.eval() 49 return p.eval(), new_value 50 51 def _initAssignSubFetch(self, x, y, use_gpu=False): 52 """Initialize a param to init, and compute param -= y.""" 53 with self.test_session(use_gpu=use_gpu): 54 p = variables.Variable(x) 55 sub = state_ops.assign_sub(p, y) 56 p.initializer.run() 57 new_value = sub.eval() 58 return p.eval(), new_value 59 60 def _testTypes(self, vals): 61 for dtype in [np.float32, np.float64, np.int32, np.int64]: 62 x = np.zeros(vals.shape).astype(dtype) 63 y = vals.astype(dtype) 64 var_value, op_value = self._initAssignFetch(x, y, use_gpu=False) 65 self.assertAllEqual(y, var_value) 66 self.assertAllEqual(y, op_value) 67 var_value, op_value = self._initAssignAddFetch(x, y, use_gpu=False) 68 self.assertAllEqual(x + y, var_value) 69 self.assertAllEqual(x + y, op_value) 70 var_value, op_value = self._initAssignSubFetch(x, y, use_gpu=False) 71 self.assertAllEqual(x - y, var_value) 72 self.assertAllEqual(x - y, op_value) 73 if test.is_built_with_cuda() and dtype in [np.float32, np.float64]: 74 var_value, op_value = self._initAssignFetch(x, y, use_gpu=True) 75 self.assertAllEqual(y, var_value) 76 self.assertAllEqual(y, op_value) 77 var_value, op_value = self._initAssignAddFetch(x, y, use_gpu=True) 78 self.assertAllEqual(x + y, var_value) 79 self.assertAllEqual(x + y, op_value) 80 var_value, op_value = self._initAssignSubFetch(x, y, use_gpu=False) 81 self.assertAllEqual(x - y, var_value) 82 self.assertAllEqual(x - y, op_value) 83 84 def testBasic(self): 85 self._testTypes(np.arange(0, 20).reshape([4, 5])) 86 87 def testAssignNonStrictShapeChecking(self): 88 with self.test_session(): 89 data = array_ops.fill([1024, 1024], 0) 90 p = variables.Variable([1]) 91 a = state_ops.assign(p, data, validate_shape=False) 92 a.op.run() 93 self.assertAllEqual(p.eval(), data.eval()) 94 95 # Assign to yet another shape 96 data2 = array_ops.fill([10, 10], 1) 97 a2 = state_ops.assign(p, data2, validate_shape=False) 98 a2.op.run() 99 self.assertAllEqual(p.eval(), data2.eval()) 100 101 def testInitRequiredAssignAdd(self): 102 with self.test_session(): 103 p = variables.Variable(array_ops.fill([1024, 1024], 1), dtypes.int32) 104 a = state_ops.assign_add(p, array_ops.fill([1024, 1024], 0)) 105 with self.assertRaisesOpError("use uninitialized"): 106 a.op.run() 107 108 def testInitRequiredAssignSub(self): 109 with self.test_session(): 110 p = variables.Variable(array_ops.fill([1024, 1024], 1), dtypes.int32) 111 a = state_ops.assign_sub(p, array_ops.fill([1024, 1024], 0)) 112 with self.assertRaisesOpError("use uninitialized"): 113 a.op.run() 114 115 116 if __name__ == "__main__": 117 test.main() 118