# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Critical Section object and execution logic."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

# TODO(ebrevdo): Re-enable once CriticalSection is in core.
# from tensorflow.core.protobuf import critical_section_pb2

from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.util import nest


# Graph Keys
CRITICAL_SECTIONS = "critical_sections"
CRITICAL_SECTION_EXECUTIONS = "critical_section_executions"


class _ExecutionSignature(
    collections.namedtuple("_ExecutionSignature",
                           ("op", "exclusive_resource_access"))):
  """An `ExecuteInCriticalSection` op and its associated attrs.

  Fields:
    op: The `Operation` created by `CriticalSection.execute`.
    exclusive_resource_access: The value of the `exclusive_resource_access`
      keyword the execution was created with.
  """


class CriticalSection(object):
  """Critical section.

  A `CriticalSection` object is a resource in the graph which executes
  subgraphs in **serial** order.  A common example of a subgraph one may wish
  to run exclusively is the one given by the following function:

  ```python
  v = resource_variable_ops.ResourceVariable(0.0, name="v")

  def count():
    value = v.read_value()
    with tf.control_dependencies([value]):
      with tf.control_dependencies([v.assign_add(1)]):
        return tf.identity(value)
  ```

  Here, a snapshot of `v` is captured in `value`; and then `v` is updated.
  The snapshot value is returned.

  If multiple workers or threads all execute `count` in parallel, there is no
  guarantee that access to the variable `v` is atomic at any point within
  any thread's calculation of `count`.  In fact, even implementing an atomic
  counter that guarantees that the user will see each value `0, 1, ...,` is
  currently impossible.

  The solution is to ensure any access to the underlying resource `v` is
  only processed through a critical section:

  ```python
  cs = CriticalSection()
  f1 = cs.execute(count)
  f2 = cs.execute(count)
  output = f1 + f2
  session.run(output)
  ```

  The two executions producing `f1` and `f2` will run serially, and updates
  to `v` will be atomic.

  **NOTES**

  All resource objects, including the critical section and any captured
  variables of functions executed on that critical section, will be
  colocated to the same device (host and cpu/gpu).

  When using multiple critical sections on the same resources, there is no
  guarantee of exclusive access to those resources.  This behavior is
  disallowed by default (but see the kwarg `exclusive_resource_access`).

  For example, running the same function in two separate critical sections
  will not ensure serial execution:

  ```python
  v = tf.get_variable("v", initializer=0.0, use_resource=True)
  def accumulate(up):
    x = v.read_value()
    with tf.control_dependencies([x]):
      with tf.control_dependencies([v.assign_add(up)]):
        return tf.identity(x)
  ex1 = CriticalSection().execute(
      accumulate, 1.0, exclusive_resource_access=False)
  ex2 = CriticalSection().execute(
      accumulate, 1.0, exclusive_resource_access=False)
  bad_sum = ex1 + ex2
  sess.run(v.initializer)
  sess.run(bad_sum)  # May return 0.0
  ```
  """

  def __init__(self, name=None, critical_section_def=None, import_scope=None):
    """Creates a critical section.

    Args:
      name: Optional name for the critical section resource op.
      critical_section_def: Optional proto to restore from.  Mutually
        exclusive with `name`.  Restoring from a proto is not yet
        implemented.
      import_scope: Optional name scope used when restoring from a proto.

    Raises:
      ValueError: If both `critical_section_def` and `name` are given.
      NotImplementedError: If `critical_section_def` is given.
    """
    if critical_section_def and name is not None:
      raise ValueError("critical_section_def and name are mutually exclusive.")
    if critical_section_def:
      self._init_from_proto(critical_section_def, import_scope=import_scope)
    else:
      self._init_from_args(name)

  def _init_from_proto(self, critical_section_def, import_scope):
    """Restores from a `CriticalSectionDef` proto (not yet implemented)."""
    raise NotImplementedError("Not yet implemented")
    # TODO(ebrevdo): Re-enable once CriticalSection is in core.
    # assert isinstance(
    #     critical_section_def, critical_section_pb2.CriticalSectionDef)
    # # Create from critical_section_def.
    # g = ops.get_default_graph()
    # self._handle = g.as_graph_element(
    #     ops.prepend_name_scope(
    #         critical_section_def.critical_section_name,
    #         import_scope=import_scope))

  def _init_from_args(self, name):
    """Initialize the CriticalSection from constructor arguments.

    Creates the critical section resource op, with no control dependencies,
    using the resolved scope name as its `shared_name`.  In graph mode the
    new object is also added to the `CRITICAL_SECTIONS` collection.

    Args:
      name: Optional name; `"CriticalSection"` is used when `None`.
    """
    with ops.name_scope(name, "CriticalSection", []) as name:
      # Create the resource outside any enclosing control-dependency context.
      with ops.control_dependencies(None):
        # pylint: disable=protected-access
        handle_name = ops._name_from_scope_name(name)
        container = ops.get_default_graph()._container
        # pylint: enable=protected-access
        if container is None:
          container = ""
        # Pass the default graph's current container through to the op;
        # without it, any enclosing container scope would be ignored.
        self._handle = gen_resource_variable_ops.critical_section_op(
            shared_name=handle_name, container=container, name=name)
    if context.in_graph_mode():
      ops.add_to_collections(CRITICAL_SECTIONS, self)

  @property
  def name(self):
    """The name of this critical section's resource op."""
    return self._handle.op.name

  def execute(self, fn, *args, **kwargs):
    """Execute function `fn(*args, **kwargs)` inside the CriticalSection.

    Args:
      fn: The function to execute.  Must return at least one tensor.
      *args: Additional positional arguments to `fn`.  Each is converted to a
        tensor before `fn` is traced.
      **kwargs: Additional keyword arguments to `fn`.
        Several keywords are reserved for `execute`.  These are:

        - name: The name to use when creating the execute operation.
        - exclusive_resource_access: Whether the resources required by
          `fn` should be exclusive to this `CriticalSection`.  Default: `True`.
          You may want to set this to `False` if you will be accessing a
          resource in read-only mode in two different CriticalSections.

    Returns:
      The tensors returned from `fn(*args, **kwargs)`, packed in the same
      structure as `fn`'s outputs.  If the execution op yields a single
      `Operation` rather than tensors, that `Operation` is returned.

    Raises:
      ValueError: If `fn` attempts to use this `CriticalSection` in any nested
        way.
      ValueError: (Graph mode only.)  If another `CriticalSection` already has
        an execution that uses one of the same resources as this one (the
        resource tensors among `*args` and any inputs captured by `fn`), and
        either this execution or that one requested exclusive access.  Note,
        this means that even if `exclusive_resource_access` is `False`, a
        `ValueError` will be raised if the other execution was created with
        `exclusive_resource_access=True` (the default).
    """
    name = kwargs.pop("name", None)
    exclusive_resource_access = kwargs.pop("exclusive_resource_access", True)

    args = nest.map_structure(ops.convert_to_tensor, args)
    with ops.name_scope(name, "critical_section_execute", []):
      fn_op = function.make_defun_op(fn, *args, **kwargs)
      flat_dtypes = nest.flatten(fn_op.output_dtypes)
      flat_shapes = nest.flatten(fn_op.output_shapes)
      all_inputs = nest.flatten(args) + fn_op.captured_inputs
      if self._handle in all_inputs:
        raise ValueError("The function fn attempts to access the "
                         "CriticalSection in which it would be running.  This "
                         "is illegal and would cause deadlocks.  "
                         "CriticalSection: %s." % self._handle)

      if context.in_graph_mode():
        # Collections and op introspection does not work in eager
        # mode.  This is generally ok; since eager mode (as of
        # writing) executes sequentially anyway.
        self._check_exclusive_resource_access(
            all_inputs, exclusive_resource_access)

      flat_outputs = gen_resource_variable_ops.execute_in_critical_section(
          critical_section=self._handle,
          arguments=all_inputs,
          f=fn_op,
          output_types=flat_dtypes,
          output_shapes=flat_shapes)

      if context.in_graph_mode():
        # Record this execution so that later executions on other critical
        # sections can be checked against it.
        if isinstance(flat_outputs, ops.Operation):
          flat_outputs = [flat_outputs]
        op = (flat_outputs[0].op if isinstance(flat_outputs[0], ops.Tensor)
              else flat_outputs[0])
        signature = _ExecutionSignature(
            op=op,
            exclusive_resource_access=exclusive_resource_access)
        ops.add_to_collections(
            CRITICAL_SECTION_EXECUTIONS, signature)

      return (flat_outputs[0]
              if (len(flat_outputs) == 1
                  and isinstance(flat_outputs[0], ops.Operation))
              else nest.pack_sequence_as(fn_op.output_dtypes, flat_outputs))

  def _check_exclusive_resource_access(self, all_inputs,
                                       exclusive_resource_access):
    """Raises if this execution's resources conflict with another section.

    Compares the resource-typed tensors in `all_inputs` against the inputs of
    every execution already recorded in the `CRITICAL_SECTION_EXECUTIONS`
    collection.  Executions on this same critical section never conflict.
    Executions on other critical sections conflict when they share a
    resource input and at least one of the two requested exclusive access.

    Args:
      all_inputs: The tensors this execution will pass as arguments.
      exclusive_resource_access: Whether this execution requests exclusive
        access to its resources.

    Raises:
      ValueError: If a conflicting shared resource is found.
    """
    all_input_resources = [
        x for x in all_inputs if x.dtype == dtypes.resource]
    for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
      # Input 0 of an execution op is the critical section handle it runs on.
      if sg.op.inputs[0].name == self._handle.name:
        # Other executions in the same critical section are allowed.
        continue
      if not (exclusive_resource_access or sg.exclusive_resource_access):
        # Neither execution requested exclusive access.
        continue
      sg_input_names = {y.name for y in sg.op.inputs[1:]}
      for res in all_input_resources:
        if res.name in sg_input_names:
          raise ValueError(
              "This execution would access resource %s; but either this "
              "execution (CriticalSection: %s) or Execution '%s' "
              "(CriticalSection: %s) requested exclusive resource access "
              "of this resource for their critical section.  Did you mean "
              "to call execute with keyword argument "
              "exclusive_resource_access=False?"
              % (res.name,
                 self.name,
                 sg.op.name,
                 sg.op.inputs[0].op.name))

  # TODO(ebrevdo): Re-enable once CriticalSection is in core.

  # def to_proto(self, export_scope=None):
  #   """Converts a `CriticalSection` to a `CriticalSectionDef` protocol buffer.

  #   Args:
  #     export_scope: Optional `string`. Name scope to remove.

  #   Returns:
  #     A `CriticalSectionDef` protocol buffer, or `None` if the
  #     `CriticalSection` is not in the specified name scope.
  #   """
  #   if export_scope is None or self._handle.name.startswith(export_scope):
  #     cs_def = critical_section_pb2.CriticalSectionDef()
  #     cs_def.critical_section_name = ops.strip_name_scope(
  #         self._handle.name, export_scope)
  #     return cs_def
  #   else:
  #     return None

  # @staticmethod
  # def from_proto(critical_section_def, import_scope=None):
  #   return CriticalSection(
  #       critical_section_def=critical_section_def, import_scope=import_scope)


