1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """A client interface for TensorFlow.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import functools 22 import re 23 import threading 24 25 import numpy as np 26 27 from tensorflow.core.protobuf import config_pb2 28 from tensorflow.python import pywrap_tensorflow as tf_session 29 from tensorflow.python.framework import c_api_util 30 from tensorflow.python.framework import device 31 from tensorflow.python.framework import errors 32 from tensorflow.python.framework import ops 33 from tensorflow.python.framework import sparse_tensor 34 from tensorflow.python.ops import session_ops 35 from tensorflow.python.platform import tf_logging as logging 36 from tensorflow.python.util import compat 37 from tensorflow.python.util import nest 38 from tensorflow.python.util.tf_export import tf_export 39 40 41 class SessionInterface(object): 42 """Base class for implementations of TensorFlow client sessions.""" 43 44 @property 45 def graph(self): 46 """The underlying TensorFlow graph, to be used in building Operations.""" 47 raise NotImplementedError('graph') 48 49 @property 50 def sess_str(self): 51 """The TensorFlow process to which this session will connect.""" 52 raise NotImplementedError('sess_str') 53 54 def run(self, fetches, feed_dict=None, options=None, run_metadata=None): 55 """Runs operations in the session. See `BaseSession.run()` for details.""" 56 raise NotImplementedError('run') 57 58 def partial_run_setup(self, fetches, feeds=None): 59 """Sets up the feeds and fetches for partial runs in the session.""" 60 raise NotImplementedError('partial_run_setup') 61 62 def partial_run(self, handle, fetches, feed_dict=None): 63 """Continues the execution with additional feeds and fetches.""" 64 raise NotImplementedError('partial_run') 65 66 67 def _get_indexed_slices_value_from_fetches(fetched_vals): 68 return ops.IndexedSlicesValue(fetched_vals[0], fetched_vals[1], 69 fetched_vals[2] 70 if len(fetched_vals) == 3 else None) 71 72 73 def _get_feeds_for_indexed_slices(feed, feed_val): 74 return list( 75 zip([feed.values, feed.indices] if feed.dense_shape is None else 76 [feed.values, feed.indices, feed.dense_shape], feed_val)) 77 78 79 # List of extensions supported to convert run arguments into actual fetches and 80 # feeds. 81 # 82 # Each element in the list is a tuple of (Type, fetch_fn, feed_fn1, feed_fn2), 83 # where the function signatures are: 84 # fetch_fn : Type -> (list of Tensors, 85 # lambda: list of fetched np.ndarray -> TypeVal) 86 # feed_fn1 : Type, TypeVal -> list of (Tensor, value) 87 # feed_fn2 : Type -> list of Tensors 88 # 89 # `fetch_fn` describes how to expand fetch into its 90 # component Tensors and how to contract the fetched results back into 91 # a single return value. 92 # 93 # Each feed function describes how to unpack a single fed value and map it to 94 # feeds of one or more tensors and their corresponding values: `feed_fn1` is 95 # used to feed a run, `feed_fn2` to set up a partial run. 96 # 97 # TODO(touts): We could reimplement these as specialized _FeedMapper 98 # implementations after we refactor the feed handling code to use them. 99 # 100 # Eventually, this registration could be opened up to support custom Tensor 101 # expansions. 102 # pylint: disable=g-long-lambda 103 _REGISTERED_EXPANSIONS = [ 104 # SparseTensors are fetched as SparseTensorValues. They can be fed 105 # SparseTensorValues or normal tuples. 106 (sparse_tensor.SparseTensor, 107 lambda fetch: ( 108 [fetch.indices, fetch.values, fetch.dense_shape], 109 lambda fetched_vals: sparse_tensor.SparseTensorValue(*fetched_vals)), 110 lambda feed, feed_val: list(zip( 111 [feed.indices, feed.values, feed.dense_shape], feed_val)), 112 lambda feed: [feed.indices, feed.values, feed.dense_shape]), 113 # IndexedSlices are fetched as IndexedSlicesValues. They can be fed 114 # IndexedSlicesValues or normal tuples. 115 (ops.IndexedSlices, 116 lambda fetch: ( 117 [fetch.values, fetch.indices] if fetch.dense_shape is None 118 else [fetch.values, fetch.indices, fetch.dense_shape], 119 _get_indexed_slices_value_from_fetches), 120 _get_feeds_for_indexed_slices, 121 lambda feed: [feed.values, feed.indices] if feed.dense_shape is None 122 else [feed.values, feed.indices, feed.dense_shape]), 123 # The default catches all other types and performs no expansions. 124 (object, 125 lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]), 126 lambda feed, feed_val: [(feed, feed_val)], 127 lambda feed: [feed])] 128 129 # pylint: enable=g-long-lambda 130 131 132 def _convert_to_numpy_obj(numpy_dtype, obj): 133 """Explicitly convert obj based on numpy type except for string type.""" 134 return numpy_dtype(obj) if numpy_dtype is not object else str(obj) 135 136 137 def register_session_run_conversion_functions( 138 tensor_type, 139 fetch_function, 140 feed_function=None, 141 feed_function_for_partial_run=None): 142 """Register fetch and feed conversion functions for `tf.Session.run()`. 143 144 This function registers a triple of conversion functions for fetching and/or 145 feeding values of user-defined types in a call to tf.Session.run(). 146 147 An example 148 149 ```python 150 class SquaredTensor(object): 151 def __init__(self, tensor): 152 self.sq = tf.square(tensor) 153 #you can define conversion functions as follows: 154 fetch_function = lambda squared_tensor:([squared_tensor.sq], 155 lambda val: val[0]) 156 feed_function = lambda feed, feed_val: [(feed.sq, feed_val)] 157 feed_function_for_partial_run = lambda feed: [feed.sq] 158 #then after invoking this register function, you can use as follows: 159 session.run(squared_tensor1, 160 feed_dict = {squared_tensor2 : some_numpy_array}) 161 ``` 162 163 Args: 164 tensor_type: The type for which you want to register a conversion function. 165 fetch_function: A callable that takes an object of type `tensor_type` and 166 returns a tuple, where the first element is a list of `tf.Tensor` objects, 167 and the second element is a callable that takes a list of ndarrays and 168 returns an object of some value type that corresponds to `tensor_type`. 169 fetch_function describes how to expand fetch into its component Tensors 170 and how to contract the fetched results back into a single return value. 171 feed_function: A callable that takes feed_key and feed_value as input, and 172 returns a list of tuples (feed_tensor, feed_val), feed_key must have type 173 `tensor_type`, and feed_tensor must have type `tf.Tensor`. Each feed 174 function describes how to unpack a single fed value and map it to feeds 175 of one or more tensors and their corresponding values. 176 feed_function_for_partial_run: A callable for specifying tensor values to 177 feed when setting up a partial run, which takes a `tensor_type` type 178 object as input, and returns a list of Tensors. 179 """ 180 for conversion_function in _REGISTERED_EXPANSIONS: 181 if issubclass(conversion_function[0], tensor_type): 182 raise ValueError('%s has already been registered so ignore it.', 183 tensor_type) 184 return 185 _REGISTERED_EXPANSIONS.insert(0, (tensor_type, fetch_function, feed_function, 186 feed_function_for_partial_run)) 187 188 189 class _FetchMapper(object): 190 """Definition of the interface provided by fetch mappers. 191 192 Fetch mappers are utility classes used by the _FetchHandler to handle 193 arbitrary structures for the `fetch` argument to `Session.run()`. 194 195 The `fetch` argument can be of various shapes: single tensor or op, list of 196 fetches, tuple of fetches, namedtuple of fetches, or dict of fetches. The 197 structures can be arbitrarily nested. 198 199 The low level run() API only wants a list of tensor or op names. The various 200 `_FetchMapper` subclasses below take care of handling the different shapes: 201 uniquifying the fetches, and constructing results with the original shape. 202 """ 203 204 def unique_fetches(self): 205 """Return the list of unique tensors or ops needed by this fetch mapper. 206 207 Returns: 208 A list of tensors or ops. 209 """ 210 raise NotImplementedError('Must be implemented by subclasses') 211 212 def build_results(self, values): 213 """Build results that match the original shape of the fetch. 214 215 Args: 216 values: List of values returned by run(). The values correspond 217 exactly to the list tensors or ops returned by unique_fetches(). 218 219 Returns: 220 A struct of the same shape as the original fetch object handled by 221 this fetch mapper. In the returned struct, the original fetches are 222 replaced by their fetched values. 223 """ 224 raise NotImplementedError('Must be implemented by subclasses') 225 226 @staticmethod 227 def for_fetch(fetch): 228 """Creates fetch mapper that handles the structure of `fetch`. 229 230 The default graph must be the one from which we want to fetch values when 231 this function is called. 232 233 Args: 234 fetch: An arbitrary fetch structure: singleton, list, tuple, 235 namedtuple, or dict. 236 237 Returns: 238 An instance of a subclass of `_FetchMapper` that handles the shape. 239 """ 240 if fetch is None: 241 raise TypeError('Fetch argument %r has invalid type %r' % (fetch, 242 type(fetch))) 243 elif isinstance(fetch, (list, tuple)): 244 # NOTE(touts): This is also the code path for namedtuples. 245 return _ListFetchMapper(fetch) 246 elif isinstance(fetch, dict): 247 return _DictFetchMapper(fetch) 248 else: 249 # Look for a handler in the registered expansions. 250 for tensor_type, fetch_fn, _, _ in _REGISTERED_EXPANSIONS: 251 if isinstance(fetch, tensor_type): 252 fetches, contraction_fn = fetch_fn(fetch) 253 return _ElementFetchMapper(fetches, contraction_fn) 254 # Did not find anything. 255 raise TypeError('Fetch argument %r has invalid type %r' % (fetch, 256 type(fetch))) 257 258 259 class _ElementFetchMapper(_FetchMapper): 260 """Fetch mapper for singleton tensors and ops.""" 261 262 def __init__(self, fetches, contraction_fn): 263 """Creates an _ElementFetchMapper. 264 265 This is the fetch mapper used for leaves in the fetch struct. Because of 266 the expansions mechanism, a leaf can actually fetch more than one tensor. 267 268 Also note that the fetches here can be just strings (tensor or op names) or 269 any other object that the graph knows how to convert to a tensor, such as a 270 Variable. So we have to run each fetch through `as_graph_element()` to get 271 the corresponding tensor or op. 272 273 Args: 274 fetches: List of objects, as returned by a fetch_fn defined 275 in _REGISTERED_EXPANSIONS. 276 contraction_fn: Callable as returned by a fetch_fn. 277 """ 278 self._unique_fetches = [] 279 for fetch in fetches: 280 try: 281 self._unique_fetches.append(ops.get_default_graph().as_graph_element( 282 fetch, allow_tensor=True, allow_operation=True)) 283 except TypeError as e: 284 raise TypeError('Fetch argument %r has invalid type %r, ' 285 'must be a string or Tensor. (%s)' % 286 (fetch, type(fetch), str(e))) 287 except ValueError as e: 288 raise ValueError('Fetch argument %r cannot be interpreted as a ' 289 'Tensor. (%s)' % (fetch, str(e))) 290 except KeyError as e: 291 raise ValueError('Fetch argument %r cannot be interpreted as a ' 292 'Tensor. (%s)' % (fetch, str(e))) 293 self._contraction_fn = contraction_fn 294 295 def unique_fetches(self): 296 return self._unique_fetches 297 298 def build_results(self, values): 299 if not values: 300 # 'Operation' case 301 return None 302 else: 303 return self._contraction_fn(values) 304 305 306 def _uniquify_fetches(fetch_mappers): 307 """Uniquifies fetches from a list of fetch_mappers. 308 309 This is a utility function used by _ListFetchMapper and _DictFetchMapper. It 310 gathers all the unique fetches from a list of mappers and builds a list 311 containing all of them but without duplicates (unique_fetches). 312 313 It also returns a 2-D list of integers (values_indices) indicating at which 314 index in unique_fetches the fetches of the mappers are located. 315 316 This list is as follows: 317 values_indices[mapper_index][mapper_fetch_index] = unique_fetches_index 318 319 Args: 320 fetch_mappers: list of fetch mappers. 321 322 Returns: 323 A list of fetches. 324 A 2-D list of integers. 325 """ 326 unique_fetches = [] 327 value_indices = [] 328 seen_fetches = {} 329 for m in fetch_mappers: 330 m_value_indices = [] 331 for f in m.unique_fetches(): 332 j = seen_fetches.get(f) 333 if j is None: 334 j = len(seen_fetches) 335 seen_fetches[f] = j 336 unique_fetches.append(f) 337 m_value_indices.append(j) 338 value_indices.append(m_value_indices) 339 return unique_fetches, value_indices 340 341 342 class _ListFetchMapper(_FetchMapper): 343 """Fetch mapper for lists, tuples, and namedtuples.""" 344 345 def __init__(self, fetches): 346 """Creates a _ListFetchMapper. 347 348 Args: 349 fetches: List, tuple, or namedtuple of fetches. 350 """ 351 self._fetch_type = type(fetches) 352 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches] 353 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers) 354 355 def unique_fetches(self): 356 return self._unique_fetches 357 358 def build_results(self, values): 359 # Create the list of results for each mapper. 360 results = [] 361 for m, vi in zip(self._mappers, self._value_indices): 362 results.append(m.build_results([values[j] for j in vi])) 363 # Return a value of the original type of the fetches. 364 if self._fetch_type == list: 365 return results 366 elif self._fetch_type == tuple: 367 return tuple(results) 368 else: 369 # This is the code path for namedtuple. 370 return self._fetch_type(*results) 371 372 373 class _DictFetchMapper(_FetchMapper): 374 """Fetch mapper for dicts.""" 375 376 def __init__(self, fetches): 377 """Creates a _DictFetchMapper. 378 379 Args: 380 fetches: Dict of fetches. 381 """ 382 self._fetch_type = type(fetches) 383 self._keys = fetches.keys() 384 self._mappers = [ 385 _FetchMapper.for_fetch(fetch) for fetch in fetches.values() 386 ] 387 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers) 388 389 def unique_fetches(self): 390 return self._unique_fetches 391 392 def build_results(self, values): 393 results = self._fetch_type() 394 for k, m, vi in zip(self._keys, self._mappers, self._value_indices): 395 results[k] = m.build_results([values[j] for j in vi]) 396 return results 397 398 399 class _FetchHandler(object): 400 """Handler for structured fetches. 401 402 Given a graph, a user-provided structure for fetches, and a feed dict, this 403 class takes care of generating a list of tensor names to fetch and op names 404 to run for a low level `run()` call. 405 406 Given the results of the low level run call, this class can also rebuild a 407 result structure matching the user-provided structure for fetches, but 408 containing the corresponding results. 409 """ 410 411 # TODO(touts): Make this class also take care of destructuring the feed 412 # dict instead of doing it in the callers. 413 414 def __init__(self, graph, fetches, feeds, feed_handles=None): 415 """Creates a fetch handler. 416 417 Args: 418 graph: Graph of the fetches. Used to check for fetchability 419 and to convert all fetches to tensors or ops as needed. 420 fetches: An arbitrary fetch structure: singleton, list, tuple, 421 namedtuple, or dict. 422 feeds: A feed dict where keys are Tensors. 423 feed_handles: A dict from feed Tensors to TensorHandle objects used as 424 direct feeds. 425 """ 426 with graph.as_default(): 427 self._fetch_mapper = _FetchMapper.for_fetch(fetches) 428 self._fetches = [] 429 self._targets = [] 430 self._feeds = feeds 431 self._feed_handles = feed_handles or {} 432 self._ops = [] 433 self._fetch_handles = {} 434 for fetch in self._fetch_mapper.unique_fetches(): 435 if isinstance(fetch, ops.Operation): 436 self._assert_fetchable(graph, fetch) 437 self._targets.append(fetch) 438 self._ops.append(True) 439 else: 440 self._assert_fetchable(graph, fetch.op) 441 self._fetches.append(fetch) 442 self._ops.append(False) 443 # Remember the fetch if it is for a tensor handle. 444 if (isinstance(fetch, ops.Tensor) and 445 (fetch.op.type == 'GetSessionHandle' or 446 fetch.op.type == 'GetSessionHandleV2')): 447 self._fetch_handles[fetch] = fetch.op.inputs[0].dtype 448 self._final_fetches = [x for x in self._fetches if x not in feeds] 449 450 def _assert_fetchable(self, graph, op): 451 if not graph.is_fetchable(op): 452 raise ValueError( 453 'Operation %r has been marked as not fetchable.' % op.name) 454 455 def fetches(self): 456 """Return the unique names of tensors to fetch. 457 458 Returns: 459 A list of strings. 460 """ 461 return self._final_fetches 462 463 def targets(self): 464 """Return the unique names of ops to run. 465 466 Returns: 467 A list of strings. 468 """ 469 return self._targets 470 471 def build_results(self, session, tensor_values): 472 """Build results matching the original fetch shape. 473 474 `tensor_values` must be a list of the same length as 475 the one returned by `fetches()`, and holding the requested 476 fetch values. 477 478 This method builds a struct with the same shape as the original `fetches` 479 passed to the constructor, in which the fetches are replaced by their 480 fetched value. 481 482 Args: 483 session: The enclosing session. Used for tensor handles. 484 tensor_values: List of values matching the list returned 485 by fetches(). 486 487 Returns: 488 A structure of the same shape as the original `fetches` argument but 489 containing tensors or None (for fetched ops). 490 """ 491 full_values = [] 492 assert len(self._final_fetches) == len(tensor_values) 493 i = 0 494 j = 0 495 for is_op in self._ops: 496 if is_op: 497 full_values.append(None) 498 else: 499 # If the fetch was in the feeds, use the fed value, otherwise 500 # use the returned value. 501 if self._fetches[i] in self._feed_handles: 502 # A fetch had a corresponding direct TensorHandle feed. Call eval() 503 # to obtain the Tensor value from the TensorHandle. 504 value = self._feed_handles[self._fetches[i]].eval() 505 else: 506 value = self._feeds.get(self._fetches[i]) 507 if value is None: 508 value = tensor_values[j] 509 j += 1 510 dtype = self._fetch_handles.get(self._fetches[i]) 511 if dtype: 512 full_values.append(session_ops.TensorHandle(value, dtype, session)) 513 else: 514 full_values.append(value) 515 i += 1 516 assert j == len(tensor_values) 517 return self._fetch_mapper.build_results(full_values) 518 519 520 def _name_list(tensor_list): 521 """Utility function for transitioning to the new session API. 522 523 Args: 524 tensor_list: a list of `Tensor`s. 525 526 Returns: 527 A list of each `Tensor`s name (as byte arrays). 528 """ 529 return [compat.as_bytes(t.name) for t in tensor_list] 530 531 532 class _DeviceAttributes(object): 533 """Struct-like object describing a device's attributes. 534 535 Each device has 3 key properties: 536 - name: the fully-qualified TensorFlow path to the device. For 537 example: /job:worker/replica:0/task:3/device:CPU:0 538 - device_type: the type of the device (e.g. CPU, GPU, TPU, etc.) 539 - memory_limit_bytes: the maximum amount of memory available on the device 540 (in bytes). 541 """ 542 543 def __init__(self, name, device_type, memory_limit_bytes): 544 self._name = device.canonical_name(name) 545 self._device_type = device_type 546 self._memory_limit_bytes = memory_limit_bytes 547 548 @property 549 def name(self): 550 return self._name 551 552 @property 553 def device_type(self): 554 return self._device_type 555 556 @property 557 def memory_limit_bytes(self): 558 return self._memory_limit_bytes 559 560 def __repr__(self): 561 return '_DeviceAttributes(%s, %s, %d)' % ( 562 self.name, 563 self.device_type, 564 self.memory_limit_bytes, 565 ) 566 567 568 class BaseSession(SessionInterface): 569 """A class for interacting with a TensorFlow computation. 570 571 The BaseSession enables incremental graph building with inline 572 execution of Operations and evaluation of Tensors. 573 """ 574 575 def __init__(self, target='', graph=None, config=None): 576 """Constructs a new TensorFlow session. 577 578 Args: 579 target: (Optional) The TensorFlow execution engine to connect to. 580 graph: (Optional) The graph to be used. If this argument is None, 581 the default graph will be used. 582 config: (Optional) ConfigProto proto used to configure the session. 583 584 Raises: 585 tf.errors.OpError: Or one of its subclasses if an error occurs while 586 creating the TensorFlow session. 587 TypeError: If one of the arguments has the wrong type. 588 """ 589 if graph is None: 590 self._graph = ops.get_default_graph() 591 else: 592 if not isinstance(graph, ops.Graph): 593 raise TypeError('graph must be a tf.Graph, but got %s' % type(graph)) 594 self._graph = graph 595 596 self._opened = False 597 self._closed = False 598 599 self._current_version = 0 600 self._extend_lock = threading.Lock() 601 if target is not None: 602 try: 603 self._target = compat.as_bytes(target) 604 except TypeError: 605 raise TypeError('target must be a string, but got %s' % type(target)) 606 else: 607 self._target = None 608 609 self._delete_lock = threading.Lock() 610 self._dead_handles = [] 611 612 if config is not None: 613 if not isinstance(config, config_pb2.ConfigProto): 614 raise TypeError( 615 'config must be a tf.ConfigProto, but got %s' % type(config)) 616 self._config = config 617 self._add_shapes = config.graph_options.infer_shapes 618 else: 619 self._config = None 620 self._add_shapes = False 621 622 # pylint: disable=protected-access 623 # We cache _USE_C_API's value because some test cases will create a session 624 # with _USE_C_API = False but set it back to True before calling close(). 625 self._created_with_new_api = ops._USE_C_API 626 # pylint: enable=protected-access 627 628 self._session = None 629 opts = tf_session.TF_NewSessionOptions(target=self._target, config=config) 630 try: 631 with errors.raise_exception_on_not_ok_status() as status: 632 if self._created_with_new_api: 633 # pylint: disable=protected-access 634 self._session = tf_session.TF_NewSession(self._graph._c_graph, opts, 635 status) 636 # pylint: enable=protected-access 637 else: 638 self._session = tf_session.TF_NewDeprecatedSession(opts, status) 639 finally: 640 tf_session.TF_DeleteSessionOptions(opts) 641 642 def list_devices(self): 643 """Lists available devices in this session. 644 645 ```python 646 devices = sess.list_devices() 647 for d in devices: 648 print(d.name) 649 ``` 650 651 Each element in the list has the following properties: 652 - `name`: A string with the full name of the device. ex: 653 `/job:worker/replica:0/task:3/device:CPU:0` 654 - `device_type`: The type of the device (e.g. `CPU`, `GPU`, `TPU`.) 655 - `memory_limit`: The maximum amount of memory available on the device. 656 Note: depending on the device, it is possible the usable memory could 657 be substantially less. 658 Raises: 659 tf.errors.OpError: If it encounters an error (e.g. session is in an 660 invalid state, or network errors occur). 661 662 Returns: 663 A list of devices in the session. 664 """ 665 with errors.raise_exception_on_not_ok_status() as status: 666 if self._created_with_new_api: 667 raw_device_list = tf_session.TF_SessionListDevices( 668 self._session, status) 669 else: 670 raw_device_list = tf_session.TF_DeprecatedSessionListDevices( 671 self._session, status) 672 device_list = [] 673 size = tf_session.TF_DeviceListCount(raw_device_list) 674 for i in range(size): 675 name = tf_session.TF_DeviceListName(raw_device_list, i, status) 676 device_type = tf_session.TF_DeviceListType(raw_device_list, i, status) 677 memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i, status) 678 device_list.append(_DeviceAttributes(name, device_type, memory)) 679 tf_session.TF_DeleteDeviceList(raw_device_list) 680 return device_list 681 682 def close(self): 683 """Closes this session. 684 685 Calling this method frees all resources associated with the session. 686 687 Raises: 688 tf.errors.OpError: Or one of its subclasses if an error occurs while 689 closing the TensorFlow session. 690 """ 691 if self._created_with_new_api: 692 if self._session and not self._closed: 693 self._closed = True 694 with errors.raise_exception_on_not_ok_status() as status: 695 tf_session.TF_CloseSession(self._session, status) 696 697 else: 698 with self._extend_lock: 699 if self._opened and not self._closed: 700 self._closed = True 701 with errors.raise_exception_on_not_ok_status() as status: 702 tf_session.TF_CloseDeprecatedSession(self._session, status) 703 704 def __del__(self): 705 # cleanly ignore all exceptions 706 try: 707 self.close() 708 except Exception: # pylint: disable=broad-except 709 pass 710 if self._session is not None: 711 try: 712 status = c_api_util.ScopedTFStatus() 713 if self._created_with_new_api: 714 tf_session.TF_DeleteSession(self._session, status) 715 else: 716 tf_session.TF_DeleteDeprecatedSession(self._session, status) 717 except AttributeError: 718 # At shutdown, `c_api_util` or `tf_session` may have been garbage 719 # collected, causing the above method calls to fail. In this case, 720 # silently leak since the program is about to terminate anyway. 721 pass 722 self._session = None 723 724 @property 725 def graph(self): 726 """The graph that was launched in this session.""" 727 return self._graph 728 729 @property 730 def graph_def(self): 731 """A serializable version of the underlying TensorFlow graph. 732 733 Returns: 734 A graph_pb2.GraphDef proto containing nodes for all of the Operations in 735 the underlying TensorFlow graph. 736 """ 737 return self._graph.as_graph_def(add_shapes=self._add_shapes) 738 739 @property 740 def sess_str(self): 741 return self._target 742 743 def as_default(self): 744 """Returns a context manager that makes this object the default session. 745 746 Use with the `with` keyword to specify that calls to 747 @{tf.Operation.run} or @{tf.Tensor.eval} should be executed in 748 this session. 749 750 ```python 751 c = tf.constant(..) 752 sess = tf.Session() 753 754 with sess.as_default(): 755 assert tf.get_default_session() is sess 756 print(c.eval()) 757 ``` 758 759 To get the current default session, use @{tf.get_default_session}. 760 761 *N.B.* The `as_default` context manager *does not* close the 762 session when you exit the context, and you must close the session 763 explicitly. 764 765 ```python 766 c = tf.constant(...) 767 sess = tf.Session() 768 with sess.as_default(): 769 print(c.eval()) 770 # ... 771 with sess.as_default(): 772 print(c.eval()) 773 774 sess.close() 775 ``` 776 777 Alternatively, you can use `with tf.Session():` to create a 778 session that is automatically closed on exiting the context, 779 including when an uncaught exception is raised. 780 781 *N.B.* The default session is a property of the current thread. If you 782 create a new thread, and wish to use the default session in that 783 thread, you must explicitly add a `with sess.as_default():` in that 784 thread's function. 785 786 *N.B.* Entering a `with sess.as_default():` block does not affect 787 the current default graph. If you are using multiple graphs, and 788 `sess.graph` is different from the value of @{tf.get_default_graph}, 789 you must explicitly enter a `with sess.graph.as_default():` block 790 to make `sess.graph` the default graph. 791 792 Returns: 793 A context manager using this session as the default session. 794 """ 795 return ops.default_session(self) 796 797 def run(self, fetches, feed_dict=None, options=None, run_metadata=None): 798 """Runs operations and evaluates tensors in `fetches`. 799 800 This method runs one "step" of TensorFlow computation, by 801 running the necessary graph fragment to execute every `Operation` 802 and evaluate every `Tensor` in `fetches`, substituting the values in 803 `feed_dict` for the corresponding input values. 804 805 The `fetches` argument may be a single graph element, or an arbitrarily 806 nested list, tuple, namedtuple, dict, or OrderedDict containing graph 807 elements at its leaves. A graph element can be one of the following types: 808 809 * An @{tf.Operation}. 810 The corresponding fetched value will be `None`. 811 * A @{tf.Tensor}. 812 The corresponding fetched value will be a numpy ndarray containing the 813 value of that tensor. 814 * A @{tf.SparseTensor}. 815 The corresponding fetched value will be a 816 @{tf.SparseTensorValue} 817 containing the value of that sparse tensor. 818 * A `get_tensor_handle` op. The corresponding fetched value will be a 819 numpy ndarray containing the handle of that tensor. 820 * A `string` which is the name of a tensor or operation in the graph. 821 822 The value returned by `run()` has the same shape as the `fetches` argument, 823 where the leaves are replaced by the corresponding values returned by 824 TensorFlow. 825 826 Example: 827 828 ```python 829 a = tf.constant([10, 20]) 830 b = tf.constant([1.0, 2.0]) 831 # 'fetches' can be a singleton 832 v = session.run(a) 833 # v is the numpy array [10, 20] 834 # 'fetches' can be a list. 835 v = session.run([a, b]) 836 # v is a Python list with 2 numpy arrays: the 1-D array [10, 20] and the 837 # 1-D array [1.0, 2.0] 838 # 'fetches' can be arbitrary lists, tuples, namedtuple, dicts: 839 MyData = collections.namedtuple('MyData', ['a', 'b']) 840 v = session.run({'k1': MyData(a, b), 'k2': [b, a]}) 841 # v is a dict with 842 # v['k1'] is a MyData namedtuple with 'a' (the numpy array [10, 20]) and 843 # 'b' (the numpy array [1.0, 2.0]) 844 # v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array 845 # [10, 20]. 846 ``` 847 848 The optional `feed_dict` argument allows the caller to override 849 the value of tensors in the graph. Each key in `feed_dict` can be 850 one of the following types: 851 852 * If the key is a @{tf.Tensor}, the 853 value may be a Python scalar, string, list, or numpy ndarray 854 that can be converted to the same `dtype` as that 855 tensor. Additionally, if the key is a 856 @{tf.placeholder}, the shape of 857 the value will be checked for compatibility with the placeholder. 858 * If the key is a 859 @{tf.SparseTensor}, 860 the value should be a 861 @{tf.SparseTensorValue}. 862 * If the key is a nested tuple of `Tensor`s or `SparseTensor`s, the value 863 should be a nested tuple with the same structure that maps to their 864 corresponding values as above. 865 866 Each value in `feed_dict` must be convertible to a numpy array of the dtype 867 of the corresponding key. 868 869 The optional `options` argument expects a [`RunOptions`] proto. The options 870 allow controlling the behavior of this particular step (e.g. turning tracing 871 on). 872 873 The optional `run_metadata` argument expects a [`RunMetadata`] proto. When 874 appropriate, the non-Tensor output of this step will be collected there. For 875 example, when users turn on tracing in `options`, the profiled info will be 876 collected into this argument and passed back. 877 878 Args: 879 fetches: A single graph element, a list of graph elements, 880 or a dictionary whose values are graph elements or lists of graph 881 elements (described above). 882 feed_dict: A dictionary that maps graph elements to values 883 (described above). 884 options: A [`RunOptions`] protocol buffer 885 run_metadata: A [`RunMetadata`] protocol buffer 886 887 Returns: 888 Either a single value if `fetches` is a single graph element, or 889 a list of values if `fetches` is a list, or a dictionary with the 890 same keys as `fetches` if that is a dictionary (described above). 891 892 Raises: 893 RuntimeError: If this `Session` is in an invalid state (e.g. has been 894 closed). 895 TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type. 896 ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a 897 `Tensor` that doesn't exist. 898 """ 899 options_ptr = tf_session.TF_NewBufferFromString( 900 compat.as_bytes(options.SerializeToString())) if options else None 901 run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None 902 903 try: 904 result = self._run(None, fetches, feed_dict, options_ptr, 905 run_metadata_ptr) 906 if run_metadata: 907 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) 908 run_metadata.ParseFromString(compat.as_bytes(proto_data)) 909 finally: 910 if run_metadata_ptr: 911 tf_session.TF_DeleteBuffer(run_metadata_ptr) 912 if options: 913 tf_session.TF_DeleteBuffer(options_ptr) 914 return result 915 916 def partial_run(self, handle, fetches, feed_dict=None): 917 """Continues the execution with more feeds and fetches. 918 919 This is EXPERIMENTAL and subject to change. 920 921 To use partial execution, a user first calls `partial_run_setup()` and 922 then a sequence of `partial_run()`. `partial_run_setup` specifies the 923 list of feeds and fetches that will be used in the subsequent 924 `partial_run` calls. 925 926 The optional `feed_dict` argument allows the caller to override 927 the value of tensors in the graph. See run() for more information. 928 929 Below is a simple example: 930 931 ```python 932 a = array_ops.placeholder(dtypes.float32, shape=[]) 933 b = array_ops.placeholder(dtypes.float32, shape=[]) 934 c = array_ops.placeholder(dtypes.float32, shape=[]) 935 r1 = math_ops.add(a, b) 936 r2 = math_ops.multiply(r1, c) 937 938 h = sess.partial_run_setup([r1, r2], [a, b, c]) 939 res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2}) 940 res = sess.partial_run(h, r2, feed_dict={c: res}) 941 ``` 942 943 Args: 944 handle: A handle for a sequence of partial runs. 945 fetches: A single graph element, a list of graph elements, 946 or a dictionary whose values are graph elements or lists of graph 947 elements (see documentation for `run`). 948 feed_dict: A dictionary that maps graph elements to values 949 (described above). 950 951 Returns: 952 Either a single value if `fetches` is a single graph element, or 953 a list of values if `fetches` is a list, or a dictionary with the 954 same keys as `fetches` if that is a dictionary 955 (see documentation for `run`). 956 957 Raises: 958 tf.errors.OpError: Or one of its subclasses on error. 959 """ 960 # TODO(touts): Support feeding and fetching the same tensor. 961 return self._run(handle, fetches, feed_dict, None, None) 962 963 def partial_run_setup(self, fetches, feeds=None): 964 """Sets up a graph with feeds and fetches for partial run. 965 966 This is EXPERIMENTAL and subject to change. 967 968 Note that contrary to `run`, `feeds` only specifies the graph elements. 969 The tensors will be supplied by the subsequent `partial_run` calls. 970 971 Args: 972 fetches: A single graph element, or a list of graph elements. 973 feeds: A single graph element, or a list of graph elements. 974 975 Returns: 976 A handle for partial run. 977 978 Raises: 979 RuntimeError: If this `Session` is in an invalid state (e.g. has been 980 closed). 981 TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type. 982 tf.errors.OpError: Or one of its subclasses if a TensorFlow error happens. 983 """ 984 985 def _feed_fn(feed): 986 for tensor_type, _, _, feed_fn in _REGISTERED_EXPANSIONS: 987 if isinstance(feed, tensor_type): 988 return feed_fn(feed) 989 raise TypeError('Feed argument %r has invalid type %r' % (feed, 990 type(feed))) 991 992 # Check session. 993 if self._closed: 994 raise RuntimeError('Attempted to use a closed Session.') 995 if self.graph.version == 0: 996 raise RuntimeError('The Session graph is empty. Add operations to the ' 997 'graph before calling run().') 998 999 if feeds is None: 1000 feeds = [] 1001 # Create request. 1002 feed_list = [] 1003 1004 # Validate and process feed_list. 1005 is_list_feed = isinstance(feeds, (list, tuple)) 1006 if not is_list_feed: 1007 feeds = [feeds] 1008 for feed in feeds: 1009 for subfeed in _feed_fn(feed): 1010 try: 1011 subfeed_t = self.graph.as_graph_element( 1012 subfeed, allow_tensor=True, allow_operation=False) 1013 if self._created_with_new_api: 1014 # pylint: disable=protected-access 1015 feed_list.append(subfeed_t._as_tf_output()) 1016 # pylint: enable=protected-access 1017 else: 1018 feed_list.append(compat.as_bytes(subfeed_t.name)) 1019 except Exception as e: 1020 e.message = ('Cannot interpret feed_list key as Tensor: ' + e.message) 1021 e.args = (e.message,) 1022 raise e 1023 1024 # Validate and process fetches. 1025 # TODO(touts): Support feeding and fetching the same tensor. 1026 fetch_handler = _FetchHandler(self._graph, fetches, {}) 1027 1028 # Set up a graph with feeds and fetches for partial run. 1029 def _setup_fn(session, feed_list, fetch_list, target_list): 1030 self._extend_graph() 1031 with errors.raise_exception_on_not_ok_status() as status: 1032 if self._created_with_new_api: 1033 return tf_session.TF_SessionPRunSetup_wrapper( 1034 session, feed_list, fetch_list, target_list, status) 1035 else: 1036 return tf_session.TF_PRunSetup(session, feed_list, fetch_list, 1037 target_list, status) 1038 1039 if self._created_with_new_api: 1040 # pylint: disable=protected-access 1041 final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()] 1042 final_targets = [op._c_op for op in fetch_handler.targets()] 1043 # pylint: enable=protected-access 1044 else: 1045 final_fetches = _name_list(fetch_handler.fetches()) 1046 final_targets = _name_list(fetch_handler.targets()) 1047 1048 return self._do_call(_setup_fn, self._session, feed_list, final_fetches, 1049 final_targets) 1050 1051 def _run(self, handle, fetches, feed_dict, options, run_metadata): 1052 """Perform either run or partial_run, depending the presence of `handle`.""" 1053 1054 def _feed_fn(feed, feed_val): 1055 for tensor_type, _, feed_fn, _ in _REGISTERED_EXPANSIONS: 1056 if isinstance(feed, tensor_type): 1057 return feed_fn(feed, feed_val) 1058 raise TypeError('Feed argument %r has invalid type %r' % (feed, 1059 type(feed))) 1060 1061 # Check session. 1062 if self._closed: 1063 raise RuntimeError('Attempted to use a closed Session.') 1064 if self.graph.version == 0: 1065 raise RuntimeError('The Session graph is empty. Add operations to the ' 1066 'graph before calling run().') 1067 1068 # Create request. 1069 feed_dict_tensor = {} 1070 feed_map = {} 1071 1072 # Validate and process feed_dict. 1073 feed_handles = {} 1074 if feed_dict: 1075 feed_dict = nest.flatten_dict_items(feed_dict) 1076 for feed, feed_val in feed_dict.items(): 1077 for subfeed, subfeed_val in _feed_fn(feed, feed_val): 1078 try: 1079 subfeed_t = self.graph.as_graph_element( 1080 subfeed, allow_tensor=True, allow_operation=False) 1081 except Exception as e: 1082 raise TypeError( 1083 'Cannot interpret feed_dict key as Tensor: ' + e.args[0]) 1084 1085 if isinstance(subfeed_val, ops.Tensor): 1086 raise TypeError('The value of a feed cannot be a tf.Tensor object. ' 1087 'Acceptable feed values include Python scalars, ' 1088 'strings, lists, numpy ndarrays, or TensorHandles.') 1089 1090 subfeed_dtype = subfeed_t.dtype.as_numpy_dtype 1091 if isinstance(subfeed_val, int) and _convert_to_numpy_obj( 1092 subfeed_dtype, subfeed_val) != subfeed_val: 1093 raise TypeError( 1094 'Type of feed value ' + str(subfeed_val) + ' with type ' + str( 1095 type(subfeed_val)) + 1096 ' is not compatible with Tensor type ' + str(subfeed_dtype) + 1097 '. Try explicitly setting the type of the feed tensor' 1098 ' to a larger type (e.g. int64).') 1099 1100 is_tensor_handle_feed = isinstance(subfeed_val, 1101 session_ops.TensorHandle) 1102 if is_tensor_handle_feed: 1103 np_val = subfeed_val.to_numpy_array() 1104 feed_handles[subfeed_t] = subfeed_val 1105 else: 1106 np_val = np.asarray(subfeed_val, dtype=subfeed_dtype) 1107 1108 if (not is_tensor_handle_feed and 1109 not subfeed_t.get_shape().is_compatible_with(np_val.shape)): 1110 raise ValueError('Cannot feed value of shape %r for Tensor %r, ' 1111 'which has shape %r' % 1112 (np_val.shape, subfeed_t.name, 1113 str(subfeed_t.get_shape()))) 1114 if not self.graph.is_feedable(subfeed_t): 1115 raise ValueError('Tensor %s may not be fed.' % subfeed_t) 1116 1117 feed_dict_tensor[subfeed_t] = np_val 1118 feed_map[compat.as_bytes(subfeed_t.name)] = (subfeed_t, subfeed_val) 1119 1120 # Create a fetch handler to take care of the structure of fetches. 1121 fetch_handler = _FetchHandler( 1122 self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles) 1123 1124 # Run request and get response. 1125 # We need to keep the returned movers alive for the following _do_run(). 1126 # These movers are no longer needed when _do_run() completes, and 1127 # are deleted when `movers` goes out of scope when this _run() ends. 1128 # TODO(yuanbyu, keveman): Revisit whether we should just treat feeding 1129 # of a handle from a different device as an error. 1130 _ = self._update_with_movers(feed_dict_tensor, feed_map) 1131 final_fetches = fetch_handler.fetches() 1132 final_targets = fetch_handler.targets() 1133 # We only want to really perform the run if fetches or targets are provided, 1134 # or if the call is a partial run that specifies feeds. 1135 if final_fetches or final_targets or (handle and feed_dict_tensor): 1136 results = self._do_run(handle, final_targets, final_fetches, 1137 feed_dict_tensor, options, run_metadata) 1138 else: 1139 results = [] 1140 return fetch_handler.build_results(self, results) 1141 1142 def make_callable(self, fetches, feed_list=None, accept_options=False): 1143 """Returns a Python callable that runs a particular step. 1144 1145 The returned callable will take `len(feed_list)` arguments whose types 1146 must be compatible feed values for the respective elements of `feed_list`. 1147 For example, if element `i` of `feed_list` is a `tf.Tensor`, the `i`th 1148 argument to the returned callable must be a numpy ndarray (or something 1149 convertible to an ndarray) with matching element type and shape. See 1150 @{tf.Session.run} for details of the allowable feed key and value types. 1151 1152 The returned callable will have the same return type as 1153 `tf.Session.run(fetches, ...)`. For example, if `fetches` is a `tf.Tensor`, 1154 the callable will return a numpy ndarray; if `fetches` is a `tf.Operation`, 1155 it will return `None`. 1156 1157 Args: 1158 fetches: A value or list of values to fetch. See @{tf.Session.run} 1159 for details of the allowable fetch types. 1160 feed_list: (Optional.) A list of `feed_dict` keys. See 1161 @{tf.Session.run} for details of the allowable feed key types. 1162 accept_options: (Optional.) Iff `True`, the returned `Callable` will be 1163 able to accept @{tf.RunOptions} and @{tf.RunMetadata} as optional 1164 keyword arguments `options` and `run_metadata`, respectively, with 1165 the same syntax and semantics as @{tf.Session.run}, which is useful 1166 for certain use cases (profiling and debugging) but will result in 1167 measurable slowdown of the `Callable`'s performance. Default: `False`. 1168 1169 Returns: 1170 A function that when called will execute the step defined by 1171 `feed_list` and `fetches` in this session. 1172 1173 Raises: 1174 TypeError: If `fetches` or `feed_list` cannot be interpreted 1175 as arguments to @{tf.Session.run}. 1176 """ 1177 if feed_list is not None: 1178 if not isinstance(feed_list, (list, tuple)): 1179 raise TypeError('`feed_list` must be a list or tuple.') 1180 # Delegate any non-empty feed lists to the existing `run()` logic. 1181 # TODO(mrry): Refactor the feed handling logic from 1182 # `Session._run()` so that we can convert the feeds to a list of 1183 # strings here. 1184 def _generic_run(*feed_args, **kwargs): 1185 feed_dict = { 1186 feed: feed_val 1187 for feed, feed_val in zip(feed_list, feed_args) 1188 } 1189 return self.run(fetches, feed_dict=feed_dict, **kwargs) 1190 1191 return _generic_run 1192 1193 # Ensure any changes to the graph are reflected in the runtime. 1194 # Note that we don't need to do this on subsequent calls to the 1195 # returned object, because the arguments to `fetches` must already be 1196 # in the graph. 1197 self._extend_graph() 1198 1199 # Create a fetch handler to take care of the structure of fetches. 1200 fetch_handler = _FetchHandler(self._graph, fetches, {}) 1201 if self._created_with_new_api: 1202 # pylint: disable=protected-access 1203 fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()] 1204 target_list = [op._c_op for op in fetch_handler.targets()] 1205 # pylint: enable=protected-access 1206 else: 1207 fetch_list = _name_list(fetch_handler.fetches()) 1208 target_list = _name_list(fetch_handler.targets()) 1209 1210 def _callable_template_with_options_and_metadata(fetch_list, 1211 target_list, 1212 fetch_handler, 1213 options=None, 1214 run_metadata=None): 1215 """Template callable that accepts RunOptions and RunMetadata.""" 1216 options_ptr = tf_session.TF_NewBufferFromString( 1217 compat.as_bytes(options.SerializeToString())) if options else None 1218 run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None 1219 try: 1220 with errors.raise_exception_on_not_ok_status() as status: 1221 if self._created_with_new_api: 1222 results = tf_session.TF_SessionRun_wrapper( 1223 self._session, options_ptr, {}, fetch_list, target_list, 1224 run_metadata_ptr, status) 1225 else: 1226 results = tf_session.TF_Run(self._session, options_ptr, {}, 1227 fetch_list, target_list, status, 1228 run_metadata_ptr) 1229 if fetch_handler: 1230 results = fetch_handler.build_results(self, results) 1231 else: 1232 results = results[0] if results else None 1233 if run_metadata: 1234 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) 1235 run_metadata.ParseFromString(compat.as_bytes(proto_data)) 1236 finally: 1237 if run_metadata_ptr: 1238 tf_session.TF_DeleteBuffer(run_metadata_ptr) 1239 if options: 1240 tf_session.TF_DeleteBuffer(options_ptr) 1241 return results 1242 1243 if accept_options: 1244 return functools.partial(_callable_template_with_options_and_metadata, 1245 fetch_list, target_list, fetch_handler) 1246 elif isinstance(fetches, ops.Operation): 1247 # Special case for fetching a single operation, because the 1248 # function will have no return value. 1249 assert not fetch_list 1250 assert len(target_list) == 1 1251 1252 def _single_operation_run(): 1253 with errors.raise_exception_on_not_ok_status() as status: 1254 if self._created_with_new_api: 1255 tf_session.TF_SessionRun_wrapper(self._session, None, {}, [], 1256 target_list, None, status) 1257 else: 1258 tf_session.TF_Run(self._session, None, {}, [], target_list, status, 1259 None) 1260 1261 return _single_operation_run 1262 elif isinstance(fetches, ops.Tensor): 1263 # Special case for fetching a single tensor, because the 1264 # function can return the result of `TF_Run()` directly. 1265 assert len(fetch_list) == 1 1266 assert not target_list 1267 1268 def _single_tensor_run(): 1269 with errors.raise_exception_on_not_ok_status() as status: 1270 if self._created_with_new_api: 1271 results = tf_session.TF_SessionRun_wrapper( 1272 self._session, None, {}, fetch_list, [], None, status) 1273 else: 1274 results = tf_session.TF_Run(self._session, None, {}, fetch_list, [], 1275 status, None) 1276 return results[0] 1277 1278 return _single_tensor_run 1279 else: 1280 # In all other cases, we must use `fetch_handler` to build the 1281 # results for us. 1282 def _fetch_handler_run(): 1283 with errors.raise_exception_on_not_ok_status() as status: 1284 if self._created_with_new_api: 1285 results = tf_session.TF_SessionRun_wrapper( 1286 self._session, None, {}, fetch_list, target_list, None, status) 1287 else: 1288 results = tf_session.TF_Run(self._session, None, {}, fetch_list, 1289 target_list, status, None) 1290 return fetch_handler.build_results(self, results) 1291 1292 return _fetch_handler_run 1293 1294 # Captures the name of a node in an error status. 1295 _NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =') 1296 1297 def _do_run(self, handle, target_list, fetch_list, feed_dict, options, 1298 run_metadata): 1299 """Runs a step based on the given fetches and feeds. 1300 1301 Args: 1302 handle: a handle for partial_run. None if this is just a call to run(). 1303 target_list: A list of operations to be run, but not fetched. 1304 fetch_list: A list of tensors to be fetched. 1305 feed_dict: A dictionary that maps tensors to numpy ndarrays. 1306 options: A (pointer to a) [`RunOptions`] protocol buffer, or None 1307 run_metadata: A (pointer to a) [`RunMetadata`] protocol buffer, or None 1308 1309 Returns: 1310 A list of numpy ndarrays, corresponding to the elements of 1311 `fetch_list`. If the ith element of `fetch_list` contains the 1312 name of an operation, the first Tensor output of that operation 1313 will be returned for that element. 1314 1315 Raises: 1316 tf.errors.OpError: Or one of its subclasses on error. 1317 """ 1318 if self._created_with_new_api: 1319 # pylint: disable=protected-access 1320 feeds = dict((t._as_tf_output(), v) for t, v in feed_dict.items()) 1321 fetches = [t._as_tf_output() for t in fetch_list] 1322 targets = [op._c_op for op in target_list] 1323 # pylint: enable=protected-access 1324 else: 1325 feeds = dict((compat.as_bytes(t.name), v) for t, v in feed_dict.items()) 1326 fetches = _name_list(fetch_list) 1327 targets = _name_list(target_list) 1328 1329 def _run_fn(session, feed_dict, fetch_list, target_list, options, 1330 run_metadata): 1331 # Ensure any changes to the graph are reflected in the runtime. 1332 self._extend_graph() 1333 with errors.raise_exception_on_not_ok_status() as status: 1334 if self._created_with_new_api: 1335 return tf_session.TF_SessionRun_wrapper(session, options, feed_dict, 1336 fetch_list, target_list, 1337 run_metadata, status) 1338 else: 1339 return tf_session.TF_Run(session, options, feed_dict, fetch_list, 1340 target_list, status, run_metadata) 1341 1342 def _prun_fn(session, handle, feed_dict, fetch_list): 1343 if target_list: 1344 raise RuntimeError('partial_run() requires empty target_list.') 1345 with errors.raise_exception_on_not_ok_status() as status: 1346 if self._created_with_new_api: 1347 return tf_session.TF_SessionPRun_wrapper(session, handle, feed_dict, 1348 fetch_list, status) 1349 else: 1350 return tf_session.TF_PRun(session, handle, feed_dict, fetch_list, 1351 status) 1352 1353 if handle is None: 1354 return self._do_call(_run_fn, self._session, feeds, fetches, targets, 1355 options, run_metadata) 1356 else: 1357 return self._do_call(_prun_fn, self._session, handle, feeds, fetches) 1358 1359 def _do_call(self, fn, *args): 1360 try: 1361 return fn(*args) 1362 except errors.OpError as e: 1363 message = compat.as_text(e.message) 1364 m = BaseSession._NODEDEF_NAME_RE.search(message) 1365 node_def = None 1366 op = None 1367 if m is not None: 1368 node_name = m.group(1) 1369 try: 1370 op = self._graph.get_operation_by_name(node_name) 1371 node_def = op.node_def 1372 except KeyError: 1373 pass 1374 raise type(e)(node_def, op, message) 1375 1376 def _extend_graph(self): 1377 # Nothing to do if we're using the new session interface 1378 # TODO(skyewm): remove this function altogether eventually 1379 if self._created_with_new_api: 1380 return 1381 1382 # Ensure any changes to the graph are reflected in the runtime. 1383 with self._extend_lock: 1384 if self._graph.version > self._current_version: 1385 # pylint: disable=protected-access 1386 graph_def, self._current_version = self._graph._as_graph_def( 1387 from_version=self._current_version, add_shapes=self._add_shapes) 1388 # pylint: enable=protected-access 1389 1390 with errors.raise_exception_on_not_ok_status() as status: 1391 tf_session.TF_ExtendGraph(self._session, 1392 graph_def.SerializeToString(), status) 1393 self._opened = True 1394 1395 # The threshold to run garbage collection to delete dead tensors. 1396 _DEAD_HANDLES_THRESHOLD = 10 1397 1398 def _register_dead_handle(self, handle): 1399 # Register a dead handle in the session. Delete the dead tensors when 1400 # the number of dead tensors exceeds certain threshold. 1401 tensors_to_delete = None 1402 with self._delete_lock: 1403 self._dead_handles.append(handle) 1404 if len(self._dead_handles) == BaseSession._DEAD_HANDLES_THRESHOLD: 1405 tensors_to_delete = self._dead_handles 1406 self._dead_handles = [] 1407 # Delete the dead tensors. 1408 if tensors_to_delete: 1409 feeds = {} 1410 fetches = [] 1411 for deleter_key, tensor_handle in enumerate(tensors_to_delete): 1412 holder, deleter = session_ops._get_handle_deleter( 1413 self.graph, deleter_key, tensor_handle) 1414 feeds[holder] = tensor_handle 1415 fetches.append(deleter) 1416 self.run(fetches, feed_dict=feeds) 1417 1418 def _update_with_movers(self, feed_dict, feed_map): 1419 # If a tensor handle that is fed to a device incompatible placeholder, 1420 # we move the tensor to the right device, generate a new tensor handle, 1421 # and update `feed_dict` to use the new handle. 1422 handle_movers = [] 1423 for feed_name, val in feed_map.items(): 1424 mover = session_ops._get_handle_mover(self.graph, *val) 1425 if mover: 1426 handle_movers.append((feed_name, val[1], mover)) 1427 # Transfer a tensor to the right device if needed. 1428 if not handle_movers: 1429 return [] 1430 else: 1431 feeds = {} 1432 fetches = [] 1433 for _, handle, mover in handle_movers: 1434 feeds[mover[0]] = handle 1435 fetches.append(mover[1]) 1436 handles = self.run(fetches, feed_dict=feeds) 1437 for handle_mover, handle in zip(handle_movers, handles): 1438 np_val = np.array(handle.handle, dtype=np.object) 1439 feed_name = handle_mover[0] 1440 feed_tensor = feed_map[feed_name][0] 1441 feed_dict[feed_tensor] = np_val 1442 return handles 1443 1444 1445 @tf_export('Session') 1446 class Session(BaseSession): 1447 """A class for running TensorFlow operations. 1448 1449 A `Session` object encapsulates the environment in which `Operation` 1450 objects are executed, and `Tensor` objects are evaluated. For 1451 example: 1452 1453 ```python 1454 # Build a graph. 1455 a = tf.constant(5.0) 1456 b = tf.constant(6.0) 1457 c = a * b 1458 1459 # Launch the graph in a session. 1460 sess = tf.Session() 1461 1462 # Evaluate the tensor `c`. 1463 print(sess.run(c)) 1464 ``` 1465 1466 A session may own resources, such as 1467 @{tf.Variable}, @{tf.QueueBase}, 1468 and @{tf.ReaderBase}. It is important to release 1469 these resources when they are no longer required. To do this, either 1470 invoke the @{tf.Session.close} method on the session, or use 1471 the session as a context manager. The following two examples are 1472 equivalent: 1473 1474 ```python 1475 # Using the `close()` method. 1476 sess = tf.Session() 1477 sess.run(...) 1478 sess.close() 1479 1480 # Using the context manager. 1481 with tf.Session() as sess: 1482 sess.run(...) 1483 ``` 1484 1485 The 1486 [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto) 1487 protocol buffer exposes various configuration options for a 1488 session. For example, to create a session that uses soft constraints 1489 for device placement, and log the resulting placement decisions, 1490 create a session as follows: 1491 1492 ```python 1493 # Launch the graph in a session that allows soft device placement and 1494 # logs the placement decisions. 1495 sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, 1496 log_device_placement=True)) 1497 ``` 1498 """ 1499 1500 def __init__(self, target='', graph=None, config=None): 1501 """Creates a new TensorFlow session. 1502 1503 If no `graph` argument is specified when constructing the session, 1504 the default graph will be launched in the session. If you are 1505 using more than one graph (created with `tf.Graph()` in the same 1506 process, you will have to use different sessions for each graph, 1507 but each graph can be used in multiple sessions. In this case, it 1508 is often clearer to pass the graph to be launched explicitly to 1509 the session constructor. 1510 1511 Args: 1512 target: (Optional.) The execution engine to connect to. 1513 Defaults to using an in-process engine. See 1514 @{$distributed$Distributed TensorFlow} 1515 for more examples. 1516 graph: (Optional.) The `Graph` to be launched (described above). 1517 config: (Optional.) A 1518 [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto) 1519 protocol buffer with configuration options for the session. 1520 1521 """ 1522 super(Session, self).__init__(target, graph, config=config) 1523 # NOTE(mrry): Create these on first `__enter__` to avoid a reference cycle. 1524 self._default_graph_context_manager = None 1525 self._default_session_context_manager = None 1526 1527 def __enter__(self): 1528 if self._default_graph_context_manager is None: 1529 self._default_graph_context_manager = self.graph.as_default() 1530 else: 1531 raise RuntimeError('Session context managers are not re-entrant. ' 1532 'Use `Session.as_default()` if you want to enter ' 1533 'a session multiple times.') 1534 if self._default_session_context_manager is None: 1535 self._default_session_context_manager = self.as_default() 1536 self._default_graph_context_manager.__enter__() 1537 return self._default_session_context_manager.__enter__() 1538 1539 def __exit__(self, exec_type, exec_value, exec_tb): 1540 if exec_type is errors.OpError: 1541 logging.error('Session closing due to OpError: %s', (exec_value,)) 1542 try: 1543 self._default_session_context_manager.__exit__(exec_type, exec_value, 1544 exec_tb) 1545 except RuntimeError as error: 1546 if error == exec_value: 1547 # NOTE(skyewm): for some reason, in Python3, 1548 # _default_session_context_manager.__exit__ will re-raise the "not 1549 # re-entrant" exception raised in __enter__ above (note that if we're 1550 # here, we're in the outer session context manager, since __exit__ is 1551 # not called when __enter__ raises an exception). We still want to 1552 # continue cleaning up this context manager before the exception is 1553 # further propagated, so we ignore it here (note that it'll continue 1554 # being propagated after this method completes). 1555 pass 1556 else: 1557 raise 1558 self._default_graph_context_manager.__exit__(exec_type, exec_value, exec_tb) 1559 1560 self._default_session_context_manager = None 1561 self._default_graph_context_manager = None 1562 1563 self.close() 1564 1565 @staticmethod 1566 def reset(target, containers=None, config=None): 1567 """Resets resource containers on `target`, and close all connected sessions. 1568 1569 A resource container is distributed across all workers in the 1570 same cluster as `target`. When a resource container on `target` 1571 is reset, resources associated with that container will be cleared. 1572 In particular, all Variables in the container will become undefined: 1573 they lose their values and shapes. 1574 1575 NOTE: 1576 (i) reset() is currently only implemented for distributed sessions. 1577 (ii) Any sessions on the master named by `target` will be closed. 1578 1579 If no resource containers are provided, all containers are reset. 1580 1581 Args: 1582 target: The execution engine to connect to. 1583 containers: A list of resource container name strings, or `None` if all of 1584 all the containers are to be reset. 1585 config: (Optional.) Protocol buffer with configuration options. 1586 1587 Raises: 1588 tf.errors.OpError: Or one of its subclasses if an error occurs while 1589 resetting containers. 1590 """ 1591 if target is not None: 1592 target = compat.as_bytes(target) 1593 if containers is not None: 1594 containers = [compat.as_bytes(c) for c in containers] 1595 else: 1596 containers = [] 1597 tf_session.TF_Reset(target, containers, config) 1598 1599 1600 @tf_export('InteractiveSession') 1601 class InteractiveSession(BaseSession): 1602 """A TensorFlow `Session` for use in interactive contexts, such as a shell. 1603 1604 The only difference with a regular `Session` is that an `InteractiveSession` 1605 installs itself as the default session on construction. 1606 The methods @{tf.Tensor.eval} 1607 and @{tf.Operation.run} 1608 will use that session to run ops. 1609 1610 This is convenient in interactive shells and [IPython 1611 notebooks](http://ipython.org), as it avoids having to pass an explicit 1612 `Session` object to run ops. 1613 1614 For example: 1615 1616 ```python 1617 sess = tf.InteractiveSession() 1618 a = tf.constant(5.0) 1619 b = tf.constant(6.0) 1620 c = a * b 1621 # We can just use 'c.eval()' without passing 'sess' 1622 print(c.eval()) 1623 sess.close() 1624 ``` 1625 1626 Note that a regular session installs itself as the default session when it 1627 is created in a `with` statement. The common usage in non-interactive 1628 programs is to follow that pattern: 1629 1630 ```python 1631 a = tf.constant(5.0) 1632 b = tf.constant(6.0) 1633 c = a * b 1634 with tf.Session(): 1635 # We can also use 'c.eval()' here. 1636 print(c.eval()) 1637 ``` 1638 """ 1639 1640 def __init__(self, target='', graph=None, config=None): 1641 """Creates a new interactive TensorFlow session. 1642 1643 If no `graph` argument is specified when constructing the session, 1644 the default graph will be launched in the session. If you are 1645 using more than one graph (created with `tf.Graph()` in the same 1646 process, you will have to use different sessions for each graph, 1647 but each graph can be used in multiple sessions. In this case, it 1648 is often clearer to pass the graph to be launched explicitly to 1649 the session constructor. 1650 1651 Args: 1652 target: (Optional.) The execution engine to connect to. 1653 Defaults to using an in-process engine. 1654 graph: (Optional.) The `Graph` to be launched (described above). 1655 config: (Optional) `ConfigProto` proto used to configure the session. 1656 """ 1657 if not config: 1658 # If config is not provided, choose some reasonable defaults for 1659 # interactive use: 1660 # 1661 # - Grow GPU memory as needed at the cost of fragmentation. 1662 gpu_options = config_pb2.GPUOptions(allow_growth=True) 1663 config = config_pb2.ConfigProto(gpu_options=gpu_options) 1664 # Interactive sessions always place pruned graphs. 1665 config.graph_options.place_pruned_graph = True 1666 1667 super(InteractiveSession, self).__init__(target, graph, config) 1668 self._default_session = self.as_default() 1669 self._default_session.enforce_nesting = False 1670 self._default_session.__enter__() 1671 self._explicit_graph = graph 1672 if self._explicit_graph is not None: 1673 self._default_graph = graph.as_default() 1674 self._default_graph.enforce_nesting = False 1675 self._default_graph.__enter__() 1676 1677 def close(self): 1678 """Closes an `InteractiveSession`.""" 1679 super(InteractiveSession, self).close() 1680 if self._explicit_graph is not None: 1681 self._default_graph.__exit__(None, None, None) 1682 self._default_session.__exit__(None, None, None) 1683