1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """critical section tests.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 from tensorflow.contrib.framework.python.ops import critical_section_ops 22 from tensorflow.python.framework import dtypes 23 from tensorflow.python.framework import function 24 from tensorflow.python.framework import ops 25 from tensorflow.python.framework import tensor_shape 26 from tensorflow.python.framework import test_util 27 from tensorflow.python.ops import array_ops 28 from tensorflow.python.ops import resource_variable_ops 29 from tensorflow.python.platform import test 30 # TODO(ebrevdo): Re-enable once CriticalSection is in core. 
# from tensorflow.python.training import saver as saver_lib


class CriticalSectionTest(test.TestCase):
  """Tests for `critical_section_ops.CriticalSection`."""

  @test_util.run_in_graph_and_eager_modes()
  def testCreateCriticalSection(self):
    """Many read-then-add executions each observe a distinct partial sum."""
    cs = critical_section_ops.CriticalSection(name="cs")
    v = resource_variable_ops.ResourceVariable(0.0, name="v")

    def fn(a, b):
      # Read first, then add; the value handed back is the pre-update read.
      old_value = v.read_value()
      with ops.control_dependencies([old_value]):
        update = v.assign_add(a * b)
        with ops.control_dependencies([update]):
          return array_ops.identity(old_value)

    num_executions = 1000
    outputs = []
    for _ in range(num_executions):
      outputs.append(cs.execute(fn, 1.0, 2.0))
    self.evaluate(v.initializer)
    observed = sorted(self.evaluate(outputs))
    # Every call adds 1.0 * 2.0, so the sorted reads should be 0, 2, 4, ...
    expected = [2.0 * step for step in range(num_executions)]
    self.assertAllClose(expected, observed)

  @test_util.run_in_graph_and_eager_modes()
  def testCreateCriticalSectionFnReturnsOp(self):
    """A function that returns no tensors still has its updates applied."""
    cs = critical_section_ops.CriticalSection(name="cs")
    v = resource_variable_ops.ResourceVariable(0.0, name="v")

    def fn_return_op(a, b):
      old_value = v.read_value()
      with ops.control_dependencies([old_value]):
        update = v.assign_add(a * b)
        with ops.control_dependencies([update]):
          return ()

    num_executions = 100
    executions = [
        cs.execute(fn_return_op, 1.0, 2.0) for _ in range(num_executions)
    ]
    self.evaluate(v.initializer)
    self.evaluate(executions)
    # Each execution added 2.0 to `v`.
    self.assertAllClose(2.0 * num_executions, self.evaluate(v))

  def testCreateCriticalSectionRaw(self):
    """Drives the raw `execute_in_critical_section` op with a Defun."""
    cs = critical_section_ops.CriticalSection(name="cs")
    v = resource_variable_ops.ResourceVariable(0.0, name="v")

    @function.Defun(dtypes.float32, dtypes.float32)
    def fn(a, b):
      old_value = v.read_value()
      with ops.control_dependencies([old_value]):
        update = v.assign_add(a * b)
        with ops.control_dependencies([update]):
          return array_ops.identity(old_value)

    def run_raw(defun, *args):
      # Output dtypes come from the Defun's signature; shapes are left
      # unknown.
      signature_outputs = defun.definition.signature.output_arg
      output_types = [arg.type for arg in signature_outputs]
      output_shapes = [
          tensor_shape.TensorShape(None) for _ in signature_outputs
      ]
      # Implicitly captured tensors (e.g. the variable handle) must be passed
      # explicitly after the positional arguments.
      return resource_variable_ops.execute_in_critical_section(
          critical_section=cs._handle,
          arguments=list(args) + defun.captured_inputs,
          f=defun,
          output_types=output_types,
          output_shapes=output_shapes)

    num_executions = 1000
    outputs = []
    for _ in range(num_executions):
      outputs.append(run_raw(fn, 1.0, 2.0)[0])
    self.evaluate(v.initializer)
    observed = sorted(self.evaluate(outputs))
    self.assertAllClose([2.0 * step for step in range(num_executions)],
                        observed)

  def testCollection(self):
    """CriticalSections and their executions are registered in collections."""
    cs = critical_section_ops.CriticalSection(name="cs")
    registered_sections = ops.get_collection(
        critical_section_ops.CRITICAL_SECTIONS)
    self.assertIn(cs, registered_sections)

    execute_op = cs.execute(lambda x: x + 1, 1.0).op
    registered_executions = ops.get_collection(
        critical_section_ops.CRITICAL_SECTION_EXECUTIONS)
    self.assertIn(execute_op, [sig.op for sig in registered_executions])

  @test_util.run_in_graph_and_eager_modes()
  def testRecursiveCriticalSectionAccessIsIllegal(self):
    """Executing a CriticalSection from inside itself raises ValueError."""
    cs = critical_section_ops.CriticalSection(name="cs")

    def fn(x):
      # Re-enters the same CriticalSection that is executing `fn`.
      return cs.execute(lambda y: y + 1, x)

    expected_message = (
        r"attempts to access the CriticalSection in which it would be running")
    with self.assertRaisesRegexp(ValueError, expected_message):
      cs.execute(fn, 1.0)

  def testMultipleCSExecutionsRequestSameResource(self):
    """Resource sharing across CriticalSections honors exclusive access."""
    cs0 = critical_section_ops.CriticalSection()
    cs1 = critical_section_ops.CriticalSection()
    conflict_message = "requested exclusive resource access"

    def assert_conflict(section, fn, **execute_kwargs):
      # `section.execute` must reject `fn` because of how another
      # CriticalSection already claimed the resource `fn` touches.
      with self.assertRaisesRegexp(ValueError, conflict_message):
        section.execute(fn, **execute_kwargs)

    v = resource_variable_ops.ResourceVariable(0.0, name="v")
    # cs0 claims `v` without opting out of exclusive access; cs0 itself may
    # keep using it.
    cs0.execute(lambda: v + 1)
    cs0.execute(lambda: v - 1)
    # A different CriticalSection may not touch `v` by default, and not even
    # when it does not request exclusive access itself.
    assert_conflict(cs1, lambda: v + 1)
    assert_conflict(cs1, lambda: v + 1, exclusive_resource_access=False)

    v2 = resource_variable_ops.ResourceVariable(0.0, name="v2")
    # Sharing is fine when neither CriticalSection requests exclusive access...
    cs0.execute(lambda: v2 + 1, exclusive_resource_access=False)
    cs1.execute(lambda: v2 + 1, exclusive_resource_access=False)
    # ...but not when a later request requires exclusive resource access.
    assert_conflict(cs1, lambda: v2 + 1)

  # TODO(ebrevdo): Re-enable once CriticalSection is in core.
  #
  # def testCriticalSectionAndExecuteOpSaverRoundTrip(self):
  #   cs = critical_section_ops.CriticalSection()
  #   r = cs.execute(lambda x: x + 1, 1.0)
  #   graph = ops.get_default_graph()
  #   meta_graph = saver_lib.export_meta_graph(
  #       graph=graph, collection_list=graph.get_all_collection_keys())
  #   graph_copy = ops.Graph()
  #   with graph_copy.as_default():
  #     _ = saver_lib.import_meta_graph(meta_graph, import_scope="imported")
  #     restored_cs = ops.get_collection(critical_section_ops.CRITICAL_SECTIONS)
  #     restored_exec = ops.get_collection(
  #         critical_section_ops.CRITICAL_SECTION_EXECUTIONS)
  #   self.assertEqual(1, len(restored_cs))
  #   self.assertEqual(1, len(restored_exec))
  #   self.assertEqual(restored_cs[0].name, "imported/%s" % cs.name)
  #   self.assertEqual(restored_exec[0].op.name, "imported/%s" % r.op.name)

  # def testToProto(self):
  #   cs = critical_section_ops.CriticalSection(name="cs")
  #   proto = cs.to_proto()
  #   self.assertEqual(proto.critical_section_name, cs._handle.name)
  #   cs_copy = critical_section_ops.CriticalSection.from_proto(proto)
  #   self.assertEqual(cs_copy._handle, cs._handle)


if __name__ == "__main__":
  test.main()