# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental API for TensorFlow's "Eager" mode of execution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
import copy
import random
import threading

from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
from tensorflow.python.util import is_in_graph_mode
from tensorflow.python.util import tf_contextlib

GRAPH_MODE = 0
EAGER_MODE = 1

# Default execution mode.
_default_mode = GRAPH_MODE

# Cache from (old_device_name, partial_new_device_name) -> (new_device_name,
# new_device_spec).
# Note that we do not protect this with a lock and instead rely on python's GIL
# and the idempotent nature of writes to provide thread safety.
_device_parsing_cache = {}

_MAXINT32 = 2**31 - 1

DEVICE_PLACEMENT_EXPLICIT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_EXPLICIT
DEVICE_PLACEMENT_WARN = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_WARN
DEVICE_PLACEMENT_SILENT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT
DEVICE_PLACEMENT_SILENT_FOR_INT32 = (
    pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)


# TODO(agarwal): better name ?
class _EagerContext(threading.local):
  """Thread local eager context."""

  def __init__(self):
    super(_EagerContext, self).__init__()
    self.device_spec = pydev.DeviceSpec.from_string(
        "/job:localhost/replica:0/task:0/device:CPU:0")
    self.device_name = self.device_spec.to_string()
    self.mode = _default_mode
    self.scope_name = ""
    self.recording_summaries = False
    self.summary_writer_resource = None
    self.scalar_cache = {}


ContextStackEntry = collections.namedtuple(
    "ContextStackEntry", ["is_building_function", "enter_context_fn"])


class ContextStack(threading.local):
  """A thread-local stack of context switches."""

  def __init__(self):
    super(ContextStack, self).__init__()
    self.stack = []

  def push(self, is_building_function, enter_context_fn):
    """Push metadata about a context switch onto the stack.

    A context switch can take one of two forms: installing a graph as the
    default graph, or entering the eager context.

    Args:
      is_building_function: (bool.) Whether the context is building a function.
      enter_context_fn: (function.) A callable that executes the context switch.
        For example, `graph.as_default` or `eager_mode`.
    """
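    # Illustrative sketch (not executed here): building a function would push
    # the graph installer, while switching to eager execution pushes the eager
    # context manager:
    #
    #   context_stack.push(True, graph.as_default)   # building a function
    #   context_stack.push(False, eager_mode)        # entering eager mode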

    self.stack.append(
        ContextStackEntry(is_building_function, enter_context_fn))

  def pop(self):
    """Pop the stack."""

    self.stack.pop()


context_stack = ContextStack()


# TODO(agarwal): rename to EagerContext / EagerRuntime ?
# TODO(agarwal): consider keeping the corresponding Graph here.
class Context(object):
  """Environment in which eager operations execute."""

  def __init__(self, config=None, device_policy=None):
    """Creates a new Context.

    Args:
      config: (Optional.) A `ConfigProto` protocol buffer with configuration
       options for the Context. Note that many of these options may currently
       be unimplemented or irrelevant when eager execution is enabled.
      device_policy: (Optional.) What policy to use when trying to run an
       operation on a device with inputs that are not on that device.
       Valid values:
         tfe.DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not
           correct.
         tfe.DEVICE_PLACEMENT_WARN: copies the tensors that are not on the
           right device and raises a warning.
         tfe.DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might
           hide performance problems.
         tfe.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors,
           raising errors for tensors of any other dtype.
    """
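    # A minimal sketch (illustrative, not executed here) of constructing a
    # context that silently copies tensors across devices; the ConfigProto is
    # optional:
    #
    #   ctx = Context(config=config_pb2.ConfigProto(log_device_placement=True),
    #                 device_policy=DEVICE_PLACEMENT_SILENT)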
    self._eager_context = _EagerContext()
    self._context_handle = None
    self._context_devices = None
    self._post_execution_callbacks = []
    self._config = config
    self._seed = None
    # Random number generator backing _internal_operation_seed(); it is
    # replaced with a seeded generator when _set_global_seed() is called.
    self._rng = random.Random()
    self._initialize_lock = threading.Lock()
    self._device_policy = device_policy

  def _set_global_seed(self, seed):
    """Set a global eager mode seed for random ops."""
    self._seed = seed
    self._rng = random.Random(self._seed)
    # Also clear the kernel cache, to reset any existing seeds.
    if self._context_handle is not None:
      pywrap_tensorflow.TFE_ContextClearCaches(self._context_handle)

  def _internal_operation_seed(self):
    """Returns a fake operation seed.

    In eager mode, the user shouldn't set or depend on the operation seed.
    Instead, we generate a random seed from the global seed so that each
    operation's randomness is different while still being derived from the
    global seed.

    Returns:
      A fake operation seed based on the global seed.
    """
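    # Illustrative sketch of the intended behaviour: after `set_global_seed(42)`
    # successive calls return different values, but the overall sequence is
    # reproducible across runs.
    #
    #   set_global_seed(42)
    #   seed_a = internal_operation_seed()
    #   seed_b = internal_operation_seed()   # differs from seed_a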
    return self._rng.randint(0, _MAXINT32)

  def _initialize_handle_and_devices(self):
    """Initialize handle and devices."""
    with self._initialize_lock:
      if self._context_handle is not None:
        return
      assert self._context_devices is None
      opts = pywrap_tensorflow.TFE_NewContextOptions()
      try:
        with errors.raise_exception_on_not_ok_status() as status:
          if self._config is not None:
            config_str = self._config.SerializeToString()
            pywrap_tensorflow.TFE_ContextOptionsSetConfig(
                opts, config_str, len(config_str), status)
          if self._device_policy is not None:
            pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy(
                opts, self._device_policy)
          self._context_handle = pywrap_tensorflow.TFE_NewContext(opts, status)
      finally:
        pywrap_tensorflow.TFE_DeleteContextOptions(opts)
      # Store list of devices
      self._context_devices = []
      with errors.raise_exception_on_not_ok_status() as status:
        device_list = pywrap_tensorflow.TFE_ContextListDevices(
            self._context_handle, status)
      try:
        self._num_gpus = 0
        for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
          with errors.raise_exception_on_not_ok_status() as status:
            dev_name = pywrap_tensorflow.TF_DeviceListName(
                device_list, i, status)
          self._context_devices.append(pydev.canonical_name(dev_name))
          with errors.raise_exception_on_not_ok_status() as status:
            dev_type = pywrap_tensorflow.TF_DeviceListType(
                device_list, i, status)
          if dev_type == "GPU":
            self._num_gpus += 1

      finally:
        pywrap_tensorflow.TF_DeleteDeviceList(device_list)

  @property
  def _handle(self):
    ctx = self._context_handle
    if ctx is None:
      self._initialize_handle_and_devices()
      return self._context_handle
    else:
      return ctx

  @property
  def _devices(self):
    devices = self._context_devices
    if devices is None:
      self._initialize_handle_and_devices()
      return self._context_devices
    else:
      return devices

  def __str__(self):
    if self._context_handle is None:
      return "Eager TensorFlow Context. Devices currently uninitialized."
    else:
      devices = self._devices
      lines = ["Eager TensorFlow Context with %d devices" % (len(devices))]
      for i, d in enumerate(devices):
        lines.append("   Device %d: %s" % (i, d))
      return "\n".join(lines)

  @tf_contextlib.contextmanager
  def _mode(self, mode):
    ctx = self._eager_context
    old_mode = ctx.mode
    ctx.mode = mode
    if mode == EAGER_MODE:
      context_stack.push(False, eager_mode)
    try:
      yield
    finally:
      ctx.mode = old_mode
      if mode == EAGER_MODE:
        context_stack.pop()

  def in_graph_mode(self):
    """Returns True if current thread is in GRAPH mode."""
    return self._eager_context.mode == GRAPH_MODE

  def in_eager_mode(self):
    """Returns True if current thread is in EAGER mode."""
    return self._eager_context.mode == EAGER_MODE

  def scalar_cache(self):
    """Per-device cache for scalars."""
    return self._eager_context.scalar_cache

  @property
  def scope_name(self):
    """Returns scope name for the current thread."""
    return self._eager_context.scope_name

  @scope_name.setter
  def scope_name(self, s):
    """Sets scope name for the current thread."""
    self._eager_context.scope_name = s

  @property
  def summary_writer_resource(self):
    """Returns summary writer resource."""
    return self._eager_context.summary_writer_resource

  @summary_writer_resource.setter
  def summary_writer_resource(self, resource):
    """Sets summary writer resource."""
    self._eager_context.summary_writer_resource = resource

  @property
  def device_name(self):
    """Returns the device name for the current thread."""
    return self._eager_context.device_name

  @property
  def device_spec(self):
    """Returns the device spec for the current thread."""
    return self._eager_context.device_spec

  @tf_contextlib.contextmanager
  def device(self, name):
    """Context-manager to force placement of operations and Tensors on a device.

    Args:
      name: Name of the device or None to get default placement.

    Yields:
      Nothing.

    Raises:
      ValueError: If name is not a string or is an invalid device name.
    """
    eager_context = self._eager_context
    old_device_name = eager_context.device_name
    old_device_spec = eager_context.device_spec
    cache_key = (old_device_name, name)
    try:
      new_device_name, new_device_spec = _device_parsing_cache[cache_key]
    except TypeError:
      # Error while trying to compute the cache key.
      raise ValueError("Expecting a string device name. Got %s(%s)" %
                       (type(name), name))
    except KeyError:
      # Handle a cache miss.
      if name is not None:
        if not isinstance(name, str):
          raise ValueError("Expecting a string device name. Got %s(%s)" %
                           (type(name), name))
        device_spec = pydev.DeviceSpec.from_string(name)
        if old_device_name:
          new_device_spec = copy.copy(old_device_spec)
        else:
          new_device_spec = pydev.DeviceSpec.from_string(
              "/job:localhost/replica:0/task:0/device:CPU:0")
        new_device_spec.merge_from(device_spec)
      else:
        new_device_spec = pydev.DeviceSpec.from_string("")
      new_device_name = new_device_spec.to_string()
      _device_parsing_cache[cache_key] = (new_device_name, new_device_spec)

    try:
      eager_context.device_name = new_device_name
      eager_context.device_spec = new_device_spec
      yield
    finally:
      eager_context.device_name = old_device_name
      eager_context.device_spec = old_device_spec

  def devices(self):
    """List of the names of devices available to execute operations."""
    return self._devices

  def num_gpus(self):
    """The number of GPUs available to execute operations."""
    self._initialize_handle_and_devices()
    return self._num_gpus

  def add_function(self, fn):
    """Add a function definition to the context.

    Once added, the function (identified by its name) can be executed like any
    other operation.

    Args:
      fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
    """
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.TFE_ContextAddFunction(
          self._handle,  # pylint: disable=protected-access
          fn,
          status)

  def add_function_def(self, fdef):
    """Add a function definition to the context.

    Once added, the function (identified by its name) can be executed like any
    other operation.

    Args:
      fdef: A FunctionDef protocol buffer message.
    """
    fdef_string = fdef.SerializeToString()
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.TFE_ContextAddFunctionDef(
          self._handle,  # pylint: disable=protected-access
          fdef_string,
          len(fdef_string),
          status)

  def add_post_execution_callback(self, callback):
    """Add a post-execution callback to the context.

    A post-execution callback is invoked immediately after an eager operation or
    function has finished execution, providing access to the op's type, name,
    attributes, and input and output tensors. Multiple callbacks can be added,
    in which case they will be invoked in the order in which they were added.

    Args:
      callback: a callable of the signature
        `f(op_type, op_name, attrs, inputs, outputs)`.
        `op_type` is the type of the operation that was just executed (e.g.,
          `MatMul`).
        `op_name` is the name of the operation that was just executed. This
          name is set by the client who created the operation and can be `None`
          if it is unset.
        `attrs` contains the attributes of the operation as a `tuple` of
          alternating attribute names and attribute values.
        `inputs` is the `list` of input `Tensor`(s) to the op.
        `outputs` is the `list` of output `Tensor`(s) from the op.
        Return value(s) from the callback are ignored.
    """
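    # For instance, a simple logging callback (an illustrative sketch, not
    # installed by default) might look like:
    #
    #   def _log_op(op_type, op_name, attrs, inputs, outputs):
    #     print("Executed %s (name=%s) with %d input(s) and %d output(s)"
    #           % (op_type, op_name, len(inputs), len(outputs)))
    #
    #   context().add_post_execution_callback(_log_op)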
    # TODO(cais): (b/64674139) Allow access to function-internal operations.
    self._post_execution_callbacks.append(callback)

  def clear_post_execution_callbacks(self):
    """Clear all post-execution callbacks added to the context."""
    del self._post_execution_callbacks[:]

  @property
  def post_execution_callbacks(self):
    """Get the list of post-execution callbacks added to the context."""
    return self._post_execution_callbacks

  def enable_run_metadata(self):
    """Enables tracing of op execution via RunMetadata.

    To retrieve the accumulated metadata, call context.export_run_metadata();
    to stop tracing, call context.disable_run_metadata().
    """
    if not self._context_handle:
      self._initialize_handle_and_devices()
    pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._context_handle)

  @tf_contextlib.contextmanager
  def device_policy(self, policy):
    """Context manager that temporarily sets the thread-local device policy."""
    if not self._context_handle:
      self._initialize_handle_and_devices()
    old = pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
        self._context_handle)
    pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
        self._handle, policy)
    try:
      yield
    finally:
      pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
          self._handle, old)

  def disable_run_metadata(self):
    """Disables tracing of op execution via RunMetadata."""
    if not self._context_handle:
      return
    pywrap_tensorflow.TFE_ContextDisableRunMetadata(self._context_handle)

  def export_run_metadata(self):
    """Returns a RunMetadata proto with accumulated information.

    The returned protocol buffer contains information since the most recent call
    to either enable_run_metadata or export_run_metadata.

    Returns:
      A RunMetadata protocol buffer, or None if tracing has not been enabled.
    """
    if not self._context_handle:
      return None
    with c_api_util.tf_buffer() as buffer_:
      with errors.raise_exception_on_not_ok_status() as status:
        pywrap_tensorflow.TFE_ContextExportRunMetadata(
            self._context_handle, buffer_, status)
      proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
    run_metadata = config_pb2.RunMetadata()
    run_metadata.ParseFromString(compat.as_bytes(proto_data))
    return run_metadata

_context = None
_context_lock = threading.Lock()


def _initialize_context():
  global _context
  with _context_lock:
    if _context is None:
      _context = Context()


def context():
  """Returns a singleton context object."""
  if _context is None:
    _initialize_context()
  return _context
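# Most callers use the process-wide singleton rather than constructing a
# Context directly; an illustrative sketch:
#
#   ctx = context()
#   print(ctx.device_name)  # e.g. "/job:localhost/replica:0/task:0/device:CPU:0"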


# TODO(agarwal): remove this.
def get_default_context():
  """Same as context."""
  if _context is None:
    _initialize_context()
  return _context


def set_global_seed(seed):
  """Sets the eager mode seed."""
  context()._set_global_seed(seed)  # pylint: disable=protected-access


def global_seed():
  """Returns the eager mode seed."""
  return context()._seed  # pylint: disable=protected-access


def internal_operation_seed():
  """Returns the operation seed generated based on global seed."""
  return context()._internal_operation_seed()  # pylint: disable=protected-access


def in_graph_mode():
  """Returns True if current thread is in GRAPH mode for default context."""
  return context().in_graph_mode()


def in_eager_mode():
  """Returns True if current thread is in EAGER mode for default context."""
  return context().in_eager_mode()


def graph_mode():
  """Context-manager to enable GRAPH mode for current thread."""
  return context()._mode(GRAPH_MODE)  # pylint: disable=protected-access


def eager_mode():
  """Context-manager to enable EAGER mode for current thread."""
  return context()._mode(EAGER_MODE)  # pylint: disable=protected-access
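# Typical usage (an illustrative sketch): temporarily run a block eagerly while
# the process default remains graph mode.
#
#   with eager_mode():
#     assert in_eager_mode()
#   assert in_graph_mode()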


# TODO(agarwal): get rid of this and use ops.name_scope instead.
@contextlib.contextmanager
def namescope(name):
  """Context manager for creating hierarchical name scopes."""
  ctx = context()
  old_name = ctx.scope_name
  ctx.scope_name = "%s/%s" % (old_name, name) if old_name else name
  try:
    yield
  finally:
    ctx.scope_name = old_name
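# Sketch of the nesting behaviour (illustrative):
#
#   with namescope("outer"):
#     with namescope("inner"):
#       scope_name()  # -> "outer/inner"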


def scope_name():
  """Name of the current scope."""
  return context().scope_name


def device(name):
  """Context-manager to force placement of operations and Tensors on a device.

  Example:
  ```python
  with tfe.device('gpu:0'):
    with tfe.device('cpu:0'):
      shape = tf.constant([], dtype=tf.int32)
    x = tf.truncated_normal(shape, dtype=tf.float32)
  ```
  will ensure that the `shape` Tensor is on CPU but the `truncated_normal`
  operation runs on GPU 0.

  Args:
    name: Name of the device (see context().devices()), or None to
      perform automatic placement.

  Returns:
    Context manager for setting the device.
  """
  return context().device(name)


def list_devices():
  """List the names of the available devices.

  Returns:
    Names of the available devices, as a `list`.
  """
  return context().devices()


def num_gpus():
  """Get the number of available GPU devices.

  Returns:
    The number of available GPU devices.
  """
  return context().num_gpus()


def enable_run_metadata():
  """Enables tracing of op execution via RunMetadata.

  To retrieve the accumulated metadata, call context.export_run_metadata();
  to stop tracing, call context.disable_run_metadata().
  """
  context().enable_run_metadata()


def disable_run_metadata():
  """Disables tracing of op execution via RunMetadata."""
  context().disable_run_metadata()


def export_run_metadata():
  """Returns a RunMetadata proto with accumulated information.

  The returned protocol buffer contains information since the most recent call
  to either enable_run_metadata or export_run_metadata.

  Returns:
    A RunMetadata protocol buffer.
  """
  return context().export_run_metadata()
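# A typical tracing flow (illustrative sketch):
#
#   enable_run_metadata()
#   # ... run some eager operations ...
#   metadata = export_run_metadata()  # a config_pb2.RunMetadata proto
#   disable_run_metadata()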


# Not every user creates a Context via context.context()
# (for example, enable_eager_execution in python/framework/ops.py),
# but they do all import this file.  Note that IS_IN_GRAPH_MODE and
# in_graph_mode are both parameterless functions.
is_in_graph_mode.IS_IN_GRAPH_MODE = in_graph_mode