# TODO(ebrevdo): Re-enable once CriticalSection is in core.

# def _execution_to_proto_fn(execution_signature, export_scope=None):
#   """Converts `_ExecutionSignature` to a `CriticalSectionExecutionDef`.

#   Args:
#     execution_signature: Instance of `_ExecutionSignature`.
#     export_scope: The export scope, if any.

#   Returns:
#     An instance of `CriticalSectionExecutionDef`.
#   """
#   if (export_scope is None
#       or execution_signature.op.name.startswith(export_scope)):
#     op_def = critical_section_pb2.CriticalSectionExecutionDef()
#     op_def.execute_in_critical_section_name = ops.strip_name_scope(
#         execution_signature.op.name, export_scope)
#     op_def.exclusive_resource_access = (
#         execution_signature.exclusive_resource_access)
#     return op_def
#   else:
#     return None


# def _execution_from_proto_fn(op_def, import_scope=None):
#   """Converts a `CriticalSectionExecutionDef` to a `_ExecutionSignature`."""
#   assert isinstance(
#       op_def, critical_section_pb2.CriticalSectionExecutionDef)

#   # Create from op_def.
#   g = ops.get_default_graph()
#   execution_op = g.as_graph_element(
#       ops.prepend_name_scope(
#           op_def.execute_in_critical_section_name,
#           import_scope=import_scope))
#   return _ExecutionSignature(
#       op=execution_op,
#       exclusive_resource_access=op_def.exclusive_resource_access)

# ops.register_proto_function(
#     CRITICAL_SECTIONS,
#     proto_type=critical_section_pb2.CriticalSectionDef,
#     to_proto=CriticalSection.to_proto,
#     from_proto=CriticalSection.from_proto)

# ops.register_proto_function(
#     CRITICAL_SECTION_EXECUTIONS,
#     proto_type=critical_section_pb2.CriticalSectionExecutionDef,
#     to_proto=_execution_to_proto_fn,
#     from_proto=_execution_from_proto_fn)