Home | History | Annotate | Download | only in ops
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Critical Section object and execution logic."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import collections
     22 
     23 # TODO(ebrevdo): Re-enable once CriticalSection is in core.
     24 # from tensorflow.core.protobuf import critical_section_pb2
     25 
     26 from tensorflow.python.eager import context
     27 from tensorflow.python.eager import function
     28 from tensorflow.python.framework import dtypes
     29 from tensorflow.python.framework import ops
     30 from tensorflow.python.ops import gen_resource_variable_ops
     31 from tensorflow.python.util import nest
     32 
     33 
     34 # Graph Keys
     35 CRITICAL_SECTIONS = "critical_sections"
     36 CRITICAL_SECTION_EXECUTIONS = "critical_section_executions"
     37 
     38 
     39 class _ExecutionSignature(
     40     collections.namedtuple("_ExecutionSignature",
     41                            ("op", "exclusive_resource_access"))):
     42   """A class storing an `ExecuteInCriticalResource` op and associated attrs."""
     43   pass
     44 
     45 
     46 class CriticalSection(object):
     47   """Critical section.
     48 
     49   A `CriticalSection` object is a resource in the graph which executes subgraphs
     50   in **serial** order.  A common example of a subgraph one may wish to run
     51   exclusively is the one given by the following function:
     52 
     53   ```python
     54   v = resource_variable_ops.ResourceVariable(0.0, name="v")
     55 
     56   def count():
     57     value = v.read_value()
     58     with tf.control_dependencies([value]):
     59       with tf.control_dependencies([v.assign_add(1)]):
     60         return tf.identity(value)
     61   ```
     62 
     63   Here, a snapshot of `v` is captured in `value`; and then `v` is updated.
     64   The snapshot value is returned.
     65 
     66   If multiple workers or threads all execute `count` in parallel, there is no
     67   guarantee that access to the variable `v` is atomic at any point within
     68   any thread's calculation of `count`.  In fact, even implementing an atomic
     69   counter that guarantees that the user will see each value `0, 1, ...,` is
     70   currently impossible.
     71 
     72   The solution is to ensure any access to the underlying resource `v` is
     73   only processed through a critical section:
     74 
     75   ```python
     76   cs = CriticalSection()
     77   f1 = cs.execute(count)
     78   f2 = cs.execute(count)
     79   output = f1 + f2
     80   session.run(output)
     81   ```
     82   The functions `f1` and `f2` will be executed serially, and updates to `v`
     83   will be atomic.
     84 
     85   **NOTES**
     86 
     87   All resource objects, including the critical section and any captured
     88   variables of functions executed on that critical section, will be
     89   colocated to the same device (host and cpu/gpu).
     90 
     91   When using multiple critical sections on the same resources, there is no
     92   guarantee of exclusive access to those resources.  This behavior is disallowed
     93   by default (but see the kwarg `exclusive_resource_access`).
     94 
     95   For example, running the same function in two separate critical sections
     96   will not ensure serial execution:
     97 
     98   ```python
     99   v = tf.get_variable("v", initializer=0.0, use_resource=True)
    100   def accumulate(up):
    101     x = v.read_value()
    102     with tf.control_dependencies([x]):
    103       with tf.control_dependencies([v.assign_add(up)]):
    104         return tf.identity(x)
    105   ex1 = CriticalSection().execute(
    106     accumulate, 1.0, exclusive_resource_access=False)
    107   ex2 = CriticalSection().execute(
    108     accumulate, 1.0, exclusive_resource_access=False)
    109   bad_sum = ex1 + ex2
    110   sess.run(v.initializer)
    111   sess.run(bad_sum)  # May return 0.0
    112   ```
    113   """
    114 
    115   def __init__(self, name=None, critical_section_def=None, import_scope=None):
    116     """Creates a critical section."""
    117     if critical_section_def and name is not None:
    118       raise ValueError("critical_section_def and name are mutually exclusive.")
    119     if critical_section_def:
    120       self._init_from_proto(critical_section_def, import_scope=import_scope)
    121     else:
    122       self._init_from_args(name)
    123 
    124   def _init_from_proto(self, critical_section_def, import_scope):
    125     raise NotImplementedError("Not yet implemented")
    126     # TODO(ebrevdo): Re-enable once CriticalSection is in core.
    127     # assert isinstance(
    128     #     critical_section_def, critical_section_pb2.CriticalSectionDef)
    129     # # Create from critical_section_def.
    130     # g = ops.get_default_graph()
    131     # self._handle = g.as_graph_element(
    132     #     ops.prepend_name_scope(
    133     #         critical_section_def.critical_section_name,
    134     #         import_scope=import_scope))
    135 
    136   def _init_from_args(self, name):
    137     """Initialize the CriticalSection from constructor arguments."""
    138     with ops.name_scope(name, "CriticalSection", []) as name:
    139       with ops.control_dependencies(None):
    140         # pylint: disable=protected-access
    141         handle_name = ops._name_from_scope_name(name)
    142         container = ops.get_default_graph()._container
    143         # pylint: enable=protected-access
    144         if container is None:
    145           container = ""
    146         self._handle = gen_resource_variable_ops.critical_section_op(
    147             shared_name=handle_name, name=name)
    148     if context.in_graph_mode():
    149       ops.add_to_collections(CRITICAL_SECTIONS, self)
    150 
    151   @property
    152   def name(self):
    153     return self._handle.op.name
    154 
    155   def execute(self, fn, *args, **kwargs):
    156     """Execute function `fn(*args, **kwargs)` inside the CriticalSection.
    157 
    158     Args:
    159       fn: The function to execute.  Must return at least one tensor.
    160       *args: Additional positional arguments to `fn`.
    161       **kwargs: Additional keyword arguments to `fn`.
    162         Several keywords are reserved for `execute`.  These are:
    163 
    164         - name; The name to use when creating the execute operation.
    165         - exclusive_resource_access; Whether the resources required by
    166           `fn` should be exclusive to this `CriticalSection`.  Default: `True`.
    167           You may want to set this to `False` if you will be accessing a
    168           resource in read-only mode in two different CriticalSections.
    169 
    170     Returns:
    171       The tensors returned from `fn(*args, **kwargs)`.
    172 
    173     Raises:
    174       ValueError: If `fn` attempts to use this `CriticalSection` in any nested
    175         way.
    176       ValueError: If `exclusive_resource_access` is not provided (is `True`) and
    177         another `CriticalSection` has an execution requesting the same
    178         resources as in `*args`, `**kwargs`, and any additionaly captured
    179         inputs in `fn`.  Note, even if `exclusive_resource_access` is `True`,
    180         if another execution in another `CriticalSection` was created without
    181         `exclusive_resource_access=True`, a `ValueError` will be raised.
    182     """
    183     name = kwargs.pop("name", None)
    184     exclusive_resource_access = kwargs.pop("exclusive_resource_access", True)
    185 
    186     args = nest.map_structure(ops.convert_to_tensor, args)
    187     with ops.name_scope(name, "critical_section_execute", []):
    188       fn_op = function.make_defun_op(fn, *args, **kwargs)
    189       flat_dtypes = nest.flatten(fn_op.output_dtypes)
    190       flat_shapes = nest.flatten(fn_op.output_shapes)
    191       all_inputs = nest.flatten(args) + fn_op.captured_inputs
    192       if self._handle in all_inputs:
    193         raise ValueError("The function fn attempts to access the "
    194                          "CriticalSection in which it would be running.  This "
    195                          "is illegal and would cause deadlocks.  "
    196                          "CriticalSection: %s." % self._handle)
    197 
    198       if context.in_graph_mode():
    199         # Collections and op introspection does not work in eager
    200         # mode.  This is generally ok; since eager mode (as of
    201         # writing) executes sequentially anyway.
    202         all_input_resources = [
    203             x for x in all_inputs if x.dtype == dtypes.resource]
    204         for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
    205           if sg.op.inputs[0].name == self._handle.name:
    206             # Other executions in the same critical section are allowed.
    207             continue
    208           if not (exclusive_resource_access or sg.exclusive_resource_access):
    209             # Neither execution requested exclusive access.
    210             continue
    211           sg_input_names = [y.name for y in sg.op.inputs[1:]]
    212           for res in all_input_resources:
    213             if res.name in sg_input_names:
    214               raise ValueError(
    215                   "This execution would access resource %s; but either this "
    216                   "execution (CriticalSection: %s) or Execution '%s' "
    217                   "(CriticalSection: %s) requested exclusive resource access "
    218                   "of this resource for their critical section.  Did you mean "
    219                   "to call execute with keyword argument "
    220                   "exclusive_resource_access=False?"
    221                   % (res.name,
    222                      self.name,
    223                      sg.op.name,
    224                      sg.op.inputs[0].op.name))
    225 
    226       flat_outputs = gen_resource_variable_ops.execute_in_critical_section(
    227           critical_section=self._handle,
    228           arguments=all_inputs,
    229           f=fn_op,
    230           output_types=flat_dtypes,
    231           output_shapes=flat_shapes)
    232 
    233       if context.in_graph_mode():
    234         if isinstance(flat_outputs, ops.Operation):
    235           flat_outputs = [flat_outputs]
    236         op = (flat_outputs[0].op if isinstance(flat_outputs[0], ops.Tensor)
    237               else flat_outputs[0])
    238         signature = _ExecutionSignature(
    239             op=op,
    240             exclusive_resource_access=exclusive_resource_access)
    241         ops.add_to_collections(
    242             CRITICAL_SECTION_EXECUTIONS, signature)
    243 
    244       return (flat_outputs[0]
    245               if (len(flat_outputs) == 1
    246                   and isinstance(flat_outputs[0], ops.Operation))
    247               else nest.pack_sequence_as(fn_op.output_dtypes, flat_outputs))
    248 
    249   # TODO(ebrevdo): Re-enable once CriticalSection is in core.
    250 
    251   # def to_proto(self, export_scope=None):
    252   #   """Converts a `CriticalSection` to a `CriticalSectoinDef` protocol buffer.
    253 
    254   #   Args:
    255   #     export_scope: Optional `string`. Name scope to remove.
    256 
    257   #   Returns:
    258   #     A `CriticalSectionDef` protocol buffer, or `None` if the
    259   #     `CriticalSection` is not in the specified name scope.
    260   #   """
    261   #   if export_scope is None or self.handle.name.startswith(export_scope):
    262   #     cs_def = critical_section_pb2.CriticalSectionDef()
    263   #     cs_def.critical_section_name = ops.strip_name_scope(
    264   #         self._handle.name, export_scope)
    265   #     return cs_def
    266   #   else:
    267   #     return None
    268 
    269   # @staticmethod
    270   # def from_proto(critical_section_def, import_scope=None):
    271   #   return CriticalSection(
    272   #       critical_section_def=critical_section_def, import_scope=import_scope)
    273 
    274 
    275 # TODO(ebrevdo): Re-enable once CriticalSection is in core.
    276 
    277 # def _execution_to_proto_fn(execution_signature, export_scope=None):
    278 #   """Converts `_ExecutionSignature` to a `CriticalSectionExecutionDef`.
    279 
    280 #   Args:
    281 #     execution_signature: Instance of `_ExecutionSignature`.
    282 #     export_scope: The export scope, if any.
    283 
    284 #   Returns:
    285 #     An instance of `CriticalSectionExecutionDef`.
    286 #   """
    287 #   if (export_scope is None
    288 #       or execution_signature.op.name.startswith(export_scope)):
    289 #     op_def = critical_section_pb2.CriticalSectionExecutionDef()
    290 #     op_def.execute_in_critical_section_name = ops.strip_name_scope(
    291 #         execution_signature.op.name, export_scope)
    292 #     op_def.exclusive_resource_access = (
    293 #         execution_signature.exclusive_resource_access)
    294 #     return op_def
    295 #   else:
    296 #     return None
    297 
    298 
    299 # def _execution_from_proto_fn(op_def, import_scope=None):
    300 #   """Converts a `CriticalSectionExecutionDef` to a `_ExecutionSignature`."""
    301 #   assert isinstance(
    302 #       op_def, critical_section_pb2.CriticalSectionExecutionDef)
    303 
    304 #   # Create from op_def.
    305 #   g = ops.get_default_graph()
    306 #   execution_op = g.as_graph_element(
    307 #       ops.prepend_name_scope(
    308 #           op_def.execute_in_critical_section_name,
    309 #           import_scope=import_scope))
    310 #   return _ExecutionSignature(
    311 #       op=execution_op,
    312 #       exclusive_resource_access=op_def.exclusive_resource_access)
    313 
    314 # ops.register_proto_function(
    315 #     CRITICAL_SECTIONS,
    316 #     proto_type=critical_section_pb2.CriticalSectionDef,
    317 #     to_proto=CriticalSection.to_proto,
    318 #     from_proto=CriticalSection.from_proto)
    319 
    320 # ops.register_proto_function(
    321 #     CRITICAL_SECTION_EXECUTIONS,
    322 #     proto_type=critical_section_pb2.CriticalSectionExecutionDef,
    323 #     to_proto=_execution_to_proto_fn,
    324 #     from_proto=_execution_from_proto_fn)
    325