Home | History | Annotate | Download | only in client
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """A client interface for TensorFlow."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import functools
     22 import re
     23 import threading
     24 
     25 import numpy as np
     26 
     27 from tensorflow.core.protobuf import config_pb2
     28 from tensorflow.python import pywrap_tensorflow as tf_session
     29 from tensorflow.python.framework import c_api_util
     30 from tensorflow.python.framework import device
     31 from tensorflow.python.framework import errors
     32 from tensorflow.python.framework import ops
     33 from tensorflow.python.framework import sparse_tensor
     34 from tensorflow.python.ops import session_ops
     35 from tensorflow.python.platform import tf_logging as logging
     36 from tensorflow.python.util import compat
     37 from tensorflow.python.util import nest
     38 from tensorflow.python.util.tf_export import tf_export
     39 
     40 
     41 class SessionInterface(object):
     42   """Base class for implementations of TensorFlow client sessions."""
     43 
     44   @property
     45   def graph(self):
     46     """The underlying TensorFlow graph, to be used in building Operations."""
     47     raise NotImplementedError('graph')
     48 
     49   @property
     50   def sess_str(self):
     51     """The TensorFlow process to which this session will connect."""
     52     raise NotImplementedError('sess_str')
     53 
     54   def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
     55     """Runs operations in the session. See `BaseSession.run()` for details."""
     56     raise NotImplementedError('run')
     57 
     58   def partial_run_setup(self, fetches, feeds=None):
     59     """Sets up the feeds and fetches for partial runs in the session."""
     60     raise NotImplementedError('partial_run_setup')
     61 
     62   def partial_run(self, handle, fetches, feed_dict=None):
     63     """Continues the execution with additional feeds and fetches."""
     64     raise NotImplementedError('partial_run')
     65 
     66 
     67 def _get_indexed_slices_value_from_fetches(fetched_vals):
     68   return ops.IndexedSlicesValue(fetched_vals[0], fetched_vals[1],
     69                                 fetched_vals[2]
     70                                 if len(fetched_vals) == 3 else None)
     71 
     72 
     73 def _get_feeds_for_indexed_slices(feed, feed_val):
     74   return list(
     75       zip([feed.values, feed.indices] if feed.dense_shape is None else
     76           [feed.values, feed.indices, feed.dense_shape], feed_val))
     77 
     78 
     79 # List of extensions supported to convert run arguments into actual fetches and
     80 # feeds.
     81 #
     82 # Each element in the list is a tuple of (Type, fetch_fn, feed_fn1, feed_fn2),
     83 # where the function signatures are:
     84 #   fetch_fn : Type -> (list of Tensors,
     85 #                       lambda: list of fetched np.ndarray -> TypeVal)
     86 #   feed_fn1 : Type, TypeVal -> list of (Tensor, value)
     87 #   feed_fn2 : Type -> list of Tensors
     88 #
     89 # `fetch_fn` describes how to expand fetch into its
     90 # component Tensors and how to contract the fetched results back into
     91 # a single return value.
     92 #
     93 # Each feed function describes how to unpack a single fed value and map it to
     94 # feeds of one or more tensors and their corresponding values: `feed_fn1` is
     95 # used to feed a run, `feed_fn2` to set up a partial run.
     96 #
     97 # TODO(touts): We could reimplement these as specialized _FeedMapper
     98 # implementations after we refactor the feed handling code to use them.
     99 #
    100 # Eventually, this registration could be opened up to support custom Tensor
    101 # expansions.
    102 # pylint: disable=g-long-lambda
_REGISTERED_EXPANSIONS = [
    # SparseTensors are fetched as SparseTensorValues. They can be fed
    # SparseTensorValues or normal tuples.
    (sparse_tensor.SparseTensor,
     lambda fetch: (
         [fetch.indices, fetch.values, fetch.dense_shape],
         lambda fetched_vals: sparse_tensor.SparseTensorValue(*fetched_vals)),
     lambda feed, feed_val: list(zip(
         [feed.indices, feed.values, feed.dense_shape], feed_val)),
     lambda feed: [feed.indices, feed.values, feed.dense_shape]),
    # IndexedSlices are fetched as IndexedSlicesValues. They can be fed
    # IndexedSlicesValues or normal tuples.
    (ops.IndexedSlices,
     lambda fetch: (
         [fetch.values, fetch.indices] if fetch.dense_shape is None
         else [fetch.values, fetch.indices, fetch.dense_shape],
         _get_indexed_slices_value_from_fetches),
     _get_feeds_for_indexed_slices,
     lambda feed: [feed.values, feed.indices] if feed.dense_shape is None
     else [feed.values, feed.indices, feed.dense_shape]),
    # The default catches all other types and performs no expansions.
    # NOTE: this (object, ...) entry must remain last: lookups scan the list
    # in order with isinstance(), and registration inserts new entries at
    # index 0, so `object` only matches when nothing more specific did.
    (object,
     lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]),
     lambda feed, feed_val: [(feed, feed_val)],
     lambda feed: [feed])]

# pylint: enable=g-long-lambda
    130 
    131 
    132 def _convert_to_numpy_obj(numpy_dtype, obj):
    133   """Explicitly convert obj based on numpy type except for string type."""
    134   return numpy_dtype(obj) if numpy_dtype is not object else str(obj)
    135 
    136 
    137 def register_session_run_conversion_functions(
    138     tensor_type,
    139     fetch_function,
    140     feed_function=None,
    141     feed_function_for_partial_run=None):
    142   """Register fetch and feed conversion functions for `tf.Session.run()`.
    143 
    144   This function registers a triple of conversion functions for fetching and/or
    145   feeding values of user-defined types in a call to tf.Session.run().
    146 
    147   An example
    148 
    149   ```python
    150      class SquaredTensor(object):
    151        def __init__(self, tensor):
    152          self.sq = tf.square(tensor)
    153      #you can define conversion functions as follows:
    154      fetch_function = lambda squared_tensor:([squared_tensor.sq],
    155                                              lambda val: val[0])
    156      feed_function = lambda feed, feed_val: [(feed.sq, feed_val)]
    157      feed_function_for_partial_run = lambda feed: [feed.sq]
    158      #then after invoking this register function, you can use as follows:
    159      session.run(squared_tensor1,
    160                  feed_dict = {squared_tensor2 : some_numpy_array})
    161   ```
    162 
    163   Args:
    164     tensor_type: The type for which you want to register a conversion function.
    165     fetch_function: A callable that takes an object of type `tensor_type` and
    166       returns a tuple, where the first element is a list of `tf.Tensor` objects,
    167       and the second element is a callable that takes a list of ndarrays and
    168       returns an object of some value type that corresponds to `tensor_type`.
    169       fetch_function describes how to expand fetch into its component Tensors
    170       and how to contract the fetched results back into a single return value.
    171     feed_function: A callable that takes feed_key and feed_value as input, and
    172       returns a list of tuples (feed_tensor, feed_val), feed_key must have type
    173       `tensor_type`, and feed_tensor must have type `tf.Tensor`. Each feed
    174       function describes how to unpack a single fed value and map it to feeds
    175       of one or more tensors and their corresponding values.
    176     feed_function_for_partial_run: A callable for specifying tensor values to
    177       feed when setting up a partial run, which takes a `tensor_type` type
    178       object as input, and returns a list of Tensors.
    179   """
    180   for conversion_function in _REGISTERED_EXPANSIONS:
    181     if issubclass(conversion_function[0], tensor_type):
    182       raise ValueError('%s has already been registered so ignore it.',
    183                        tensor_type)
    184       return
    185   _REGISTERED_EXPANSIONS.insert(0, (tensor_type, fetch_function, feed_function,
    186                                     feed_function_for_partial_run))
    187 
    188 
    189 class _FetchMapper(object):
    190   """Definition of the interface provided by fetch mappers.
    191 
    192   Fetch mappers are utility classes used by the _FetchHandler to handle
    193   arbitrary structures for the `fetch` argument to `Session.run()`.
    194 
    195   The `fetch` argument can be of various shapes: single tensor or op, list of
    196   fetches, tuple of fetches, namedtuple of fetches, or dict of fetches.  The
    197   structures can be arbitrarily nested.
    198 
    199   The low level run() API only wants a list of tensor or op names.  The various
    200   `_FetchMapper` subclasses below take care of handling the different shapes:
    201   uniquifying the fetches, and constructing results with the original shape.
    202   """
    203 
    204   def unique_fetches(self):
    205     """Return the list of unique tensors or ops needed by this fetch mapper.
    206 
    207     Returns:
    208       A list of tensors or ops.
    209     """
    210     raise NotImplementedError('Must be implemented by subclasses')
    211 
    212   def build_results(self, values):
    213     """Build results that match the original shape of the fetch.
    214 
    215     Args:
    216       values: List of values returned by run(). The values correspond
    217         exactly to the list tensors or ops returned by unique_fetches().
    218 
    219     Returns:
    220       A struct of the same shape as the original fetch object handled by
    221       this fetch mapper.  In the returned struct, the original fetches are
    222       replaced by their fetched values.
    223     """
    224     raise NotImplementedError('Must be implemented by subclasses')
    225 
    226   @staticmethod
    227   def for_fetch(fetch):
    228     """Creates fetch mapper that handles the structure of `fetch`.
    229 
    230     The default graph must be the one from which we want to fetch values when
    231     this function is called.
    232 
    233     Args:
    234       fetch: An arbitrary fetch structure: singleton, list, tuple,
    235         namedtuple, or dict.
    236 
    237     Returns:
    238       An instance of a subclass of `_FetchMapper` that handles the shape.
    239     """
    240     if fetch is None:
    241       raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
    242                                                                  type(fetch)))
    243     elif isinstance(fetch, (list, tuple)):
    244       # NOTE(touts): This is also the code path for namedtuples.
    245       return _ListFetchMapper(fetch)
    246     elif isinstance(fetch, dict):
    247       return _DictFetchMapper(fetch)
    248     else:
    249       # Look for a handler in the registered expansions.
    250       for tensor_type, fetch_fn, _, _ in _REGISTERED_EXPANSIONS:
    251         if isinstance(fetch, tensor_type):
    252           fetches, contraction_fn = fetch_fn(fetch)
    253           return _ElementFetchMapper(fetches, contraction_fn)
    254     # Did not find anything.
    255     raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
    256                                                                type(fetch)))
    257 
    258 
    259 class _ElementFetchMapper(_FetchMapper):
    260   """Fetch mapper for singleton tensors and ops."""
    261 
    262   def __init__(self, fetches, contraction_fn):
    263     """Creates an _ElementFetchMapper.
    264 
    265     This is the fetch mapper used for leaves in the fetch struct.  Because of
    266     the expansions mechanism, a leaf can actually fetch more than one tensor.
    267 
    268     Also note that the fetches here can be just strings (tensor or op names) or
    269     any other object that the graph knows how to convert to a tensor, such as a
    270     Variable.  So we have to run each fetch through `as_graph_element()` to get
    271     the corresponding tensor or op.
    272 
    273     Args:
    274       fetches: List of objects, as returned by a fetch_fn defined
    275         in _REGISTERED_EXPANSIONS.
    276       contraction_fn: Callable as returned by a fetch_fn.
    277     """
    278     self._unique_fetches = []
    279     for fetch in fetches:
    280       try:
    281         self._unique_fetches.append(ops.get_default_graph().as_graph_element(
    282             fetch, allow_tensor=True, allow_operation=True))
    283       except TypeError as e:
    284         raise TypeError('Fetch argument %r has invalid type %r, '
    285                         'must be a string or Tensor. (%s)' %
    286                         (fetch, type(fetch), str(e)))
    287       except ValueError as e:
    288         raise ValueError('Fetch argument %r cannot be interpreted as a '
    289                          'Tensor. (%s)' % (fetch, str(e)))
    290       except KeyError as e:
    291         raise ValueError('Fetch argument %r cannot be interpreted as a '
    292                          'Tensor. (%s)' % (fetch, str(e)))
    293     self._contraction_fn = contraction_fn
    294 
    295   def unique_fetches(self):
    296     return self._unique_fetches
    297 
    298   def build_results(self, values):
    299     if not values:
    300       # 'Operation' case
    301       return None
    302     else:
    303       return self._contraction_fn(values)
    304 
    305 
    306 def _uniquify_fetches(fetch_mappers):
    307   """Uniquifies fetches from a list of fetch_mappers.
    308 
    309   This is a utility function used by _ListFetchMapper and _DictFetchMapper.  It
    310   gathers all the unique fetches from a list of mappers and builds a list
    311   containing all of them but without duplicates (unique_fetches).
    312 
    313   It also returns a 2-D list of integers (values_indices) indicating at which
    314   index in unique_fetches the fetches of the mappers are located.
    315 
    316   This list is as follows:
    317     values_indices[mapper_index][mapper_fetch_index] = unique_fetches_index
    318 
    319   Args:
    320     fetch_mappers: list of fetch mappers.
    321 
    322   Returns:
    323     A list of fetches.
    324     A 2-D list of integers.
    325   """
    326   unique_fetches = []
    327   value_indices = []
    328   seen_fetches = {}
    329   for m in fetch_mappers:
    330     m_value_indices = []
    331     for f in m.unique_fetches():
    332       j = seen_fetches.get(f)
    333       if j is None:
    334         j = len(seen_fetches)
    335         seen_fetches[f] = j
    336         unique_fetches.append(f)
    337       m_value_indices.append(j)
    338     value_indices.append(m_value_indices)
    339   return unique_fetches, value_indices
    340 
    341 
    342 class _ListFetchMapper(_FetchMapper):
    343   """Fetch mapper for lists, tuples, and namedtuples."""
    344 
    345   def __init__(self, fetches):
    346     """Creates a _ListFetchMapper.
    347 
    348     Args:
    349       fetches: List, tuple, or namedtuple of fetches.
    350     """
    351     self._fetch_type = type(fetches)
    352     self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
    353     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
    354 
    355   def unique_fetches(self):
    356     return self._unique_fetches
    357 
    358   def build_results(self, values):
    359     # Create the list of results for each mapper.
    360     results = []
    361     for m, vi in zip(self._mappers, self._value_indices):
    362       results.append(m.build_results([values[j] for j in vi]))
    363     # Return a value of the original type of the fetches.
    364     if self._fetch_type == list:
    365       return results
    366     elif self._fetch_type == tuple:
    367       return tuple(results)
    368     else:
    369       # This is the code path for namedtuple.
    370       return self._fetch_type(*results)
    371 
    372 
    373 class _DictFetchMapper(_FetchMapper):
    374   """Fetch mapper for dicts."""
    375 
    376   def __init__(self, fetches):
    377     """Creates a _DictFetchMapper.
    378 
    379     Args:
    380       fetches: Dict of fetches.
    381     """
    382     self._fetch_type = type(fetches)
    383     self._keys = fetches.keys()
    384     self._mappers = [
    385         _FetchMapper.for_fetch(fetch) for fetch in fetches.values()
    386     ]
    387     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
    388 
    389   def unique_fetches(self):
    390     return self._unique_fetches
    391 
    392   def build_results(self, values):
    393     results = self._fetch_type()
    394     for k, m, vi in zip(self._keys, self._mappers, self._value_indices):
    395       results[k] = m.build_results([values[j] for j in vi])
    396     return results
    397 
    398 
class _FetchHandler(object):
  """Handler for structured fetches.

  Given a graph, a user-provided structure for fetches, and a feed dict, this
  class takes care of generating a list of tensor names to fetch and op names
  to run for a low level `run()` call.

  Given the results of the low level run call, this class can also rebuild a
  result structure matching the user-provided structure for fetches, but
  containing the corresponding results.
  """

  # TODO(touts): Make this class also take care of destructuring the feed
  # dict instead of doing it in the callers.

  def __init__(self, graph, fetches, feeds, feed_handles=None):
    """Creates a fetch handler.

    Args:
      graph: Graph of the fetches.   Used to check for fetchability
        and to convert all fetches to tensors or ops as needed.
      fetches: An arbitrary fetch structure: singleton, list, tuple,
        namedtuple, or dict.
      feeds: A feed dict where keys are Tensors.
      feed_handles: A dict from feed Tensors to TensorHandle objects used as
        direct feeds.

    Raises:
      ValueError: If one of the fetches is marked not fetchable in `graph`.
    """
    # Build the mapper with `graph` as the default graph so string and
    # convertible fetches resolve against the correct graph.
    with graph.as_default():
      self._fetch_mapper = _FetchMapper.for_fetch(fetches)
    self._fetches = []     # Tensors whose values will be fetched.
    self._targets = []     # Operations to run; they produce no value.
    self._feeds = feeds
    self._feed_handles = feed_handles or {}
    # _ops[k] is True iff the k-th unique fetch is an Operation; used by
    # build_results() to interleave None placeholders for ops.
    self._ops = []
    self._fetch_handles = {}
    for fetch in self._fetch_mapper.unique_fetches():
      if isinstance(fetch, ops.Operation):
        self._assert_fetchable(graph, fetch)
        self._targets.append(fetch)
        self._ops.append(True)
      else:
        self._assert_fetchable(graph, fetch.op)
        self._fetches.append(fetch)
        self._ops.append(False)
      # Remember the fetch if it is for a tensor handle.
      if (isinstance(fetch, ops.Tensor) and
          (fetch.op.type == 'GetSessionHandle' or
           fetch.op.type == 'GetSessionHandleV2')):
        # Record the dtype of the handled value so build_results() can wrap
        # the fetched handle in a session_ops.TensorHandle.
        self._fetch_handles[fetch] = fetch.op.inputs[0].dtype
    # Tensors that are also directly fed need not be fetched from the
    # runtime; their values come straight from the feed dict.
    self._final_fetches = [x for x in self._fetches if x not in feeds]

  def _assert_fetchable(self, graph, op):
    """Raises ValueError if `op` is marked not fetchable in `graph`."""
    if not graph.is_fetchable(op):
      raise ValueError(
          'Operation %r has been marked as not fetchable.' % op.name)

  def fetches(self):
    """Return the unique tensors to fetch.

    Returns:
      A list of `Tensor` objects (those not covered by the feed dict).
    """
    return self._final_fetches

  def targets(self):
    """Return the unique ops to run.

    Returns:
      A list of `Operation` objects.
    """
    return self._targets

  def build_results(self, session, tensor_values):
    """Build results matching the original fetch shape.

    `tensor_values` must be a list of the same length as
    the one returned by `fetches()`, and holding the requested
    fetch values.

    This method builds a struct with the same shape as the original `fetches`
    passed to the constructor, in which the fetches are replaced by their
    fetched value.

    Args:
      session: The enclosing session.  Used for tensor handles.
      tensor_values: List of values matching the list returned
        by fetches().

    Returns:
      A structure of the same shape as the original `fetches` argument but
        containing tensors or None (for fetched ops).
    """
    full_values = []
    assert len(self._final_fetches) == len(tensor_values)
    # `i` walks self._fetches (every tensor fetch); `j` walks tensor_values
    # (only the tensors actually fetched from the runtime, i.e. not fed).
    i = 0
    j = 0
    for is_op in self._ops:
      if is_op:
        full_values.append(None)
      else:
        # If the fetch was in the feeds, use the fed value, otherwise
        # use the returned value.
        if self._fetches[i] in self._feed_handles:
          # A fetch had a corresponding direct TensorHandle feed. Call eval()
          # to obtain the Tensor value from the TensorHandle.
          value = self._feed_handles[self._fetches[i]].eval()
        else:
          value = self._feeds.get(self._fetches[i])
        if value is None:
          # Not fed at all: consume the next runtime-fetched value.
          value = tensor_values[j]
          j += 1
        dtype = self._fetch_handles.get(self._fetches[i])
        if dtype:
          # This fetch produced a session handle; wrap it so the caller can
          # later feed or eval it.
          full_values.append(session_ops.TensorHandle(value, dtype, session))
        else:
          full_values.append(value)
        i += 1
    assert j == len(tensor_values)
    return self._fetch_mapper.build_results(full_values)
    518 
    519 
    520 def _name_list(tensor_list):
    521   """Utility function for transitioning to the new session API.
    522 
    523   Args:
    524     tensor_list: a list of `Tensor`s.
    525 
    526   Returns:
    527     A list of each `Tensor`s name (as byte arrays).
    528   """
    529   return [compat.as_bytes(t.name) for t in tensor_list]
    530 
    531 
    532 class _DeviceAttributes(object):
    533   """Struct-like object describing a device's attributes.
    534 
    535   Each device has 3 key properties:
    536    - name: the fully-qualified TensorFlow path to the device. For
    537         example: /job:worker/replica:0/task:3/device:CPU:0
    538    - device_type: the type of the device (e.g. CPU, GPU, TPU, etc.)
    539    - memory_limit_bytes: the maximum amount of memory available on the device
    540         (in bytes).
    541   """
    542 
    543   def __init__(self, name, device_type, memory_limit_bytes):
    544     self._name = device.canonical_name(name)
    545     self._device_type = device_type
    546     self._memory_limit_bytes = memory_limit_bytes
    547 
    548   @property
    549   def name(self):
    550     return self._name
    551 
    552   @property
    553   def device_type(self):
    554     return self._device_type
    555 
    556   @property
    557   def memory_limit_bytes(self):
    558     return self._memory_limit_bytes
    559 
    560   def __repr__(self):
    561     return '_DeviceAttributes(%s, %s, %d)' % (
    562         self.name,
    563         self.device_type,
    564         self.memory_limit_bytes,
    565     )
    566 
    567 
    568 class BaseSession(SessionInterface):
    569   """A class for interacting with a TensorFlow computation.
    570 
    571   The BaseSession enables incremental graph building with inline
    572   execution of Operations and evaluation of Tensors.
    573   """
    574 
  def __init__(self, target='', graph=None, config=None):
    """Constructs a new TensorFlow session.

    Args:
      target: (Optional) The TensorFlow execution engine to connect to.
      graph: (Optional) The graph to be used. If this argument is None,
        the default graph will be used.
      config: (Optional) ConfigProto proto used to configure the session.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow session.
      TypeError: If one of the arguments has the wrong type.
    """
    if graph is None:
      self._graph = ops.get_default_graph()
    else:
      if not isinstance(graph, ops.Graph):
        raise TypeError('graph must be a tf.Graph, but got %s' % type(graph))
      self._graph = graph

    self._opened = False
    self._closed = False

    self._current_version = 0
    self._extend_lock = threading.Lock()
    if target is not None:
      try:
        # The C layer expects the target as bytes; as_bytes also validates
        # that `target` is string-like.
        self._target = compat.as_bytes(target)
      except TypeError:
        raise TypeError('target must be a string, but got %s' % type(target))
    else:
      self._target = None

    self._delete_lock = threading.Lock()
    # NOTE(review): presumably tensor handles awaiting deferred deletion —
    # their consumer is outside this chunk; confirm before relying on it.
    self._dead_handles = []

    if config is not None:
      if not isinstance(config, config_pb2.ConfigProto):
        raise TypeError(
            'config must be a tf.ConfigProto, but got %s' % type(config))
      self._config = config
      self._add_shapes = config.graph_options.infer_shapes
    else:
      self._config = None
      self._add_shapes = False

    # pylint: disable=protected-access
    # We cache _USE_C_API's value because some test cases will create a session
    # with _USE_C_API = False but set it back to True before calling close().
    self._created_with_new_api = ops._USE_C_API
    # pylint: enable=protected-access

    self._session = None
    opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
    try:
      with errors.raise_exception_on_not_ok_status() as status:
        if self._created_with_new_api:
          # pylint: disable=protected-access
          self._session = tf_session.TF_NewSession(self._graph._c_graph, opts,
                                                   status)
          # pylint: enable=protected-access
        else:
          self._session = tf_session.TF_NewDeprecatedSession(opts, status)
    finally:
      # The options object is owned by this call: always free it, even when
      # session creation raises.
      tf_session.TF_DeleteSessionOptions(opts)
    641 
  def list_devices(self):
    """Lists available devices in this session.

    ```python
    devices = sess.list_devices()
    for d in devices:
      print(d.name)
    ```

    Each element in the list has the following properties:
     - `name`: A string with the full name of the device. ex:
          `/job:worker/replica:0/task:3/device:CPU:0`
     - `device_type`: The type of the device (e.g. `CPU`, `GPU`, `TPU`.)
     - `memory_limit`: The maximum amount of memory available on the device.
          Note: depending on the device, it is possible the usable memory could
          be substantially less.
    Raises:
      tf.errors.OpError: If it encounters an error (e.g. session is in an
      invalid state, or network errors occur).

    Returns:
      A list of devices in the session.
    """
    # A single status object is reused for every C API call below; the
    # context manager raises on a not-OK status when the block exits.
    with errors.raise_exception_on_not_ok_status() as status:
      if self._created_with_new_api:
        raw_device_list = tf_session.TF_SessionListDevices(
            self._session, status)
      else:
        raw_device_list = tf_session.TF_DeprecatedSessionListDevices(
            self._session, status)
      device_list = []
      size = tf_session.TF_DeviceListCount(raw_device_list)
      for i in range(size):
        name = tf_session.TF_DeviceListName(raw_device_list, i, status)
        device_type = tf_session.TF_DeviceListType(raw_device_list, i, status)
        memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i, status)
        device_list.append(_DeviceAttributes(name, device_type, memory))
      # NOTE(review): if one of the calls above raised, raw_device_list would
      # leak — a try/finally around the loop would be safer; confirm whether
      # these wrappers can raise before the status check.
      tf_session.TF_DeleteDeviceList(raw_device_list)
      return device_list
    681 
  def close(self):
    """Closes this session.

    Calling this method frees all resources associated with the session.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        closing the TensorFlow session.
    """
    if self._created_with_new_api:
      # C API path: guard only on the _closed flag and a live session
      # pointer, so repeated close() calls are no-ops.
      if self._session and not self._closed:
        self._closed = True
        with errors.raise_exception_on_not_ok_status() as status:
          tf_session.TF_CloseSession(self._session, status)

    else:
      # Deprecated API path: closing is serialized against graph extension
      # via _extend_lock, and only happens for sessions that were opened.
      with self._extend_lock:
        if self._opened and not self._closed:
          self._closed = True
          with errors.raise_exception_on_not_ok_status() as status:
            tf_session.TF_CloseDeprecatedSession(self._session, status)
    703 
  def __del__(self):
    """Closes and deletes the underlying C session, swallowing all errors."""
    # cleanly ignore all exceptions
    try:
      self.close()
    except Exception:  # pylint: disable=broad-except
      pass
    if self._session is not None:
      try:
        status = c_api_util.ScopedTFStatus()
        if self._created_with_new_api:
          tf_session.TF_DeleteSession(self._session, status)
        else:
          tf_session.TF_DeleteDeprecatedSession(self._session, status)
      except AttributeError:
        # At shutdown, `c_api_util` or `tf_session` may have been garbage
        # collected, causing the above method calls to fail. In this case,
        # silently leak since the program is about to terminate anyway.
        pass
      # Clear the pointer so a second finalization does not double-delete.
      self._session = None
    723 
    724   @property
    725   def graph(self):
    726     """The graph that was launched in this session."""
    727     return self._graph
    728 
    729   @property
    730   def graph_def(self):
    731     """A serializable version of the underlying TensorFlow graph.
    732 
    733     Returns:
    734       A graph_pb2.GraphDef proto containing nodes for all of the Operations in
    735       the underlying TensorFlow graph.
    736     """
    737     return self._graph.as_graph_def(add_shapes=self._add_shapes)
    738 
    739   @property
    740   def sess_str(self):
    741     return self._target
    742 
    743   def as_default(self):
    744     """Returns a context manager that makes this object the default session.
    745 
    746     Use with the `with` keyword to specify that calls to
    747     @{tf.Operation.run} or @{tf.Tensor.eval} should be executed in
    748     this session.
    749 
    750     ```python
    751     c = tf.constant(..)
    752     sess = tf.Session()
    753 
    754     with sess.as_default():
    755       assert tf.get_default_session() is sess
    756       print(c.eval())
    757     ```
    758 
    759     To get the current default session, use @{tf.get_default_session}.
    760 
    761     *N.B.* The `as_default` context manager *does not* close the
    762     session when you exit the context, and you must close the session
    763     explicitly.
    764 
    765     ```python
    766     c = tf.constant(...)
    767     sess = tf.Session()
    768     with sess.as_default():
    769       print(c.eval())
    770     # ...
    771     with sess.as_default():
    772       print(c.eval())
    773 
    774     sess.close()
    775     ```
    776 
    777     Alternatively, you can use `with tf.Session():` to create a
    778     session that is automatically closed on exiting the context,
    779     including when an uncaught exception is raised.
    780 
    781     *N.B.* The default session is a property of the current thread. If you
    782     create a new thread, and wish to use the default session in that
    783     thread, you must explicitly add a `with sess.as_default():` in that
    784     thread's function.
    785 
    786     *N.B.* Entering a `with sess.as_default():` block does not affect
    787     the current default graph. If you are using multiple graphs, and
    788     `sess.graph` is different from the value of @{tf.get_default_graph},
    789     you must explicitly enter a `with sess.graph.as_default():` block
    790     to make `sess.graph` the default graph.
    791 
    792     Returns:
    793       A context manager using this session as the default session.
    794     """
    795     return ops.default_session(self)
    796 
    797   def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    798     """Runs operations and evaluates tensors in `fetches`.
    799 
    800     This method runs one "step" of TensorFlow computation, by
    801     running the necessary graph fragment to execute every `Operation`
    802     and evaluate every `Tensor` in `fetches`, substituting the values in
    803     `feed_dict` for the corresponding input values.
    804 
    805     The `fetches` argument may be a single graph element, or an arbitrarily
    806     nested list, tuple, namedtuple, dict, or OrderedDict containing graph
    807     elements at its leaves.  A graph element can be one of the following types:
    808 
    809     * An @{tf.Operation}.
    810       The corresponding fetched value will be `None`.
    811     * A @{tf.Tensor}.
    812       The corresponding fetched value will be a numpy ndarray containing the
    813       value of that tensor.
    814     * A @{tf.SparseTensor}.
    815       The corresponding fetched value will be a
    816       @{tf.SparseTensorValue}
    817       containing the value of that sparse tensor.
    818     * A `get_tensor_handle` op.  The corresponding fetched value will be a
    819       numpy ndarray containing the handle of that tensor.
    820     * A `string` which is the name of a tensor or operation in the graph.
    821 
    822     The value returned by `run()` has the same shape as the `fetches` argument,
    823     where the leaves are replaced by the corresponding values returned by
    824     TensorFlow.
    825 
    826     Example:
    827 
    828     ```python
    829        a = tf.constant([10, 20])
    830        b = tf.constant([1.0, 2.0])
    831        # 'fetches' can be a singleton
    832        v = session.run(a)
    833        # v is the numpy array [10, 20]
    834        # 'fetches' can be a list.
    835        v = session.run([a, b])
    836        # v is a Python list with 2 numpy arrays: the 1-D array [10, 20] and the
    837        # 1-D array [1.0, 2.0]
    838        # 'fetches' can be arbitrary lists, tuples, namedtuple, dicts:
    839        MyData = collections.namedtuple('MyData', ['a', 'b'])
    840        v = session.run({'k1': MyData(a, b), 'k2': [b, a]})
    841        # v is a dict with
    842        # v['k1'] is a MyData namedtuple with 'a' (the numpy array [10, 20]) and
    843        # 'b' (the numpy array [1.0, 2.0])
    844        # v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array
    845        # [10, 20].
    846     ```
    847 
    848     The optional `feed_dict` argument allows the caller to override
    849     the value of tensors in the graph. Each key in `feed_dict` can be
    850     one of the following types:
    851 
    852     * If the key is a @{tf.Tensor}, the
    853       value may be a Python scalar, string, list, or numpy ndarray
    854       that can be converted to the same `dtype` as that
    855       tensor. Additionally, if the key is a
    856       @{tf.placeholder}, the shape of
    857       the value will be checked for compatibility with the placeholder.
    858     * If the key is a
    859       @{tf.SparseTensor},
    860       the value should be a
    861       @{tf.SparseTensorValue}.
    862     * If the key is a nested tuple of `Tensor`s or `SparseTensor`s, the value
    863       should be a nested tuple with the same structure that maps to their
    864       corresponding values as above.
    865 
    866     Each value in `feed_dict` must be convertible to a numpy array of the dtype
    867     of the corresponding key.
    868 
    869     The optional `options` argument expects a [`RunOptions`] proto. The options
    870     allow controlling the behavior of this particular step (e.g. turning tracing
    871     on).
    872 
    873     The optional `run_metadata` argument expects a [`RunMetadata`] proto. When
    874     appropriate, the non-Tensor output of this step will be collected there. For
    875     example, when users turn on tracing in `options`, the profiled info will be
    876     collected into this argument and passed back.
    877 
    878     Args:
    879       fetches: A single graph element, a list of graph elements,
    880         or a dictionary whose values are graph elements or lists of graph
    881         elements (described above).
    882       feed_dict: A dictionary that maps graph elements to values
    883         (described above).
    884       options: A [`RunOptions`] protocol buffer
    885       run_metadata: A [`RunMetadata`] protocol buffer
    886 
    887     Returns:
    888       Either a single value if `fetches` is a single graph element, or
    889       a list of values if `fetches` is a list, or a dictionary with the
    890       same keys as `fetches` if that is a dictionary (described above).
    891 
    892     Raises:
    893       RuntimeError: If this `Session` is in an invalid state (e.g. has been
    894         closed).
    895       TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type.
    896       ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a
    897         `Tensor` that doesn't exist.
    898     """
    899     options_ptr = tf_session.TF_NewBufferFromString(
    900         compat.as_bytes(options.SerializeToString())) if options else None
    901     run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None
    902 
    903     try:
    904       result = self._run(None, fetches, feed_dict, options_ptr,
    905                          run_metadata_ptr)
    906       if run_metadata:
    907         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
    908         run_metadata.ParseFromString(compat.as_bytes(proto_data))
    909     finally:
    910       if run_metadata_ptr:
    911         tf_session.TF_DeleteBuffer(run_metadata_ptr)
    912       if options:
    913         tf_session.TF_DeleteBuffer(options_ptr)
    914     return result
    915 
    916   def partial_run(self, handle, fetches, feed_dict=None):
    917     """Continues the execution with more feeds and fetches.
    918 
    919     This is EXPERIMENTAL and subject to change.
    920 
    921     To use partial execution, a user first calls `partial_run_setup()` and
    922     then a sequence of `partial_run()`. `partial_run_setup` specifies the
    923     list of feeds and fetches that will be used in the subsequent
    924     `partial_run` calls.
    925 
    926     The optional `feed_dict` argument allows the caller to override
    927     the value of tensors in the graph. See run() for more information.
    928 
    929     Below is a simple example:
    930 
    931     ```python
    932     a = array_ops.placeholder(dtypes.float32, shape=[])
    933     b = array_ops.placeholder(dtypes.float32, shape=[])
    934     c = array_ops.placeholder(dtypes.float32, shape=[])
    935     r1 = math_ops.add(a, b)
    936     r2 = math_ops.multiply(r1, c)
    937 
    938     h = sess.partial_run_setup([r1, r2], [a, b, c])
    939     res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
    940     res = sess.partial_run(h, r2, feed_dict={c: res})
    941     ```
    942 
    943     Args:
    944       handle: A handle for a sequence of partial runs.
    945       fetches: A single graph element, a list of graph elements,
    946         or a dictionary whose values are graph elements or lists of graph
    947         elements (see documentation for `run`).
    948       feed_dict: A dictionary that maps graph elements to values
    949         (described above).
    950 
    951     Returns:
    952       Either a single value if `fetches` is a single graph element, or
    953       a list of values if `fetches` is a list, or a dictionary with the
    954       same keys as `fetches` if that is a dictionary
    955       (see documentation for `run`).
    956 
    957     Raises:
    958       tf.errors.OpError: Or one of its subclasses on error.
    959     """
    960     # TODO(touts): Support feeding and fetching the same tensor.
    961     return self._run(handle, fetches, feed_dict, None, None)
    962 
    963   def partial_run_setup(self, fetches, feeds=None):
    964     """Sets up a graph with feeds and fetches for partial run.
    965 
    966     This is EXPERIMENTAL and subject to change.
    967 
    968     Note that contrary to `run`, `feeds` only specifies the graph elements.
    969     The tensors will be supplied by the subsequent `partial_run` calls.
    970 
    971     Args:
    972       fetches: A single graph element, or a list of graph elements.
    973       feeds: A single graph element, or a list of graph elements.
    974 
    975     Returns:
    976       A handle for partial run.
    977 
    978     Raises:
    979       RuntimeError: If this `Session` is in an invalid state (e.g. has been
    980         closed).
    981       TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type.
    982       tf.errors.OpError: Or one of its subclasses if a TensorFlow error happens.
    983     """
    984 
    985     def _feed_fn(feed):
    986       for tensor_type, _, _, feed_fn in _REGISTERED_EXPANSIONS:
    987         if isinstance(feed, tensor_type):
    988           return feed_fn(feed)
    989       raise TypeError('Feed argument %r has invalid type %r' % (feed,
    990                                                                 type(feed)))
    991 
    992     # Check session.
    993     if self._closed:
    994       raise RuntimeError('Attempted to use a closed Session.')
    995     if self.graph.version == 0:
    996       raise RuntimeError('The Session graph is empty.  Add operations to the '
    997                          'graph before calling run().')
    998 
    999     if feeds is None:
   1000       feeds = []
   1001     # Create request.
   1002     feed_list = []
   1003 
   1004     # Validate and process feed_list.
   1005     is_list_feed = isinstance(feeds, (list, tuple))
   1006     if not is_list_feed:
   1007       feeds = [feeds]
   1008     for feed in feeds:
   1009       for subfeed in _feed_fn(feed):
   1010         try:
   1011           subfeed_t = self.graph.as_graph_element(
   1012               subfeed, allow_tensor=True, allow_operation=False)
   1013           if self._created_with_new_api:
   1014             # pylint: disable=protected-access
   1015             feed_list.append(subfeed_t._as_tf_output())
   1016             # pylint: enable=protected-access
   1017           else:
   1018             feed_list.append(compat.as_bytes(subfeed_t.name))
   1019         except Exception as e:
   1020           e.message = ('Cannot interpret feed_list key as Tensor: ' + e.message)
   1021           e.args = (e.message,)
   1022           raise e
   1023 
   1024     # Validate and process fetches.
   1025     # TODO(touts): Support feeding and fetching the same tensor.
   1026     fetch_handler = _FetchHandler(self._graph, fetches, {})
   1027 
   1028     # Set up a graph with feeds and fetches for partial run.
   1029     def _setup_fn(session, feed_list, fetch_list, target_list):
   1030       self._extend_graph()
   1031       with errors.raise_exception_on_not_ok_status() as status:
   1032         if self._created_with_new_api:
   1033           return tf_session.TF_SessionPRunSetup_wrapper(
   1034               session, feed_list, fetch_list, target_list, status)
   1035         else:
   1036           return tf_session.TF_PRunSetup(session, feed_list, fetch_list,
   1037                                          target_list, status)
   1038 
   1039     if self._created_with_new_api:
   1040       # pylint: disable=protected-access
   1041       final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()]
   1042       final_targets = [op._c_op for op in fetch_handler.targets()]
   1043       # pylint: enable=protected-access
   1044     else:
   1045       final_fetches = _name_list(fetch_handler.fetches())
   1046       final_targets = _name_list(fetch_handler.targets())
   1047 
   1048     return self._do_call(_setup_fn, self._session, feed_list, final_fetches,
   1049                          final_targets)
   1050 
  def _run(self, handle, fetches, feed_dict, options, run_metadata):
    """Perform either run or partial_run, depending the presence of `handle`.

    Args:
      handle: A handle for partial_run, or None for a plain `run()` call.
      fetches: A single graph element, or an arbitrarily nested structure of
        graph elements (see `run`).
      feed_dict: A dictionary that maps graph elements to values, or None.
      options: A (pointer to a) [`RunOptions`] protocol buffer, or None.
      run_metadata: A (pointer to a) [`RunMetadata`] protocol buffer, or None.

    Returns:
      The fetched values, structured like `fetches` (see `run`).
    """

    def _feed_fn(feed, feed_val):
      # Expand a (feed, value) pair into its subfeed pairs using the
      # registered converters (e.g. a SparseTensor expands into its
      # component tensors).
      for tensor_type, _, feed_fn, _ in _REGISTERED_EXPANSIONS:
        if isinstance(feed, tensor_type):
          return feed_fn(feed, feed_val)
      raise TypeError('Feed argument %r has invalid type %r' % (feed,
                                                                type(feed)))

    # Check session.
    if self._closed:
      raise RuntimeError('Attempted to use a closed Session.')
    if self.graph.version == 0:
      raise RuntimeError('The Session graph is empty.  Add operations to the '
                         'graph before calling run().')

    # Create request.
    feed_dict_tensor = {}  # maps Tensor -> numpy value actually fed
    feed_map = {}  # maps tensor name (bytes) -> (Tensor, original feed value)

    # Validate and process feed_dict.
    feed_handles = {}  # maps Tensor -> TensorHandle, for handle-based feeds
    if feed_dict:
      # Flatten nested-tuple keys so each subfeed is handled individually.
      feed_dict = nest.flatten_dict_items(feed_dict)
      for feed, feed_val in feed_dict.items():
        for subfeed, subfeed_val in _feed_fn(feed, feed_val):
          try:
            subfeed_t = self.graph.as_graph_element(
                subfeed, allow_tensor=True, allow_operation=False)
          except Exception as e:
            raise TypeError(
                'Cannot interpret feed_dict key as Tensor: ' + e.args[0])

          if isinstance(subfeed_val, ops.Tensor):
            raise TypeError('The value of a feed cannot be a tf.Tensor object. '
                            'Acceptable feed values include Python scalars, '
                            'strings, lists, numpy ndarrays, or TensorHandles.')

          subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
          # Detect silent overflow: if converting the Python int to the
          # tensor's numpy dtype changes its value, reject the feed.
          if isinstance(subfeed_val, int) and _convert_to_numpy_obj(
              subfeed_dtype, subfeed_val) != subfeed_val:
            raise TypeError(
                'Type of feed value ' + str(subfeed_val) + ' with type ' + str(
                    type(subfeed_val)) +
                ' is not compatible with Tensor type ' + str(subfeed_dtype) +
                '. Try explicitly setting the type of the feed tensor'
                ' to a larger type (e.g. int64).')

          is_tensor_handle_feed = isinstance(subfeed_val,
                                             session_ops.TensorHandle)
          if is_tensor_handle_feed:
            # Handle feeds are passed through as the handle's numpy form;
            # shape compatibility is not checked for them (see below).
            np_val = subfeed_val.to_numpy_array()
            feed_handles[subfeed_t] = subfeed_val
          else:
            np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)

          if (not is_tensor_handle_feed and
              not subfeed_t.get_shape().is_compatible_with(np_val.shape)):
            raise ValueError('Cannot feed value of shape %r for Tensor %r, '
                             'which has shape %r' %
                             (np_val.shape, subfeed_t.name,
                              str(subfeed_t.get_shape())))
          if not self.graph.is_feedable(subfeed_t):
            raise ValueError('Tensor %s may not be fed.' % subfeed_t)

          feed_dict_tensor[subfeed_t] = np_val
          feed_map[compat.as_bytes(subfeed_t.name)] = (subfeed_t, subfeed_val)

    # Create a fetch handler to take care of the structure of fetches.
    fetch_handler = _FetchHandler(
        self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)

    # Run request and get response.
    # We need to keep the returned movers alive for the following _do_run().
    # These movers are no longer needed when _do_run() completes, and
    # are deleted when `movers` goes out of scope when this _run() ends.
    # TODO(yuanbyu, keveman): Revisit whether we should just treat feeding
    # of a handle from a different device as an error.
    _ = self._update_with_movers(feed_dict_tensor, feed_map)
    final_fetches = fetch_handler.fetches()
    final_targets = fetch_handler.targets()
    # We only want to really perform the run if fetches or targets are provided,
    # or if the call is a partial run that specifies feeds.
    if final_fetches or final_targets or (handle and feed_dict_tensor):
      results = self._do_run(handle, final_targets, final_fetches,
                             feed_dict_tensor, options, run_metadata)
    else:
      results = []
    return fetch_handler.build_results(self, results)
   1141 
  def make_callable(self, fetches, feed_list=None, accept_options=False):
    """Returns a Python callable that runs a particular step.

    The returned callable will take `len(feed_list)` arguments whose types
    must be compatible feed values for the respective elements of `feed_list`.
    For example, if element `i` of `feed_list` is a `tf.Tensor`, the `i`th
    argument to the returned callable must be a numpy ndarray (or something
    convertible to an ndarray) with matching element type and shape. See
    @{tf.Session.run} for details of the allowable feed key and value types.

    The returned callable will have the same return type as
    `tf.Session.run(fetches, ...)`. For example, if `fetches` is a `tf.Tensor`,
    the callable will return a numpy ndarray; if `fetches` is a `tf.Operation`,
    it will return `None`.

    Args:
      fetches: A value or list of values to fetch. See @{tf.Session.run}
        for details of the allowable fetch types.
      feed_list: (Optional.) A list of `feed_dict` keys. See
        @{tf.Session.run} for details of the allowable feed key types.
      accept_options: (Optional.) Iff `True`, the returned `Callable` will be
        able to accept @{tf.RunOptions} and @{tf.RunMetadata} as optional
        keyword arguments `options` and `run_metadata`, respectively, with
        the same syntax and semantics as @{tf.Session.run}, which is useful
        for certain use cases (profiling and debugging) but will result in
        measurable slowdown of the `Callable`'s performance. Default: `False`.

    Returns:
      A function that when called will execute the step defined by
      `feed_list` and `fetches` in this session.

    Raises:
      TypeError: If `fetches` or `feed_list` cannot be interpreted
        as arguments to @{tf.Session.run}.
    """
    if feed_list is not None:
      if not isinstance(feed_list, (list, tuple)):
        raise TypeError('`feed_list` must be a list or tuple.')
      # Delegate any non-empty feed lists to the existing `run()` logic.
      # TODO(mrry): Refactor the feed handling logic from
      # `Session._run()` so that we can convert the feeds to a list of
      # strings here.
      def _generic_run(*feed_args, **kwargs):
        # NOTE(review): `zip` silently drops positional args beyond
        # len(feed_list); callers are expected to pass exactly one value
        # per feed.
        feed_dict = {
            feed: feed_val
            for feed, feed_val in zip(feed_list, feed_args)
        }
        return self.run(fetches, feed_dict=feed_dict, **kwargs)

      return _generic_run

    # Ensure any changes to the graph are reflected in the runtime.
    # Note that we don't need to do this on subsequent calls to the
    # returned object, because the arguments to `fetches` must already be
    # in the graph.
    self._extend_graph()

    # Create a fetch handler to take care of the structure of fetches.
    fetch_handler = _FetchHandler(self._graph, fetches, {})
    # Build the fetch/target lists once, in the representation the active
    # C API generation expects (TF_Output/TF_Operation vs. name bytes).
    if self._created_with_new_api:
      # pylint: disable=protected-access
      fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()]
      target_list = [op._c_op for op in fetch_handler.targets()]
      # pylint: enable=protected-access
    else:
      fetch_list = _name_list(fetch_handler.fetches())
      target_list = _name_list(fetch_handler.targets())

    def _callable_template_with_options_and_metadata(fetch_list,
                                                     target_list,
                                                     fetch_handler,
                                                     options=None,
                                                     run_metadata=None):
      """Template callable that accepts RunOptions and RunMetadata."""
      # Mirrors `Session.run`'s C-buffer lifecycle for options/metadata.
      options_ptr = tf_session.TF_NewBufferFromString(
          compat.as_bytes(options.SerializeToString())) if options else None
      run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None
      try:
        with errors.raise_exception_on_not_ok_status() as status:
          if self._created_with_new_api:
            results = tf_session.TF_SessionRun_wrapper(
                self._session, options_ptr, {}, fetch_list, target_list,
                run_metadata_ptr, status)
          else:
            results = tf_session.TF_Run(self._session, options_ptr, {},
                                        fetch_list, target_list, status,
                                        run_metadata_ptr)
          if fetch_handler:
            results = fetch_handler.build_results(self, results)
          else:
            results = results[0] if results else None
        if run_metadata:
          proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
          run_metadata.ParseFromString(compat.as_bytes(proto_data))
      finally:
        if run_metadata_ptr:
          tf_session.TF_DeleteBuffer(run_metadata_ptr)
        if options:
          tf_session.TF_DeleteBuffer(options_ptr)
      return results

    if accept_options:
      return functools.partial(_callable_template_with_options_and_metadata,
                               fetch_list, target_list, fetch_handler)
    elif isinstance(fetches, ops.Operation):
      # Special case for fetching a single operation, because the
      # function will have no return value.
      assert not fetch_list
      assert len(target_list) == 1

      def _single_operation_run():
        with errors.raise_exception_on_not_ok_status() as status:
          if self._created_with_new_api:
            tf_session.TF_SessionRun_wrapper(self._session, None, {}, [],
                                             target_list, None, status)
          else:
            tf_session.TF_Run(self._session, None, {}, [], target_list, status,
                              None)

      return _single_operation_run
    elif isinstance(fetches, ops.Tensor):
      # Special case for fetching a single tensor, because the
      # function can return the result of `TF_Run()` directly.
      assert len(fetch_list) == 1
      assert not target_list

      def _single_tensor_run():
        with errors.raise_exception_on_not_ok_status() as status:
          if self._created_with_new_api:
            results = tf_session.TF_SessionRun_wrapper(
                self._session, None, {}, fetch_list, [], None, status)
          else:
            results = tf_session.TF_Run(self._session, None, {}, fetch_list, [],
                                        status, None)
        return results[0]

      return _single_tensor_run
    else:
      # In all other cases, we must use `fetch_handler` to build the
      # results for us.
      def _fetch_handler_run():
        with errors.raise_exception_on_not_ok_status() as status:
          if self._created_with_new_api:
            results = tf_session.TF_SessionRun_wrapper(
                self._session, None, {}, fetch_list, target_list, None, status)
          else:
            results = tf_session.TF_Run(self._session, None, {}, fetch_list,
                                        target_list, status, None)
        return fetch_handler.build_results(self, results)

      return _fetch_handler_run
   1293 
  # Captures the name of a node in an error status, i.e. the `foo` in
  # messages containing "[[Node: foo = ...". Used by `_do_call` to attach
  # the failing op's NodeDef to re-raised errors.
  _NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =')
   1296 
  def _do_run(self, handle, target_list, fetch_list, feed_dict, options,
              run_metadata):
    """Runs a step based on the given fetches and feeds.

    Args:
      handle: a handle for partial_run. None if this is just a call to run().
      target_list: A list of operations to be run, but not fetched.
      fetch_list: A list of tensors to be fetched.
      feed_dict: A dictionary that maps tensors to numpy ndarrays.
      options: A (pointer to a) [`RunOptions`] protocol buffer, or None
      run_metadata: A (pointer to a) [`RunMetadata`] protocol buffer, or None

    Returns:
      A list of numpy ndarrays, corresponding to the elements of
      `fetch_list`.  If the ith element of `fetch_list` contains the
      name of an operation, the first Tensor output of that operation
      will be returned for that element.

    Raises:
      tf.errors.OpError: Or one of its subclasses on error.
    """
    # Translate feeds/fetches/targets into the representation the active
    # C API generation expects.
    if self._created_with_new_api:
      # New C API addresses tensors as TF_Output and ops as TF_Operation.
      # pylint: disable=protected-access
      feeds = dict((t._as_tf_output(), v) for t, v in feed_dict.items())
      fetches = [t._as_tf_output() for t in fetch_list]
      targets = [op._c_op for op in target_list]
      # pylint: enable=protected-access
    else:
      # Deprecated API addresses everything by name (bytes).
      feeds = dict((compat.as_bytes(t.name), v) for t, v in feed_dict.items())
      fetches = _name_list(fetch_list)
      targets = _name_list(target_list)

    def _run_fn(session, feed_dict, fetch_list, target_list, options,
                run_metadata):
      # Ensure any changes to the graph are reflected in the runtime.
      self._extend_graph()
      with errors.raise_exception_on_not_ok_status() as status:
        if self._created_with_new_api:
          return tf_session.TF_SessionRun_wrapper(session, options, feed_dict,
                                                  fetch_list, target_list,
                                                  run_metadata, status)
        else:
          return tf_session.TF_Run(session, options, feed_dict, fetch_list,
                                   target_list, status, run_metadata)

    def _prun_fn(session, handle, feed_dict, fetch_list):
      # Partial runs cannot specify targets or options; the setup call
      # already fixed the set of operations to run.
      if target_list:
        raise RuntimeError('partial_run() requires empty target_list.')
      with errors.raise_exception_on_not_ok_status() as status:
        if self._created_with_new_api:
          return tf_session.TF_SessionPRun_wrapper(session, handle, feed_dict,
                                                   fetch_list, status)
        else:
          return tf_session.TF_PRun(session, handle, feed_dict, fetch_list,
                                    status)

    # Route through _do_call so OpErrors are annotated with node info.
    if handle is None:
      return self._do_call(_run_fn, self._session, feeds, fetches, targets,
                           options, run_metadata)
    else:
      return self._do_call(_prun_fn, self._session, handle, feeds, fetches)
   1358 
   1359   def _do_call(self, fn, *args):
   1360     try:
   1361       return fn(*args)
   1362     except errors.OpError as e:
   1363       message = compat.as_text(e.message)
   1364       m = BaseSession._NODEDEF_NAME_RE.search(message)
   1365       node_def = None
   1366       op = None
   1367       if m is not None:
   1368         node_name = m.group(1)
   1369         try:
   1370           op = self._graph.get_operation_by_name(node_name)
   1371           node_def = op.node_def
   1372         except KeyError:
   1373           pass
   1374       raise type(e)(node_def, op, message)
   1375 
   1376   def _extend_graph(self):
   1377     # Nothing to do if we're using the new session interface
   1378     # TODO(skyewm): remove this function altogether eventually
   1379     if self._created_with_new_api:
   1380       return
   1381 
   1382     # Ensure any changes to the graph are reflected in the runtime.
   1383     with self._extend_lock:
   1384       if self._graph.version > self._current_version:
   1385         # pylint: disable=protected-access
   1386         graph_def, self._current_version = self._graph._as_graph_def(
   1387             from_version=self._current_version, add_shapes=self._add_shapes)
   1388         # pylint: enable=protected-access
   1389 
   1390         with errors.raise_exception_on_not_ok_status() as status:
   1391           tf_session.TF_ExtendGraph(self._session,
   1392                                     graph_def.SerializeToString(), status)
   1393         self._opened = True
   1394 
  # The threshold to run garbage collection to delete dead tensors.
  # Dead handles accumulate in `self._dead_handles` (see
  # `_register_dead_handle`) and are deleted in a single batched `run` call
  # once this many have been collected.
  _DEAD_HANDLES_THRESHOLD = 10
   1397 
   1398   def _register_dead_handle(self, handle):
   1399     # Register a dead handle in the session. Delete the dead tensors when
   1400     # the number of dead tensors exceeds certain threshold.
   1401     tensors_to_delete = None
   1402     with self._delete_lock:
   1403       self._dead_handles.append(handle)
   1404       if len(self._dead_handles) == BaseSession._DEAD_HANDLES_THRESHOLD:
   1405         tensors_to_delete = self._dead_handles
   1406         self._dead_handles = []
   1407     # Delete the dead tensors.
   1408     if tensors_to_delete:
   1409       feeds = {}
   1410       fetches = []
   1411       for deleter_key, tensor_handle in enumerate(tensors_to_delete):
   1412         holder, deleter = session_ops._get_handle_deleter(
   1413             self.graph, deleter_key, tensor_handle)
   1414         feeds[holder] = tensor_handle
   1415         fetches.append(deleter)
   1416       self.run(fetches, feed_dict=feeds)
   1417 
   1418   def _update_with_movers(self, feed_dict, feed_map):
   1419     # If a tensor handle that is fed to a device incompatible placeholder,
   1420     # we move the tensor to the right device, generate a new tensor handle,
   1421     # and update `feed_dict` to use the new handle.
   1422     handle_movers = []
   1423     for feed_name, val in feed_map.items():
   1424       mover = session_ops._get_handle_mover(self.graph, *val)
   1425       if mover:
   1426         handle_movers.append((feed_name, val[1], mover))
   1427     # Transfer a tensor to the right device if needed.
   1428     if not handle_movers:
   1429       return []
   1430     else:
   1431       feeds = {}
   1432       fetches = []
   1433       for _, handle, mover in handle_movers:
   1434         feeds[mover[0]] = handle
   1435         fetches.append(mover[1])
   1436       handles = self.run(fetches, feed_dict=feeds)
   1437       for handle_mover, handle in zip(handle_movers, handles):
   1438         np_val = np.array(handle.handle, dtype=np.object)
   1439         feed_name = handle_mover[0]
   1440         feed_tensor = feed_map[feed_name][0]
   1441         feed_dict[feed_tensor] = np_val
   1442       return handles
   1443 
   1444 
@tf_export('Session')
class Session(BaseSession):
  """A class for running TensorFlow operations.

  A `Session` object encapsulates the environment in which `Operation`
  objects are executed, and `Tensor` objects are evaluated. For
  example:

  ```python
  # Build a graph.
  a = tf.constant(5.0)
  b = tf.constant(6.0)
  c = a * b

  # Launch the graph in a session.
  sess = tf.Session()

  # Evaluate the tensor `c`.
  print(sess.run(c))
  ```

  A session may own resources, such as
  @{tf.Variable}, @{tf.QueueBase},
  and @{tf.ReaderBase}. It is important to release
  these resources when they are no longer required. To do this, either
  invoke the @{tf.Session.close} method on the session, or use
  the session as a context manager. The following two examples are
  equivalent:

  ```python
  # Using the `close()` method.
  sess = tf.Session()
  sess.run(...)
  sess.close()

  # Using the context manager.
  with tf.Session() as sess:
    sess.run(...)
  ```

  The
  [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
  protocol buffer exposes various configuration options for a
  session. For example, to create a session that uses soft constraints
  for device placement, and log the resulting placement decisions,
  create a session as follows:

  ```python
  # Launch the graph in a session that allows soft device placement and
  # logs the placement decisions.
  sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=True))
  ```
  """

  def __init__(self, target='', graph=None, config=None):
    """Creates a new TensorFlow session.

    If no `graph` argument is specified when constructing the session,
    the default graph will be launched in the session. If you are
    using more than one graph (created with `tf.Graph()` in the same
    process, you will have to use different sessions for each graph,
    but each graph can be used in multiple sessions. In this case, it
    is often clearer to pass the graph to be launched explicitly to
    the session constructor.

    Args:
      target: (Optional.) The execution engine to connect to.
        Defaults to using an in-process engine. See
        @{$distributed$Distributed TensorFlow}
        for more examples.
      graph: (Optional.) The `Graph` to be launched (described above).
      config: (Optional.) A
        [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
        protocol buffer with configuration options for the session.

    """
    super(Session, self).__init__(target, graph, config=config)
    # NOTE(mrry): Create these on first `__enter__` to avoid a reference cycle.
    self._default_graph_context_manager = None
    self._default_session_context_manager = None

  def __enter__(self):
    """Makes this session's graph and the session itself the defaults.

    The graph context manager doubles as a re-entrancy guard: it is created
    on first entry and only reset to `None` in `__exit__`, so a nested
    `with` on the same `Session` raises instead of silently nesting.
    """
    if self._default_graph_context_manager is None:
      self._default_graph_context_manager = self.graph.as_default()
    else:
      raise RuntimeError('Session context managers are not re-entrant. '
                         'Use `Session.as_default()` if you want to enter '
                         'a session multiple times.')
    if self._default_session_context_manager is None:
      self._default_session_context_manager = self.as_default()
    # Enter the graph context first so the session context sees it as the
    # current default graph.
    self._default_graph_context_manager.__enter__()
    return self._default_session_context_manager.__enter__()

  def __exit__(self, exec_type, exec_value, exec_tb):
    """Exits the default-session and default-graph contexts and closes.

    Exceptions raised in the `with` body propagate normally; this method
    only logs `OpError`s and always closes the session on the way out.
    """
    # NOTE(review): `is errors.OpError` matches only the exact base class;
    # subclasses (e.g. `CancelledError`) are not logged here — confirm this
    # is intended.
    if exec_type is errors.OpError:
      logging.error('Session closing due to OpError: %s', (exec_value,))
    try:
      self._default_session_context_manager.__exit__(exec_type, exec_value,
                                                     exec_tb)
    except RuntimeError as error:
      if error == exec_value:
        # NOTE(skyewm): for some reason, in Python3,
        # _default_session_context_manager.__exit__ will re-raise the "not
        # re-entrant" exception raised in __enter__ above (note that if we're
        # here, we're in the outer session context manager, since __exit__ is
        # not called when __enter__ raises an exception). We still want to
        # continue cleaning up this context manager before the exception is
        # further propagated, so we ignore it here (note that it'll continue
        # being propagated after this method completes).
        pass
      else:
        raise
    # Exit in reverse order of entry: session context first, then graph.
    self._default_graph_context_manager.__exit__(exec_type, exec_value, exec_tb)

    # Reset the guards so the managers can be garbage-collected; note that
    # re-entry is still disallowed only while inside the `with` block.
    self._default_session_context_manager = None
    self._default_graph_context_manager = None

    self.close()

  @staticmethod
  def reset(target, containers=None, config=None):
    """Resets resource containers on `target`, and close all connected sessions.

    A resource container is distributed across all workers in the
    same cluster as `target`.  When a resource container on `target`
    is reset, resources associated with that container will be cleared.
    In particular, all Variables in the container will become undefined:
    they lose their values and shapes.

    NOTE:
    (i) reset() is currently only implemented for distributed sessions.
    (ii) Any sessions on the master named by `target` will be closed.

    If no resource containers are provided, all containers are reset.

    Args:
      target: The execution engine to connect to.
      containers: A list of resource container name strings, or `None` if all of
        all the containers are to be reset.
      config: (Optional.) Protocol buffer with configuration options.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        resetting containers.
    """
    # The C API expects byte strings; normalize target and container names.
    if target is not None:
      target = compat.as_bytes(target)
    if containers is not None:
      containers = [compat.as_bytes(c) for c in containers]
    else:
      containers = []
    tf_session.TF_Reset(target, containers, config)
   1598 
   1599 
@tf_export('InteractiveSession')
class InteractiveSession(BaseSession):
  """A TensorFlow `Session` for use in interactive contexts, such as a shell.

  The only difference with a regular `Session` is that an `InteractiveSession`
  installs itself as the default session on construction.
  The methods @{tf.Tensor.eval}
  and @{tf.Operation.run}
  will use that session to run ops.

  This is convenient in interactive shells and [IPython
  notebooks](http://ipython.org), as it avoids having to pass an explicit
  `Session` object to run ops.

  For example:

  ```python
  sess = tf.InteractiveSession()
  a = tf.constant(5.0)
  b = tf.constant(6.0)
  c = a * b
  # We can just use 'c.eval()' without passing 'sess'
  print(c.eval())
  sess.close()
  ```

  Note that a regular session installs itself as the default session when it
  is created in a `with` statement.  The common usage in non-interactive
  programs is to follow that pattern:

  ```python
  a = tf.constant(5.0)
  b = tf.constant(6.0)
  c = a * b
  with tf.Session():
    # We can also use 'c.eval()' here.
    print(c.eval())
  ```
  """

  def __init__(self, target='', graph=None, config=None):
    """Creates a new interactive TensorFlow session.

    If no `graph` argument is specified when constructing the session,
    the default graph will be launched in the session. If you are
    using more than one graph (created with `tf.Graph()` in the same
    process, you will have to use different sessions for each graph,
    but each graph can be used in multiple sessions. In this case, it
    is often clearer to pass the graph to be launched explicitly to
    the session constructor.

    Args:
      target: (Optional.) The execution engine to connect to.
        Defaults to using an in-process engine.
      graph: (Optional.) The `Graph` to be launched (described above).
      config: (Optional) `ConfigProto` proto used to configure the session.
    """
    if not config:
      # If config is not provided, choose some reasonable defaults for
      # interactive use:
      #
      #   - Grow GPU memory as needed at the cost of fragmentation.
      gpu_options = config_pb2.GPUOptions(allow_growth=True)
      config = config_pb2.ConfigProto(gpu_options=gpu_options)
    # Interactive sessions always place pruned graphs.
    # NOTE(review): when `config` is caller-supplied, this mutates the
    # caller's proto in place — confirm that side effect is acceptable.
    config.graph_options.place_pruned_graph = True

    super(InteractiveSession, self).__init__(target, graph, config)
    # Enter the default-session context immediately and leave it open; the
    # matching `__exit__` happens in `close()`. `enforce_nesting=False`
    # allows this long-lived context to be exited out of stack order.
    self._default_session = self.as_default()
    self._default_session.enforce_nesting = False
    self._default_session.__enter__()
    # If the caller passed an explicit graph, also install it as the default
    # graph for the lifetime of this session (exited in `close()`).
    self._explicit_graph = graph
    if self._explicit_graph is not None:
      self._default_graph = graph.as_default()
      self._default_graph.enforce_nesting = False
      self._default_graph.__enter__()

  def close(self):
    """Closes an `InteractiveSession`."""
    # Close the underlying session first, then pop the context managers
    # entered in `__init__` (graph before session).
    super(InteractiveSession, self).close()
    if self._explicit_graph is not None:
      self._default_graph.__exit__(None, None, None)
    self._default_session.__exit__(None, None, None)
   1683