Home | History | Annotate | Download | only in ops
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """critical section tests."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.contrib.framework.python.ops import critical_section_ops
     22 from tensorflow.python.framework import dtypes
     23 from tensorflow.python.framework import function
     24 from tensorflow.python.framework import ops
     25 from tensorflow.python.framework import tensor_shape
     26 from tensorflow.python.framework import test_util
     27 from tensorflow.python.ops import array_ops
     28 from tensorflow.python.ops import resource_variable_ops
     29 from tensorflow.python.platform import test
     30 # TODO(ebrevdo): Re-enable once CriticalSection is in core.
     31 # from tensorflow.python.training import saver as saver_lib
     32 
     33 
     34 class CriticalSectionTest(test.TestCase):
     35 
     36   @test_util.run_in_graph_and_eager_modes()
     37   def testCreateCriticalSection(self):
     38     cs = critical_section_ops.CriticalSection(name="cs")
     39     v = resource_variable_ops.ResourceVariable(0.0, name="v")
     40 
     41     def fn(a, b):
     42       c = v.read_value()
     43       with ops.control_dependencies([c]):
     44         nv = v.assign_add(a * b)
     45         with ops.control_dependencies([nv]):
     46           return array_ops.identity(c)
     47 
     48     num_concurrent = 1000
     49     r = [cs.execute(fn, 1.0, 2.0) for _ in range(num_concurrent)]
     50     self.evaluate(v.initializer)
     51     r_value = self.evaluate(r)
     52     self.assertAllClose([2.0 * i for i in range(num_concurrent)],
     53                         sorted(r_value))
     54 
     55   @test_util.run_in_graph_and_eager_modes()
     56   def testCreateCriticalSectionFnReturnsOp(self):
     57     cs = critical_section_ops.CriticalSection(name="cs")
     58     v = resource_variable_ops.ResourceVariable(0.0, name="v")
     59 
     60     def fn_return_op(a, b):
     61       c = v.read_value()
     62       with ops.control_dependencies([c]):
     63         nv = v.assign_add(a * b)
     64         with ops.control_dependencies([nv]):
     65           return ()
     66 
     67     num_concurrent = 100
     68     r = [cs.execute(fn_return_op, 1.0, 2.0) for _ in range(num_concurrent)]
     69     self.evaluate(v.initializer)
     70     self.evaluate(r)
     71     final_v = self.evaluate(v)
     72     self.assertAllClose(2.0 * num_concurrent, final_v)
     73 
     74   def testCreateCriticalSectionRaw(self):
     75     cs = critical_section_ops.CriticalSection(name="cs")
     76     v = resource_variable_ops.ResourceVariable(0.0, name="v")
     77 
     78     @function.Defun(dtypes.float32, dtypes.float32)
     79     def fn(a, b):
     80       c = v.read_value()
     81       with ops.control_dependencies([c]):
     82         nv = v.assign_add(a * b)
     83         with ops.control_dependencies([nv]):
     84           return array_ops.identity(c)
     85 
     86     def execute(fn, *args):
     87       output_args = fn.definition.signature.output_arg
     88       return resource_variable_ops.execute_in_critical_section(
     89           critical_section=cs._handle,
     90           arguments=list(args) + fn.captured_inputs,
     91           f=fn,
     92           output_types=[out.type for out in output_args],
     93           output_shapes=[tensor_shape.TensorShape(None) for _ in output_args])
     94 
     95     num_concurrent = 1000
     96     r = [execute(fn, 1.0, 2.0)[0] for _ in range(num_concurrent)]
     97     self.evaluate(v.initializer)
     98     r_value = self.evaluate(r)
     99     self.assertAllClose([2.0 * i for i in range(num_concurrent)],
    100                         sorted(r_value))
    101 
    102   def testCollection(self):
    103     cs = critical_section_ops.CriticalSection(name="cs")
    104     self.assertIn(
    105         cs, ops.get_collection(critical_section_ops.CRITICAL_SECTIONS))
    106     execute_op = cs.execute(lambda x: x + 1, 1.0).op
    107     self.assertIn(
    108         execute_op,
    109         [signature.op for signature in
    110          ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)])
    111 
    112   @test_util.run_in_graph_and_eager_modes()
    113   def testRecursiveCriticalSectionAccessIsIllegal(self):
    114     cs = critical_section_ops.CriticalSection(name="cs")
    115     def fn(x):
    116       return cs.execute(lambda x: x+1, x)
    117     with self.assertRaisesRegexp(
    118         ValueError,
    119         r"attempts to access the CriticalSection in which it would be running"):
    120       cs.execute(fn, 1.0)
    121 
    122   def testMultipleCSExecutionsRequestSameResource(self):
    123     cs0 = critical_section_ops.CriticalSection()
    124     cs1 = critical_section_ops.CriticalSection()
    125     v = resource_variable_ops.ResourceVariable(0.0, name="v")
    126     cs0.execute(lambda: v + 1)
    127     # It's OK for the same CriticalSection to access this resource.
    128     cs0.execute(lambda: v - 1)
    129     # It's *not* OK for a different CriticalSection to access it by
    130     # default.
    131     with self.assertRaisesRegexp(
    132         ValueError, "requested exclusive resource access"):
    133       cs1.execute(lambda: v + 1)
    134     # It's not even OK if the second call doesn't request exclusive access.
    135     with self.assertRaisesRegexp(
    136         ValueError, "requested exclusive resource access"):
    137       cs1.execute(lambda: v + 1, exclusive_resource_access=False)
    138 
    139     v2 = resource_variable_ops.ResourceVariable(0.0, name="v2")
    140     cs0.execute(lambda: v2 + 1, exclusive_resource_access=False)
    141     # It's OK if neither requests exclusive resource access.
    142     cs1.execute(lambda: v2 + 1, exclusive_resource_access=False)
    143 
    144     # It's not OK if the second request requires exlusive resource
    145     # access.
    146     with self.assertRaisesRegexp(
    147         ValueError, "requested exclusive resource access"):
    148       cs1.execute(lambda: v2 + 1)
    149 
    150   # TODO(ebrevdo): Re-enable once CriticalSection is in core.
    151   #
    152   # def testCriticalSectionAndExecuteOpSaverRoundTrip(self):
    153   #   cs = critical_section_ops.CriticalSection()
    154   #   r = cs.execute(lambda x: x + 1, 1.0)
    155   #   graph = ops.get_default_graph()
    156   #   meta_graph = saver_lib.export_meta_graph(
    157   #       graph=graph, collection_list=graph.get_all_collection_keys())
    158   #   graph_copy = ops.Graph()
    159   #   with graph_copy.as_default():
    160   #     _ = saver_lib.import_meta_graph(meta_graph, import_scope="imported")
    161   #     restored_cs = ops.get_collection(critical_section_ops.CRITICAL_SECTIONS)
    162   #     restored_exec = ops.get_collection(
    163   #         critical_section_ops.CRITICAL_SECTION_EXECUTIONS)
    164   #     self.assertEqual(1, len(restored_cs))
    165   #     self.assertEqual(1, len(restored_exec))
    166   #     self.assertEqual(restored_cs[0].name, "imported/%s" % cs.name)
    167   #     self.assertEqual(restored_exec[0].op.name, "imported/%s" % r.op.name)
    168 
    169   # def testToProto(self):
    170   #   cs = critical_section_ops.CriticalSection(name="cs")
    171   #   proto = cs.to_proto()
    172   #   self.assertEqual(proto.critical_section_name, cs._handle.name)
    173   #   cs_copy = critical_section_ops.CriticalSection.from_proto(proto)
    174   #   self.assertEqual(cs_copy._handle, cs._handle)
    175 
    176 
if __name__ == "__main__":
  # Run directly as a script: delegate to TensorFlow's test entry point
  # (tensorflow.python.platform.test.main).
  test.main()
    179