1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 """Tensor Handle Operations. See the @{$python/session_ops} guide. 17 18 @@get_session_handle 19 @@get_session_handle_v2 20 @@get_session_tensor 21 @@delete_session_tensor 22 """ 23 24 # pylint: disable=g-bad-name 25 from __future__ import absolute_import 26 from __future__ import division 27 from __future__ import print_function 28 29 import numpy as np 30 31 from tensorflow.core.framework import resource_handle_pb2 32 from tensorflow.python import pywrap_tensorflow_internal 33 from tensorflow.python.framework import device as pydev 34 from tensorflow.python.framework import dtypes 35 from tensorflow.python.framework import ops 36 from tensorflow.python.ops import array_ops 37 from tensorflow.python.ops import gen_data_flow_ops 38 from tensorflow.python.util import compat 39 from tensorflow.python.util.tf_export import tf_export 40 41 42 def encode_resource_handle(resource_handle): 43 """Encode a ResourceHandle proto as custom numpy struct type.""" 44 return np.asarray(bytearray(resource_handle.SerializeToString()), 45 dtype=dtypes.np_resource) 46 47 48 class TensorHandle(object): 49 """Represents a handle for a live tensor in a session.""" 50 51 def __init__(self, handle, dtype, session): 52 """Constructs a new tensor handle. 53 54 A tensor handle for a persistent tensor is a python string 55 that has the form of "tensor_name;unique_id;device_name". 56 57 Args: 58 handle: A tensor handle. 59 dtype: The data type of the tensor represented by `handle`. 60 session: The session in which the tensor is produced. 61 """ 62 self._handle = compat.as_str_any(handle) 63 self._resource_handle = None 64 self._dtype = dtype 65 self._session = session 66 self._auto_gc_enabled = True 67 68 def __del__(self): 69 if self._auto_gc_enabled: 70 self._session._register_dead_handle(self.handle) 71 72 def __str__(self): 73 return self._handle 74 75 def _get_resource_handle(self): 76 """The ResourceHandle representation of this handle.""" 77 if not self._resource_handle: 78 self._resource_handle = resource_handle_pb2.ResourceHandleProto() 79 self._resource_handle.device = self._handle.split(";")[-1] 80 self._resource_handle.container = ( 81 pywrap_tensorflow_internal.TENSOR_HANDLE_KEY) 82 self._resource_handle.name = self._handle 83 return self._resource_handle 84 85 def to_numpy_array(self): 86 """Convert a TensorHandle object to a feedable numpy value. 87 88 Returns: 89 A numpy array of a custom struct type that can be used as a feed value 90 to run(). 91 """ 92 return encode_resource_handle(self._get_resource_handle()) 93 94 @property 95 def handle(self): 96 """The string representation of this handle.""" 97 return self._handle 98 99 def eval(self): 100 """Return the value of the tensor represented by this handle.""" 101 if not self._auto_gc_enabled: 102 raise TypeError("Persistent tensor %s may have already been deleted." 103 % self.handle) 104 holder, reader = _get_handle_reader(self._session.graph, self._handle, 105 self._dtype) 106 return self._session.run(reader, feed_dict={holder: self._handle}) 107 108 def delete(self): 109 """Force the deletion of this persistent tensor.""" 110 if not self._auto_gc_enabled: 111 raise TypeError("Persistent tensor %s may have already been deleted." 112 % self.handle) 113 self._auto_gc_enabled = False 114 holder, deleter = _get_handle_deleter(self._session.graph, 0, self._handle) 115 self._session.run(deleter, feed_dict={holder: self.handle}) 116 117 def get_raw_handle(self): 118 """Return the raw handle of the tensor. 119 120 Note that the method disables the automatic garbage collection of this 121 persistent tensor. The caller is now responsible for managing the life 122 time of the tensor. 123 """ 124 self._auto_gc_enabled = False 125 return self._handle 126 127 @staticmethod 128 def _get_device_name(handle): 129 """The device name encoded in the handle.""" 130 handle_str = compat.as_str_any(handle) 131 return pydev.canonical_name(handle_str.split(";")[-1]) 132 133 @staticmethod 134 def _get_reader_key(handle): 135 """The graph key for reader.""" 136 handle_parts = str(handle).split(";") 137 return handle_parts[0] + ";" + handle_parts[-1] 138 139 @staticmethod 140 def _get_mover_key(feeder, handle): 141 """The graph key for mover.""" 142 return feeder.op.name + ";" + TensorHandle._get_reader_key(handle) 143 144 145 @tf_export("get_session_handle") 146 def get_session_handle(data, name=None): 147 """Return the handle of `data`. 148 149 This is EXPERIMENTAL and subject to change. 150 151 Keep `data` "in-place" in the runtime and create a handle that can be 152 used to retrieve `data` in a subsequent run(). 153 154 Combined with `get_session_tensor`, we can keep a tensor produced in 155 one run call in place, and use it as the input in a future run call. 156 157 Args: 158 data: A tensor to be stored in the session. 159 name: Optional name prefix for the return tensor. 160 161 Returns: 162 A scalar string tensor representing a unique handle for `data`. 163 164 Raises: 165 TypeError: if `data` is not a Tensor. 166 167 Example: 168 169 ```python 170 c = tf.multiply(a, b) 171 h = tf.get_session_handle(c) 172 h = sess.run(h) 173 174 p, a = tf.get_session_tensor(h.handle, tf.float32) 175 b = tf.multiply(a, 10) 176 c = sess.run(b, feed_dict={p: h.handle}) 177 ``` 178 179 """ 180 if not isinstance(data, ops.Tensor): 181 raise TypeError("`data` must be of type Tensor.") 182 183 # Colocate this operation with data. 184 with ops.colocate_with(data): 185 return gen_data_flow_ops._get_session_handle(data, name=name) # pylint: disable=protected-access 186 187 188 @tf_export("get_session_tensor") 189 def get_session_tensor(handle, dtype, name=None): 190 """Get the tensor of type `dtype` by feeding a tensor handle. 191 192 This is EXPERIMENTAL and subject to change. 193 194 Get the value of the tensor from a tensor handle. The tensor 195 is produced in a previous run() and stored in the state of the 196 session. 197 198 Args: 199 handle: The string representation of a persistent tensor handle. 200 dtype: The type of the output tensor. 201 name: Optional name prefix for the return tensor. 202 203 Returns: 204 A pair of tensors. The first is a placeholder for feeding a 205 tensor handle and the second is the tensor in the session state 206 keyed by the tensor handle. 207 208 Example: 209 210 ```python 211 c = tf.multiply(a, b) 212 h = tf.get_session_handle(c) 213 h = sess.run(h) 214 215 p, a = tf.get_session_tensor(h.handle, tf.float32) 216 b = tf.multiply(a, 10) 217 c = sess.run(b, feed_dict={p: h.handle}) 218 ``` 219 220 """ 221 handle_device = TensorHandle._get_device_name(handle) 222 with ops.device(handle_device): 223 holder = array_ops.placeholder(dtypes.string) 224 _register_handle_feeder(holder.graph, holder, dtype) 225 tensor = gen_data_flow_ops._get_session_tensor(holder, dtype, name=name) 226 return (holder, tensor) 227 228 229 @tf_export("delete_session_tensor") 230 def delete_session_tensor(handle, name=None): 231 """Delete the tensor for the given tensor handle. 232 233 This is EXPERIMENTAL and subject to change. 234 235 Delete the tensor of a given tensor handle. The tensor is produced 236 in a previous run() and stored in the state of the session. 237 238 Args: 239 handle: The string representation of a persistent tensor handle. 240 name: Optional name prefix for the return tensor. 241 242 Returns: 243 A pair of graph elements. The first is a placeholder for feeding a 244 tensor handle and the second is a deletion operation. 245 """ 246 handle_device = TensorHandle._get_device_name(handle) 247 with ops.device(handle_device): 248 holder = array_ops.placeholder(dtypes.string) 249 deleter = gen_data_flow_ops._delete_session_tensor(holder, name=name) 250 return (holder, deleter) 251 252 253 def _register_handle_feeder(graph, feeder, dtype): 254 graph._handle_feeders[feeder.op.name] = dtype 255 256 257 def _get_handle_feeder(graph, feeder): 258 return graph._handle_feeders.get(feeder.op.name) 259 260 261 def _get_handle_reader(graph, handle, dtype): 262 """Return a read subgraph for this handle.""" 263 graph_key = TensorHandle._get_reader_key(handle) 264 result = graph._handle_readers.get(graph_key) 265 if result is None: 266 # Create reader if we haven't done it. 267 handle_device = TensorHandle._get_device_name(handle) 268 with graph.as_default(), graph.device(handle_device): 269 holder = array_ops.placeholder(dtypes.string) 270 _register_handle_feeder(holder.graph, holder, dtype) 271 reader = gen_data_flow_ops._get_session_tensor(holder, dtype) 272 result = (holder, reader) 273 graph._handle_readers[graph_key] = result 274 return result 275 276 277 def _get_handle_mover(graph, feeder, handle): 278 """Return a move subgraph for this pair of feeder and handle.""" 279 dtype = _get_handle_feeder(graph, feeder) 280 if dtype is None: 281 return None 282 handle_device = TensorHandle._get_device_name(handle) 283 if feeder.op.device == handle_device: 284 return None 285 # Now we know we have to move the tensor. 286 graph_key = TensorHandle._get_mover_key(feeder, handle) 287 result = graph._handle_movers.get(graph_key) 288 if result is None: 289 # Create mover if we haven't done it. 290 holder, reader = _get_handle_reader(graph, handle, dtype) 291 with graph.as_default(), graph.device(feeder.op.device): 292 mover = gen_data_flow_ops._get_session_handle(reader) # pylint: disable=protected-access 293 result = (holder, mover) 294 graph._handle_movers[graph_key] = result 295 return result 296 297 298 def _get_handle_deleter(graph, deleter_key, handle): 299 """Return a deletion subgraph for this handle.""" 300 result = graph._handle_deleters.get(deleter_key) 301 if result is None: 302 # Create deleter if we haven't done it. 303 handle_device = TensorHandle._get_device_name(handle) 304 with graph.as_default(), graph.device(handle_device): 305 holder = array_ops.placeholder(dtypes.string) 306 deleter = gen_data_flow_ops._delete_session_tensor(holder) 307 result = (holder, deleter) 308 graph._handle_deleters[deleter_key] = result 309 return result 310