1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tensorflow.python.client.session.Session.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import collections 21 import os 22 import sys 23 import threading 24 import time 25 26 import numpy as np 27 import six 28 from six.moves import xrange # pylint: disable=redefined-builtin 29 30 from tensorflow.core.framework import attr_value_pb2 31 from tensorflow.core.framework import types_pb2 32 from tensorflow.core.lib.core import error_codes_pb2 33 from tensorflow.core.protobuf import config_pb2 34 from tensorflow.python.client import session 35 from tensorflow.python.framework import common_shapes 36 from tensorflow.python.framework import constant_op 37 from tensorflow.python.framework import dtypes 38 from tensorflow.python.framework import errors 39 from tensorflow.python.framework import function 40 from tensorflow.python.framework import ops 41 from tensorflow.python.framework import sparse_tensor 42 from tensorflow.python.framework import tensor_util 43 from tensorflow.python.framework import test_util 44 from tensorflow.python.framework import versions 45 from tensorflow.python.ops import array_ops 46 from tensorflow.python.ops import control_flow_ops 47 from tensorflow.python.ops import data_flow_ops 48 from tensorflow.python.ops import gen_control_flow_ops 49 from tensorflow.python.ops import math_ops 50 # Import resource_variable_ops for the variables-to-tensor implicit conversion. 51 from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import 52 from tensorflow.python.ops import state_ops 53 from tensorflow.python.ops import variables 54 from tensorflow.python.platform import googletest 55 from tensorflow.python.training import server_lib 56 from tensorflow.python.util import compat 57 58 # NOTE(mrry): Dummy shape registration for ops used in the tests, since they 59 # don't have C++ op registrations on which to attach C++ shape fns. 60 ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape) 61 62 63 @test_util.with_c_api 64 class SessionTest(test_util.TensorFlowTestCase): 65 66 def testUseExistingGraph(self): 67 with ops.Graph().as_default() as g, ops.device('/cpu:0'): 68 a = constant_op.constant(6.0, shape=[1, 1]) 69 b = constant_op.constant(7.0, shape=[1, 1]) 70 c = math_ops.matmul(a, b, name='matmul') 71 with session.Session(graph=g): 72 result = c.eval() 73 self.assertAllEqual(result, [[42.0]]) 74 75 def testUseDefaultGraph(self): 76 with ops.Graph().as_default(), ops.device('/cpu:0'): 77 a = constant_op.constant(6.0, shape=[1, 1]) 78 b = constant_op.constant(7.0, shape=[1, 1]) 79 c = math_ops.matmul(a, b, name='matmul') 80 with session.Session(): 81 result = c.eval() 82 self.assertAllEqual(result, [[42.0]]) 83 84 def testCreate(self): 85 with session.Session(): 86 inp = constant_op.constant(10.0, shape=[2, 3], name='W1') 87 copy = array_ops.identity(inp) 88 # Test with feed. 89 # TODO(mrry): Investigate why order='F' didn't work. 90 arr = np.asarray([[0, 1, 2], [3, 4, 5]], dtype=np.float32, order='C') 91 copy_val = copy.eval({'W1:0': arr}) 92 self.assertAllEqual(arr, copy_val) 93 # Test without feed. 94 copy_val = copy.eval() 95 self.assertAllEqual( 96 np.asarray( 97 [[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]], dtype=np.float32), 98 copy_val) 99 100 def testManyCPUs(self): 101 # TODO(keveman): Implement ListDevices and test for the number of 102 # devices returned by ListDevices. 103 with session.Session( 104 config=config_pb2.ConfigProto(device_count={ 105 'CPU': 2 106 })): 107 inp = constant_op.constant(10.0, name='W1') 108 self.assertAllEqual(inp.eval(), 10.0) 109 110 def testPerSessionThreads(self): 111 # TODO(keveman): Implement ListDevices and test for the number of 112 # devices returned by ListDevices. 113 with session.Session( 114 config=config_pb2.ConfigProto(use_per_session_threads=True)): 115 inp = constant_op.constant(10.0, name='W1') 116 self.assertAllEqual(inp.eval(), 10.0) 117 118 def testSessionInterOpThreadPool(self): 119 config = config_pb2.ConfigProto() 120 pool = config.session_inter_op_thread_pool.add() 121 with session.Session(config=config) as s: 122 inp = constant_op.constant(10.0, name='W1') 123 results = s.run([inp]) 124 self.assertAllEqual([10.0], results) 125 126 pool = config.session_inter_op_thread_pool.add() 127 pool.num_threads = 1 128 with session.Session(config=config) as s: 129 inp = constant_op.constant(20.0, name='W2') 130 results = s.run([inp]) 131 self.assertAllEqual([20.0], results) 132 133 pool = config.session_inter_op_thread_pool.add() 134 pool.num_threads = 1 135 pool.global_name = 't1' 136 run_options = config_pb2.RunOptions() 137 run_options.inter_op_thread_pool = ( 138 len(config.session_inter_op_thread_pool) - 1) 139 with session.Session(config=config) as s: 140 inp = constant_op.constant(30.0, name='W2') 141 results = s.run([inp], options=run_options) 142 self.assertAllEqual([30.0], results) 143 144 def testErrorsReported(self): 145 with session.Session() as s: 146 constant_op.constant(10.0, name='W1') 147 with self.assertRaises(ValueError): 148 s.run('foo:0') 149 150 def testErrorPayload(self): 151 with session.Session(): 152 a = array_ops.placeholder(dtypes.float32) 153 with self.assertRaisesOpError(lambda e: e.op == a.op): 154 a.eval() 155 156 def testErrorCodeWithNoNodeDef(self): 157 with session.Session() as s: 158 a = array_ops.placeholder(dtypes.float32, shape=[]) 159 b = array_ops.placeholder(dtypes.float32, shape=[]) 160 r1 = math_ops.add(a, b) 161 162 def exc_predicate(e): 163 return (e.op is None and e.node_def is None and 164 e.error_code == error_codes_pb2.INVALID_ARGUMENT) 165 166 with self.assertRaisesOpError(exc_predicate): 167 # Run with a bogus handle. 168 s.partial_run('foo', r1, feed_dict={a: 1, b: 2}) 169 170 def testOpConstructionErrorPayload(self): 171 if ops._USE_C_API: 172 return # No shape registration for 'ConstructionFails' 173 174 with session.Session(): 175 failing_op = ops.get_default_graph().create_op( 176 'ConstructionFails', [], [], name='f') 177 178 def exc_predicate(e): 179 return (e.op == failing_op and 180 e.error_code == error_codes_pb2.INVALID_ARGUMENT) 181 182 with self.assertRaisesOpError(exc_predicate): 183 failing_op.run() 184 185 def testErrorBasedOn(self): 186 with session.Session() as sess: 187 a = constant_op.constant(0.0, shape=[2, 3]) 188 # NOTE(mrry): The original_op is nonsense, but used here to test that the 189 # errors are reported correctly. 190 # pylint: disable=protected-access 191 with sess.graph._original_op(a.op): 192 b = array_ops.identity(a, name='id') 193 with sess.graph._original_op(b.op): 194 c = array_ops.placeholder(dtypes.float32) 195 # pylint: enable=protected-access 196 197 def exc_predicate(e): 198 return (e.op == c.op and e.op._original_op == b.op and 199 e.op._original_op._original_op == a.op) 200 201 with self.assertRaisesOpError(exc_predicate): 202 c.eval() 203 204 def testFetchNone(self): 205 with session.Session() as s: 206 a = constant_op.constant(1.0) 207 with self.assertRaises(TypeError): 208 s.run(None) 209 with self.assertRaises(TypeError): 210 s.run([None]) 211 with self.assertRaises(TypeError): 212 s.run({'b': None}) 213 with self.assertRaises(TypeError): 214 s.run({'a': a, 'b': None}) 215 216 def testFetchSingleton(self): 217 with session.Session() as sess: 218 a = constant_op.constant(42.0) 219 res = sess.run(a) 220 self.assertEqual(42.0, res) 221 res = sess.run(a.op) # An op, not a tensor. 222 self.assertEqual(None, res) 223 tensor_runner = sess.make_callable(a) 224 res = tensor_runner() 225 self.assertEqual(42.0, res) 226 op_runner = sess.make_callable(a.op) 227 res = op_runner() 228 self.assertEqual(None, res) 229 230 def testFetchSingletonByName(self): 231 with session.Session() as sess: 232 a = constant_op.constant(42.0) 233 res = sess.run(a.name) 234 self.assertEqual(42.0, res) 235 res = sess.run(a.op) # An op, not a tensor. 236 self.assertEqual(None, res) 237 238 def testFetchList(self): 239 with session.Session() as sess: 240 a = constant_op.constant(42.0) 241 b = control_flow_ops.no_op() # An op, not a tensor. 242 c = constant_op.constant(44.0) 243 v = variables.Variable([54.0]) 244 assign = v.assign([63.0]) 245 res = sess.run([a, b, c, a.name, assign.op]) 246 self.assertTrue(isinstance(res, list)) 247 self.assertEqual([42.0, None, 44.0, 42.0, None], res) 248 list_runner = sess.make_callable([a, b, c, a.name, assign.op]) 249 res = list_runner() 250 self.assertTrue(isinstance(res, list)) 251 self.assertEqual([42.0, None, 44.0, 42.0, None], res) 252 253 def testFetchTuple(self): 254 with session.Session() as sess: 255 a = constant_op.constant(42.0) 256 b = control_flow_ops.no_op() # An op, not a tensor. 257 c = constant_op.constant(44.0) 258 res = sess.run((a, b, c, a.name)) 259 self.assertTrue(isinstance(res, tuple)) 260 self.assertEqual((42.0, None, 44.0, 42.0), res) 261 tuple_runner = sess.make_callable((a, b, c, a.name)) 262 res = tuple_runner() 263 self.assertTrue(isinstance(res, tuple)) 264 self.assertEqual((42.0, None, 44.0, 42.0), res) 265 266 def testFetchNamedTuple(self): 267 # pylint: disable=invalid-name 268 ABC = collections.namedtuple('ABC', ['a', 'b', 'c']) 269 # pylint: enable=invalid-name 270 with session.Session() as sess: 271 a = constant_op.constant(42.0) 272 b = control_flow_ops.no_op() # An op, not a tensor. 273 c = constant_op.constant(44.0) 274 res = sess.run(ABC(a, b, c)) 275 self.assertTrue(isinstance(res, ABC)) 276 self.assertEqual(42.0, res.a) 277 self.assertEqual(None, res.b) 278 self.assertEqual(44.0, res.c) 279 namedtuple_runner = sess.make_callable(ABC(a, b, c)) 280 res = namedtuple_runner() 281 self.assertTrue(isinstance(res, ABC)) 282 self.assertEqual(42.0, res.a) 283 self.assertEqual(None, res.b) 284 self.assertEqual(44.0, res.c) 285 286 def testFetchDict(self): 287 with session.Session() as sess: 288 a = constant_op.constant(42.0) 289 b = control_flow_ops.no_op() # An op, not a tensor. 290 c = constant_op.constant(44.0) 291 res = sess.run({'a': a, 'b': b, 'c': c}) 292 self.assertTrue(isinstance(res, dict)) 293 self.assertEqual(42.0, res['a']) 294 self.assertEqual(None, res['b']) 295 self.assertEqual(44.0, res['c']) 296 297 def testFetchOrderedDict(self): 298 with session.Session() as sess: 299 a = constant_op.constant(42.0) 300 b = control_flow_ops.no_op() # An op, not a tensor. 301 c = constant_op.constant(44.0) 302 res = sess.run(collections.OrderedDict([(3, a), (2, b), (1, c)])) 303 self.assertTrue(isinstance(res, collections.OrderedDict)) 304 self.assertEqual([3, 2, 1], list(res.keys())) 305 self.assertEqual(42.0, res[3]) 306 self.assertEqual(None, res[2]) 307 self.assertEqual(44.0, res[1]) 308 309 def testFetchNestingEmptyOneLevel(self): 310 with session.Session() as sess: 311 a_val = 11.0 312 a = constant_op.constant(a_val) 313 314 res = sess.run([[], tuple(), {}]) 315 self.assertTrue(isinstance(res, list)) 316 self.assertEquals(3, len(res)) 317 self.assertTrue(isinstance(res[0], list)) 318 self.assertEqual(0, len(res[0])) 319 self.assertTrue(isinstance(res[1], tuple)) 320 self.assertEqual(0, len(res[1])) 321 self.assertTrue(isinstance(res[2], dict)) 322 self.assertEqual(0, len(res[2])) 323 324 res = sess.run([[], tuple(), {}, a]) 325 self.assertTrue(isinstance(res, list)) 326 self.assertEquals(4, len(res)) 327 self.assertTrue(isinstance(res[0], list)) 328 self.assertEqual(0, len(res[0])) 329 self.assertTrue(isinstance(res[1], tuple)) 330 self.assertEqual(0, len(res[1])) 331 self.assertTrue(isinstance(res[2], dict)) 332 self.assertEqual(0, len(res[2])) 333 self.assertEqual(a_val, res[3]) 334 335 def testFetchNestingOneLevel(self): 336 with session.Session() as sess: 337 # pylint: disable=invalid-name 338 ABC = collections.namedtuple('ABC', ['a', 'b', 'c']) 339 DEFG = collections.namedtuple('DEFG', ['d', 'e', 'f', 'g']) 340 # pylint: enable=invalid-name 341 a_val = 42.0 342 b_val = None 343 c_val = 44.0 344 a = constant_op.constant(a_val) 345 b = control_flow_ops.no_op() # An op, not a tensor. 346 c = constant_op.constant(c_val) 347 # List of lists, tuples, namedtuple, and dict 348 res = sess.run([[a, b, c], (a, b, c), 349 ABC(a=a, b=b, c=c), { 350 'a': a.name, 351 'c': c, 352 'b': b 353 }]) 354 self.assertTrue(isinstance(res, list)) 355 self.assertEqual(4, len(res)) 356 self.assertTrue(isinstance(res[0], list)) 357 self.assertEqual(3, len(res[0])) 358 self.assertEqual(a_val, res[0][0]) 359 self.assertEqual(b_val, res[0][1]) 360 self.assertEqual(c_val, res[0][2]) 361 self.assertTrue(isinstance(res[1], tuple)) 362 self.assertEqual(3, len(res[1])) 363 self.assertEqual(a_val, res[1][0]) 364 self.assertEqual(b_val, res[1][1]) 365 self.assertEqual(c_val, res[1][2]) 366 self.assertTrue(isinstance(res[2], ABC)) 367 self.assertEqual(a_val, res[2].a) 368 self.assertEqual(b_val, res[2].b) 369 self.assertEqual(c_val, res[2].c) 370 self.assertTrue(isinstance(res[3], dict)) 371 self.assertEqual(3, len(res[3])) 372 self.assertEqual(a_val, res[3]['a']) 373 self.assertEqual(b_val, res[3]['b']) 374 self.assertEqual(c_val, res[3]['c']) 375 # Tuple of lists, tuples, namedtuple, and dict 376 res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b, c=c), { 377 'a': a, 378 'c': c, 379 'b': b 380 })) 381 self.assertTrue(isinstance(res, tuple)) 382 self.assertEqual(4, len(res)) 383 self.assertTrue(isinstance(res[0], list)) 384 self.assertEqual(3, len(res[0])) 385 self.assertEqual(a_val, res[0][0]) 386 self.assertEqual(b_val, res[0][1]) 387 self.assertEqual(c_val, res[0][2]) 388 self.assertTrue(isinstance(res[1], tuple)) 389 self.assertEqual(3, len(res[1])) 390 self.assertEqual(a_val, res[1][0]) 391 self.assertEqual(b_val, res[1][1]) 392 self.assertEqual(c_val, res[1][2]) 393 self.assertTrue(isinstance(res[2], ABC)) 394 self.assertEqual(a_val, res[2].a) 395 self.assertEqual(b_val, res[2].b) 396 self.assertEqual(c_val, res[2].c) 397 self.assertTrue(isinstance(res[3], dict)) 398 self.assertEqual(3, len(res[3])) 399 self.assertEqual(a_val, res[3]['a']) 400 self.assertEqual(b_val, res[3]['b']) 401 self.assertEqual(c_val, res[3]['c']) 402 # Namedtuple of lists, tuples, namedtuples, and dict 403 res = sess.run( 404 DEFG( 405 d=[a, b, c], 406 e=(a, b, c), 407 f=ABC(a=a.name, b=b, c=c), 408 g={ 409 'a': a, 410 'c': c, 411 'b': b 412 })) 413 self.assertTrue(isinstance(res, DEFG)) 414 self.assertTrue(isinstance(res.d, list)) 415 self.assertEqual(3, len(res.d)) 416 self.assertEqual(a_val, res.d[0]) 417 self.assertEqual(b_val, res.d[1]) 418 self.assertEqual(c_val, res.d[2]) 419 self.assertTrue(isinstance(res.e, tuple)) 420 self.assertEqual(3, len(res.e)) 421 self.assertEqual(a_val, res.e[0]) 422 self.assertEqual(b_val, res.e[1]) 423 self.assertEqual(c_val, res.e[2]) 424 self.assertTrue(isinstance(res.f, ABC)) 425 self.assertEqual(a_val, res.f.a) 426 self.assertEqual(b_val, res.f.b) 427 self.assertEqual(c_val, res.f.c) 428 self.assertTrue(isinstance(res.g, dict)) 429 self.assertEqual(3, len(res.g)) 430 self.assertEqual(a_val, res.g['a']) 431 self.assertEqual(b_val, res.g['b']) 432 self.assertEqual(c_val, res.g['c']) 433 # Dict of lists, tuples, namedtuples, and dict 434 res = sess.run({ 435 'd': [a, b, c], 436 'e': (a, b, c), 437 'f': ABC(a=a, b=b, c=c), 438 'g': { 439 'a': a.name, 440 'c': c, 441 'b': b 442 } 443 }) 444 self.assertTrue(isinstance(res, dict)) 445 self.assertEqual(4, len(res)) 446 self.assertTrue(isinstance(res['d'], list)) 447 self.assertEqual(3, len(res['d'])) 448 self.assertEqual(a_val, res['d'][0]) 449 self.assertEqual(b_val, res['d'][1]) 450 self.assertEqual(c_val, res['d'][2]) 451 self.assertTrue(isinstance(res['e'], tuple)) 452 self.assertEqual(3, len(res['e'])) 453 self.assertEqual(a_val, res['e'][0]) 454 self.assertEqual(b_val, res['e'][1]) 455 self.assertEqual(c_val, res['e'][2]) 456 self.assertTrue(isinstance(res['f'], ABC)) 457 self.assertEqual(a_val, res['f'].a) 458 self.assertEqual(b_val, res['f'].b) 459 self.assertEqual(c_val, res['f'].c) 460 self.assertTrue(isinstance(res['g'], dict)) 461 self.assertEqual(3, len(res['g'])) 462 self.assertEqual(a_val, res['g']['a']) 463 self.assertEqual(b_val, res['g']['b']) 464 self.assertEqual(c_val, res['g']['c']) 465 466 def testFetchTensorObject(self): 467 with session.Session() as s: 468 a = constant_op.constant(1.0, shape=[1, 2]) 469 b = constant_op.constant(2.0, shape=[2, 3]) 470 c = math_ops.matmul(a, b) 471 results_with_list = s.run([c]) 472 self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_list[0]) 473 results_with_single = s.run(c) 474 self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_single) 475 results_with_get = c.eval() 476 self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_get) 477 a_val, b_val = s.run([a, b]) # Test multiple fetches. 478 self.assertAllEqual([[1.0, 1.0]], a_val) 479 self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], b_val) 480 results_with_dict = s.run({'a': [a], 'b': b, 'z': [a, b]}) 481 self.assertAllEqual([[1.0, 1.0]], results_with_dict['a'][0]) 482 self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], 483 results_with_dict['b']) 484 self.assertAllEqual(results_with_dict['a'][0], results_with_dict['z'][0]) 485 self.assertAllEqual(results_with_dict['b'], results_with_dict['z'][1]) 486 487 # Test nested structures 488 results_with_nested_list = s.run([[[a, b], b], a, [a, b]]) 489 self.assertAllEqual([[1.0, 1.0]], results_with_nested_list[0][0][0]) 490 self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], 491 results_with_nested_list[0][0][1]) 492 self.assertAllEqual(results_with_nested_list[0][0][0], 493 results_with_nested_list[1]) 494 self.assertAllEqual(results_with_nested_list[1], 495 results_with_nested_list[2][0]) 496 self.assertAllEqual(results_with_nested_list[0][0][1], 497 results_with_nested_list[0][1]) 498 self.assertAllEqual(results_with_nested_list[0][1], 499 results_with_nested_list[2][1]) 500 501 def testFetchScalar(self): 502 with session.Session() as s: 503 for scalar in np.int32, np.int64, np.float16, np.float32, np.float64: 504 x = scalar(7) 505 y = scalar(8) 506 tf_x = constant_op.constant(x, shape=[]) 507 tf_y = constant_op.constant(y) 508 tf_xy = math_ops.add(tf_x, tf_y) 509 # Single fetch 510 xy = s.run(tf_xy) 511 self.assertEqual(scalar, type(xy)) 512 self.assertEqual(x + y, xy) 513 # List fetch 514 xy, = s.run([tf_xy]) 515 self.assertEqual(scalar, type(xy)) 516 self.assertEqual(x + y, xy) 517 # Dict fetch 518 xy = s.run({'xy': tf_xy})['xy'] 519 self.assertEqual(scalar, type(xy)) 520 self.assertEqual(x + y, xy) 521 # Nested list fetch 522 xy = s.run([[[tf_xy]], tf_xy, [tf_xy]]) 523 self.assertAllEqual(xy, [[[x + y]], x + y, [x + y]]) 524 self.assertEqual(scalar, type(xy[0][0][0])) 525 self.assertEqual(scalar, type(xy[1])) 526 self.assertEqual(scalar, type(xy[2][0])) 527 528 def testFetchOperationObject(self): 529 with session.Session() as s: 530 a = constant_op.constant(1.0, shape=[1, 2]) 531 v = variables.Variable(a, name='testFetchOperationObject_v') 532 s.run(v.initializer) 533 v_val = s.run(v) 534 self.assertAllEqual([[1.0, 1.0]], v_val) 535 536 def testFetchSparseTensor(self): 537 with session.Session() as s: 538 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 539 values = np.array([1.0, 2.0]).astype(np.float32) 540 shape = np.array([7, 9, 2]).astype(np.int64) 541 sp = sparse_tensor.SparseTensor( 542 constant_op.constant(indices), constant_op.constant(values), 543 constant_op.constant(shape)) 544 # Single fetch, use as tuple 545 sp_out = s.run(sp) 546 indices_out, values_out, shape_out = sp_out 547 self.assertAllEqual(indices_out, indices) 548 self.assertAllEqual(values_out, values) 549 self.assertAllEqual(shape_out, shape) 550 # Single fetch, use as SparseTensorValue 551 sp_out = s.run(sp) 552 self.assertAllEqual(sp_out.indices, indices) 553 self.assertAllEqual(sp_out.values, values) 554 self.assertAllEqual(sp_out.dense_shape, shape) 555 # Tuple fetch, use as tuple 556 indices_out, values_out, shape_out = s.run(sp) 557 self.assertAllEqual(indices_out, indices) 558 self.assertAllEqual(values_out, values) 559 self.assertAllEqual(shape_out, shape) 560 # List fetch, use as tuple 561 (indices_out, values_out, shape_out), = s.run([sp]) 562 self.assertAllEqual(indices_out, indices) 563 self.assertAllEqual(values_out, values) 564 self.assertAllEqual(shape_out, shape) 565 # List fetch, use as SparseTensorValue 566 sp_out, = s.run([sp]) 567 self.assertAllEqual(sp_out.indices, indices) 568 self.assertAllEqual(sp_out.values, values) 569 self.assertAllEqual(sp_out.dense_shape, shape) 570 # Dict fetch (single value), use as tuple 571 indices_out, values_out, shape_out = s.run({'sp': sp})['sp'] 572 self.assertAllEqual(indices_out, indices) 573 self.assertAllEqual(values_out, values) 574 self.assertAllEqual(shape_out, shape) 575 # Dict fetch (list value), use as tuple 576 (indices_out, values_out, shape_out), = s.run({'sp': [sp]})['sp'] 577 self.assertAllEqual(indices_out, indices) 578 self.assertAllEqual(values_out, values) 579 self.assertAllEqual(shape_out, shape) 580 # Dict fetch, use as SparseTensorValue 581 sp_out = s.run({'sp': sp})['sp'] 582 self.assertAllEqual(sp_out.indices, indices) 583 self.assertAllEqual(sp_out.values, values) 584 self.assertAllEqual(sp_out.dense_shape, shape) 585 # Nested list fetch use as tuple 586 sp_out = s.run([[[sp]], sp]) 587 indices_out, values_out, shape_out = sp_out[0][0][0] 588 self.assertAllEqual(indices_out, indices) 589 self.assertAllEqual(values_out, values) 590 self.assertAllEqual(shape_out, shape) 591 indices_out, values_out, shape_out = sp_out[1] 592 self.assertAllEqual(indices_out, indices) 593 self.assertAllEqual(values_out, values) 594 self.assertAllEqual(shape_out, shape) 595 # Nested list fetch, use as SparseTensorValue 596 sp_out = s.run([[[sp]], sp]) 597 self.assertAllEqual(sp_out[0][0][0].indices, indices) 598 self.assertAllEqual(sp_out[0][0][0].values, values) 599 self.assertAllEqual(sp_out[0][0][0].dense_shape, shape) 600 self.assertAllEqual(sp_out[1].indices, indices) 601 self.assertAllEqual(sp_out[1].values, values) 602 self.assertAllEqual(sp_out[1].dense_shape, shape) 603 604 def testFeedSparseTensor(self): 605 with session.Session() as s: 606 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 607 values = np.array([1.0, 2.0]).astype(np.float32) 608 shape = np.array([7, 9, 2]).astype(np.int64) 609 sp = sparse_tensor.SparseTensor( 610 array_ops.placeholder(dtype=np.int64, shape=(2, 3)), 611 array_ops.placeholder(dtype=np.float32, shape=(2,)), 612 array_ops.placeholder(dtype=np.int64, shape=(3,)), 613 ) 614 sp_indices = array_ops.identity(sp.indices) 615 sp_values = array_ops.identity(sp.values) 616 sp_shape = array_ops.identity(sp.dense_shape) 617 sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape) 618 # Feed with tuple 619 indices_out, values_out, shape_out = s.run( 620 [sp_indices, sp_values, sp_shape], { 621 sp: (indices, values, shape) 622 }) 623 self.assertAllEqual(indices_out, indices) 624 self.assertAllEqual(values_out, values) 625 self.assertAllEqual(shape_out, shape) 626 # Feed with tuple, fetch sp directly 627 sp_out = s.run(sp, {sp: (indices, values, shape)}) 628 self.assertAllEqual(sp_out.indices, indices) 629 self.assertAllEqual(sp_out.values, values) 630 self.assertAllEqual(sp_out.dense_shape, shape) 631 # Feed with SparseTensorValue 632 indices_out, values_out, shape_out = s.run( 633 [sp_indices, sp_values, sp_shape], { 634 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 635 }) 636 self.assertAllEqual(indices_out, indices) 637 self.assertAllEqual(values_out, values) 638 self.assertAllEqual(shape_out, shape) 639 # Feed with SparseTensorValue, fetch SparseTensorValue 640 sp2_out = s.run(sp2, { 641 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 642 }) 643 self.assertAllEqual(sp2_out.indices, indices) 644 self.assertAllEqual(sp2_out.values, values) 645 self.assertAllEqual(sp2_out.dense_shape, shape) 646 # Feed SparseTensorValue and fetch sp directly. 647 sp_out = s.run(sp, { 648 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 649 }) 650 self.assertAllEqual(sp_out.indices, indices) 651 self.assertAllEqual(sp_out.values, values) 652 self.assertAllEqual(sp_out.dense_shape, shape) 653 654 def testFeedSparsePlaceholder(self): 655 with session.Session() as s: 656 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 657 values = np.array([1.0, 2.0]).astype(np.float32) 658 shape = np.array([7, 9, 2]).astype(np.int64) 659 sp = array_ops.sparse_placeholder(dtype=np.float32, name='placeholder1') 660 sp_indices = array_ops.identity(sp.indices) 661 sp_values = array_ops.identity(sp.values) 662 sp_shape = array_ops.identity(sp.dense_shape) 663 sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape) 664 # Feed with tuple 665 indices_out, values_out, shape_out = s.run( 666 [sp_indices, sp_values, sp_shape], { 667 sp: (indices, values, shape) 668 }) 669 self.assertAllEqual(indices_out, indices) 670 self.assertAllEqual(values_out, values) 671 self.assertAllEqual(shape_out, shape) 672 # Feed with SparseTensorValue 673 indices_out, values_out, shape_out = s.run( 674 [sp_indices, sp_values, sp_shape], { 675 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 676 }) 677 self.assertAllEqual(indices_out, indices) 678 self.assertAllEqual(values_out, values) 679 self.assertAllEqual(shape_out, shape) 680 # Feed with SparseTensorValue, fetch SparseTensorValue 681 sp2_out = s.run(sp2, { 682 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 683 }) 684 self.assertAllEqual(sp2_out.indices, indices) 685 self.assertAllEqual(sp2_out.values, values) 686 self.assertAllEqual(sp2_out.dense_shape, shape) 687 688 def testFeedSparsePlaceholderPartialShape(self): 689 with session.Session() as s: 690 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 691 values = np.array([1.0, 2.0]).astype(np.float32) 692 shape = np.array([7, 9, 2]).astype(np.int64) 693 sp = array_ops.sparse_placeholder( 694 shape=[None, 9, 2], dtype=np.float32, name='placeholder1') 695 sp_indices = array_ops.identity(sp.indices) 696 sp_values = array_ops.identity(sp.values) 697 sp_shape = array_ops.identity(sp.dense_shape) 698 sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape) 699 # Feed with tuple 700 indices_out, values_out, shape_out = s.run( 701 [sp_indices, sp_values, sp_shape], { 702 sp: (indices, values, shape) 703 }) 704 self.assertAllEqual(indices_out, indices) 705 self.assertAllEqual(values_out, values) 706 self.assertAllEqual(shape_out, shape) 707 # Feed with SparseTensorValue 708 indices_out, values_out, shape_out = s.run( 709 [sp_indices, sp_values, sp_shape], { 710 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 711 }) 712 self.assertAllEqual(indices_out, indices) 713 self.assertAllEqual(values_out, values) 714 self.assertAllEqual(shape_out, shape) 715 # Feed with SparseTensorValue, fetch SparseTensorValue 716 sp2_out = s.run(sp2, { 717 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 718 }) 719 self.assertAllEqual(sp2_out.indices, indices) 720 self.assertAllEqual(sp2_out.values, values) 721 self.assertAllEqual(sp2_out.dense_shape, shape) 722 723 def testFeedSparsePlaceholderConstantShape(self): 724 with session.Session() as s: 725 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 726 values = np.array([1.0, 2.0]).astype(np.float32) 727 shape = np.array([7, 9, 2]).astype(np.int64) 728 sp = array_ops.sparse_placeholder( 729 dtype=np.float32, shape=shape, name='placeholder1') 730 self.assertAllEqual(sp.dense_shape.eval(session=s), shape) 731 self.assertAllEqual(tensor_util.constant_value(sp.dense_shape), shape) 732 sp_indices = array_ops.identity(sp.indices) 733 sp_values = array_ops.identity(sp.values) 734 sp_shape = array_ops.identity(sp.dense_shape) 735 # Feed with tuple 736 indices_out, values_out, shape_out = s.run( 737 [sp_indices, sp_values, sp_shape], { 738 sp: (indices, values) 739 }) 740 self.assertAllEqual(indices_out, indices) 741 self.assertAllEqual(values_out, values) 742 self.assertAllEqual(shape_out, shape) 743 744 def testFetchIndexedSlices(self): 745 with session.Session() as s: 746 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 747 values = np.array([1.0, 2.0]).astype(np.float32) 748 dense_shape = np.array([7, 9, 2]).astype(np.int64) 749 ind = ops.IndexedSlices( 750 constant_op.constant(values), constant_op.constant(indices), 751 constant_op.constant(dense_shape)) 752 # Single fetch, use as tuple 753 ind_out = s.run(ind) 754 values_out, indices_out, dense_shape_out = ind_out 755 self.assertAllEqual(values_out, values) 756 self.assertAllEqual(indices_out, indices) 757 self.assertAllEqual(dense_shape_out, dense_shape) 758 # Single fetch, use as IndexedSlicesValue 759 ind_out = s.run(ind) 760 self.assertAllEqual(ind_out.values, values) 761 self.assertAllEqual(ind_out.indices, indices) 762 self.assertAllEqual(ind_out.dense_shape, dense_shape) 763 # Tuple fetch, use as tuple 764 values_out, indices_out, dense_shape_out = s.run(ind) 765 self.assertAllEqual(values_out, values) 766 self.assertAllEqual(indices_out, indices) 767 self.assertAllEqual(dense_shape_out, dense_shape) 768 # List fetch, use as tuple 769 (values_out, indices_out, dense_shape_out), = s.run([ind]) 770 self.assertAllEqual(values_out, values) 771 self.assertAllEqual(indices_out, indices) 772 self.assertAllEqual(dense_shape_out, dense_shape) 773 # List fetch, use as IndexedSlicesValue 774 ind_out, = s.run([ind]) 775 self.assertAllEqual(ind_out.values, values) 776 self.assertAllEqual(ind_out.indices, indices) 777 self.assertAllEqual(ind_out.dense_shape, dense_shape) 778 779 def testFeedIndexedSlices(self): 780 with session.Session() as s: 781 values = np.array([1.0, 2.0]).astype(np.float32) 782 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 783 dense_shape = np.array([7, 9, 2]).astype(np.int64) 784 ind = ops.IndexedSlices( 785 array_ops.placeholder(dtype=np.float32, shape=(2,)), 786 array_ops.placeholder(dtype=np.int64, shape=(2, 3)), 787 array_ops.placeholder(dtype=np.int64, shape=(3,)), 788 ) 789 ind_values = array_ops.identity(ind.values) 790 ind_indices = array_ops.identity(ind.indices) 791 ind_dense_shape = array_ops.identity(ind.dense_shape) 792 ind2 = ops.IndexedSlices(ind_values, ind_indices, ind_dense_shape) 793 # Feed with tuple 794 values_out, indices_out, dense_shape_out = s.run( 795 [ind_values, ind_indices, ind_dense_shape], { 796 ind: (values, indices, dense_shape) 797 }) 798 self.assertAllEqual(values_out, values) 799 self.assertAllEqual(indices_out, indices) 800 self.assertAllEqual(dense_shape_out, dense_shape) 801 # Feed with IndexedSlicesValue 802 values_out, indices_out, dense_shape_out = s.run( 803 [ind_values, ind_indices, ind_dense_shape], { 804 ind: ops.IndexedSlicesValue(values, indices, dense_shape) 805 }) 806 self.assertAllEqual(values_out, values) 807 self.assertAllEqual(indices_out, indices) 808 self.assertAllEqual(dense_shape_out, dense_shape) 809 # Feed with IndexedSlicesValue, fetch IndexedSlicesValue 810 ind2_out = s.run(ind2, { 811 ind: ops.IndexedSlicesValue(values, indices, dense_shape) 812 }) 813 self.assertAllEqual(ind2_out.values, values) 814 self.assertAllEqual(ind2_out.indices, indices) 815 self.assertAllEqual(ind2_out.dense_shape, dense_shape) 816 817 def testFetchIndexedSlicesWithoutDenseShape(self): 818 with session.Session() as s: 819 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 820 values = np.array([1.0, 2.0]).astype(np.float32) 821 dense_shape = None 822 ind = ops.IndexedSlices( 823 constant_op.constant(values), constant_op.constant(indices), None) 824 # Single fetch, use as tuple 825 ind_out = s.run(ind) 826 values_out, indices_out, dense_shape_out = ind_out 827 self.assertAllEqual(values_out, values) 828 self.assertAllEqual(indices_out, indices) 829 self.assertAllEqual(dense_shape_out, dense_shape) 830 # Single fetch, use as IndexedSlicesValue 831 ind_out = s.run(ind) 832 self.assertAllEqual(ind_out.values, values) 833 self.assertAllEqual(ind_out.indices, indices) 834 self.assertAllEqual(ind_out.dense_shape, dense_shape) 835 # Tuple fetch, use as tuple 836 values_out, indices_out, dense_shape_out = s.run(ind) 837 self.assertAllEqual(values_out, values) 838 self.assertAllEqual(indices_out, indices) 839 self.assertAllEqual(dense_shape_out, dense_shape) 840 # List fetch, use as tuple 841 (values_out, indices_out, dense_shape_out), = s.run([ind]) 842 self.assertAllEqual(values_out, values) 843 self.assertAllEqual(indices_out, indices) 844 self.assertAllEqual(dense_shape_out, dense_shape) 845 # List fetch, use as IndexedSlicesValue 846 ind_out, = s.run([ind]) 847 self.assertAllEqual(ind_out.values, values) 848 self.assertAllEqual(ind_out.indices, indices) 849 self.assertAllEqual(ind_out.dense_shape, dense_shape) 850 851 def testFeedIndexedSlicesWithoutDenseShape(self): 852 with session.Session() as s: 853 values = np.array([1.0, 2.0]).astype(np.float32) 854 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 855 dense_shape = None 856 ind = ops.IndexedSlices( 857 array_ops.placeholder(dtype=np.float32, shape=(2,)), 858 array_ops.placeholder(dtype=np.int64, shape=(2, 3)), None) 859 ind_values = array_ops.identity(ind.values) 860 ind_indices = array_ops.identity(ind.indices) 861 ind2 = ops.IndexedSlices(ind_values, ind_indices) 862 # Feed with tuple 863 values_out, indices_out = s.run([ind_values, ind_indices], { 864 ind: (values, indices) 865 }) 866 self.assertAllEqual(values_out, values) 867 self.assertAllEqual(indices_out, indices) 868 # Feed with IndexedSlicesValue 869 values_out, indices_out = s.run([ind_values, ind_indices], { 870 ind: ops.IndexedSlicesValue(values, indices, dense_shape) 871 }) 872 self.assertAllEqual(values_out, values) 873 self.assertAllEqual(indices_out, indices) 874 # Feed with IndexedSlicesValue, fetch IndexedSlicesValue 875 ind2_out = s.run(ind2, { 876 ind: ops.IndexedSlicesValue(values, indices, dense_shape) 877 }) 878 self.assertAllEqual(ind2_out.values, values) 879 self.assertAllEqual(ind2_out.indices, indices) 880 self.assertAllEqual(ind2_out.dense_shape, dense_shape) 881 882 def testExtendWithStatelessOperations(self): 883 with session.Session() as s: 884 a = constant_op.constant(1.0, shape=[1, 2]) 885 b = constant_op.constant(2.0, shape=[2, 3]) 886 c = math_ops.matmul(a, b) 887 c_val = s.run(c) 888 self.assertAllEqual([[4.0, 4.0, 4.0]], c_val) 889 d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1]) 890 e = math_ops.matmul(c, d) 891 # Extend will happen here. 892 e_val = s.run(e) 893 self.assertAllEqual([[24.0]], e_val) 894 895 def testExtendWithStatefulOperations(self): 896 with session.Session() as s: 897 a = constant_op.constant(1.0, shape=[1, 2]) 898 b = constant_op.constant(2.0, shape=[2, 3]) 899 c = math_ops.matmul(a, b) 900 v = variables.Variable(c, name='testExtendWithStatefulOperations_v') 901 v.initializer.run() 902 v_val = v.eval() 903 self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) 904 d = constant_op.constant(3.0, shape=[2, 3]) 905 e = math_ops.matmul(a, d) 906 assign_e_to_v = state_ops.assign(v, e) 907 # Extend will happen here. 908 e_val = e.eval() 909 self.assertAllEqual([[6.0, 6.0, 6.0]], e_val) 910 v_val = v.eval() 911 self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) 912 s.run(assign_e_to_v) 913 v_val = v.eval() 914 self.assertAllEqual([[6.0, 6.0, 6.0]], v_val) 915 916 def testExtendWithGroupBy(self): 917 with session.Session() as s: 918 a = constant_op.constant(1.0, shape=[1, 2]) 919 p = variables.Variable(a, name='testExtendWithGroupBy_p') 920 a_val = a.eval() # Force an Extend after this op. 921 self.assertAllEqual([[1.0, 1.0]], a_val) 922 923 b = constant_op.constant(2.0, shape=[1, 2]) 924 q = variables.Variable(b, name='testExtendWithGroupBy_q') 925 # Extend will happen here. 926 init = control_flow_ops.group(p.initializer, q.initializer) 927 s.run(init) 928 p_val, q_val = s.run([p, q]) 929 930 self.assertAllEqual([[1.0, 1.0]], p_val) 931 self.assertAllEqual([[2.0, 2.0]], q_val) 932 933 def testTensorGetMethod(self): 934 with session.Session(): 935 a = constant_op.constant(1.0, shape=[1, 2]) 936 b = constant_op.constant(2.0, shape=[2, 3]) 937 c = math_ops.matmul(a, b) 938 939 c_val = c.eval() 940 self.assertAllEqual([[4.0, 4.0, 4.0]], c_val) 941 942 fed_c_val = c.eval(feed_dict={a.name: [[4.0, 4.0]]}) 943 self.assertAllEqual([[16.0, 16.0, 16.0]], fed_c_val) 944 945 def testOperationRunMethod(self): 946 with session.Session(): 947 a = constant_op.constant(1.0, shape=[1, 2]) 948 b = constant_op.constant(2.0, shape=[1, 2], name='b') 949 v = variables.Variable(a, a.dtype) 950 assign_a_to_v = state_ops.assign(v, a) 951 952 assign_a_to_v.eval() 953 954 v_val = v.eval() 955 self.assertAllEqual([[1.0, 1.0]], v_val) 956 957 assign_b_to_v = state_ops.assign(v, b) 958 959 assign_b_to_v.eval() 960 v_val = v.eval() 961 self.assertAllEqual([[2.0, 2.0]], v_val) 962 963 assign_b_to_v.eval(feed_dict={'b:0': [[3.0, 3.0]]}) 964 v_val = v.eval() 965 self.assertAllEqual([[3.0, 3.0]], v_val) 966 967 def testDefaultGraph(self): 968 with session.Session() as s: 969 self.assertEqual(ops.get_default_graph(), s.graph) 970 a = constant_op.constant(1.0, shape=[1, 2]) 971 b = constant_op.constant(2.0, shape=[2, 3]) 972 self.assertEqual(ops.get_default_graph(), a.graph) 973 self.assertEqual(ops.get_default_graph(), b.graph) 974 c = math_ops.matmul(a, b) 975 v = variables.Variable(c, name='testDefaultGraph_v') 976 v.initializer.run() 977 v_val = v.eval() 978 self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) 979 d = constant_op.constant(3.0, shape=[2, 3]) 980 e = math_ops.matmul(a, d) 981 assign_e_to_v = state_ops.assign(v, e) 982 e_val = e.eval() 983 self.assertAllEqual([[6.0, 6.0, 6.0]], e_val) 984 v_val = v.eval() 985 self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) 986 s.run(assign_e_to_v) 987 v_val = v.eval() 988 self.assertAllEqual([[6.0, 6.0, 6.0]], v_val) 989 self.assertEqual(ops.get_default_graph(), s.graph) 990 991 def _testDefaultGraphInThread(self, constructed_event, continue_event, i): 992 with session.Session() as s: 993 self.assertEqual(ops.get_default_graph(), s.graph) 994 a = constant_op.constant(1.0, shape=[1, 2]) 995 b = constant_op.constant(2.0, shape=[2, 3]) 996 c = math_ops.matmul(a, b) 997 v = variables.Variable(c, name='var_%d' % i) 998 999 # Block here until all threads have constructed their graph. 1000 constructed_event.set() 1001 continue_event.wait() 1002 1003 assign_c_to_v = state_ops.assign(v, c) 1004 v.initializer.run() 1005 assign_c_to_v.eval() 1006 v_val = v.eval() 1007 self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) 1008 d = constant_op.constant(3.0, shape=[2, 3]) 1009 e = math_ops.matmul(a, d) 1010 assign_e_to_v = state_ops.assign(v, e) 1011 e_val = e.eval() 1012 self.assertAllEqual([[6.0, 6.0, 6.0]], e_val) 1013 v_val = v.eval() 1014 self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) 1015 s.run(assign_e_to_v) 1016 v_val = v.eval() 1017 self.assertAllEqual([[6.0, 6.0, 6.0]], v_val) 1018 self.assertEqual(ops.get_default_graph(), s.graph) 1019 1020 def testDefaultGraphWithThreads(self): 1021 # Fork ten threads that use their thread-local default graph. 1022 threads = [] 1023 constructed_events = [threading.Event() for _ in range(10)] 1024 continue_event = threading.Event() 1025 for i, constructed_event in enumerate(constructed_events): 1026 t = self.checkedThread( 1027 target=self._testDefaultGraphInThread, 1028 args=(constructed_event, continue_event, i)) 1029 threads.append(t) 1030 for t in threads: 1031 t.start() 1032 for constructed_event in constructed_events: 1033 constructed_event.wait() 1034 continue_event.set() 1035 for t in threads: 1036 t.join() 1037 1038 def testParallelRun(self): 1039 with session.Session() as sess: 1040 c = constant_op.constant(5.0) 1041 ev = threading.Event() 1042 1043 def run_step(): 1044 ev.wait() 1045 val = c.eval(session=sess) 1046 self.assertEqual(val, 5.0) 1047 1048 threads = [self.checkedThread(target=run_step) for _ in range(100)] 1049 for t in threads: 1050 t.start() 1051 ev.set() 1052 for t in threads: 1053 t.join() 1054 1055 def testRunFeedDict(self): 1056 with session.Session() as s: 1057 x = array_ops.zeros([2]) 1058 1059 y = s.run(2 * x, feed_dict={x: np.ones(2).astype(np.float32)}) 1060 self.assertAllEqual(y, 2 * np.ones(2)) 1061 1062 y = s.run(2 * x, feed_dict={x.name: np.ones(2).astype(np.float32)}) 1063 self.assertAllEqual(y, 2 * np.ones(2)) 1064 1065 y = s.run(2 * x, feed_dict={x: [1, 1]}) 1066 assert (y == 2 * np.ones(2)).all() 1067 1068 # Test nested tuple keys 1069 z = (((array_ops.zeros([2]),),), array_ops.zeros([2]), 1070 (array_ops.zeros([2]),)) 1071 result = [z[0][0][0] * 2, z[1] * 2, z[2][0] * 2] 1072 values = (((np.array([1, 1]),),), np.array([2, 2]), (np.array([3, 3]),)) 1073 result_value = s.run(result, feed_dict={z: values}) 1074 self.assertAllEqual(result_value[0], 2 * np.ones(2)) 1075 self.assertAllEqual(result_value[1], 2 * np.array([2, 2])) 1076 self.assertAllEqual(result_value[2], 2 * np.array([3, 3])) 1077 1078 def testGraphDef(self): 1079 with session.Session() as sess: 1080 self.assertProtoEquals('versions { producer: %d min_consumer: %d }' % 1081 (versions.GRAPH_DEF_VERSION, 1082 versions.GRAPH_DEF_VERSION_MIN_CONSUMER), 1083 sess.graph_def) 1084 c = constant_op.constant(5.0, name='c') 1085 self.assertEquals(len(sess.graph_def.node), 1) 1086 d = constant_op.constant(6.0, name='d') 1087 self.assertEquals(len(sess.graph_def.node), 2) 1088 self.assertAllEqual(c.eval(), 5.0) 1089 self.assertAllEqual(d.eval(), 6.0) 1090 e = constant_op.constant(7.0, name='e') 1091 self.assertEquals(len(sess.graph_def.node), 3) 1092 self.assertAllEqual(e.eval(), 7.0) 1093 1094 def testUseAfterClose(self): 1095 with session.Session() as sess: 1096 c = constant_op.constant(5.0) 1097 self.assertAllEqual(sess.run(c), 5.0) 1098 with self.assertRaisesWithPredicateMatch( 1099 RuntimeError, lambda e: 'Attempted to use a closed Session.' in str(e)): 1100 sess.run(c) 1101 1102 def testUseAfterCloseConcurrent(self): 1103 with session.Session() as sess: 1104 c = constant_op.constant(5.0) 1105 self.assertAllEqual(sess.run(c), 5.0) 1106 1107 def update_thread(): 1108 with self.assertRaisesWithPredicateMatch( 1109 RuntimeError, 1110 lambda e: 'Attempted to use a closed Session.' in str(e)): 1111 while True: 1112 sess.run(c) 1113 1114 t = threading.Thread(target=update_thread) 1115 t.start() 1116 time.sleep(0.1) 1117 sess.close() 1118 t.join() 1119 1120 def testUseEmptyGraph(self): 1121 with session.Session() as sess: 1122 with self.assertRaisesRegexp(RuntimeError, 'The Session graph is empty.'): 1123 sess.run([]) 1124 with self.assertRaisesRegexp(RuntimeError, 'The Session graph is empty.'): 1125 sess.run(()) 1126 with self.assertRaisesRegexp(RuntimeError, 'The Session graph is empty.'): 1127 sess.run({}) 1128 1129 def testNotEntered(self): 1130 # pylint: disable=protected-access 1131 self.assertEqual(ops._default_session_stack.get_default(), None) 1132 # pylint: enable=protected-access 1133 with ops.device('/cpu:0'): 1134 sess = session.Session() 1135 c_1 = constant_op.constant(5.0) 1136 with sess.graph.as_default(): 1137 c_2 = constant_op.constant(5.0) 1138 self.assertEqual(c_1.graph, c_2.graph) 1139 self.assertEqual(sess.run(c_2), 5.0) 1140 with self.assertRaisesWithPredicateMatch( 1141 ValueError, lambda e: 'No default session is registered.' in str(e)): 1142 c_2.eval() 1143 1144 def testInteractive(self): 1145 with ops.device('/cpu:0'): 1146 sess = session.InteractiveSession() 1147 a = constant_op.constant(1.0, shape=[1, 2]) 1148 b = constant_op.constant(2.0, shape=[2, 3]) 1149 c = math_ops.matmul(a, b) 1150 self.assertAllEqual([[4.0, 4.0, 4.0]], c.eval()) 1151 d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1]) 1152 e = math_ops.matmul(c, d) 1153 self.assertAllEqual([[24.0]], e.eval()) 1154 sess.close() 1155 1156 def testInteractivePlacePrunedGraph(self): 1157 sess = session.InteractiveSession() 1158 1159 # Build a graph that has a bad op in it (no kernel). 1160 # 1161 # This test currently does not link in any GPU kernels, 1162 # which is why placing this is invalid. If at some point 1163 # GPU kernels are added to this test, some other different 1164 # op / device combo should be chosen. 1165 with ops.device('/device:GPU:0'): 1166 a = constant_op.constant(1.0, shape=[1, 2]) 1167 1168 b = constant_op.constant(1.0, shape=[1, 2]) 1169 1170 # Only run the valid op, this should work. 1171 b.eval() 1172 1173 with self.assertRaises(errors.InvalidArgumentError): 1174 a.eval() 1175 sess.close() 1176 1177 def testDefaultSessionPlacePrunedGraph(self): 1178 sess = session.Session() 1179 1180 # Build a graph that has a bad op in it (no kernel). 1181 # 1182 # This test currently does not link in any GPU kernels, 1183 # which is why placing this is invalid. If at some point 1184 # GPU kernels are added to this test, some other different 1185 # op / device combo should be chosen. 1186 with ops.device('/device:GPU:0'): 1187 _ = constant_op.constant(1.0, shape=[1, 2]) 1188 1189 b = constant_op.constant(1.0, shape=[1, 2]) 1190 1191 with self.assertRaises(errors.InvalidArgumentError): 1192 # Even though we don't run the bad op, we place the entire 1193 # graph, which should fail with a non-interactive session. 1194 sess.run(b) 1195 1196 sess.close() 1197 1198 def testSharedGraph(self): 1199 with ops.Graph().as_default() as g, ops.device('/cpu:0'): 1200 a = constant_op.constant(1.0, shape=[1, 2]) 1201 b = constant_op.constant(2.0, shape=[2, 3]) 1202 c = math_ops.matmul(a, b) 1203 1204 with session.Session(graph=g) as sess1: 1205 with session.Session(graph=g) as sess2: 1206 self.assertAllEqual(sess1.run(c), sess2.run(c)) 1207 1208 def testDuplicatedInputs(self): 1209 with session.Session() as sess: 1210 a = constant_op.constant(1.0, shape=[1, 2]) 1211 b = constant_op.constant(2.0, shape=[1, 3]) 1212 a_val, b_val, a2_val = sess.run([a, b, a]) 1213 self.assertAllEqual(a_val, [[1.0, 1.0]]) 1214 self.assertAllEqual(b_val, [[2.0, 2.0, 2.0]]) 1215 self.assertAllEqual(a2_val, [[1.0, 1.0]]) 1216 1217 def testFeedAndFetch(self): 1218 with session.Session() as sess: 1219 for dtype in [ 1220 dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32, 1221 dtypes.uint8, dtypes.int16, dtypes.int8, dtypes.int64, dtypes.bool, 1222 dtypes.complex64, dtypes.complex128 1223 ]: 1224 for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]: 1225 np_dtype = dtype.as_numpy_dtype 1226 1227 feed_t = array_ops.placeholder(dtype=dtype, shape=shape) 1228 out_t = array_ops.identity(feed_t) 1229 1230 np_array = np.random.randint(-10, 10, shape) 1231 1232 if dtype == dtypes.bool: 1233 np_array = np_array > 0 1234 elif dtype == dtypes.complex64: 1235 np_array = np.sqrt(np_array.astype(np_dtype)) 1236 elif dtype == dtypes.complex64: 1237 np_array = np.sqrt(np_array.astype(np_dtype)) 1238 else: 1239 np_array = np_array.astype(np_dtype) 1240 1241 self.assertAllEqual(np_array, 1242 sess.run(out_t, feed_dict={ 1243 feed_t: np_array 1244 })) 1245 # Check that we can also get the feed back. 1246 self.assertAllEqual(np_array, 1247 sess.run(feed_t, feed_dict={ 1248 feed_t: np_array 1249 })) 1250 # Also check that we can get both back. 1251 out_v, feed_v = sess.run( 1252 [out_t, feed_t], feed_dict={ 1253 feed_t: np_array 1254 }) 1255 self.assertAllEqual(np_array, out_v) 1256 self.assertAllEqual(np_array, feed_v) 1257 1258 feed_fetch_runner = sess.make_callable([out_t, feed_t], [feed_t]) 1259 out_v, feed_v = feed_fetch_runner(np_array) 1260 self.assertAllEqual(np_array, out_v) 1261 self.assertAllEqual(np_array, feed_v) 1262 1263 def testMakeCallableOnTensorWithRunOptions(self): 1264 with session.Session() as sess: 1265 a = constant_op.constant(42.0) 1266 tensor_runner = sess.make_callable(a, accept_options=True) 1267 run_options = config_pb2.RunOptions( 1268 trace_level=config_pb2.RunOptions.FULL_TRACE) 1269 run_metadata = config_pb2.RunMetadata() 1270 self.assertEqual(0, len(run_metadata.step_stats.dev_stats)) 1271 res = tensor_runner(options=run_options, run_metadata=run_metadata) 1272 self.assertEqual(42.0, res) 1273 self.assertGreater(len(run_metadata.step_stats.dev_stats), 0) 1274 1275 def testMakeCallableOnOperationWithRunOptions(self): 1276 with session.Session() as sess: 1277 a = variables.Variable(42.0) 1278 b = state_ops.assign_add(a, 1.0) 1279 sess.run(a.initializer) 1280 tensor_runner = sess.make_callable(b.op, accept_options=True) 1281 run_options = config_pb2.RunOptions( 1282 trace_level=config_pb2.RunOptions.FULL_TRACE) 1283 run_metadata = config_pb2.RunMetadata() 1284 self.assertEqual(0, len(run_metadata.step_stats.dev_stats)) 1285 tensor_runner(options=run_options, run_metadata=run_metadata) 1286 self.assertEqual(43.0, sess.run(a)) 1287 self.assertGreater(len(run_metadata.step_stats.dev_stats), 0) 1288 1289 def testMakeCallableWithFeedListAndRunOptions(self): 1290 with session.Session() as sess: 1291 ph = array_ops.placeholder(dtypes.float32) 1292 a = math_ops.add(ph, 1.0) 1293 tensor_runner = sess.make_callable( 1294 a, feed_list=[ph.name], accept_options=True) 1295 run_options = config_pb2.RunOptions( 1296 trace_level=config_pb2.RunOptions.FULL_TRACE) 1297 run_metadata = config_pb2.RunMetadata() 1298 self.assertEqual(0, len(run_metadata.step_stats.dev_stats)) 1299 self.assertAllClose(42.0, 1300 tensor_runner( 1301 41.0, 1302 options=run_options, 1303 run_metadata=run_metadata)) 1304 self.assertGreater(len(run_metadata.step_stats.dev_stats), 0) 1305 1306 def testFeedError(self): 1307 with session.Session() as sess: 1308 feed_t = array_ops.placeholder(dtype=dtypes.float32) 1309 out_t = array_ops.identity(feed_t) 1310 feed_val = constant_op.constant(5.0) 1311 with self.assertRaisesRegexp(TypeError, 'cannot be a tf.Tensor object'): 1312 sess.run(out_t, feed_dict={feed_t: feed_val}) 1313 with self.assertRaisesRegexp(TypeError, 'cannot be a tf.Tensor object'): 1314 out_t.eval(feed_dict={feed_t: feed_val}) 1315 with self.assertRaisesRegexp(TypeError, 'cannot be a tf.Tensor object'): 1316 out_t.op.run(feed_dict={feed_t: feed_val}) 1317 1318 def testFeedPrecisionLossError(self): 1319 with session.Session() as sess: 1320 largest_int64 = np.iinfo(np.int64).max 1321 1322 feed_int_implicit_int32 = constant_op.constant(1) 1323 feed_int_explicit_int32 = constant_op.constant(1, dtype=dtypes.int32) 1324 1325 out_t = constant_op.constant(1.0) 1326 1327 with self.assertRaisesRegexp(TypeError, 1328 'is not compatible with Tensor type'): 1329 sess.run(out_t, feed_dict={feed_int_implicit_int32: largest_int64}) 1330 with self.assertRaisesRegexp(TypeError, 1331 'is not compatible with Tensor type'): 1332 sess.run(out_t, feed_dict={feed_int_explicit_int32: largest_int64}) 1333 1334 def testStringFetch(self): 1335 with session.Session(): 1336 for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]: 1337 size = 1 1338 for s in shape: 1339 size *= s 1340 c_list = np.array( 1341 [compat.as_bytes(str(i)) for i in xrange(size)], 1342 dtype=np.object).reshape(shape) if size > 0 else [] 1343 c = constant_op.constant(c_list) 1344 self.assertAllEqual(c.eval(), c_list) 1345 1346 def testStringFeed(self): 1347 with session.Session() as sess: 1348 for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]: 1349 size = 1 1350 for s in shape: 1351 size *= s 1352 c_list = np.array( 1353 [compat.as_bytes(str(i)) for i in xrange(size)], 1354 dtype=np.object).reshape(shape) 1355 feed_t = array_ops.placeholder(dtype=dtypes.string, shape=shape) 1356 c = array_ops.identity(feed_t) 1357 self.assertAllEqual(sess.run(c, feed_dict={feed_t: c_list}), c_list) 1358 self.assertAllEqual( 1359 sess.run(feed_t, feed_dict={ 1360 feed_t: c_list 1361 }), c_list) 1362 c_v, feed_v = sess.run([c, feed_t], feed_dict={feed_t: c_list}) 1363 self.assertAllEqual(c_v, c_list) 1364 self.assertAllEqual(feed_v, c_list) 1365 1366 def testStringFeedWithNullCharacters(self): 1367 with session.Session(): 1368 c_list = [b'\n\x01\x00', b'\n\x00\x01'] 1369 feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[2]) 1370 c = array_ops.identity(feed_t) 1371 out = c.eval(feed_dict={feed_t: c_list}) 1372 self.assertEqual(c_list[0], out[0]) 1373 self.assertEqual(c_list[1], out[1]) 1374 1375 def testStringFeedWithUnicode(self): 1376 with session.Session(): 1377 c_list = [ 1378 u'\n\x01\x00', u'\n\x00\x01', u'\u26a3 unicode', 1379 u'\U0001f60e deal with it' 1380 ] 1381 feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[len(c_list)]) 1382 c = array_ops.identity(feed_t) 1383 1384 out = c.eval(feed_dict={feed_t: c_list}) 1385 for i in range(len(c_list)): 1386 self.assertEqual(c_list[i], out[i].decode('utf-8')) 1387 1388 out = c.eval(feed_dict={feed_t: np.array(c_list, dtype=np.object)}) 1389 for i in range(len(c_list)): 1390 self.assertEqual(c_list[i], out[i].decode('utf-8')) 1391 1392 def testInvalidTargetFails(self): 1393 with self.assertRaisesRegexp( 1394 errors.NotFoundError, 1395 'No session factory registered for the given session options'): 1396 session.Session('INVALID_TARGET') 1397 1398 def testFetchByNameDifferentStringTypes(self): 1399 with session.Session() as sess: 1400 c = constant_op.constant(42.0, name='c') 1401 d = constant_op.constant(43.0, name=u'd') 1402 e = constant_op.constant(44.0, name=b'e') 1403 f = constant_op.constant(45.0, name=r'f') 1404 1405 self.assertTrue(isinstance(c.name, six.text_type)) 1406 self.assertTrue(isinstance(d.name, six.text_type)) 1407 self.assertTrue(isinstance(e.name, six.text_type)) 1408 self.assertTrue(isinstance(f.name, six.text_type)) 1409 1410 self.assertEqual(42.0, sess.run('c:0')) 1411 self.assertEqual(42.0, sess.run(u'c:0')) 1412 self.assertEqual(42.0, sess.run(b'c:0')) 1413 self.assertEqual(42.0, sess.run(r'c:0')) 1414 1415 self.assertEqual(43.0, sess.run('d:0')) 1416 self.assertEqual(43.0, sess.run(u'd:0')) 1417 self.assertEqual(43.0, sess.run(b'd:0')) 1418 self.assertEqual(43.0, sess.run(r'd:0')) 1419 1420 self.assertEqual(44.0, sess.run('e:0')) 1421 self.assertEqual(44.0, sess.run(u'e:0')) 1422 self.assertEqual(44.0, sess.run(b'e:0')) 1423 self.assertEqual(44.0, sess.run(r'e:0')) 1424 1425 self.assertEqual(45.0, sess.run('f:0')) 1426 self.assertEqual(45.0, sess.run(u'f:0')) 1427 self.assertEqual(45.0, sess.run(b'f:0')) 1428 self.assertEqual(45.0, sess.run(r'f:0')) 1429 1430 def testIncorrectGraph(self): 1431 with ops.Graph().as_default() as g_1: 1432 c_1 = constant_op.constant(1.0, name='c') 1433 1434 with ops.Graph().as_default() as g_2: 1435 c_2 = constant_op.constant(2.0, name='c') 1436 1437 self.assertEqual('c', c_1.op.name) 1438 self.assertEqual('c', c_2.op.name) 1439 1440 with session.Session(graph=g_1) as sess_1: 1441 self.assertEqual(1.0, sess_1.run(c_1)) 1442 with self.assertRaises(ValueError): 1443 sess_1.run(c_2) 1444 with self.assertRaises(ValueError): 1445 sess_1.run(c_2.op) 1446 1447 with session.Session(graph=g_2) as sess_2: 1448 with self.assertRaises(ValueError): 1449 sess_2.run(c_1) 1450 with self.assertRaises(ValueError): 1451 sess_2.run(c_1.op) 1452 self.assertEqual(2.0, sess_2.run(c_2)) 1453 1454 def testFeedDictKeyException(self): 1455 with session.Session() as sess: 1456 a = constant_op.constant(1.0, dtypes.float32, name='a') 1457 with self.assertRaisesRegexp(TypeError, 'Cannot interpret feed_dict'): 1458 sess.run(a, feed_dict={'a': [2.0]}) 1459 1460 def testPerStepTrace(self): 1461 run_options = config_pb2.RunOptions( 1462 trace_level=config_pb2.RunOptions.FULL_TRACE) 1463 run_metadata = config_pb2.RunMetadata() 1464 1465 with ops.device('/cpu:0'): 1466 with session.Session() as sess: 1467 sess.run(constant_op.constant(1.0)) 1468 self.assertTrue(not run_metadata.HasField('step_stats')) 1469 1470 sess.run(constant_op.constant(1.0), run_metadata=run_metadata) 1471 self.assertTrue(not run_metadata.HasField('step_stats')) 1472 1473 sess.run( 1474 constant_op.constant(1.0), 1475 options=run_options, 1476 run_metadata=run_metadata) 1477 1478 self.assertTrue(run_metadata.HasField('step_stats')) 1479 self.assertEquals(len(run_metadata.step_stats.dev_stats), 1) 1480 1481 def testRunOptionsRunMetadata(self): 1482 run_options = config_pb2.RunOptions( 1483 trace_level=config_pb2.RunOptions.FULL_TRACE) 1484 run_metadata = config_pb2.RunMetadata() 1485 1486 with ops.device('/cpu:0'): 1487 with session.Session() as sess: 1488 # all combinations are valid 1489 sess.run(constant_op.constant(1.0), options=None, run_metadata=None) 1490 sess.run( 1491 constant_op.constant(1.0), options=None, run_metadata=run_metadata) 1492 self.assertTrue(not run_metadata.HasField('step_stats')) 1493 1494 sess.run( 1495 constant_op.constant(1.0), options=run_options, run_metadata=None) 1496 self.assertTrue(not run_metadata.HasField('step_stats')) 1497 1498 sess.run( 1499 constant_op.constant(1.0), 1500 options=run_options, 1501 run_metadata=run_metadata) 1502 1503 self.assertTrue(run_metadata.HasField('step_stats')) 1504 self.assertEquals(len(run_metadata.step_stats.dev_stats), 1) 1505 1506 def testFeedShapeCompatibility(self): 1507 # TODO(nolivia): C API doesn't yet handle marking nodes as not feedable. 1508 if ops._USE_C_API: 1509 return 1510 1511 with session.Session() as sess: 1512 some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0]) 1513 new_shape = constant_op.constant([2, 2]) 1514 reshaped_tensor = array_ops.reshape(some_tensor, new_shape) 1515 1516 with self.assertRaisesRegexp(ValueError, 'Cannot feed value of shape'): 1517 sess.run(reshaped_tensor, feed_dict={some_tensor: [1.0, 2.0, 3.0]}) 1518 1519 with self.assertRaisesRegexp(ValueError, 'may not be fed'): 1520 sess.run(reshaped_tensor, feed_dict={new_shape: [3, 7]}) 1521 1522 def testInferShapesFalse(self): 1523 with ops.Graph().as_default(), ops.device('/cpu:0'): 1524 a = constant_op.constant([[1, 2]]) 1525 sess = session.Session() 1526 self.assertFalse('_output_shapes' in sess.graph_def.node[0].attr) 1527 # Avoid lint error regarding 'unused' var a. 1528 self.assertTrue(a == a) 1529 1530 def testInferShapesTrue(self): 1531 config = config_pb2.ConfigProto( 1532 graph_options=config_pb2.GraphOptions(infer_shapes=True)) 1533 with ops.Graph().as_default(), ops.device('/cpu:0'): 1534 a = constant_op.constant([[1, 2]]) 1535 sess = session.Session(config=config) 1536 self.assertTrue('_output_shapes' in sess.graph_def.node[0].attr) 1537 # Avoid lint error regarding 'unused' var a. 1538 self.assertTrue(a == a) 1539 1540 def testBuildCostModel(self): 1541 run_options = config_pb2.RunOptions() 1542 config = config_pb2.ConfigProto( 1543 allow_soft_placement=True, 1544 graph_options=config_pb2.GraphOptions(build_cost_model=100)) 1545 with session.Session(config=config) as sess: 1546 with ops.device('/device:GPU:0'): 1547 a = array_ops.placeholder(dtypes.float32, shape=[]) 1548 b = math_ops.add(a, a) 1549 c = array_ops.identity(b) 1550 d = math_ops.multiply(c, c) 1551 for step in xrange(120): 1552 run_metadata = config_pb2.RunMetadata() 1553 sess.run( 1554 d, 1555 feed_dict={a: 1.0}, 1556 options=run_options, 1557 run_metadata=run_metadata) 1558 if step == 99: 1559 self.assertTrue(run_metadata.HasField('cost_graph')) 1560 else: 1561 self.assertFalse(run_metadata.HasField('cost_graph')) 1562 1563 def runTestOutputPartitionGraphs(self, sess): 1564 run_options = config_pb2.RunOptions(output_partition_graphs=True) 1565 a = constant_op.constant(1) 1566 run_metadata = config_pb2.RunMetadata() 1567 sess.run(a, options=run_options, run_metadata=run_metadata) 1568 self.assertGreater(len(run_metadata.partition_graphs), 0) 1569 sess.run(a, run_metadata=run_metadata) 1570 self.assertEqual(len(run_metadata.partition_graphs), 0) 1571 1572 def testOutputPartitionGraphsDirect(self): 1573 self.runTestOutputPartitionGraphs(session.Session()) 1574 1575 def testOutputPartitionGraphsDistributed(self): 1576 server = server_lib.Server.create_local_server() 1577 self.runTestOutputPartitionGraphs(session.Session(server.target)) 1578 1579 def testNonInteractiveSessionNesting(self): 1580 sess1 = session.Session() 1581 sess1_controller = sess1.as_default() 1582 sess1_controller.__enter__() 1583 1584 sess2 = session.Session() 1585 sess2_controller = sess2.as_default() 1586 sess2_controller.__enter__() 1587 1588 with self.assertRaisesRegexp(AssertionError, 'Nesting violated'): 1589 sess1_controller.__exit__(None, None, None) 1590 1591 ops._default_session_stack.reset() 1592 1593 def testInteractiveSessionNesting(self): 1594 sess1 = session.InteractiveSession() 1595 sess2 = session.InteractiveSession() 1596 del sess1 1597 del sess2 1598 1599 def testAsDefault(self): 1600 c = constant_op.constant(37) 1601 sess = session.Session() 1602 with sess.as_default(): 1603 self.assertEqual(37, c.eval()) 1604 1605 # Ensure that the session remains valid even when it is not captured. 1606 with session.Session().as_default(): 1607 self.assertEqual(37, c.eval()) 1608 1609 def testReentry(self): 1610 sess = session.Session() 1611 with self.assertRaisesRegexp(RuntimeError, 'not re-entrant'): 1612 with sess: 1613 with sess: 1614 pass 1615 1616 def testInvalidArgument(self): 1617 with self.assertRaisesRegexp(TypeError, 'target must be a string'): 1618 session.Session(37) 1619 with self.assertRaisesRegexp(TypeError, 'config must be a tf.ConfigProto'): 1620 session.Session(config=37) 1621 with self.assertRaisesRegexp(TypeError, 'graph must be a tf.Graph'): 1622 session.Session(graph=37) 1623 1624 def testTimeoutWithShortOperations(self): 1625 num_epochs = 5 1626 q = data_flow_ops.FIFOQueue(capacity=50, dtypes=[dtypes.int32], shapes=[()]) 1627 enqueue_op = q.enqueue_many(constant_op.constant([1, 2])) 1628 1629 # Use a 10-second timeout, which should be longer than any 1630 # non-blocking enqueue_many op. 1631 config = config_pb2.ConfigProto(operation_timeout_in_ms=10000) 1632 with session.Session(config=config) as sess: 1633 for _ in range(num_epochs): 1634 sess.run(enqueue_op) 1635 self.assertEqual(sess.run(q.size()), num_epochs * 2) 1636 1637 def testRegisterFetchAndFeedConversionFunctions(self): 1638 1639 class SquaredTensor(object): 1640 1641 def __init__(self, tensor): 1642 self.sq = math_ops.square(tensor) 1643 1644 fetch_fn = lambda squared_tensor: ([squared_tensor.sq], lambda val: val[0]) 1645 feed_fn1 = lambda feed, feed_val: [(feed.sq, feed_val)] 1646 feed_fn2 = lambda feed: [feed.sq] 1647 1648 session.register_session_run_conversion_functions(SquaredTensor, fetch_fn, 1649 feed_fn1, feed_fn2) 1650 with self.assertRaises(ValueError): 1651 session.register_session_run_conversion_functions(SquaredTensor, fetch_fn, 1652 feed_fn1, feed_fn2) 1653 with self.test_session() as sess: 1654 np1 = np.array([1.0, 1.5, 2.0, 2.5]) 1655 np2 = np.array([3.0, 3.5, 4.0, 4.5]) 1656 squared_tensor = SquaredTensor(np2) 1657 squared_eval = sess.run(squared_tensor) 1658 self.assertAllClose(np2 * np2, squared_eval) 1659 squared_eval = sess.run( 1660 squared_tensor, feed_dict={ 1661 squared_tensor: np1 * np1 1662 }) 1663 self.assertAllClose(np1 * np1, squared_eval) 1664 partial_run = sess.partial_run_setup([squared_tensor], []) 1665 squared_eval = sess.partial_run(partial_run, squared_tensor) 1666 self.assertAllClose(np2 * np2, squared_eval) 1667 1668 def testDefaultLogDevicePlacement(self): 1669 1670 class CaptureStderr(str): 1671 """Class to capture stderr from C++ shared library.""" 1672 1673 def __enter__(self): 1674 self._esc = compat.as_str('\b') 1675 self._output = compat.as_str('') 1676 self._stderr = sys.stderr 1677 self._fd = self._stderr.fileno() 1678 self._out_pipe, in_pipe = os.pipe() 1679 # Save the original io stream. 1680 self._dup_fd = os.dup(self._fd) 1681 # Replace the original io stream with in pipe. 1682 os.dup2(in_pipe, self._fd) 1683 return self 1684 1685 def __exit__(self, *args): 1686 self._stderr.write(self._esc) 1687 self._stderr.flush() 1688 self.read() 1689 os.close(self._out_pipe) 1690 # Restore the original io stream. 1691 os.dup2(self._dup_fd, self._fd) 1692 1693 def read(self): 1694 while True: 1695 data = os.read(self._out_pipe, 1) 1696 if not data or compat.as_str(data) == self._esc: 1697 break 1698 self._output += compat.as_str(data) 1699 1700 def __str__(self): 1701 return self._output 1702 1703 # Passing the config to the server, but not the session should still result 1704 # in logging device placement. 1705 config = config_pb2.ConfigProto(log_device_placement=True) 1706 server = server_lib.Server.create_local_server(config=config) 1707 a = constant_op.constant(1) 1708 b = constant_op.constant(2) 1709 c = a + b 1710 with session.Session(server.target) as sess: 1711 with CaptureStderr() as log: 1712 sess.run(c) 1713 # Ensure that we did log device placement. 1714 self.assertTrue('/job:local/replica:0/task:0/device:CPU:0' in str(log), 1715 str(log)) 1716 1717 def testLocalMasterSessionTimeout(self): 1718 # Test that the timeout passed in a config to the session works correctly. 1719 config = config_pb2.ConfigProto(operation_timeout_in_ms=1000) 1720 server = server_lib.Server.create_local_server() 1721 q = data_flow_ops.FIFOQueue(1, dtypes.float32) 1722 dequeued_t = q.dequeue() 1723 1724 with session.Session(server.target, config=config) as sess: 1725 # Intentionally do not run any enqueue_ops so that dequeue will block 1726 # until operation_timeout_in_ms. 1727 with self.assertRaises(errors.DeadlineExceededError): 1728 sess.run(dequeued_t) 1729 1730 def testDefaultServerTimeout(self): 1731 # Test that the default server config timeout gets used when no Session 1732 # config is provided. 1733 config = config_pb2.ConfigProto(operation_timeout_in_ms=1000) 1734 server = server_lib.Server.create_local_server(config=config) 1735 q = data_flow_ops.FIFOQueue(1, dtypes.float32) 1736 dequeued_t = q.dequeue() 1737 1738 with session.Session(server.target) as sess: 1739 # Intentionally do not run any enqueue_ops so that dequeue will block 1740 # until operation_timeout_in_ms. 1741 with self.assertRaises(errors.DeadlineExceededError): 1742 sess.run(dequeued_t) 1743 1744 def runTestBuildGraphError(self, sess): 1745 # Ensure that errors from building the graph get propagated. 1746 data = array_ops.placeholder(dtypes.float32, shape=[]) 1747 # pylint: disable=protected-access 1748 enter_1 = gen_control_flow_ops._enter(data, 'foo_1', False) 1749 enter_2 = gen_control_flow_ops._enter(data, 'foo_2', False) 1750 # pylint: enable=protected-access 1751 res = math_ops.add(enter_1, enter_2) 1752 with self.assertRaisesOpError('has inputs from different frames'): 1753 sess.run(res, feed_dict={data: 1.0}) 1754 1755 def testBuildGraphErrorDirect(self): 1756 self.runTestBuildGraphError(session.Session()) 1757 1758 def testBuildGraphErrorDist(self): 1759 server = server_lib.Server.create_local_server() 1760 self.runTestBuildGraphError(session.Session(server.target)) 1761 1762 def testDeviceAttributes(self): 1763 attrs = session._DeviceAttributes( 1764 '/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337) 1765 self.assertEqual(1337, attrs.memory_limit_bytes) 1766 self.assertEqual('/job:worker/replica:0/task:3/device:CPU:2', attrs.name) 1767 self.assertEqual('TYPE', attrs.device_type) 1768 str_repr = '%s' % attrs 1769 self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr) 1770 1771 def testDeviceAttributesCanonicalization(self): 1772 attrs = session._DeviceAttributes('/job:worker/replica:0/task:3/cpu:1', 1773 'TYPE', 1337) 1774 self.assertEqual(1337, attrs.memory_limit_bytes) 1775 self.assertEqual('/job:worker/replica:0/task:3/device:CPU:1', attrs.name) 1776 self.assertEqual('TYPE', attrs.device_type) 1777 str_repr = '%s' % attrs 1778 self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr) 1779 1780 def runTestAddFunctionToSession(self, target=''): 1781 """Add a function to a session after the graph has already been run.""" 1782 1783 @function.Defun(dtypes.float32) 1784 def foo(x): 1785 return x + 1 1786 1787 x = constant_op.constant(1.0) 1788 with session.Session(target=target) as sess: 1789 sess.run(x) 1790 f = foo(x) 1791 result = sess.run(f) 1792 self.assertEqual(result, 2.0) 1793 1794 def testAddFunctionToSession(self): 1795 self.runTestAddFunctionToSession() 1796 1797 def testAddFunctionToGrpcSession(self): 1798 server = server_lib.Server.create_local_server() 1799 self.runTestAddFunctionToSession(server.target) 1800 1801 def testOpenAndCloseGrpcSession(self): 1802 server = server_lib.Server.create_local_server() 1803 with session.Session(server.target): 1804 pass 1805 1806 def testOpenAndCloseSession(self): 1807 with session.Session(): 1808 pass 1809 1810 def testAutoConvertAndCheckData(self): 1811 with self.test_session() as sess: 1812 a = array_ops.placeholder(dtype=dtypes.string) 1813 with self.assertRaisesRegexp( 1814 TypeError, 'Type of feed value 1 with type <(\w+) \'int\'> is not'): 1815 sess.run(a, feed_dict={a: 1}) 1816 1817 1818 class GraphMutationTest(test_util.TensorFlowTestCase): 1819 1820 def setUp(self): 1821 self._original_use_c_api_value = ops._USE_C_API 1822 ops._USE_C_API = True 1823 super(GraphMutationTest, self).setUp() 1824 1825 def tearDown(self): 1826 ops._USE_C_API = self._original_use_c_api_value 1827 super(GraphMutationTest, self).tearDown() 1828 1829 def testUpdateInputAfterRunning(self): 1830 with ops.Graph().as_default() as g: 1831 a = constant_op.constant(1.0) 1832 b = constant_op.constant(2.0) 1833 c = a + b 1834 1835 with session.Session(graph=g) as sess: 1836 self.assertAllEqual(3.0, sess.run(c)) 1837 c.op._update_input(1, a) # pylint: disable=protected-access 1838 with self.assertRaisesRegexp( 1839 errors.FailedPreconditionError, 1840 'add.*was changed by updating input tensor after it was run'): 1841 sess.run(c) 1842 1843 # Check that running the graph with a new session is fine 1844 with session.Session(graph=g) as sess2: 1845 self.assertAllEqual(2.0, sess2.run(c)) 1846 1847 def testSetDeviceAfterRunning(self): 1848 with ops.Graph().as_default() as g: 1849 a = constant_op.constant(1.0) 1850 b = constant_op.constant(2.0) 1851 c = a + b 1852 1853 with session.Session(graph=g) as sess: 1854 self.assertAllEqual(3.0, sess.run(c)) 1855 c.op._set_device('/cpu:0') # pylint: disable=protected-access 1856 with self.assertRaisesRegexp( 1857 errors.FailedPreconditionError, 1858 'add.*was changed by setting device after it was run'): 1859 sess.run(c) 1860 1861 def testSetAttrAfterRunning(self): 1862 with ops.Graph().as_default() as g: 1863 a = constant_op.constant(1.0, dtype=dtypes.float32) 1864 b = math_ops.cast(a, dtypes.float64) 1865 1866 with session.Session(graph=g) as sess: 1867 self.assertAllEqual(1.0, sess.run(b)) 1868 b.op._set_attr('DstT', attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT)) 1869 with self.assertRaisesRegexp( 1870 errors.FailedPreconditionError, 1871 'Cast.*was changed by setting attribute after it was run'): 1872 sess.run(b) 1873 1874 def testRunModifyRun(self): 1875 with ops.Graph().as_default() as g: 1876 a = constant_op.constant(1.0) 1877 b = constant_op.constant(2.0) 1878 c = a + b 1879 1880 with session.Session(graph=g) as sess: 1881 self.assertAllEqual(3.0, sess.run(c)) 1882 1883 d = b + c 1884 d.op._update_input(0, a) # pylint: disable=protected-access 1885 self.assertAllEqual(3.0, sess.run(c)) 1886 self.assertAllEqual(4.0, sess.run(d)) 1887 1888 def testRunModifyRunTwoSessions(self): 1889 with ops.Graph().as_default() as g: 1890 a = constant_op.constant(1.0) 1891 b = constant_op.constant(2.0) 1892 c = a + b 1893 1894 with session.Session(graph=g) as sess1: 1895 with session.Session(graph=g) as sess2: 1896 self.assertAllEqual(3.0, sess1.run(c)) 1897 self.assertAllEqual(3.0, sess2.run(c)) 1898 1899 d = b + c 1900 d.op._update_input(0, a) # pylint: disable=protected-access 1901 self.assertAllEqual(3.0, sess2.run(c)) 1902 self.assertAllEqual(4.0, sess2.run(d)) 1903 1904 d.op._update_input(0, b) # pylint: disable=protected-access 1905 self.assertAllEqual(3.0, sess1.run(c)) 1906 self.assertAllEqual(5.0, sess1.run(d)) 1907 1908 with self.assertRaisesRegexp( 1909 errors.FailedPreconditionError, 1910 'add.*was changed by updating input tensor after it was run'): 1911 sess2.run(c) 1912 1913 def testTwoSessionsOneRunBeforeModification(self): 1914 with ops.Graph().as_default() as g, ops.device('/cpu:0'): 1915 a = constant_op.constant(1.0) 1916 b = constant_op.constant(2.0) 1917 c = a + b 1918 1919 with session.Session(graph=g) as sess1: 1920 with session.Session(graph=g) as sess2: 1921 sess1.run(c) 1922 1923 c.op._set_device('/cpu:0') # pylint: disable=protected-access 1924 1925 with self.assertRaisesRegexp( 1926 errors.FailedPreconditionError, 1927 'add.*was changed by setting device after it was run'): 1928 sess1.run(c) 1929 1930 # sess2 was not run before modification 1931 self.assertAllEqual(3.0, sess2.run(c)) 1932 1933 def testTwoSessionsBothRunBeforeModification(self): 1934 with ops.Graph().as_default() as g, ops.device('/cpu:0'): 1935 a = constant_op.constant(1.0) 1936 b = constant_op.constant(2.0) 1937 c = a + b 1938 1939 with session.Session(graph=g) as sess1: 1940 with session.Session(graph=g) as sess2: 1941 sess1.run(c) 1942 sess2.run(c) 1943 1944 c.op._set_device('/cpu:0') # pylint: disable=protected-access 1945 1946 with self.assertRaisesRegexp( 1947 errors.FailedPreconditionError, 1948 'add.*was changed by setting device after it was run'): 1949 sess1.run(c) 1950 1951 with self.assertRaisesRegexp( 1952 errors.FailedPreconditionError, 1953 'add.*was changed by setting device after it was run'): 1954 sess2.run(c) 1955 1956 1957 if __name__ == '__main__': 1958 googletest.main() 1959