1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 # pylint: disable=unidiomatic-typecheck 16 """Defun decorator for defining graph-mode functions.""" 17 18 from __future__ import absolute_import 19 from __future__ import division 20 from __future__ import print_function 21 22 import collections 23 import contextlib 24 import threading 25 26 import numpy as np 27 28 from tensorflow.core.framework import function_pb2 29 from tensorflow.python import pywrap_tensorflow 30 from tensorflow.python.eager import context 31 from tensorflow.python.eager import execute 32 from tensorflow.python.eager import tape 33 from tensorflow.python.eager.graph_only_ops import graph_placeholder 34 from tensorflow.python.framework import c_api_util 35 from tensorflow.python.framework import constant_op 36 from tensorflow.python.framework import dtypes as dtypes_module 37 from tensorflow.python.framework import errors 38 from tensorflow.python.framework import ops 39 from tensorflow.python.ops import control_flow_ops 40 from tensorflow.python.ops import gradients_impl 41 from tensorflow.python.util import compat 42 from tensorflow.python.util import nest 43 from tensorflow.python.util import tf_decorator 44 45 # Thread-local storage for tfe Tensors which are referenced while evaluating a 46 # graph-mode function. 47 _scoped_captures = threading.local() 48 # _scoped_captures.tensors is either None or a map from Tensor id to a pair 49 # of a tfe tensor and its corresponding placeholder to pass as a function 50 # argument. The value should be None unless we're in function definition 51 # context. 52 _scoped_captures.tensors = None 53 54 55 @contextlib.contextmanager 56 def capture_tensors(captures): 57 old = _scoped_captures.__dict__.get("tensors", None) 58 try: 59 _scoped_captures.tensors = captures 60 yield 61 finally: 62 _scoped_captures.tensors = old 63 64 65 def capture_value(tensor_map, value, dtype, name): 66 """Capture a value from outside the function, to pass in as an extra arg.""" 67 captured_value = tensor_map.get(ops.tensor_id(value), None) 68 if captured_value is None: 69 captured_value = graph_placeholder( 70 dtype=dtype or value.dtype, shape=value.shape, name=name) 71 if captured_value.dtype == dtypes_module.resource: 72 handle_data = value._handle_data # pylint: disable=protected-access 73 captured_value._handle_data = handle_data # pylint: disable=protected-access 74 if handle_data is not None and handle_data.is_set: 75 # Ensure that shapes and dtypes are propagated. 76 shapes, types = zip(*[(pair.shape, pair.dtype) 77 for pair in handle_data.shape_and_type]) 78 ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes] 79 shapes = [[d.size for d in s.dim] 80 if not s.unknown_rank else None for s in shapes] 81 with errors.raise_exception_on_not_ok_status() as status: 82 pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper( 83 captured_value._op._graph._c_graph, # pylint: disable=protected-access 84 captured_value._as_tf_output(), # pylint: disable=protected-access 85 shapes, 86 ranks, 87 types, 88 status) 89 90 tensor_map[ops.tensor_id(value)] = (value, captured_value) 91 else: 92 captured_value = captured_value[1] 93 tape.record_operation("captured_value", [captured_value], [value], 94 lambda x: [x]) 95 return captured_value 96 97 98 def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False): 99 """Captures a Tensor while building a graph mode function. 100 101 Arguments: 102 value: A Tensor object. 103 dtype: The datatype of the value produced by the node in the graph. 104 name: str, Name of the node in the graph. 105 as_ref: Ignored (required by register_tensor_conversion_function). 106 107 Returns: 108 Returns a constant (the current value of the tensor) if capturing 109 is not enabled. A placeholder which will have the value of the 110 tensor at runtime otherwise. 111 """ 112 del as_ref # Unused. 113 114 if context.in_eager_mode(): 115 return value 116 117 default_graph = ops.get_default_graph() 118 if not default_graph.building_function: 119 return value 120 121 tensor_map = _scoped_captures.tensors 122 if tensor_map is None: 123 # Capturing is not enabled. 124 if value.dtype == dtypes_module.resource: 125 return value 126 return constant_op.constant(value.numpy()) 127 if type(value) == ops.Tensor and value.graph is default_graph: 128 # The tensor has already been converted and captured. The type check 129 # is intentional: we are checking that value is a Tensor and not an 130 # EagerTensor. 131 return value 132 return capture_value(tensor_map, value, dtype, name) 133 134 135 class CapturingGraph(ops.Graph): 136 """Graph used when constructing eager functions.""" 137 138 def __init__(self, captures): 139 super(CapturingGraph, self).__init__() 140 self._building_function = True 141 self.captures = captures 142 # Map from resource tensor name to last op (in program order) which uses 143 # this tensor. Used to enforce that execution order matches program order 144 # for resource tensors. 145 self._last_op_using_resource_tensor = {} 146 147 # TODO(apassos) remove once the C API is used by default. 148 def _use_c_api_hack(self): 149 return True 150 151 def clear_resource_control_flow_state(self): 152 self._last_op_using_resource_tensor = {} 153 154 def create_op( 155 self, 156 op_type, 157 inputs, 158 dtypes, # pylint: disable=redefined-outer-name 159 input_types=None, 160 name=None, 161 attrs=None, 162 op_def=None, 163 compute_shapes=True, 164 compute_device=True): 165 # TODO(apassos) probably control flow has to be handled delicately here as 166 # in if a resource is accessed inside a control flow context we need the 167 # control dependency to point to something outside the context which is 168 # guaranteed to happen after the access. 169 # 170 # TODO(apassos) this should do some form of alias analysis as ops which 171 # forward the resources such as Identity and Switch can cause serialization 172 # to fail. 173 resource_inputs = set() 174 control_inputs = set() 175 for i, inp in enumerate(inputs): 176 if inp.graph is not self: 177 inputs[i] = capture_value(self.captures, inp, inp.dtype, inp.op.name) 178 inp = inputs[i] 179 if inp.dtype == dtypes_module.resource: 180 if inp.name in self._last_op_using_resource_tensor: 181 control_inputs.add(self._last_op_using_resource_tensor[inp.name]) 182 resource_inputs.add(inp.name) 183 with self.control_dependencies(list(control_inputs)): 184 op = super(CapturingGraph, self).create_op( 185 op_type, inputs, dtypes, input_types, name, attrs, op_def, 186 compute_shapes, compute_device) 187 for name in resource_inputs: 188 self._last_op_using_resource_tensor[name] = op 189 return op 190 191 192 # TODO(apassos): it'd be really nice if we could scope this registration. 193 # Note that we register this at a higher priority than ops.Tensor since we want 194 # to handle subclass specific conversion before a superclass conversion. 195 ops.register_tensor_conversion_function( 196 ops.EagerTensor, _convert_to_graph_tensor, priority=-1) 197 198 199 class _CapturingContext(object): 200 """Tracks references to Tensors outside this context while it is active.""" 201 202 def __init__(self): 203 # known_ops are ops which are created while this context is active 204 self.known_ops = set() 205 206 # captured_tensors are all tensors referenced to by ops in this context but 207 # not produced in it 208 self.captured_tensors = set() 209 210 def AddOp(self, op): # pylint: disable=invalid-name 211 if op.type in ["Variable", "VariableV2", "VarHandleOp"]: 212 raise ValueError("tfe.defun cannot capture variables created without " 213 "using tf.get_variable. Op: %s" % op) 214 self.known_ops.add(op) 215 for i in op.inputs: 216 if i.op not in self.known_ops: 217 self.captured_tensors.add(i) 218 219 def __enter__(self): 220 self._g = ops.get_default_graph() 221 self._old = self._g._get_control_flow_context() # pylint: disable=protected-access 222 self._g._set_control_flow_context(self) # pylint: disable=protected-access 223 224 def __exit__(self, _, __, ___): # pylint: disable=invalid-name 225 self._g._set_control_flow_context(self._old) # pylint: disable=protected-access 226 227 228 def _forward_name(n): 229 """The name of a generated forward defun named n.""" 230 return "__forward_%s_%s" % (n, ops.uid()) 231 232 233 def _backward_name(n): 234 """The name of a generated backward defun named n.""" 235 return "__backward_%s_%s" % (n, ops.uid()) 236 237 238 def _inference_name(n): 239 """The name of a forward-but-no-gradient defun named n.""" 240 return "__inference_%s_%s" % (n, ops.uid()) 241 242 243 # TODO(apassos) get rid of this by splitting framework.function._DefinedFunction 244 # so it doesn't have the definition-generating logic and is just a container for 245 # an already-defined function. 246 class _EagerDefinedFunction(object): 247 """Function object with the interface of tf _DefinedFunction.""" 248 249 def __init__(self, name, graph, operations, inputs, outputs): 250 """Initializes an eager defined function. 251 252 Args: 253 name: str, the name for the created function. 254 graph: Graph, the graph containing the operations in the function 255 operations: list of Operation; the subset of operations in the graph 256 which will be in the function 257 inputs: the tensors in the graph to be used as inputs to the function 258 outputs: the tensors in the graph which will be outputs to the function 259 """ 260 with errors.raise_exception_on_not_ok_status() as status: 261 fn = pywrap_tensorflow.TF_GraphToFunction_wrapper( 262 graph._c_graph, # pylint: disable=protected-access 263 compat.as_str(name), 264 False, 265 [o._c_op for o in operations], # pylint: disable=protected-access 266 [t._as_tf_output() for t in inputs], # pylint: disable=protected-access 267 [t._as_tf_output() for t in outputs], # pylint: disable=protected-access 268 [], 269 None, 270 compat.as_str(""), 271 status) 272 # TODO(apassos) avoid creating a FunctionDef (specially to grab the 273 # signature, but also in general it's nice not to depend on it. 274 with c_api_util.tf_buffer() as buffer_: 275 with errors.raise_exception_on_not_ok_status() as status: 276 pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status) 277 proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_) 278 function_def = function_pb2.FunctionDef() 279 function_def.ParseFromString(compat.as_bytes(proto_data)) 280 if context.in_eager_mode(): 281 _register(fn) 282 self.definition = function_def 283 self.name = function_def.signature.name 284 self.signature = function_def.signature 285 self.grad_func_name = None 286 self.python_grad_func = None 287 self._c_func = fn 288 self._grad_func = None 289 290 291 def _map_sequence_obj_to_idx(sequence): 292 """Maps objs in the sequence from id(obj) to sequence index.""" 293 return {id(x): i for i, x in enumerate(sequence)} 294 295 296 def _flatten(sequence): 297 """A wrapper around `nest.flatten` that also unpacks `IndexedSlices`.""" 298 # TODO(akshayka): Support `SparseTensor` in a similar fashion. 299 flat_sequence = nest.flatten(sequence) 300 outputs = [] 301 for item in flat_sequence: 302 if isinstance(item, ops.IndexedSlices): 303 if item.dense_shape is not None: 304 outputs.extend([item.values, item.indices, item.dense_shape]) 305 else: 306 outputs.extend([item.values, item.indices]) 307 else: 308 outputs.append(item) 309 return outputs 310 311 312 class GraphModeFunction(object): 313 """Callable object representing a graph-mode function. 314 315 Args: 316 name: str the name of the created function 317 input_placeholders: list of placeholder values (tensors) to feed when 318 calling the wrapped function. 319 extra_inputs: Tensor inputs this function definition closed over which 320 are passed as arguments. Need to track so gradients are supported 321 correctly. 322 graph: the Graph from which the operations will be pulled. Used as 323 a context when computing gradients. 324 operations: the subset of Operations in the graph used in the function 325 definition. 326 outputs: a flat list of the Tensors in the graph used as outputs to the 327 function 328 func_outputs: a possibly nested python object which will be returned by 329 this function. The Tensors in this structure will be replaced by their 330 corresponding values in outputs. 331 output_shapes: List of shapes of all tensors in outputs 332 variables: (optional) List of variables to watch during function execution. 333 """ 334 335 def __init__(self, 336 name, 337 input_placeholders, 338 extra_inputs, 339 graph, 340 operations, 341 outputs, 342 func_outputs, 343 output_shapes, 344 variables=None): 345 defined_function = _EagerDefinedFunction( 346 name, graph, operations, input_placeholders, outputs) 347 if len(input_placeholders) != len(defined_function.signature.input_arg): 348 raise ValueError("Internal error: invalid lengths. %s %s" % ( 349 len(input_placeholders), len(defined_function.signature.input_arg))) 350 self._input_placeholders = input_placeholders 351 self._extra_inputs = list(extra_inputs) 352 self._graph = graph 353 self._backward_function = None 354 self._func_name = name 355 self._function_def = defined_function 356 self._num_outputs = len(defined_function.signature.output_arg) 357 self._ops = operations 358 self._func_outputs = func_outputs 359 self._returns = [func_outputs] if isinstance( 360 func_outputs, (ops.Tensor, type(None))) else _flatten(func_outputs) 361 self._output_shapes = output_shapes 362 self._variables = variables if variables is not None else [] 363 364 @property 365 def variables(self): 366 return self._variables 367 368 def _construct_backprop_function(self): 369 """Constructs the backprop function object for this function.""" 370 with self._graph.as_default(), context.graph_mode(): 371 c = _CapturingContext() 372 with c: 373 filtered_outputs = [x for x in self._returns if x is not None] 374 self._out_grad_placeholders = [ 375 graph_placeholder(x.dtype, x.shape) for x in filtered_outputs] 376 in_gradients = gradients_impl.gradients( 377 filtered_outputs, 378 self._input_placeholders, 379 grad_ys=self._out_grad_placeholders) 380 381 backward_outputs = tuple( 382 grad for grad in _flatten(in_gradients) if grad is not None) 383 output_shapes = tuple(grad.shape for grad in backward_outputs) 384 385 captures = list(sorted(c.captured_tensors, key=lambda x: x.name)) 386 forward_name = _forward_name(self._func_name) 387 self._forward_fdef = _EagerDefinedFunction( 388 forward_name, self._graph, self._ops, self._input_placeholders, 389 filtered_outputs + captures) 390 all_inputs = self._out_grad_placeholders + captures 391 # Excluding input ops from the body as we do not intend to execute these 392 # operations when the function is executed. 393 all_ignored_ops = frozenset(x.op for x in all_inputs) 394 # Enforce a deterministic order of operations in the generated graph. This 395 # means rerunning the function-defining code will always define the same 396 # function, which is useful if we serialize this etc. 397 function_def_ops = tuple(x 398 for x in sorted(c.known_ops, key=lambda x: x.name) 399 if x not in all_ignored_ops) 400 bname = _backward_name(self._func_name) 401 self._backward_function = GraphModeFunction( 402 bname, all_inputs, [], self._graph, function_def_ops, 403 backward_outputs, in_gradients, output_shapes) 404 405 def _backprop_call(self, args): 406 """Calls the wrapped function and records the result on a tape.""" 407 all_args = args + self._extra_inputs 408 signature = self._forward_fdef.signature 409 ctx = context.context() 410 if ctx.in_graph_mode(): 411 g = ops.get_default_graph() 412 g._add_function(self._forward_fdef) # pylint: disable=protected-access 413 op = g.create_op( 414 signature.name, 415 [ops.internal_convert_to_tensor(x, ctx=ctx) for x in all_args], 416 tuple(dtypes_module.DType(x.type) for x in signature.output_arg), 417 op_def=signature, 418 name="FunctionCall", 419 compute_shapes=False) 420 outputs = op.outputs 421 outputs = [outputs] if isinstance( 422 outputs, (ops.Tensor, type(None))) else list(outputs) 423 for i, s in enumerate(self._output_shapes): 424 outputs[i].set_shape(s) 425 else: 426 outputs = execute.execute( 427 str(signature.name), 428 num_outputs=len(signature.output_arg), 429 inputs=all_args, 430 attrs=None, 431 ctx=ctx) 432 real_outputs = outputs[:len(self._returns)] 433 side_outputs = outputs[len(self._returns):] 434 435 def backward_function(*args): 436 return self._backward_function(*(list(args) + side_outputs)) # pylint: disable=not-callable 437 438 tape.record_operation( 439 signature.name, 440 real_outputs, 441 (args + self._extra_inputs), 442 backward_function) 443 444 return self._build_call_outputs(real_outputs) 445 446 @property 447 def output_shapes(self): 448 """The function's output shapes.""" 449 # TODO(ebrevdo): Should we only keep the output shapes associated 450 # with len(self._returns) outputs? 451 outputs_list = nest.flatten(self._func_outputs) 452 j = 0 453 for i, o in enumerate(outputs_list): 454 if o is not None: 455 if isinstance(o, ops.IndexedSlices): 456 # Extract the shape of the `IndexedSlices` object's `values` field. 457 outputs_list[i] = self._output_shapes[j] # the `values` shape 458 if o.dense_shape is not None: 459 j += 3 # skip over shapes for `values`, `indices`, `dense_shape` 460 else: 461 j += 2 # skip over shapes for `values`, `indices` 462 else: 463 outputs_list[i] = self._output_shapes[j] 464 j += 1 465 return nest.pack_sequence_as(self._func_outputs, outputs_list) 466 467 @property 468 def output_dtypes(self): 469 return nest.map_structure( 470 lambda x: x.dtype if x is not None else None, self._func_outputs) 471 472 @property 473 def captured_inputs(self): 474 return self._extra_inputs 475 476 @property 477 def name(self): 478 """Returns the name of the function in Eager-compatible format.""" 479 return self._function_def.name.encode("utf-8") 480 481 def add_to_graph(self, g): 482 if self._function_def.name not in g._functions: # pylint: disable=protected-access 483 g._add_function(self._function_def) # pylint: disable=protected-access 484 for f in self._graph._functions.values(): # pylint: disable=protected-access 485 if f.name not in g._functions: # pylint: disable=protected-access 486 g._add_function(f) # pylint: disable=protected-access 487 488 def __call__(self, *args): 489 """Executes the passed function in eager mode.""" 490 for v in self._variables: 491 if v._trainable: # pylint: disable=protected-access 492 tape.watch_variable(v) 493 494 tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)] 495 if tape.should_record(tensor_inputs) or tape.should_record( 496 self._extra_inputs): 497 if self._backward_function is None: 498 self._construct_backprop_function() 499 return self._backprop_call(tensor_inputs) 500 501 ctx = context.context() 502 if ctx.in_graph_mode(): 503 g = ops.get_default_graph() 504 self.add_to_graph(g) 505 signature = self._function_def.definition.signature 506 args = list(tensor_inputs) + self._extra_inputs 507 op = g.create_op( 508 signature.name, 509 [ops.internal_convert_to_tensor(x, ctx=ctx) for x in args], 510 tuple(dtypes_module.DType(x.type) for x in signature.output_arg), 511 op_def=signature, 512 name="FunctionCall", 513 compute_shapes=False) 514 result = op.outputs 515 if not result: 516 return op 517 for i, s in enumerate(self._output_shapes): 518 result[i].set_shape(s) 519 else: 520 result = execute.execute( 521 str(self._func_name), 522 num_outputs=self._num_outputs, 523 inputs=tensor_inputs + self._extra_inputs, 524 attrs=None, 525 ctx=ctx) 526 527 return self._build_call_outputs(result) 528 529 def _build_call_outputs(self, result): 530 """Maps the fdef output list to actual output structure. 531 532 Args: 533 result: Output lists defined by FunctionDef. 534 Returns: 535 The actual call output. 536 """ 537 if self._func_outputs is None: 538 return None 539 # Use `nest.flatten` instead of `_flatten` in order to preserve any 540 # IndexedSlices in `self._func_outputs`. 541 outputs_list = nest.flatten(self._func_outputs) 542 j = 0 543 for i, o in enumerate(outputs_list): 544 if o is not None: 545 if isinstance(o, ops.IndexedSlices): 546 # Repack Tensors for IndexedSlices. 547 if o.dense_shape is not None: 548 outputs_list[i] = ops.IndexedSlices( 549 values=result[j], 550 indices=result[j + 1], 551 dense_shape=result[j + 2]) 552 j += 3 553 else: 554 outputs_list[i] = ops.IndexedSlices( 555 values=result[j], 556 indices=result[j + 1]) 557 j += 2 558 else: 559 outputs_list[i] = result[j] 560 j += 1 561 ret = nest.pack_sequence_as(self._func_outputs, outputs_list) 562 return ret 563 564 565 def _get_defun_inputs(args): 566 """Maps the inputs args to graph inputs.""" 567 ret = [] 568 flat_args = nest.flatten(args) 569 for a in flat_args: 570 if isinstance(a, ops.Tensor): 571 ret.append(graph_placeholder(a.dtype, a.shape)) 572 else: 573 ret.append(a) 574 return nest.pack_sequence_as(args, ret) 575 576 577 def _defun_internal(name, func, args, kwds): 578 """Defines and returns graph-mode version of func.""" 579 graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access 580 with context.graph_mode(): 581 captures = {} 582 tmp_graph = CapturingGraph(captures) 583 # Inherit the graph key, since this is used for matching variables in 584 # optimizers. 585 tmp_graph._graph_key = graph_key # pylint: disable=protected-access 586 # Copy the graph collections to ensure summaries and other things work. This 587 # lets the function access (but not mutate) collections of the containing 588 # graph, such as the global step and the summary writer collections. 589 curr_graph = ops.get_default_graph() 590 for collection in curr_graph.collections: 591 tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection( 592 collection) 593 with tmp_graph.as_default(): 594 func_inputs = _get_defun_inputs(args) 595 596 def convert(x): 597 if x is None: 598 return None 599 return ops.convert_to_tensor_or_indexed_slices(x) 600 601 with capture_tensors(captures): 602 this_tape = tape.push_new_tape() 603 try: 604 func_outputs = func(*func_inputs, **kwds) 605 func_outputs = nest.map_structure(convert, func_outputs) 606 finally: 607 tape.pop_tape(this_tape) 608 variables = this_tape.watched_variables() 609 610 # Returning a closed-over tensor as an output does not trigger a 611 # call to convert_to_tensor, so we manually capture all such tensors. 612 outputs_list = _flatten(func_outputs) 613 func_def_outputs = [ 614 _convert_to_graph_tensor(x) for x in outputs_list if x is not None 615 ] 616 617 ids = list(sorted(captures.keys())) 618 if ids: 619 extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids]) 620 else: 621 extra_inputs = [] 622 extra_placeholders = [] 623 output_shapes = tuple( 624 x.shape if isinstance(x, ops.Tensor) else None 625 for x in outputs_list) 626 627 flat_inputs = [x for x in nest.flatten(func_inputs) 628 if isinstance(x, ops.Tensor)] 629 all_inputs = flat_inputs + list(extra_placeholders) 630 all_ignored_ops = frozenset(x.op for x in all_inputs) 631 fname = _inference_name(name) 632 operations = tuple(x for x in tmp_graph.get_operations() 633 if x not in all_ignored_ops) 634 # Register any other functions defined in the graph 635 # TODO(ashankar): Oh lord, forgive me for this lint travesty. 636 if context.in_eager_mode(): 637 for f in tmp_graph._functions.values(): # pylint: disable=protected-access 638 # TODO(ashankar): What about the gradient registry? 639 _register(f._c_func) # pylint: disable=protected-access 640 return GraphModeFunction( 641 fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs, 642 func_outputs, output_shapes, variables) 643 644 645 # Defun uses this instead of Tensor as a cache key. Using dtype because 646 # TensorFlow graphs are not parametric wrt dtypes, and using shapes for 647 # performance reasons, as much TensorFlow code specializes on known shapes to 648 # produce slimmer graphs. 649 _TensorDtype = collections.namedtuple("_TensorDtype", ["dtype", "shape"]) 650 _ZeroDtype = collections.namedtuple("_ZeroDtype", ["dtype", "shape"]) 651 652 653 def _cache_key(x): 654 """Cache key for tfe functions.""" 655 if isinstance(x, ops.Tensor): 656 return _TensorDtype(x.dtype, x._shape_tuple()) # pylint: disable=protected-access 657 if isinstance(x, ops.IndexedSlices): 658 if x.dense_shape is not None: 659 return tuple([ 660 _TensorDtype(x.values.dtype, x.values._shape_tuple()), # pylint: disable=protected-access 661 _TensorDtype(x.indices.dtype, x.indices._shape_tuple()), # pylint: disable=protected-access 662 _TensorDtype(x.dense_shape.dtype, x.dense_shape._shape_tuple()) # pylint: disable=protected-access 663 ]) 664 else: 665 return tuple([ 666 _TensorDtype(x.values.dtype, x.values._shape_tuple()), # pylint: disable=protected-access 667 _TensorDtype(x.indices.dtype, x.indices._shape_tuple()) # pylint: disable=protected-access 668 ]) 669 if isinstance(x, np.ndarray): 670 return ("array", x.shape, tuple(x.reshape(-1))) 671 if isinstance(x, (list, tuple)): 672 return tuple([_cache_key(a) for a in x]) 673 if isinstance(x, dict): 674 return tuple(tuple([_cache_key(k), _cache_key(v)]) for k, v in x.items()) 675 return x 676 677 678 def _register(fn): 679 """Registers the function `fn`.""" 680 context.context().add_function(fn) 681 682 683 # TODO(apassos): better error messages for non-hashable arguments. 684 def named_defun(func, name): 685 """Defines a function with a given name. 686 687 See the documentation for `defun` for more information on the semantics of the 688 function. 689 690 Args: 691 func: the function to be wrapped. 692 name: the name given to it. 693 694 Returns: 695 the wrapped function. 696 """ 697 arguments_to_functions = {} 698 699 def decorated(*args, **kwds): 700 """Decorated version of func.""" 701 # Macroexpand on non-Tensor arguments 702 cache_key = tuple(_cache_key(x) for x in args) 703 if any(isinstance(x, ops.EagerTensor) for x in kwds.values()): 704 raise ValueError("Tensor keyword arguments are not supported.") 705 cache_key = (cache_key, tuple(kwds.items())) 706 707 if cache_key not in arguments_to_functions: 708 arguments_to_functions[cache_key] = _defun_internal( 709 name, func, args, kwds) 710 return arguments_to_functions[cache_key](*args) 711 712 return decorated 713 714 715 def defun(func): 716 """Decorator to compile func into graph_mode. 717 718 `defun` converts a function that constructs a TensorFlow graph into a function 719 that executes the graph. TensorFlow graphs typically execute faster and with a 720 lower memory-footprint than executing each of the operations that make up the 721 function individually as the TensorFlow runtime can optimize the graph and 722 execute sub-operations in parallel. 723 724 func must be a Python function that constructs a TensorFlow graph, 725 typically using functions in the tensorflow module. 726 727 Arguments to func can be either Tensor objects or Python 728 objects. Non-Tensor python objects are treated as constants, and new function 729 definitions are created internally based on their values. 730 731 func must return a tf.Tensor (NOT a Tensor) or a list of tf.Tensor (NOT a 732 Tensor). 733 734 Control flow constructs (e.g., `if`, `while`) are not yet compatible with 735 `defun`. 736 737 Example: 738 ```python 739 def f(x, y): 740 return tf.reduce_mean(tf.multiply(x ** 2, 3) + y) 741 742 @tfe.defun 743 def g(x, y): 744 return tf.reduce_mean(tf.multiply(x ** 2, 3) + y) 745 746 x = tf.constant([[2.0, 3.0]]) 747 y = tf.constant([[3.0, -2.0]]) 748 # The plain function and defun-compiled function should return the same value. 749 assert f(x, y).numpy() == g(x, y).numpy() 750 751 # After the first invocation, the defun-compiled (graph) function runs faster 752 # than the plain function because the defun-compiled function does not involve 753 # Python interpreter overhead during the execution. 754 %time print(f(x, y)) 755 %time print(g(x, y)) 756 ``` 757 758 Args: 759 func: function to be compiled. 760 761 Returns: 762 A callable that will execute the compiled function (and return zero 763 or more Tensor objects). 764 """ 765 # TODO(apassos): deal with captured global state. Deal with control flow. 766 try: 767 name = func.__name__ 768 except AttributeError: 769 name = "function" 770 return tf_decorator.make_decorator(func, named_defun(func, name)) 771 772 773 def make_defun_op(func, *args, **kwds): 774 """Compile func into graph_mode, assuming func arguments are *args, **kwargs. 775 776 `make_defun_op` converts a function that constructs a TensorFlow graph into 777 a function object and attaches it to the graph. The resulting function 778 object can be queried for its properties, and called directly with different 779 inputs to execute. 780 781 More details on use cases and limitations are available in the 782 documentation for `defun`. 783 784 Example: 785 ```python 786 def f(x, y): 787 return tf.reduce_mean(tf.multiply(x ** 2, 3) + y) 788 789 def g(x, y): 790 return tf.reduce_mean(tf.multiply(x ** 2, 3) + y) 791 792 z = tf.constant([[0.0, 0.0]]) 793 g_op = make_defun_op(g, z, z) 794 795 assert g_op.output_shapes == tf.TensorShape([]) 796 assert g_op.output_types == tf.float32 797 798 x = tf.constant([[2.0, 3.0]]) 799 y = tf.constant([[3.0, -2.0]]) 800 801 # The plain function and defun-compiled function should return the same value. 802 assert f(x, y).numpy() == g_op(x, y).numpy() 803 ``` 804 805 Args: 806 func: function to be compiled. 807 *args: List arguments to pass to `func` when attaching to the graph. 808 **kwds: Keyword arguments to pass to `func` when attaching to the graph. 809 810 Returns: 811 A wrapper object which can be queried for its output properties, 812 and which can be called directly the way a `@defun` wrapped function 813 can. 814 815 Raises: 816 ValueError: if any of the keyword arguments to `func` are `EagerTensor` 817 objects (not yet supported). 818 """ 819 name = func.__name__ 820 if any(isinstance(x, ops.EagerTensor) for x in kwds.values()): 821 raise ValueError("Tensor keyword arguments are not supported.") 822 return _defun_internal(name, func, args, kwds) 823 824 825 class AutomaticControlDependencies(object): 826 """Context manager to automatically add control dependencies. 827 828 Code under this context manager will act as if a sensible set of control 829 dependencies were present. More specifically: 830 1. All stateful ops in the scope will execute 831 2. Stateful ops which modify the same resource will execute in program order 832 833 Note: creating variables in an automatic control dependencies context is not 834 supported (the value of the variables will never change as they will keep 835 getting reinitialized). 836 837 NOT THREAD SAFE 838 """ 839 840 def __init__(self): 841 self._returned_tensors = set() 842 843 def mark_as_return(self, tensor): 844 self._returned_tensors.add(tensor) 845 846 def __enter__(self): 847 if context.in_eager_mode(): 848 return self 849 # This code assumes no other thread is adding ops to the graph while 850 # we're adding ops to the graph. 851 # TODO(apassos): Fix this by locking the graph or using a temporary 852 # graph (but that would mess up devices and collections at least, 853 # probably other things as well). 854 self._graph = ops.get_default_graph() 855 self._n_operations = len(self._graph.get_operations()) 856 return self 857 858 def _process_switch(self, switch_op, ops_which_must_run, 859 last_op_using_resource_tensor, merge_for_resource): 860 """Processes a switch node for a resource input. 861 862 When tensorflow creates a cond, it creates a control flow context for each 863 branch of the cond. Each external tensor accessed by that branch is routed 864 through a switch op, which gets created in the graph _after_ the op which 865 uses that tensor get created. 866 867 If the resource comes from another switch op we process that one first. 868 869 _process_switch creates a corresponding merge node for the switch node. This 870 merge node is added to the outer control flow context of the switch 871 node. We also ensure that: 872 873 1. The switch node executes after the previous op which used the resource 874 tensor 875 876 2. Any op which uses a resource output of the switch node executes before 877 the merge for the switch node. 878 879 3. The next op which uses the input resource to the switch node (which 880 might be another switch node for the other branch of the conditional) 881 will execute after the merge node is done. 882 883 4. The merge node is marked as must_run so it will run even if no 884 subsequent operation uses the resource. 885 886 Args: 887 switch_op: the switch op to be processed 888 ops_which_must_run: the set of ops which must run 889 last_op_using_resource_tensor: map from resource tensor to last op using 890 it 891 merge_for_resource: map from resource tensor to merge which must follow 892 all usages of it. 893 """ 894 inp = switch_op.inputs[0] 895 if inp.dtype == dtypes_module.resource and inp.op.type == "Switch": 896 self._process_switch(inp.op, ops_which_must_run, 897 last_op_using_resource_tensor, merge_for_resource) 898 if switch_op.outputs[0] in merge_for_resource: 899 return 900 new_merge = control_flow_ops.merge(switch_op.outputs, 901 name="artificial_merge") 902 new_merge[0].op._control_flow_context = ( # pylint: disable=protected-access 903 switch_op._control_flow_context.outer_context) # pylint: disable=protected-access 904 # Ensures the merge always runs 905 ops_which_must_run.add(new_merge[0].op) 906 if inp in last_op_using_resource_tensor: 907 # Ensures the switch exectutes after the previous op using the resource. 908 switch_op._add_control_input(last_op_using_resource_tensor[inp]) # pylint: disable=protected-access 909 # Ensure the next op outside the cond happens after the merge. 910 last_op_using_resource_tensor[inp] = new_merge[0].op 911 if inp in merge_for_resource: 912 merge_for_resource[inp]._add_control_input(new_merge[0].op) # pylint: disable=protected-access 913 for o in switch_op.outputs: 914 # Ensures the merge will execute after all ops inside the cond 915 merge_for_resource[o] = new_merge[0].op 916 917 def __exit__(self, unused_type, unused_value, unused_traceback): 918 if context.in_eager_mode(): 919 return 920 921 if self._graph is not ops.get_default_graph(): 922 raise RuntimeError( 923 "Graph changed while trying to add control dependencies.") 924 925 # map from resource tensor to the last op which used it 926 last_op_using_resource_tensor = {} 927 # set of conditional and loop exits 928 ops_which_must_run = set() 929 # merge which must depend on ops which use this resource 930 merge_for_resource = {} 931 932 new_operations = self._graph.get_operations()[self._n_operations:] 933 934 # Ensures that uses of resource tensors get serialized properly and all 935 # execute. This is done by keeping a map from resource tensor to the last op 936 # in graph-construction order which used it (last_op_using_resource_tensor). 937 # 938 # Conditionals are written in TensorFlow such that every external tensor 939 # accessed in the conditional goes through a switch op and every return 940 # tensor (it's guaranteed that there will be at least one) goes through a 941 # merge op. 942 # 943 # To handle conditionals, switches are handled in a special way (see 944 # comments for _process_switch). Merge nodes created by TF's conditional 945 # logic (as opposed to by _process_switch) are forced to run and also get a 946 # control dependency added to them to ensure all stateful ops inside their 947 # control flow context run. 948 # 949 # We also ensure that if an op is using a resource output by a switch node 950 # (that is, a resource tensor for which there's a value in 951 # merge_for_resource) this op will run before the merge for that resource. 952 # 953 # We try to add control inputs to nodes respecting their control flow 954 # contexts to avoid dead nodes propagating everywhere and leading to 955 # "retval[0] doesn't have value" errors. If a node gets a control dependency 956 # on a dead node (i.e. a note from an untaken control flow branch) that node 957 # will be marked as dead unless it's a merge node. 958 # 959 # TODO(apassos): serialize non-resource-taking stateful ops as well, and 960 # test that it works. Support while loops. Support init_scope escaping from 961 # this. 962 for op in new_operations: 963 control_inputs = set() 964 # Ensure stateful ops run 965 if self._graph._registered_ops[op.type].is_stateful: # pylint: disable=protected-access 966 ops_which_must_run.add(op) 967 # Ignore switches (they're handled separately) 968 if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource: 969 continue 970 # Make merges trigger all other computation which must run 971 if op.type == "Merge": 972 for o in ops_which_must_run: 973 op._add_control_input(o) # pylint: disable=protected-access 974 for inp in o.inputs: 975 if inp in last_op_using_resource_tensor: 976 last_op_using_resource_tensor[inp] = op 977 ops_which_must_run = set([op]) 978 continue 979 for inp in op.inputs: 980 if inp.dtype == dtypes_module.resource: 981 # Deal with switches, finally. 982 if inp.op.type == "Switch": 983 self._process_switch(inp.op, ops_which_must_run, 984 last_op_using_resource_tensor, 985 merge_for_resource) 986 # Ensure uses of resources are serialized 987 if inp in last_op_using_resource_tensor: 988 if (last_op_using_resource_tensor[inp]._control_flow_context # pylint: disable=protected-access 989 is op._control_flow_context): # pylint: disable=protected-access 990 control_inputs.add(last_op_using_resource_tensor[inp]) 991 # Ensure merges happen after the closing of a cond block 992 if inp in merge_for_resource: 993 merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access 994 last_op_using_resource_tensor[inp] = op 995 control_inputs = [c for c in control_inputs 996 if c._control_flow_context is op._control_flow_context] # pylint: disable=protected-access 997 op._add_control_inputs(control_inputs) # pylint: disable=protected-access 998 999 # Ensure all ops which must run do run 1000 for r in self._returned_tensors: 1001 r.op._add_control_inputs( # pylint: disable=protected-access 1002 [o for o in ops_which_must_run 1003 if o._control_flow_context is r.op._control_flow_context]) # pylint: disable=protected-access 1004 1005 1006 def automatic_control_dependencies(f): 1007 """Wraps f to automatically insert control dependencies. 1008 1009 The inserted dependencies ensure that: 1010 1. All stateful ops in f run when the result of f runs 1011 2. Updates to the same resources happen in order. 1012 1013 Args: 1014 f: the function to be wrapped. 1015 1016 Returns: 1017 The wrapped function. 1018 """ 1019 1020 def wrapper(*args, **kwds): 1021 with AutomaticControlDependencies() as a: 1022 result = f(*args, **kwds) 1023 for t in nest.flatten(result): 1024 a.mark_as_return(t) 1025 return result 1026 1027 return tf_decorator.make_decorator(f, wrapper) 1028