Home | History | Annotate | Download | only in ops
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Tensor Handle Operations. See the @{$python/session_ops} guide.
     17 
     18 @@get_session_handle
     19 @@get_session_handle_v2
     20 @@get_session_tensor
     21 @@delete_session_tensor
     22 """
     23 
     24 # pylint: disable=g-bad-name
     25 from __future__ import absolute_import
     26 from __future__ import division
     27 from __future__ import print_function
     28 
     29 import numpy as np
     30 
     31 from tensorflow.core.framework import resource_handle_pb2
     32 from tensorflow.python import pywrap_tensorflow_internal
     33 from tensorflow.python.framework import device as pydev
     34 from tensorflow.python.framework import dtypes
     35 from tensorflow.python.framework import ops
     36 from tensorflow.python.ops import array_ops
     37 from tensorflow.python.ops import gen_data_flow_ops
     38 from tensorflow.python.util import compat
     39 from tensorflow.python.util.tf_export import tf_export
     40 
     41 
     42 def encode_resource_handle(resource_handle):
     43   """Encode a ResourceHandle proto as custom numpy struct type."""
     44   return np.asarray(bytearray(resource_handle.SerializeToString()),
     45                     dtype=dtypes.np_resource)
     46 
     47 
     48 class TensorHandle(object):
     49   """Represents a handle for a live tensor in a session."""
     50 
     51   def __init__(self, handle, dtype, session):
     52     """Constructs a new tensor handle.
     53 
     54     A tensor handle for a persistent tensor is a python string
     55     that has the form of "tensor_name;unique_id;device_name".
     56 
     57     Args:
     58       handle: A tensor handle.
     59       dtype: The data type of the tensor represented by `handle`.
     60       session: The session in which the tensor is produced.
     61     """
     62     self._handle = compat.as_str_any(handle)
     63     self._resource_handle = None
     64     self._dtype = dtype
     65     self._session = session
     66     self._auto_gc_enabled = True
     67 
     68   def __del__(self):
     69     if self._auto_gc_enabled:
     70       self._session._register_dead_handle(self.handle)
     71 
     72   def __str__(self):
     73     return self._handle
     74 
     75   def _get_resource_handle(self):
     76     """The ResourceHandle representation of this handle."""
     77     if not self._resource_handle:
     78       self._resource_handle = resource_handle_pb2.ResourceHandleProto()
     79       self._resource_handle.device = self._handle.split(";")[-1]
     80       self._resource_handle.container = (
     81           pywrap_tensorflow_internal.TENSOR_HANDLE_KEY)
     82       self._resource_handle.name = self._handle
     83     return self._resource_handle
     84 
     85   def to_numpy_array(self):
     86     """Convert a TensorHandle object to a feedable numpy value.
     87 
     88     Returns:
     89       A numpy array of a custom struct type that can be used as a feed value
     90       to run().
     91     """
     92     return encode_resource_handle(self._get_resource_handle())
     93 
     94   @property
     95   def handle(self):
     96     """The string representation of this handle."""
     97     return self._handle
     98 
     99   def eval(self):
    100     """Return the value of the tensor represented by this handle."""
    101     if not self._auto_gc_enabled:
    102       raise TypeError("Persistent tensor %s may have already been deleted."
    103                       % self.handle)
    104     holder, reader = _get_handle_reader(self._session.graph, self._handle,
    105                                         self._dtype)
    106     return self._session.run(reader, feed_dict={holder: self._handle})
    107 
    108   def delete(self):
    109     """Force the deletion of this persistent tensor."""
    110     if not self._auto_gc_enabled:
    111       raise TypeError("Persistent tensor %s may have already been deleted."
    112                       % self.handle)
    113     self._auto_gc_enabled = False
    114     holder, deleter = _get_handle_deleter(self._session.graph, 0, self._handle)
    115     self._session.run(deleter, feed_dict={holder: self.handle})
    116 
    117   def get_raw_handle(self):
    118     """Return the raw handle of the tensor.
    119 
    120     Note that the method disables the automatic garbage collection of this
    121     persistent tensor. The caller is now responsible for managing the life
    122     time of the tensor.
    123     """
    124     self._auto_gc_enabled = False
    125     return self._handle
    126 
    127   @staticmethod
    128   def _get_device_name(handle):
    129     """The device name encoded in the handle."""
    130     handle_str = compat.as_str_any(handle)
    131     return pydev.canonical_name(handle_str.split(";")[-1])
    132 
    133   @staticmethod
    134   def _get_reader_key(handle):
    135     """The graph key for reader."""
    136     handle_parts = str(handle).split(";")
    137     return handle_parts[0] + ";" + handle_parts[-1]
    138 
    139   @staticmethod
    140   def _get_mover_key(feeder, handle):
    141     """The graph key for mover."""
    142     return feeder.op.name + ";" + TensorHandle._get_reader_key(handle)
    143 
    144 
    145 @tf_export("get_session_handle")
    146 def get_session_handle(data, name=None):
    147   """Return the handle of `data`.
    148 
    149   This is EXPERIMENTAL and subject to change.
    150 
    151   Keep `data` "in-place" in the runtime and create a handle that can be
    152   used to retrieve `data` in a subsequent run().
    153 
    154   Combined with `get_session_tensor`, we can keep a tensor produced in
    155   one run call in place, and use it as the input in a future run call.
    156 
    157   Args:
    158     data: A tensor to be stored in the session.
    159     name: Optional name prefix for the return tensor.
    160 
    161   Returns:
    162     A scalar string tensor representing a unique handle for `data`.
    163 
    164   Raises:
    165     TypeError: if `data` is not a Tensor.
    166 
    167   Example:
    168 
    169   ```python
    170   c = tf.multiply(a, b)
    171   h = tf.get_session_handle(c)
    172   h = sess.run(h)
    173 
    174   p, a = tf.get_session_tensor(h.handle, tf.float32)
    175   b = tf.multiply(a, 10)
    176   c = sess.run(b, feed_dict={p: h.handle})
    177   ```
    178 
    179   """
    180   if not isinstance(data, ops.Tensor):
    181     raise TypeError("`data` must be of type Tensor.")
    182 
    183   # Colocate this operation with data.
    184   with ops.colocate_with(data):
    185     return gen_data_flow_ops._get_session_handle(data, name=name)  # pylint: disable=protected-access
    186 
    187 
    188 @tf_export("get_session_tensor")
    189 def get_session_tensor(handle, dtype, name=None):
    190   """Get the tensor of type `dtype` by feeding a tensor handle.
    191 
    192   This is EXPERIMENTAL and subject to change.
    193 
    194   Get the value of the tensor from a tensor handle. The tensor
    195   is produced in a previous run() and stored in the state of the
    196   session.
    197 
    198   Args:
    199     handle: The string representation of a persistent tensor handle.
    200     dtype: The type of the output tensor.
    201     name: Optional name prefix for the return tensor.
    202 
    203   Returns:
    204     A pair of tensors. The first is a placeholder for feeding a
    205     tensor handle and the second is the tensor in the session state
    206     keyed by the tensor handle.
    207 
    208   Example:
    209 
    210   ```python
    211   c = tf.multiply(a, b)
    212   h = tf.get_session_handle(c)
    213   h = sess.run(h)
    214 
    215   p, a = tf.get_session_tensor(h.handle, tf.float32)
    216   b = tf.multiply(a, 10)
    217   c = sess.run(b, feed_dict={p: h.handle})
    218   ```
    219 
    220   """
    221   handle_device = TensorHandle._get_device_name(handle)
    222   with ops.device(handle_device):
    223     holder = array_ops.placeholder(dtypes.string)
    224     _register_handle_feeder(holder.graph, holder, dtype)
    225     tensor = gen_data_flow_ops._get_session_tensor(holder, dtype, name=name)
    226   return (holder, tensor)
    227 
    228 
    229 @tf_export("delete_session_tensor")
    230 def delete_session_tensor(handle, name=None):
    231   """Delete the tensor for the given tensor handle.
    232 
    233   This is EXPERIMENTAL and subject to change.
    234 
    235   Delete the tensor of a given tensor handle. The tensor is produced
    236   in a previous run() and stored in the state of the session.
    237 
    238   Args:
    239     handle: The string representation of a persistent tensor handle.
    240     name: Optional name prefix for the return tensor.
    241 
    242   Returns:
    243     A pair of graph elements. The first is a placeholder for feeding a
    244     tensor handle and the second is a deletion operation.
    245   """
    246   handle_device = TensorHandle._get_device_name(handle)
    247   with ops.device(handle_device):
    248     holder = array_ops.placeholder(dtypes.string)
    249     deleter = gen_data_flow_ops._delete_session_tensor(holder, name=name)
    250   return (holder, deleter)
    251 
    252 
    253 def _register_handle_feeder(graph, feeder, dtype):
    254   graph._handle_feeders[feeder.op.name] = dtype
    255 
    256 
    257 def _get_handle_feeder(graph, feeder):
    258   return graph._handle_feeders.get(feeder.op.name)
    259 
    260 
    261 def _get_handle_reader(graph, handle, dtype):
    262   """Return a read subgraph for this handle."""
    263   graph_key = TensorHandle._get_reader_key(handle)
    264   result = graph._handle_readers.get(graph_key)
    265   if result is None:
    266     # Create reader if we haven't done it.
    267     handle_device = TensorHandle._get_device_name(handle)
    268     with graph.as_default(), graph.device(handle_device):
    269       holder = array_ops.placeholder(dtypes.string)
    270       _register_handle_feeder(holder.graph, holder, dtype)
    271       reader = gen_data_flow_ops._get_session_tensor(holder, dtype)
    272     result = (holder, reader)
    273     graph._handle_readers[graph_key] = result
    274   return result
    275 
    276 
    277 def _get_handle_mover(graph, feeder, handle):
    278   """Return a move subgraph for this pair of feeder and handle."""
    279   dtype = _get_handle_feeder(graph, feeder)
    280   if dtype is None:
    281     return None
    282   handle_device = TensorHandle._get_device_name(handle)
    283   if feeder.op.device == handle_device:
    284     return None
    285   # Now we know we have to move the tensor.
    286   graph_key = TensorHandle._get_mover_key(feeder, handle)
    287   result = graph._handle_movers.get(graph_key)
    288   if result is None:
    289     # Create mover if we haven't done it.
    290     holder, reader = _get_handle_reader(graph, handle, dtype)
    291     with graph.as_default(), graph.device(feeder.op.device):
    292       mover = gen_data_flow_ops._get_session_handle(reader)  # pylint: disable=protected-access
    293     result = (holder, mover)
    294     graph._handle_movers[graph_key] = result
    295   return result
    296 
    297 
    298 def _get_handle_deleter(graph, deleter_key, handle):
    299   """Return a deletion subgraph for this handle."""
    300   result = graph._handle_deleters.get(deleter_key)
    301   if result is None:
    302     # Create deleter if we haven't done it.
    303     handle_device = TensorHandle._get_device_name(handle)
    304     with graph.as_default(), graph.device(handle_device):
    305       holder = array_ops.placeholder(dtypes.string)
    306       deleter = gen_data_flow_ops._delete_session_tensor(holder)
    307     result = (holder, deleter)
    308     graph._handle_deleters[deleter_key] = result
    309   return result
    310