Home | History | Annotate | Download | only in client
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tensorflow.python.client.session.Session."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import collections
     21 import os
     22 import sys
     23 import threading
     24 import time
     25 
     26 import numpy as np
     27 import six
     28 from six.moves import xrange  # pylint: disable=redefined-builtin
     29 
     30 from tensorflow.core.framework import attr_value_pb2
     31 from tensorflow.core.framework import types_pb2
     32 from tensorflow.core.lib.core import error_codes_pb2
     33 from tensorflow.core.protobuf import config_pb2
     34 from tensorflow.python.client import session
     35 from tensorflow.python.framework import common_shapes
     36 from tensorflow.python.framework import constant_op
     37 from tensorflow.python.framework import dtypes
     38 from tensorflow.python.framework import errors
     39 from tensorflow.python.framework import function
     40 from tensorflow.python.framework import ops
     41 from tensorflow.python.framework import sparse_tensor
     42 from tensorflow.python.framework import tensor_util
     43 from tensorflow.python.framework import test_util
     44 from tensorflow.python.framework import versions
     45 from tensorflow.python.ops import array_ops
     46 from tensorflow.python.ops import control_flow_ops
     47 from tensorflow.python.ops import data_flow_ops
     48 from tensorflow.python.ops import gen_control_flow_ops
     49 from tensorflow.python.ops import math_ops
     50 # Import resource_variable_ops for the variables-to-tensor implicit conversion.
     51 from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
     52 from tensorflow.python.ops import state_ops
     53 from tensorflow.python.ops import variables
     54 from tensorflow.python.platform import googletest
     55 from tensorflow.python.training import server_lib
     56 from tensorflow.python.util import compat
     57 
# NOTE(mrry): Dummy shape registration for ops used in the tests, since they
# don't have C++ op registrations on which to attach C++ shape fns.
# 'ConstructionFails' is built with Graph.create_op in
# testOpConstructionErrorPayload; `common_shapes.unknown_shape` is registered
# as its Python shape function.
ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
     61 
     62 
     63 @test_util.with_c_api
     64 class SessionTest(test_util.TensorFlowTestCase):
     65 
     66   def testUseExistingGraph(self):
     67     with ops.Graph().as_default() as g, ops.device('/cpu:0'):
     68       a = constant_op.constant(6.0, shape=[1, 1])
     69       b = constant_op.constant(7.0, shape=[1, 1])
     70       c = math_ops.matmul(a, b, name='matmul')
     71     with session.Session(graph=g):
     72       result = c.eval()
     73       self.assertAllEqual(result, [[42.0]])
     74 
     75   def testUseDefaultGraph(self):
     76     with ops.Graph().as_default(), ops.device('/cpu:0'):
     77       a = constant_op.constant(6.0, shape=[1, 1])
     78       b = constant_op.constant(7.0, shape=[1, 1])
     79       c = math_ops.matmul(a, b, name='matmul')
     80       with session.Session():
     81         result = c.eval()
     82         self.assertAllEqual(result, [[42.0]])
     83 
     84   def testCreate(self):
     85     with session.Session():
     86       inp = constant_op.constant(10.0, shape=[2, 3], name='W1')
     87       copy = array_ops.identity(inp)
     88       # Test with feed.
     89       # TODO(mrry): Investigate why order='F' didn't work.
     90       arr = np.asarray([[0, 1, 2], [3, 4, 5]], dtype=np.float32, order='C')
     91       copy_val = copy.eval({'W1:0': arr})
     92       self.assertAllEqual(arr, copy_val)
     93       # Test without feed.
     94       copy_val = copy.eval()
     95       self.assertAllEqual(
     96           np.asarray(
     97               [[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]], dtype=np.float32),
     98           copy_val)
     99 
    100   def testManyCPUs(self):
    101     # TODO(keveman): Implement ListDevices and test for the number of
    102     # devices returned by ListDevices.
    103     with session.Session(
    104         config=config_pb2.ConfigProto(device_count={
    105             'CPU': 2
    106         })):
    107       inp = constant_op.constant(10.0, name='W1')
    108       self.assertAllEqual(inp.eval(), 10.0)
    109 
    110   def testPerSessionThreads(self):
    111     # TODO(keveman): Implement ListDevices and test for the number of
    112     # devices returned by ListDevices.
    113     with session.Session(
    114         config=config_pb2.ConfigProto(use_per_session_threads=True)):
    115       inp = constant_op.constant(10.0, name='W1')
    116       self.assertAllEqual(inp.eval(), 10.0)
    117 
    118   def testSessionInterOpThreadPool(self):
    119     config = config_pb2.ConfigProto()
    120     pool = config.session_inter_op_thread_pool.add()
    121     with session.Session(config=config) as s:
    122       inp = constant_op.constant(10.0, name='W1')
    123       results = s.run([inp])
    124       self.assertAllEqual([10.0], results)
    125 
    126     pool = config.session_inter_op_thread_pool.add()
    127     pool.num_threads = 1
    128     with session.Session(config=config) as s:
    129       inp = constant_op.constant(20.0, name='W2')
    130       results = s.run([inp])
    131       self.assertAllEqual([20.0], results)
    132 
    133     pool = config.session_inter_op_thread_pool.add()
    134     pool.num_threads = 1
    135     pool.global_name = 't1'
    136     run_options = config_pb2.RunOptions()
    137     run_options.inter_op_thread_pool = (
    138         len(config.session_inter_op_thread_pool) - 1)
    139     with session.Session(config=config) as s:
    140       inp = constant_op.constant(30.0, name='W2')
    141       results = s.run([inp], options=run_options)
    142       self.assertAllEqual([30.0], results)
    143 
    144   def testErrorsReported(self):
    145     with session.Session() as s:
    146       constant_op.constant(10.0, name='W1')
    147       with self.assertRaises(ValueError):
    148         s.run('foo:0')
    149 
    150   def testErrorPayload(self):
    151     with session.Session():
    152       a = array_ops.placeholder(dtypes.float32)
    153       with self.assertRaisesOpError(lambda e: e.op == a.op):
    154         a.eval()
    155 
    156   def testErrorCodeWithNoNodeDef(self):
    157     with session.Session() as s:
    158       a = array_ops.placeholder(dtypes.float32, shape=[])
    159       b = array_ops.placeholder(dtypes.float32, shape=[])
    160       r1 = math_ops.add(a, b)
    161 
    162       def exc_predicate(e):
    163         return (e.op is None and e.node_def is None and
    164                 e.error_code == error_codes_pb2.INVALID_ARGUMENT)
    165 
    166       with self.assertRaisesOpError(exc_predicate):
    167         # Run with a bogus handle.
    168         s.partial_run('foo', r1, feed_dict={a: 1, b: 2})
    169 
    170   def testOpConstructionErrorPayload(self):
    171     if ops._USE_C_API:
    172       return  # No shape registration for 'ConstructionFails'
    173 
    174     with session.Session():
    175       failing_op = ops.get_default_graph().create_op(
    176           'ConstructionFails', [], [], name='f')
    177 
    178       def exc_predicate(e):
    179         return (e.op == failing_op and
    180                 e.error_code == error_codes_pb2.INVALID_ARGUMENT)
    181 
    182       with self.assertRaisesOpError(exc_predicate):
    183         failing_op.run()
    184 
    185   def testErrorBasedOn(self):
    186     with session.Session() as sess:
    187       a = constant_op.constant(0.0, shape=[2, 3])
    188       # NOTE(mrry): The original_op is nonsense, but used here to test that the
    189       #   errors are reported correctly.
    190       # pylint: disable=protected-access
    191       with sess.graph._original_op(a.op):
    192         b = array_ops.identity(a, name='id')
    193       with sess.graph._original_op(b.op):
    194         c = array_ops.placeholder(dtypes.float32)
    195       # pylint: enable=protected-access
    196 
    197       def exc_predicate(e):
    198         return (e.op == c.op and e.op._original_op == b.op and
    199                 e.op._original_op._original_op == a.op)
    200 
    201       with self.assertRaisesOpError(exc_predicate):
    202         c.eval()
    203 
    204   def testFetchNone(self):
    205     with session.Session() as s:
    206       a = constant_op.constant(1.0)
    207       with self.assertRaises(TypeError):
    208         s.run(None)
    209       with self.assertRaises(TypeError):
    210         s.run([None])
    211       with self.assertRaises(TypeError):
    212         s.run({'b': None})
    213       with self.assertRaises(TypeError):
    214         s.run({'a': a, 'b': None})
    215 
    216   def testFetchSingleton(self):
    217     with session.Session() as sess:
    218       a = constant_op.constant(42.0)
    219       res = sess.run(a)
    220       self.assertEqual(42.0, res)
    221       res = sess.run(a.op)  # An op, not a tensor.
    222       self.assertEqual(None, res)
    223       tensor_runner = sess.make_callable(a)
    224       res = tensor_runner()
    225       self.assertEqual(42.0, res)
    226       op_runner = sess.make_callable(a.op)
    227       res = op_runner()
    228       self.assertEqual(None, res)
    229 
    230   def testFetchSingletonByName(self):
    231     with session.Session() as sess:
    232       a = constant_op.constant(42.0)
    233       res = sess.run(a.name)
    234       self.assertEqual(42.0, res)
    235       res = sess.run(a.op)  # An op, not a tensor.
    236       self.assertEqual(None, res)
    237 
    238   def testFetchList(self):
    239     with session.Session() as sess:
    240       a = constant_op.constant(42.0)
    241       b = control_flow_ops.no_op()  # An op, not a tensor.
    242       c = constant_op.constant(44.0)
    243       v = variables.Variable([54.0])
    244       assign = v.assign([63.0])
    245       res = sess.run([a, b, c, a.name, assign.op])
    246       self.assertTrue(isinstance(res, list))
    247       self.assertEqual([42.0, None, 44.0, 42.0, None], res)
    248       list_runner = sess.make_callable([a, b, c, a.name, assign.op])
    249       res = list_runner()
    250       self.assertTrue(isinstance(res, list))
    251       self.assertEqual([42.0, None, 44.0, 42.0, None], res)
    252 
    253   def testFetchTuple(self):
    254     with session.Session() as sess:
    255       a = constant_op.constant(42.0)
    256       b = control_flow_ops.no_op()  # An op, not a tensor.
    257       c = constant_op.constant(44.0)
    258       res = sess.run((a, b, c, a.name))
    259       self.assertTrue(isinstance(res, tuple))
    260       self.assertEqual((42.0, None, 44.0, 42.0), res)
    261       tuple_runner = sess.make_callable((a, b, c, a.name))
    262       res = tuple_runner()
    263       self.assertTrue(isinstance(res, tuple))
    264       self.assertEqual((42.0, None, 44.0, 42.0), res)
    265 
    266   def testFetchNamedTuple(self):
    267     # pylint: disable=invalid-name
    268     ABC = collections.namedtuple('ABC', ['a', 'b', 'c'])
    269     # pylint: enable=invalid-name
    270     with session.Session() as sess:
    271       a = constant_op.constant(42.0)
    272       b = control_flow_ops.no_op()  # An op, not a tensor.
    273       c = constant_op.constant(44.0)
    274       res = sess.run(ABC(a, b, c))
    275       self.assertTrue(isinstance(res, ABC))
    276       self.assertEqual(42.0, res.a)
    277       self.assertEqual(None, res.b)
    278       self.assertEqual(44.0, res.c)
    279       namedtuple_runner = sess.make_callable(ABC(a, b, c))
    280       res = namedtuple_runner()
    281       self.assertTrue(isinstance(res, ABC))
    282       self.assertEqual(42.0, res.a)
    283       self.assertEqual(None, res.b)
    284       self.assertEqual(44.0, res.c)
    285 
    286   def testFetchDict(self):
    287     with session.Session() as sess:
    288       a = constant_op.constant(42.0)
    289       b = control_flow_ops.no_op()  # An op, not a tensor.
    290       c = constant_op.constant(44.0)
    291       res = sess.run({'a': a, 'b': b, 'c': c})
    292       self.assertTrue(isinstance(res, dict))
    293       self.assertEqual(42.0, res['a'])
    294       self.assertEqual(None, res['b'])
    295       self.assertEqual(44.0, res['c'])
    296 
    297   def testFetchOrderedDict(self):
    298     with session.Session() as sess:
    299       a = constant_op.constant(42.0)
    300       b = control_flow_ops.no_op()  # An op, not a tensor.
    301       c = constant_op.constant(44.0)
    302       res = sess.run(collections.OrderedDict([(3, a), (2, b), (1, c)]))
    303       self.assertTrue(isinstance(res, collections.OrderedDict))
    304       self.assertEqual([3, 2, 1], list(res.keys()))
    305       self.assertEqual(42.0, res[3])
    306       self.assertEqual(None, res[2])
    307       self.assertEqual(44.0, res[1])
    308 
    309   def testFetchNestingEmptyOneLevel(self):
    310     with session.Session() as sess:
    311       a_val = 11.0
    312       a = constant_op.constant(a_val)
    313 
    314       res = sess.run([[], tuple(), {}])
    315       self.assertTrue(isinstance(res, list))
    316       self.assertEquals(3, len(res))
    317       self.assertTrue(isinstance(res[0], list))
    318       self.assertEqual(0, len(res[0]))
    319       self.assertTrue(isinstance(res[1], tuple))
    320       self.assertEqual(0, len(res[1]))
    321       self.assertTrue(isinstance(res[2], dict))
    322       self.assertEqual(0, len(res[2]))
    323 
    324       res = sess.run([[], tuple(), {}, a])
    325       self.assertTrue(isinstance(res, list))
    326       self.assertEquals(4, len(res))
    327       self.assertTrue(isinstance(res[0], list))
    328       self.assertEqual(0, len(res[0]))
    329       self.assertTrue(isinstance(res[1], tuple))
    330       self.assertEqual(0, len(res[1]))
    331       self.assertTrue(isinstance(res[2], dict))
    332       self.assertEqual(0, len(res[2]))
    333       self.assertEqual(a_val, res[3])
    334 
    335   def testFetchNestingOneLevel(self):
    336     with session.Session() as sess:
    337       # pylint: disable=invalid-name
    338       ABC = collections.namedtuple('ABC', ['a', 'b', 'c'])
    339       DEFG = collections.namedtuple('DEFG', ['d', 'e', 'f', 'g'])
    340       # pylint: enable=invalid-name
    341       a_val = 42.0
    342       b_val = None
    343       c_val = 44.0
    344       a = constant_op.constant(a_val)
    345       b = control_flow_ops.no_op()  # An op, not a tensor.
    346       c = constant_op.constant(c_val)
    347       # List of lists, tuples, namedtuple, and dict
    348       res = sess.run([[a, b, c], (a, b, c),
    349                       ABC(a=a, b=b, c=c), {
    350                           'a': a.name,
    351                           'c': c,
    352                           'b': b
    353                       }])
    354       self.assertTrue(isinstance(res, list))
    355       self.assertEqual(4, len(res))
    356       self.assertTrue(isinstance(res[0], list))
    357       self.assertEqual(3, len(res[0]))
    358       self.assertEqual(a_val, res[0][0])
    359       self.assertEqual(b_val, res[0][1])
    360       self.assertEqual(c_val, res[0][2])
    361       self.assertTrue(isinstance(res[1], tuple))
    362       self.assertEqual(3, len(res[1]))
    363       self.assertEqual(a_val, res[1][0])
    364       self.assertEqual(b_val, res[1][1])
    365       self.assertEqual(c_val, res[1][2])
    366       self.assertTrue(isinstance(res[2], ABC))
    367       self.assertEqual(a_val, res[2].a)
    368       self.assertEqual(b_val, res[2].b)
    369       self.assertEqual(c_val, res[2].c)
    370       self.assertTrue(isinstance(res[3], dict))
    371       self.assertEqual(3, len(res[3]))
    372       self.assertEqual(a_val, res[3]['a'])
    373       self.assertEqual(b_val, res[3]['b'])
    374       self.assertEqual(c_val, res[3]['c'])
    375       # Tuple of lists, tuples, namedtuple, and dict
    376       res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b, c=c), {
    377           'a': a,
    378           'c': c,
    379           'b': b
    380       }))
    381       self.assertTrue(isinstance(res, tuple))
    382       self.assertEqual(4, len(res))
    383       self.assertTrue(isinstance(res[0], list))
    384       self.assertEqual(3, len(res[0]))
    385       self.assertEqual(a_val, res[0][0])
    386       self.assertEqual(b_val, res[0][1])
    387       self.assertEqual(c_val, res[0][2])
    388       self.assertTrue(isinstance(res[1], tuple))
    389       self.assertEqual(3, len(res[1]))
    390       self.assertEqual(a_val, res[1][0])
    391       self.assertEqual(b_val, res[1][1])
    392       self.assertEqual(c_val, res[1][2])
    393       self.assertTrue(isinstance(res[2], ABC))
    394       self.assertEqual(a_val, res[2].a)
    395       self.assertEqual(b_val, res[2].b)
    396       self.assertEqual(c_val, res[2].c)
    397       self.assertTrue(isinstance(res[3], dict))
    398       self.assertEqual(3, len(res[3]))
    399       self.assertEqual(a_val, res[3]['a'])
    400       self.assertEqual(b_val, res[3]['b'])
    401       self.assertEqual(c_val, res[3]['c'])
    402       # Namedtuple of lists, tuples, namedtuples, and dict
    403       res = sess.run(
    404           DEFG(
    405               d=[a, b, c],
    406               e=(a, b, c),
    407               f=ABC(a=a.name, b=b, c=c),
    408               g={
    409                   'a': a,
    410                   'c': c,
    411                   'b': b
    412               }))
    413       self.assertTrue(isinstance(res, DEFG))
    414       self.assertTrue(isinstance(res.d, list))
    415       self.assertEqual(3, len(res.d))
    416       self.assertEqual(a_val, res.d[0])
    417       self.assertEqual(b_val, res.d[1])
    418       self.assertEqual(c_val, res.d[2])
    419       self.assertTrue(isinstance(res.e, tuple))
    420       self.assertEqual(3, len(res.e))
    421       self.assertEqual(a_val, res.e[0])
    422       self.assertEqual(b_val, res.e[1])
    423       self.assertEqual(c_val, res.e[2])
    424       self.assertTrue(isinstance(res.f, ABC))
    425       self.assertEqual(a_val, res.f.a)
    426       self.assertEqual(b_val, res.f.b)
    427       self.assertEqual(c_val, res.f.c)
    428       self.assertTrue(isinstance(res.g, dict))
    429       self.assertEqual(3, len(res.g))
    430       self.assertEqual(a_val, res.g['a'])
    431       self.assertEqual(b_val, res.g['b'])
    432       self.assertEqual(c_val, res.g['c'])
    433       # Dict of lists, tuples, namedtuples, and dict
    434       res = sess.run({
    435           'd': [a, b, c],
    436           'e': (a, b, c),
    437           'f': ABC(a=a, b=b, c=c),
    438           'g': {
    439               'a': a.name,
    440               'c': c,
    441               'b': b
    442           }
    443       })
    444       self.assertTrue(isinstance(res, dict))
    445       self.assertEqual(4, len(res))
    446       self.assertTrue(isinstance(res['d'], list))
    447       self.assertEqual(3, len(res['d']))
    448       self.assertEqual(a_val, res['d'][0])
    449       self.assertEqual(b_val, res['d'][1])
    450       self.assertEqual(c_val, res['d'][2])
    451       self.assertTrue(isinstance(res['e'], tuple))
    452       self.assertEqual(3, len(res['e']))
    453       self.assertEqual(a_val, res['e'][0])
    454       self.assertEqual(b_val, res['e'][1])
    455       self.assertEqual(c_val, res['e'][2])
    456       self.assertTrue(isinstance(res['f'], ABC))
    457       self.assertEqual(a_val, res['f'].a)
    458       self.assertEqual(b_val, res['f'].b)
    459       self.assertEqual(c_val, res['f'].c)
    460       self.assertTrue(isinstance(res['g'], dict))
    461       self.assertEqual(3, len(res['g']))
    462       self.assertEqual(a_val, res['g']['a'])
    463       self.assertEqual(b_val, res['g']['b'])
    464       self.assertEqual(c_val, res['g']['c'])
    465 
    466   def testFetchTensorObject(self):
    467     with session.Session() as s:
    468       a = constant_op.constant(1.0, shape=[1, 2])
    469       b = constant_op.constant(2.0, shape=[2, 3])
    470       c = math_ops.matmul(a, b)
    471       results_with_list = s.run([c])
    472       self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_list[0])
    473       results_with_single = s.run(c)
    474       self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_single)
    475       results_with_get = c.eval()
    476       self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_get)
    477       a_val, b_val = s.run([a, b])  # Test multiple fetches.
    478       self.assertAllEqual([[1.0, 1.0]], a_val)
    479       self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], b_val)
    480       results_with_dict = s.run({'a': [a], 'b': b, 'z': [a, b]})
    481       self.assertAllEqual([[1.0, 1.0]], results_with_dict['a'][0])
    482       self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]],
    483                           results_with_dict['b'])
    484       self.assertAllEqual(results_with_dict['a'][0], results_with_dict['z'][0])
    485       self.assertAllEqual(results_with_dict['b'], results_with_dict['z'][1])
    486 
    487       # Test nested structures
    488       results_with_nested_list = s.run([[[a, b], b], a, [a, b]])
    489       self.assertAllEqual([[1.0, 1.0]], results_with_nested_list[0][0][0])
    490       self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]],
    491                           results_with_nested_list[0][0][1])
    492       self.assertAllEqual(results_with_nested_list[0][0][0],
    493                           results_with_nested_list[1])
    494       self.assertAllEqual(results_with_nested_list[1],
    495                           results_with_nested_list[2][0])
    496       self.assertAllEqual(results_with_nested_list[0][0][1],
    497                           results_with_nested_list[0][1])
    498       self.assertAllEqual(results_with_nested_list[0][1],
    499                           results_with_nested_list[2][1])
    500 
    501   def testFetchScalar(self):
    502     with session.Session() as s:
    503       for scalar in np.int32, np.int64, np.float16, np.float32, np.float64:
    504         x = scalar(7)
    505         y = scalar(8)
    506         tf_x = constant_op.constant(x, shape=[])
    507         tf_y = constant_op.constant(y)
    508         tf_xy = math_ops.add(tf_x, tf_y)
    509         # Single fetch
    510         xy = s.run(tf_xy)
    511         self.assertEqual(scalar, type(xy))
    512         self.assertEqual(x + y, xy)
    513         # List fetch
    514         xy, = s.run([tf_xy])
    515         self.assertEqual(scalar, type(xy))
    516         self.assertEqual(x + y, xy)
    517         # Dict fetch
    518         xy = s.run({'xy': tf_xy})['xy']
    519         self.assertEqual(scalar, type(xy))
    520         self.assertEqual(x + y, xy)
    521         # Nested list fetch
    522         xy = s.run([[[tf_xy]], tf_xy, [tf_xy]])
    523         self.assertAllEqual(xy, [[[x + y]], x + y, [x + y]])
    524         self.assertEqual(scalar, type(xy[0][0][0]))
    525         self.assertEqual(scalar, type(xy[1]))
    526         self.assertEqual(scalar, type(xy[2][0]))
    527 
    528   def testFetchOperationObject(self):
    529     with session.Session() as s:
    530       a = constant_op.constant(1.0, shape=[1, 2])
    531       v = variables.Variable(a, name='testFetchOperationObject_v')
    532       s.run(v.initializer)
    533       v_val = s.run(v)
    534       self.assertAllEqual([[1.0, 1.0]], v_val)
    535 
    536   def testFetchSparseTensor(self):
    537     with session.Session() as s:
    538       indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
    539       values = np.array([1.0, 2.0]).astype(np.float32)
    540       shape = np.array([7, 9, 2]).astype(np.int64)
    541       sp = sparse_tensor.SparseTensor(
    542           constant_op.constant(indices), constant_op.constant(values),
    543           constant_op.constant(shape))
    544       # Single fetch, use as tuple
    545       sp_out = s.run(sp)
    546       indices_out, values_out, shape_out = sp_out
    547       self.assertAllEqual(indices_out, indices)
    548       self.assertAllEqual(values_out, values)
    549       self.assertAllEqual(shape_out, shape)
    550       # Single fetch, use as SparseTensorValue
    551       sp_out = s.run(sp)
    552       self.assertAllEqual(sp_out.indices, indices)
    553       self.assertAllEqual(sp_out.values, values)
    554       self.assertAllEqual(sp_out.dense_shape, shape)
    555       # Tuple fetch, use as tuple
    556       indices_out, values_out, shape_out = s.run(sp)
    557       self.assertAllEqual(indices_out, indices)
    558       self.assertAllEqual(values_out, values)
    559       self.assertAllEqual(shape_out, shape)
    560       # List fetch, use as tuple
    561       (indices_out, values_out, shape_out), = s.run([sp])
    562       self.assertAllEqual(indices_out, indices)
    563       self.assertAllEqual(values_out, values)
    564       self.assertAllEqual(shape_out, shape)
    565       # List fetch, use as SparseTensorValue
    566       sp_out, = s.run([sp])
    567       self.assertAllEqual(sp_out.indices, indices)
    568       self.assertAllEqual(sp_out.values, values)
    569       self.assertAllEqual(sp_out.dense_shape, shape)
    570       # Dict fetch (single value), use as tuple
    571       indices_out, values_out, shape_out = s.run({'sp': sp})['sp']
    572       self.assertAllEqual(indices_out, indices)
    573       self.assertAllEqual(values_out, values)
    574       self.assertAllEqual(shape_out, shape)
    575       # Dict fetch (list value), use as tuple
    576       (indices_out, values_out, shape_out), = s.run({'sp': [sp]})['sp']
    577       self.assertAllEqual(indices_out, indices)
    578       self.assertAllEqual(values_out, values)
    579       self.assertAllEqual(shape_out, shape)
    580       # Dict fetch, use as SparseTensorValue
    581       sp_out = s.run({'sp': sp})['sp']
    582       self.assertAllEqual(sp_out.indices, indices)
    583       self.assertAllEqual(sp_out.values, values)
    584       self.assertAllEqual(sp_out.dense_shape, shape)
    585       # Nested list fetch use as tuple
    586       sp_out = s.run([[[sp]], sp])
    587       indices_out, values_out, shape_out = sp_out[0][0][0]
    588       self.assertAllEqual(indices_out, indices)
    589       self.assertAllEqual(values_out, values)
    590       self.assertAllEqual(shape_out, shape)
    591       indices_out, values_out, shape_out = sp_out[1]
    592       self.assertAllEqual(indices_out, indices)
    593       self.assertAllEqual(values_out, values)
    594       self.assertAllEqual(shape_out, shape)
    595       # Nested list fetch, use as SparseTensorValue
    596       sp_out = s.run([[[sp]], sp])
    597       self.assertAllEqual(sp_out[0][0][0].indices, indices)
    598       self.assertAllEqual(sp_out[0][0][0].values, values)
    599       self.assertAllEqual(sp_out[0][0][0].dense_shape, shape)
    600       self.assertAllEqual(sp_out[1].indices, indices)
    601       self.assertAllEqual(sp_out[1].values, values)
    602       self.assertAllEqual(sp_out[1].dense_shape, shape)
    603 
    604   def testFeedSparseTensor(self):
    605     with session.Session() as s:
    606       indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
    607       values = np.array([1.0, 2.0]).astype(np.float32)
    608       shape = np.array([7, 9, 2]).astype(np.int64)
    609       sp = sparse_tensor.SparseTensor(
    610           array_ops.placeholder(dtype=np.int64, shape=(2, 3)),
    611           array_ops.placeholder(dtype=np.float32, shape=(2,)),
    612           array_ops.placeholder(dtype=np.int64, shape=(3,)),
    613       )
    614       sp_indices = array_ops.identity(sp.indices)
    615       sp_values = array_ops.identity(sp.values)
    616       sp_shape = array_ops.identity(sp.dense_shape)
    617       sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
    618       # Feed with tuple
    619       indices_out, values_out, shape_out = s.run(
    620           [sp_indices, sp_values, sp_shape], {
    621               sp: (indices, values, shape)
    622           })
    623       self.assertAllEqual(indices_out, indices)
    624       self.assertAllEqual(values_out, values)
    625       self.assertAllEqual(shape_out, shape)
    626       # Feed with tuple, fetch sp directly
    627       sp_out = s.run(sp, {sp: (indices, values, shape)})
    628       self.assertAllEqual(sp_out.indices, indices)
    629       self.assertAllEqual(sp_out.values, values)
    630       self.assertAllEqual(sp_out.dense_shape, shape)
    631       # Feed with SparseTensorValue
    632       indices_out, values_out, shape_out = s.run(
    633           [sp_indices, sp_values, sp_shape], {
    634               sp: sparse_tensor.SparseTensorValue(indices, values, shape)
    635           })
    636       self.assertAllEqual(indices_out, indices)
    637       self.assertAllEqual(values_out, values)
    638       self.assertAllEqual(shape_out, shape)
    639       # Feed with SparseTensorValue, fetch SparseTensorValue
    640       sp2_out = s.run(sp2, {
    641           sp: sparse_tensor.SparseTensorValue(indices, values, shape)
    642       })
    643       self.assertAllEqual(sp2_out.indices, indices)
    644       self.assertAllEqual(sp2_out.values, values)
    645       self.assertAllEqual(sp2_out.dense_shape, shape)
    646       # Feed SparseTensorValue and fetch sp directly.
    647       sp_out = s.run(sp, {
    648           sp: sparse_tensor.SparseTensorValue(indices, values, shape)
    649       })
    650       self.assertAllEqual(sp_out.indices, indices)
    651       self.assertAllEqual(sp_out.values, values)
    652       self.assertAllEqual(sp_out.dense_shape, shape)
    653 
    654   def testFeedSparsePlaceholder(self):
    655     with session.Session() as s:
    656       indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
    657       values = np.array([1.0, 2.0]).astype(np.float32)
    658       shape = np.array([7, 9, 2]).astype(np.int64)
    659       sp = array_ops.sparse_placeholder(dtype=np.float32, name='placeholder1')
    660       sp_indices = array_ops.identity(sp.indices)
    661       sp_values = array_ops.identity(sp.values)
    662       sp_shape = array_ops.identity(sp.dense_shape)
    663       sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
    664       # Feed with tuple
    665       indices_out, values_out, shape_out = s.run(
    666           [sp_indices, sp_values, sp_shape], {
    667               sp: (indices, values, shape)
    668           })
    669       self.assertAllEqual(indices_out, indices)
    670       self.assertAllEqual(values_out, values)
    671       self.assertAllEqual(shape_out, shape)
    672       # Feed with SparseTensorValue
    673       indices_out, values_out, shape_out = s.run(
    674           [sp_indices, sp_values, sp_shape], {
    675               sp: sparse_tensor.SparseTensorValue(indices, values, shape)
    676           })
    677       self.assertAllEqual(indices_out, indices)
    678       self.assertAllEqual(values_out, values)
    679       self.assertAllEqual(shape_out, shape)
    680       # Feed with SparseTensorValue, fetch SparseTensorValue
    681       sp2_out = s.run(sp2, {
    682           sp: sparse_tensor.SparseTensorValue(indices, values, shape)
    683       })
    684       self.assertAllEqual(sp2_out.indices, indices)
    685       self.assertAllEqual(sp2_out.values, values)
    686       self.assertAllEqual(sp2_out.dense_shape, shape)
    687 
    688   def testFeedSparsePlaceholderPartialShape(self):
    689     with session.Session() as s:
    690       indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
    691       values = np.array([1.0, 2.0]).astype(np.float32)
    692       shape = np.array([7, 9, 2]).astype(np.int64)
    693       sp = array_ops.sparse_placeholder(
    694           shape=[None, 9, 2], dtype=np.float32, name='placeholder1')
    695       sp_indices = array_ops.identity(sp.indices)
    696       sp_values = array_ops.identity(sp.values)
    697       sp_shape = array_ops.identity(sp.dense_shape)
    698       sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
    699       # Feed with tuple
    700       indices_out, values_out, shape_out = s.run(
    701           [sp_indices, sp_values, sp_shape], {
    702               sp: (indices, values, shape)
    703           })
    704       self.assertAllEqual(indices_out, indices)
    705       self.assertAllEqual(values_out, values)
    706       self.assertAllEqual(shape_out, shape)
    707       # Feed with SparseTensorValue
    708       indices_out, values_out, shape_out = s.run(
    709           [sp_indices, sp_values, sp_shape], {
    710               sp: sparse_tensor.SparseTensorValue(indices, values, shape)
    711           })
    712       self.assertAllEqual(indices_out, indices)
    713       self.assertAllEqual(values_out, values)
    714       self.assertAllEqual(shape_out, shape)
    715       # Feed with SparseTensorValue, fetch SparseTensorValue
    716       sp2_out = s.run(sp2, {
    717           sp: sparse_tensor.SparseTensorValue(indices, values, shape)
    718       })
    719       self.assertAllEqual(sp2_out.indices, indices)
    720       self.assertAllEqual(sp2_out.values, values)
    721       self.assertAllEqual(sp2_out.dense_shape, shape)
    722 
    723   def testFeedSparsePlaceholderConstantShape(self):
    724     with session.Session() as s:
    725       indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
    726       values = np.array([1.0, 2.0]).astype(np.float32)
    727       shape = np.array([7, 9, 2]).astype(np.int64)
    728       sp = array_ops.sparse_placeholder(
    729           dtype=np.float32, shape=shape, name='placeholder1')
    730       self.assertAllEqual(sp.dense_shape.eval(session=s), shape)
    731       self.assertAllEqual(tensor_util.constant_value(sp.dense_shape), shape)
    732       sp_indices = array_ops.identity(sp.indices)
    733       sp_values = array_ops.identity(sp.values)
    734       sp_shape = array_ops.identity(sp.dense_shape)
    735       # Feed with tuple
    736       indices_out, values_out, shape_out = s.run(
    737           [sp_indices, sp_values, sp_shape], {
    738               sp: (indices, values)
    739           })
    740       self.assertAllEqual(indices_out, indices)
    741       self.assertAllEqual(values_out, values)
    742       self.assertAllEqual(shape_out, shape)
    743 
    744   def testFetchIndexedSlices(self):
    745     with session.Session() as s:
    746       indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
    747       values = np.array([1.0, 2.0]).astype(np.float32)
    748       dense_shape = np.array([7, 9, 2]).astype(np.int64)
    749       ind = ops.IndexedSlices(
    750           constant_op.constant(values), constant_op.constant(indices),
    751           constant_op.constant(dense_shape))
    752       # Single fetch, use as tuple
    753       ind_out = s.run(ind)
    754       values_out, indices_out, dense_shape_out = ind_out
    755       self.assertAllEqual(values_out, values)
    756       self.assertAllEqual(indices_out, indices)
    757       self.assertAllEqual(dense_shape_out, dense_shape)
    758       # Single fetch, use as IndexedSlicesValue
    759       ind_out = s.run(ind)
    760       self.assertAllEqual(ind_out.values, values)
    761       self.assertAllEqual(ind_out.indices, indices)
    762       self.assertAllEqual(ind_out.dense_shape, dense_shape)
    763       # Tuple fetch, use as tuple
    764       values_out, indices_out, dense_shape_out = s.run(ind)
    765       self.assertAllEqual(values_out, values)
    766       self.assertAllEqual(indices_out, indices)
    767       self.assertAllEqual(dense_shape_out, dense_shape)
    768       # List fetch, use as tuple
    769       (values_out, indices_out, dense_shape_out), = s.run([ind])
    770       self.assertAllEqual(values_out, values)
    771       self.assertAllEqual(indices_out, indices)
    772       self.assertAllEqual(dense_shape_out, dense_shape)
    773       # List fetch, use as IndexedSlicesValue
    774       ind_out, = s.run([ind])
    775       self.assertAllEqual(ind_out.values, values)
    776       self.assertAllEqual(ind_out.indices, indices)
    777       self.assertAllEqual(ind_out.dense_shape, dense_shape)
    778 
    779   def testFeedIndexedSlices(self):
    780     with session.Session() as s:
    781       values = np.array([1.0, 2.0]).astype(np.float32)
    782       indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
    783       dense_shape = np.array([7, 9, 2]).astype(np.int64)
    784       ind = ops.IndexedSlices(
    785           array_ops.placeholder(dtype=np.float32, shape=(2,)),
    786           array_ops.placeholder(dtype=np.int64, shape=(2, 3)),
    787           array_ops.placeholder(dtype=np.int64, shape=(3,)),
    788       )
    789       ind_values = array_ops.identity(ind.values)
    790       ind_indices = array_ops.identity(ind.indices)
    791       ind_dense_shape = array_ops.identity(ind.dense_shape)
    792       ind2 = ops.IndexedSlices(ind_values, ind_indices, ind_dense_shape)
    793       # Feed with tuple
    794       values_out, indices_out, dense_shape_out = s.run(
    795           [ind_values, ind_indices, ind_dense_shape], {
    796               ind: (values, indices, dense_shape)
    797           })
    798       self.assertAllEqual(values_out, values)
    799       self.assertAllEqual(indices_out, indices)
    800       self.assertAllEqual(dense_shape_out, dense_shape)
    801       # Feed with IndexedSlicesValue
    802       values_out, indices_out, dense_shape_out = s.run(
    803           [ind_values, ind_indices, ind_dense_shape], {
    804               ind: ops.IndexedSlicesValue(values, indices, dense_shape)
    805           })
    806       self.assertAllEqual(values_out, values)
    807       self.assertAllEqual(indices_out, indices)
    808       self.assertAllEqual(dense_shape_out, dense_shape)
    809       # Feed with IndexedSlicesValue, fetch IndexedSlicesValue
    810       ind2_out = s.run(ind2, {
    811           ind: ops.IndexedSlicesValue(values, indices, dense_shape)
    812       })
    813       self.assertAllEqual(ind2_out.values, values)
    814       self.assertAllEqual(ind2_out.indices, indices)
    815       self.assertAllEqual(ind2_out.dense_shape, dense_shape)
    816 
    817   def testFetchIndexedSlicesWithoutDenseShape(self):
    818     with session.Session() as s:
    819       indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
    820       values = np.array([1.0, 2.0]).astype(np.float32)
    821       dense_shape = None
    822       ind = ops.IndexedSlices(
    823           constant_op.constant(values), constant_op.constant(indices), None)
    824       # Single fetch, use as tuple
    825       ind_out = s.run(ind)
    826       values_out, indices_out, dense_shape_out = ind_out
    827       self.assertAllEqual(values_out, values)
    828       self.assertAllEqual(indices_out, indices)
    829       self.assertAllEqual(dense_shape_out, dense_shape)
    830       # Single fetch, use as IndexedSlicesValue
    831       ind_out = s.run(ind)
    832       self.assertAllEqual(ind_out.values, values)
    833       self.assertAllEqual(ind_out.indices, indices)
    834       self.assertAllEqual(ind_out.dense_shape, dense_shape)
    835       # Tuple fetch, use as tuple
    836       values_out, indices_out, dense_shape_out = s.run(ind)
    837       self.assertAllEqual(values_out, values)
    838       self.assertAllEqual(indices_out, indices)
    839       self.assertAllEqual(dense_shape_out, dense_shape)
    840       # List fetch, use as tuple
    841       (values_out, indices_out, dense_shape_out), = s.run([ind])
    842       self.assertAllEqual(values_out, values)
    843       self.assertAllEqual(indices_out, indices)
    844       self.assertAllEqual(dense_shape_out, dense_shape)
    845       # List fetch, use as IndexedSlicesValue
    846       ind_out, = s.run([ind])
    847       self.assertAllEqual(ind_out.values, values)
    848       self.assertAllEqual(ind_out.indices, indices)
    849       self.assertAllEqual(ind_out.dense_shape, dense_shape)
    850 
    851   def testFeedIndexedSlicesWithoutDenseShape(self):
    852     with session.Session() as s:
    853       values = np.array([1.0, 2.0]).astype(np.float32)
    854       indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
    855       dense_shape = None
    856       ind = ops.IndexedSlices(
    857           array_ops.placeholder(dtype=np.float32, shape=(2,)),
    858           array_ops.placeholder(dtype=np.int64, shape=(2, 3)), None)
    859       ind_values = array_ops.identity(ind.values)
    860       ind_indices = array_ops.identity(ind.indices)
    861       ind2 = ops.IndexedSlices(ind_values, ind_indices)
    862       # Feed with tuple
    863       values_out, indices_out = s.run([ind_values, ind_indices], {
    864           ind: (values, indices)
    865       })
    866       self.assertAllEqual(values_out, values)
    867       self.assertAllEqual(indices_out, indices)
    868       # Feed with IndexedSlicesValue
    869       values_out, indices_out = s.run([ind_values, ind_indices], {
    870           ind: ops.IndexedSlicesValue(values, indices, dense_shape)
    871       })
    872       self.assertAllEqual(values_out, values)
    873       self.assertAllEqual(indices_out, indices)
    874       # Feed with IndexedSlicesValue, fetch IndexedSlicesValue
    875       ind2_out = s.run(ind2, {
    876           ind: ops.IndexedSlicesValue(values, indices, dense_shape)
    877       })
    878       self.assertAllEqual(ind2_out.values, values)
    879       self.assertAllEqual(ind2_out.indices, indices)
    880       self.assertAllEqual(ind2_out.dense_shape, dense_shape)
    881 
  def testExtendWithStatelessOperations(self):
    """Stateless ops added after a run are picked up by the same session.

    `c` is run first. `d` and `e` are then added to the graph, so running `e`
    forces the session to extend its copy of the graph with the new nodes.
    """
    with session.Session() as s:
      a = constant_op.constant(1.0, shape=[1, 2])
      b = constant_op.constant(2.0, shape=[2, 3])
      c = math_ops.matmul(a, b)
      c_val = s.run(c)
      self.assertAllEqual([[4.0, 4.0, 4.0]], c_val)
      # These nodes are created only after the first run above.
      d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1])
      e = math_ops.matmul(c, d)
      # Extend will happen here.
      e_val = s.run(e)
      # [4, 4, 4] . [1, 2, 3] = 24.
      self.assertAllEqual([[24.0]], e_val)
    894 
  def testExtendWithStatefulOperations(self):
    """A variable keeps its value when the graph is extended around it.

    The variable is initialized and read, then new ops (including an assign
    to it) are added. Evaluating the new ops must not change the variable
    until the assign op itself is run.
    """
    with session.Session() as s:
      a = constant_op.constant(1.0, shape=[1, 2])
      b = constant_op.constant(2.0, shape=[2, 3])
      c = math_ops.matmul(a, b)
      v = variables.Variable(c, name='testExtendWithStatefulOperations_v')
      v.initializer.run()
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      # New ops created after the variable has already been initialized.
      d = constant_op.constant(3.0, shape=[2, 3])
      e = math_ops.matmul(a, d)
      assign_e_to_v = state_ops.assign(v, e)
      # Extend will happen here.
      e_val = e.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
      # Merely adding (not running) the assign leaves the old value in place.
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      s.run(assign_e_to_v)
      v_val = v.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
    915 
  def testExtendWithGroupBy(self):
    """A group op can combine initializers from before and after an extend.

    `p` is created before the first run and `q` after it. Both initializers
    are then grouped and run together, and each variable gets its own
    initial value.
    """
    with session.Session() as s:
      a = constant_op.constant(1.0, shape=[1, 2])
      p = variables.Variable(a, name='testExtendWithGroupBy_p')
      a_val = a.eval()  # Force an Extend after this op.
      self.assertAllEqual([[1.0, 1.0]], a_val)

      b = constant_op.constant(2.0, shape=[1, 2])
      q = variables.Variable(b, name='testExtendWithGroupBy_q')
      # Extend will happen here.
      init = control_flow_ops.group(p.initializer, q.initializer)
      s.run(init)
      p_val, q_val = s.run([p, q])

      self.assertAllEqual([[1.0, 1.0]], p_val)
      self.assertAllEqual([[2.0, 2.0]], q_val)
    932 
    933   def testTensorGetMethod(self):
    934     with session.Session():
    935       a = constant_op.constant(1.0, shape=[1, 2])
    936       b = constant_op.constant(2.0, shape=[2, 3])
    937       c = math_ops.matmul(a, b)
    938 
    939       c_val = c.eval()
    940       self.assertAllEqual([[4.0, 4.0, 4.0]], c_val)
    941 
    942       fed_c_val = c.eval(feed_dict={a.name: [[4.0, 4.0]]})
    943       self.assertAllEqual([[16.0, 16.0, 16.0]], fed_c_val)
    944 
    945   def testOperationRunMethod(self):
    946     with session.Session():
    947       a = constant_op.constant(1.0, shape=[1, 2])
    948       b = constant_op.constant(2.0, shape=[1, 2], name='b')
    949       v = variables.Variable(a, a.dtype)
    950       assign_a_to_v = state_ops.assign(v, a)
    951 
    952       assign_a_to_v.eval()
    953 
    954       v_val = v.eval()
    955       self.assertAllEqual([[1.0, 1.0]], v_val)
    956 
    957       assign_b_to_v = state_ops.assign(v, b)
    958 
    959       assign_b_to_v.eval()
    960       v_val = v.eval()
    961       self.assertAllEqual([[2.0, 2.0]], v_val)
    962 
    963       assign_b_to_v.eval(feed_dict={'b:0': [[3.0, 3.0]]})
    964       v_val = v.eval()
    965       self.assertAllEqual([[3.0, 3.0]], v_val)
    966 
  def testDefaultGraph(self):
    """Ops built inside `with Session()` land in that session's graph.

    The default graph must equal `s.graph` before and after the ops are
    created. Ops added after the variable was initialized (including an
    assign) must be runnable through the same session.
    """
    with session.Session() as s:
      self.assertEqual(ops.get_default_graph(), s.graph)
      a = constant_op.constant(1.0, shape=[1, 2])
      b = constant_op.constant(2.0, shape=[2, 3])
      self.assertEqual(ops.get_default_graph(), a.graph)
      self.assertEqual(ops.get_default_graph(), b.graph)
      c = math_ops.matmul(a, b)
      v = variables.Variable(c, name='testDefaultGraph_v')
      v.initializer.run()
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      # Extend the graph after the variable already holds a value.
      d = constant_op.constant(3.0, shape=[2, 3])
      e = math_ops.matmul(a, d)
      assign_e_to_v = state_ops.assign(v, e)
      e_val = e.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
      # The assign has only been created, not run, so `v` is unchanged.
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      s.run(assign_e_to_v)
      v_val = v.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
      self.assertEqual(ops.get_default_graph(), s.graph)
    990 
  def _testDefaultGraphInThread(self, constructed_event, continue_event, i):
    """Worker body for testDefaultGraphWithThreads.

    Builds a small graph under a Session created in the calling thread and
    checks that the thread's default graph is that session's graph. It then
    runs the same variable/assign sequence as testDefaultGraph.

    Args:
      constructed_event: threading.Event set once this thread's graph has
        been built.
      continue_event: threading.Event waited on before any op is run.
      i: integer used to give this thread's variable a distinct name.
    """
    with session.Session() as s:
      self.assertEqual(ops.get_default_graph(), s.graph)
      a = constant_op.constant(1.0, shape=[1, 2])
      b = constant_op.constant(2.0, shape=[2, 3])
      c = math_ops.matmul(a, b)
      v = variables.Variable(c, name='var_%d' % i)

      # Block here until all threads have constructed their graph.
      constructed_event.set()
      continue_event.wait()

      assign_c_to_v = state_ops.assign(v, c)
      v.initializer.run()
      assign_c_to_v.eval()
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      d = constant_op.constant(3.0, shape=[2, 3])
      e = math_ops.matmul(a, d)
      assign_e_to_v = state_ops.assign(v, e)
      e_val = e.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
      # The new assign has not been run yet, so `v` keeps its old value.
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      s.run(assign_e_to_v)
      v_val = v.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
      self.assertEqual(ops.get_default_graph(), s.graph)
   1019 
   1020   def testDefaultGraphWithThreads(self):
   1021     # Fork ten threads that use their thread-local default graph.
   1022     threads = []
   1023     constructed_events = [threading.Event() for _ in range(10)]
   1024     continue_event = threading.Event()
   1025     for i, constructed_event in enumerate(constructed_events):
   1026       t = self.checkedThread(
   1027           target=self._testDefaultGraphInThread,
   1028           args=(constructed_event, continue_event, i))
   1029       threads.append(t)
   1030     for t in threads:
   1031       t.start()
   1032     for constructed_event in constructed_events:
   1033       constructed_event.wait()
   1034     continue_event.set()
   1035     for t in threads:
   1036       t.join()
   1037 
   1038   def testParallelRun(self):
   1039     with session.Session() as sess:
   1040       c = constant_op.constant(5.0)
   1041       ev = threading.Event()
   1042 
   1043       def run_step():
   1044         ev.wait()
   1045         val = c.eval(session=sess)
   1046         self.assertEqual(val, 5.0)
   1047 
   1048       threads = [self.checkedThread(target=run_step) for _ in range(100)]
   1049       for t in threads:
   1050         t.start()
   1051       ev.set()
   1052       for t in threads:
   1053         t.join()
   1054 
   1055   def testRunFeedDict(self):
   1056     with session.Session() as s:
   1057       x = array_ops.zeros([2])
   1058 
   1059       y = s.run(2 * x, feed_dict={x: np.ones(2).astype(np.float32)})
   1060       self.assertAllEqual(y, 2 * np.ones(2))
   1061 
   1062       y = s.run(2 * x, feed_dict={x.name: np.ones(2).astype(np.float32)})
   1063       self.assertAllEqual(y, 2 * np.ones(2))
   1064 
   1065       y = s.run(2 * x, feed_dict={x: [1, 1]})
   1066       assert (y == 2 * np.ones(2)).all()
   1067 
   1068       # Test nested tuple keys
   1069       z = (((array_ops.zeros([2]),),), array_ops.zeros([2]),
   1070            (array_ops.zeros([2]),))
   1071       result = [z[0][0][0] * 2, z[1] * 2, z[2][0] * 2]
   1072       values = (((np.array([1, 1]),),), np.array([2, 2]), (np.array([3, 3]),))
   1073       result_value = s.run(result, feed_dict={z: values})
   1074       self.assertAllEqual(result_value[0], 2 * np.ones(2))
   1075       self.assertAllEqual(result_value[1], 2 * np.array([2, 2]))
   1076       self.assertAllEqual(result_value[2], 2 * np.array([3, 3]))
   1077 
   1078   def testGraphDef(self):
   1079     with session.Session() as sess:
   1080       self.assertProtoEquals('versions { producer: %d min_consumer: %d }' %
   1081                              (versions.GRAPH_DEF_VERSION,
   1082                               versions.GRAPH_DEF_VERSION_MIN_CONSUMER),
   1083                              sess.graph_def)
   1084       c = constant_op.constant(5.0, name='c')
   1085       self.assertEquals(len(sess.graph_def.node), 1)
   1086       d = constant_op.constant(6.0, name='d')
   1087       self.assertEquals(len(sess.graph_def.node), 2)
   1088       self.assertAllEqual(c.eval(), 5.0)
   1089       self.assertAllEqual(d.eval(), 6.0)
   1090       e = constant_op.constant(7.0, name='e')
   1091       self.assertEquals(len(sess.graph_def.node), 3)
   1092       self.assertAllEqual(e.eval(), 7.0)
   1093 
   1094   def testUseAfterClose(self):
   1095     with session.Session() as sess:
   1096       c = constant_op.constant(5.0)
   1097       self.assertAllEqual(sess.run(c), 5.0)
   1098     with self.assertRaisesWithPredicateMatch(
   1099         RuntimeError, lambda e: 'Attempted to use a closed Session.' in str(e)):
   1100       sess.run(c)
   1101 
   1102   def testUseAfterCloseConcurrent(self):
   1103     with session.Session() as sess:
   1104       c = constant_op.constant(5.0)
   1105       self.assertAllEqual(sess.run(c), 5.0)
   1106 
   1107       def update_thread():
   1108         with self.assertRaisesWithPredicateMatch(
   1109             RuntimeError,
   1110             lambda e: 'Attempted to use a closed Session.' in str(e)):
   1111           while True:
   1112             sess.run(c)
   1113 
   1114       t = threading.Thread(target=update_thread)
   1115       t.start()
   1116       time.sleep(0.1)
   1117       sess.close()
   1118       t.join()
   1119 
   1120   def testUseEmptyGraph(self):
   1121     with session.Session() as sess:
   1122       with self.assertRaisesRegexp(RuntimeError, 'The Session graph is empty.'):
   1123         sess.run([])
   1124       with self.assertRaisesRegexp(RuntimeError, 'The Session graph is empty.'):
   1125         sess.run(())
   1126       with self.assertRaisesRegexp(RuntimeError, 'The Session graph is empty.'):
   1127         sess.run({})
   1128 
   1129   def testNotEntered(self):
   1130     # pylint: disable=protected-access
   1131     self.assertEqual(ops._default_session_stack.get_default(), None)
   1132     # pylint: enable=protected-access
   1133     with ops.device('/cpu:0'):
   1134       sess = session.Session()
   1135       c_1 = constant_op.constant(5.0)
   1136       with sess.graph.as_default():
   1137         c_2 = constant_op.constant(5.0)
   1138       self.assertEqual(c_1.graph, c_2.graph)
   1139       self.assertEqual(sess.run(c_2), 5.0)
   1140       with self.assertRaisesWithPredicateMatch(
   1141           ValueError, lambda e: 'No default session is registered.' in str(e)):
   1142         c_2.eval()
   1143 
   1144   def testInteractive(self):
   1145     with ops.device('/cpu:0'):
   1146       sess = session.InteractiveSession()
   1147       a = constant_op.constant(1.0, shape=[1, 2])
   1148       b = constant_op.constant(2.0, shape=[2, 3])
   1149       c = math_ops.matmul(a, b)
   1150       self.assertAllEqual([[4.0, 4.0, 4.0]], c.eval())
   1151       d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1])
   1152       e = math_ops.matmul(c, d)
   1153       self.assertAllEqual([[24.0]], e.eval())
   1154       sess.close()
   1155 
   1156   def testInteractivePlacePrunedGraph(self):
   1157     sess = session.InteractiveSession()
   1158 
   1159     # Build a graph that has a bad op in it (no kernel).
   1160     #
   1161     # This test currently does not link in any GPU kernels,
   1162     # which is why placing this is invalid.  If at some point
   1163     # GPU kernels are added to this test, some other different
   1164     # op / device combo should be chosen.
   1165     with ops.device('/device:GPU:0'):
   1166       a = constant_op.constant(1.0, shape=[1, 2])
   1167 
   1168     b = constant_op.constant(1.0, shape=[1, 2])
   1169 
   1170     # Only run the valid op, this should work.
   1171     b.eval()
   1172 
   1173     with self.assertRaises(errors.InvalidArgumentError):
   1174       a.eval()
   1175     sess.close()
   1176 
   1177   def testDefaultSessionPlacePrunedGraph(self):
   1178     sess = session.Session()
   1179 
   1180     # Build a graph that has a bad op in it (no kernel).
   1181     #
   1182     # This test currently does not link in any GPU kernels,
   1183     # which is why placing this is invalid.  If at some point
   1184     # GPU kernels are added to this test, some other different
   1185     # op / device combo should be chosen.
   1186     with ops.device('/device:GPU:0'):
   1187       _ = constant_op.constant(1.0, shape=[1, 2])
   1188 
   1189     b = constant_op.constant(1.0, shape=[1, 2])
   1190 
   1191     with self.assertRaises(errors.InvalidArgumentError):
   1192       # Even though we don't run the bad op, we place the entire
   1193       # graph, which should fail with a non-interactive session.
   1194       sess.run(b)
   1195 
   1196     sess.close()
   1197 
   1198   def testSharedGraph(self):
   1199     with ops.Graph().as_default() as g, ops.device('/cpu:0'):
   1200       a = constant_op.constant(1.0, shape=[1, 2])
   1201       b = constant_op.constant(2.0, shape=[2, 3])
   1202       c = math_ops.matmul(a, b)
   1203 
   1204     with session.Session(graph=g) as sess1:
   1205       with session.Session(graph=g) as sess2:
   1206         self.assertAllEqual(sess1.run(c), sess2.run(c))
   1207 
   1208   def testDuplicatedInputs(self):
   1209     with session.Session() as sess:
   1210       a = constant_op.constant(1.0, shape=[1, 2])
   1211       b = constant_op.constant(2.0, shape=[1, 3])
   1212       a_val, b_val, a2_val = sess.run([a, b, a])
   1213       self.assertAllEqual(a_val, [[1.0, 1.0]])
   1214       self.assertAllEqual(b_val, [[2.0, 2.0, 2.0]])
   1215       self.assertAllEqual(a2_val, [[1.0, 1.0]])
   1216 
   1217   def testFeedAndFetch(self):
   1218     with session.Session() as sess:
   1219       for dtype in [
   1220           dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32,
   1221           dtypes.uint8, dtypes.int16, dtypes.int8, dtypes.int64, dtypes.bool,
   1222           dtypes.complex64, dtypes.complex128
   1223       ]:
   1224         for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
   1225           np_dtype = dtype.as_numpy_dtype
   1226 
   1227           feed_t = array_ops.placeholder(dtype=dtype, shape=shape)
   1228           out_t = array_ops.identity(feed_t)
   1229 
   1230           np_array = np.random.randint(-10, 10, shape)
   1231 
   1232           if dtype == dtypes.bool:
   1233             np_array = np_array > 0
   1234           elif dtype == dtypes.complex64:
   1235             np_array = np.sqrt(np_array.astype(np_dtype))
   1236           elif dtype == dtypes.complex64:
   1237             np_array = np.sqrt(np_array.astype(np_dtype))
   1238           else:
   1239             np_array = np_array.astype(np_dtype)
   1240 
   1241           self.assertAllEqual(np_array,
   1242                               sess.run(out_t, feed_dict={
   1243                                   feed_t: np_array
   1244                               }))
   1245           # Check that we can also get the feed back.
   1246           self.assertAllEqual(np_array,
   1247                               sess.run(feed_t, feed_dict={
   1248                                   feed_t: np_array
   1249                               }))
   1250           # Also check that we can get both back.
   1251           out_v, feed_v = sess.run(
   1252               [out_t, feed_t], feed_dict={
   1253                   feed_t: np_array
   1254               })
   1255           self.assertAllEqual(np_array, out_v)
   1256           self.assertAllEqual(np_array, feed_v)
   1257 
   1258           feed_fetch_runner = sess.make_callable([out_t, feed_t], [feed_t])
   1259           out_v, feed_v = feed_fetch_runner(np_array)
   1260           self.assertAllEqual(np_array, out_v)
   1261           self.assertAllEqual(np_array, feed_v)
   1262 
   1263   def testMakeCallableOnTensorWithRunOptions(self):
   1264     with session.Session() as sess:
   1265       a = constant_op.constant(42.0)
   1266       tensor_runner = sess.make_callable(a, accept_options=True)
   1267       run_options = config_pb2.RunOptions(
   1268           trace_level=config_pb2.RunOptions.FULL_TRACE)
   1269       run_metadata = config_pb2.RunMetadata()
   1270       self.assertEqual(0, len(run_metadata.step_stats.dev_stats))
   1271       res = tensor_runner(options=run_options, run_metadata=run_metadata)
   1272       self.assertEqual(42.0, res)
   1273       self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
   1274 
   1275   def testMakeCallableOnOperationWithRunOptions(self):
   1276     with session.Session() as sess:
   1277       a = variables.Variable(42.0)
   1278       b = state_ops.assign_add(a, 1.0)
   1279       sess.run(a.initializer)
   1280       tensor_runner = sess.make_callable(b.op, accept_options=True)
   1281       run_options = config_pb2.RunOptions(
   1282           trace_level=config_pb2.RunOptions.FULL_TRACE)
   1283       run_metadata = config_pb2.RunMetadata()
   1284       self.assertEqual(0, len(run_metadata.step_stats.dev_stats))
   1285       tensor_runner(options=run_options, run_metadata=run_metadata)
   1286       self.assertEqual(43.0, sess.run(a))
   1287       self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
   1288 
   1289   def testMakeCallableWithFeedListAndRunOptions(self):
   1290     with session.Session() as sess:
   1291       ph = array_ops.placeholder(dtypes.float32)
   1292       a = math_ops.add(ph, 1.0)
   1293       tensor_runner = sess.make_callable(
   1294           a, feed_list=[ph.name], accept_options=True)
   1295       run_options = config_pb2.RunOptions(
   1296           trace_level=config_pb2.RunOptions.FULL_TRACE)
   1297       run_metadata = config_pb2.RunMetadata()
   1298       self.assertEqual(0, len(run_metadata.step_stats.dev_stats))
   1299       self.assertAllClose(42.0,
   1300                           tensor_runner(
   1301                               41.0,
   1302                               options=run_options,
   1303                               run_metadata=run_metadata))
   1304       self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
   1305 
   1306   def testFeedError(self):
   1307     with session.Session() as sess:
   1308       feed_t = array_ops.placeholder(dtype=dtypes.float32)
   1309       out_t = array_ops.identity(feed_t)
   1310       feed_val = constant_op.constant(5.0)
   1311       with self.assertRaisesRegexp(TypeError, 'cannot be a tf.Tensor object'):
   1312         sess.run(out_t, feed_dict={feed_t: feed_val})
   1313       with self.assertRaisesRegexp(TypeError, 'cannot be a tf.Tensor object'):
   1314         out_t.eval(feed_dict={feed_t: feed_val})
   1315       with self.assertRaisesRegexp(TypeError, 'cannot be a tf.Tensor object'):
   1316         out_t.op.run(feed_dict={feed_t: feed_val})
   1317 
   1318   def testFeedPrecisionLossError(self):
   1319     with session.Session() as sess:
   1320       largest_int64 = np.iinfo(np.int64).max
   1321 
   1322       feed_int_implicit_int32 = constant_op.constant(1)
   1323       feed_int_explicit_int32 = constant_op.constant(1, dtype=dtypes.int32)
   1324 
   1325       out_t = constant_op.constant(1.0)
   1326 
   1327       with self.assertRaisesRegexp(TypeError,
   1328                                    'is not compatible with Tensor type'):
   1329         sess.run(out_t, feed_dict={feed_int_implicit_int32: largest_int64})
   1330       with self.assertRaisesRegexp(TypeError,
   1331                                    'is not compatible with Tensor type'):
   1332         sess.run(out_t, feed_dict={feed_int_explicit_int32: largest_int64})
   1333 
   1334   def testStringFetch(self):
   1335     with session.Session():
   1336       for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
   1337         size = 1
   1338         for s in shape:
   1339           size *= s
   1340         c_list = np.array(
   1341             [compat.as_bytes(str(i)) for i in xrange(size)],
   1342             dtype=np.object).reshape(shape) if size > 0 else []
   1343         c = constant_op.constant(c_list)
   1344         self.assertAllEqual(c.eval(), c_list)
   1345 
   1346   def testStringFeed(self):
   1347     with session.Session() as sess:
   1348       for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
   1349         size = 1
   1350         for s in shape:
   1351           size *= s
   1352         c_list = np.array(
   1353             [compat.as_bytes(str(i)) for i in xrange(size)],
   1354             dtype=np.object).reshape(shape)
   1355         feed_t = array_ops.placeholder(dtype=dtypes.string, shape=shape)
   1356         c = array_ops.identity(feed_t)
   1357         self.assertAllEqual(sess.run(c, feed_dict={feed_t: c_list}), c_list)
   1358         self.assertAllEqual(
   1359             sess.run(feed_t, feed_dict={
   1360                 feed_t: c_list
   1361             }), c_list)
   1362         c_v, feed_v = sess.run([c, feed_t], feed_dict={feed_t: c_list})
   1363         self.assertAllEqual(c_v, c_list)
   1364         self.assertAllEqual(feed_v, c_list)
   1365 
   1366   def testStringFeedWithNullCharacters(self):
   1367     with session.Session():
   1368       c_list = [b'\n\x01\x00', b'\n\x00\x01']
   1369       feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[2])
   1370       c = array_ops.identity(feed_t)
   1371       out = c.eval(feed_dict={feed_t: c_list})
   1372       self.assertEqual(c_list[0], out[0])
   1373       self.assertEqual(c_list[1], out[1])
   1374 
   1375   def testStringFeedWithUnicode(self):
   1376     with session.Session():
   1377       c_list = [
   1378           u'\n\x01\x00', u'\n\x00\x01', u'\u26a3 unicode',
   1379           u'\U0001f60e deal with it'
   1380       ]
   1381       feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[len(c_list)])
   1382       c = array_ops.identity(feed_t)
   1383 
   1384       out = c.eval(feed_dict={feed_t: c_list})
   1385       for i in range(len(c_list)):
   1386         self.assertEqual(c_list[i], out[i].decode('utf-8'))
   1387 
   1388       out = c.eval(feed_dict={feed_t: np.array(c_list, dtype=np.object)})
   1389       for i in range(len(c_list)):
   1390         self.assertEqual(c_list[i], out[i].decode('utf-8'))
   1391 
   1392   def testInvalidTargetFails(self):
   1393     with self.assertRaisesRegexp(
   1394         errors.NotFoundError,
   1395         'No session factory registered for the given session options'):
   1396       session.Session('INVALID_TARGET')
   1397 
   1398   def testFetchByNameDifferentStringTypes(self):
   1399     with session.Session() as sess:
   1400       c = constant_op.constant(42.0, name='c')
   1401       d = constant_op.constant(43.0, name=u'd')
   1402       e = constant_op.constant(44.0, name=b'e')
   1403       f = constant_op.constant(45.0, name=r'f')
   1404 
   1405       self.assertTrue(isinstance(c.name, six.text_type))
   1406       self.assertTrue(isinstance(d.name, six.text_type))
   1407       self.assertTrue(isinstance(e.name, six.text_type))
   1408       self.assertTrue(isinstance(f.name, six.text_type))
   1409 
   1410       self.assertEqual(42.0, sess.run('c:0'))
   1411       self.assertEqual(42.0, sess.run(u'c:0'))
   1412       self.assertEqual(42.0, sess.run(b'c:0'))
   1413       self.assertEqual(42.0, sess.run(r'c:0'))
   1414 
   1415       self.assertEqual(43.0, sess.run('d:0'))
   1416       self.assertEqual(43.0, sess.run(u'd:0'))
   1417       self.assertEqual(43.0, sess.run(b'd:0'))
   1418       self.assertEqual(43.0, sess.run(r'd:0'))
   1419 
   1420       self.assertEqual(44.0, sess.run('e:0'))
   1421       self.assertEqual(44.0, sess.run(u'e:0'))
   1422       self.assertEqual(44.0, sess.run(b'e:0'))
   1423       self.assertEqual(44.0, sess.run(r'e:0'))
   1424 
   1425       self.assertEqual(45.0, sess.run('f:0'))
   1426       self.assertEqual(45.0, sess.run(u'f:0'))
   1427       self.assertEqual(45.0, sess.run(b'f:0'))
   1428       self.assertEqual(45.0, sess.run(r'f:0'))
   1429 
   1430   def testIncorrectGraph(self):
   1431     with ops.Graph().as_default() as g_1:
   1432       c_1 = constant_op.constant(1.0, name='c')
   1433 
   1434     with ops.Graph().as_default() as g_2:
   1435       c_2 = constant_op.constant(2.0, name='c')
   1436 
   1437     self.assertEqual('c', c_1.op.name)
   1438     self.assertEqual('c', c_2.op.name)
   1439 
   1440     with session.Session(graph=g_1) as sess_1:
   1441       self.assertEqual(1.0, sess_1.run(c_1))
   1442       with self.assertRaises(ValueError):
   1443         sess_1.run(c_2)
   1444       with self.assertRaises(ValueError):
   1445         sess_1.run(c_2.op)
   1446 
   1447     with session.Session(graph=g_2) as sess_2:
   1448       with self.assertRaises(ValueError):
   1449         sess_2.run(c_1)
   1450       with self.assertRaises(ValueError):
   1451         sess_2.run(c_1.op)
   1452       self.assertEqual(2.0, sess_2.run(c_2))
   1453 
   1454   def testFeedDictKeyException(self):
   1455     with session.Session() as sess:
   1456       a = constant_op.constant(1.0, dtypes.float32, name='a')
   1457       with self.assertRaisesRegexp(TypeError, 'Cannot interpret feed_dict'):
   1458         sess.run(a, feed_dict={'a': [2.0]})
   1459 
   1460   def testPerStepTrace(self):
   1461     run_options = config_pb2.RunOptions(
   1462         trace_level=config_pb2.RunOptions.FULL_TRACE)
   1463     run_metadata = config_pb2.RunMetadata()
   1464 
   1465     with ops.device('/cpu:0'):
   1466       with session.Session() as sess:
   1467         sess.run(constant_op.constant(1.0))
   1468         self.assertTrue(not run_metadata.HasField('step_stats'))
   1469 
   1470         sess.run(constant_op.constant(1.0), run_metadata=run_metadata)
   1471         self.assertTrue(not run_metadata.HasField('step_stats'))
   1472 
   1473         sess.run(
   1474             constant_op.constant(1.0),
   1475             options=run_options,
   1476             run_metadata=run_metadata)
   1477 
   1478         self.assertTrue(run_metadata.HasField('step_stats'))
   1479         self.assertEquals(len(run_metadata.step_stats.dev_stats), 1)
   1480 
   1481   def testRunOptionsRunMetadata(self):
   1482     run_options = config_pb2.RunOptions(
   1483         trace_level=config_pb2.RunOptions.FULL_TRACE)
   1484     run_metadata = config_pb2.RunMetadata()
   1485 
   1486     with ops.device('/cpu:0'):
   1487       with session.Session() as sess:
   1488         # all combinations are valid
   1489         sess.run(constant_op.constant(1.0), options=None, run_metadata=None)
   1490         sess.run(
   1491             constant_op.constant(1.0), options=None, run_metadata=run_metadata)
   1492         self.assertTrue(not run_metadata.HasField('step_stats'))
   1493 
   1494         sess.run(
   1495             constant_op.constant(1.0), options=run_options, run_metadata=None)
   1496         self.assertTrue(not run_metadata.HasField('step_stats'))
   1497 
   1498         sess.run(
   1499             constant_op.constant(1.0),
   1500             options=run_options,
   1501             run_metadata=run_metadata)
   1502 
   1503         self.assertTrue(run_metadata.HasField('step_stats'))
   1504         self.assertEquals(len(run_metadata.step_stats.dev_stats), 1)
   1505 
   1506   def testFeedShapeCompatibility(self):
   1507     # TODO(nolivia): C API doesn't yet handle marking nodes as not feedable.
   1508     if ops._USE_C_API:
   1509       return
   1510 
   1511     with session.Session() as sess:
   1512       some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0])
   1513       new_shape = constant_op.constant([2, 2])
   1514       reshaped_tensor = array_ops.reshape(some_tensor, new_shape)
   1515 
   1516       with self.assertRaisesRegexp(ValueError, 'Cannot feed value of shape'):
   1517         sess.run(reshaped_tensor, feed_dict={some_tensor: [1.0, 2.0, 3.0]})
   1518 
   1519       with self.assertRaisesRegexp(ValueError, 'may not be fed'):
   1520         sess.run(reshaped_tensor, feed_dict={new_shape: [3, 7]})
   1521 
   1522   def testInferShapesFalse(self):
   1523     with ops.Graph().as_default(), ops.device('/cpu:0'):
   1524       a = constant_op.constant([[1, 2]])
   1525       sess = session.Session()
   1526       self.assertFalse('_output_shapes' in sess.graph_def.node[0].attr)
   1527       # Avoid lint error regarding 'unused' var a.
   1528       self.assertTrue(a == a)
   1529 
   1530   def testInferShapesTrue(self):
   1531     config = config_pb2.ConfigProto(
   1532         graph_options=config_pb2.GraphOptions(infer_shapes=True))
   1533     with ops.Graph().as_default(), ops.device('/cpu:0'):
   1534       a = constant_op.constant([[1, 2]])
   1535       sess = session.Session(config=config)
   1536       self.assertTrue('_output_shapes' in sess.graph_def.node[0].attr)
   1537       # Avoid lint error regarding 'unused' var a.
   1538       self.assertTrue(a == a)
   1539 
   1540   def testBuildCostModel(self):
   1541     run_options = config_pb2.RunOptions()
   1542     config = config_pb2.ConfigProto(
   1543         allow_soft_placement=True,
   1544         graph_options=config_pb2.GraphOptions(build_cost_model=100))
   1545     with session.Session(config=config) as sess:
   1546       with ops.device('/device:GPU:0'):
   1547         a = array_ops.placeholder(dtypes.float32, shape=[])
   1548         b = math_ops.add(a, a)
   1549         c = array_ops.identity(b)
   1550         d = math_ops.multiply(c, c)
   1551       for step in xrange(120):
   1552         run_metadata = config_pb2.RunMetadata()
   1553         sess.run(
   1554             d,
   1555             feed_dict={a: 1.0},
   1556             options=run_options,
   1557             run_metadata=run_metadata)
   1558         if step == 99:
   1559           self.assertTrue(run_metadata.HasField('cost_graph'))
   1560         else:
   1561           self.assertFalse(run_metadata.HasField('cost_graph'))
   1562 
   1563   def runTestOutputPartitionGraphs(self, sess):
   1564     run_options = config_pb2.RunOptions(output_partition_graphs=True)
   1565     a = constant_op.constant(1)
   1566     run_metadata = config_pb2.RunMetadata()
   1567     sess.run(a, options=run_options, run_metadata=run_metadata)
   1568     self.assertGreater(len(run_metadata.partition_graphs), 0)
   1569     sess.run(a, run_metadata=run_metadata)
   1570     self.assertEqual(len(run_metadata.partition_graphs), 0)
   1571 
   1572   def testOutputPartitionGraphsDirect(self):
   1573     self.runTestOutputPartitionGraphs(session.Session())
   1574 
   1575   def testOutputPartitionGraphsDistributed(self):
   1576     server = server_lib.Server.create_local_server()
   1577     self.runTestOutputPartitionGraphs(session.Session(server.target))
   1578 
  def testNonInteractiveSessionNesting(self):
    """Exiting an outer default-session controller out of order is an error.

    Two sessions are made default through manually entered `as_default()`
    controllers. Exiting the first controller while the second is still
    active must raise an AssertionError that mentions 'Nesting violated'.
    """
    sess1 = session.Session()
    sess1_controller = sess1.as_default()
    sess1_controller.__enter__()

    sess2 = session.Session()
    sess2_controller = sess2.as_default()
    sess2_controller.__enter__()

    # sess2 was entered last, so exiting sess1's controller first breaks
    # the expected LIFO order.
    with self.assertRaisesRegexp(AssertionError, 'Nesting violated'):
      sess1_controller.__exit__(None, None, None)

    # The controllers are left unbalanced above; reset the global
    # default-session stack so later tests do not see these sessions.
    ops._default_session_stack.reset()
   1592 
  def testInteractiveSessionNesting(self):
    """Creating two InteractiveSessions back to back does not raise.

    NOTE(review): only construction and deletion are exercised; which session
    ends up as the default is not asserted here.
    """
    sess1 = session.InteractiveSession()
    sess2 = session.InteractiveSession()
    # Drop the references in creation order, presumably so both sessions can
    # be garbage collected before the next test.
    del sess1
    del sess2
   1598 
   1599   def testAsDefault(self):
   1600     c = constant_op.constant(37)
   1601     sess = session.Session()
   1602     with sess.as_default():
   1603       self.assertEqual(37, c.eval())
   1604 
   1605     # Ensure that the session remains valid even when it is not captured.
   1606     with session.Session().as_default():
   1607       self.assertEqual(37, c.eval())
   1608 
   1609   def testReentry(self):
   1610     sess = session.Session()
   1611     with self.assertRaisesRegexp(RuntimeError, 'not re-entrant'):
   1612       with sess:
   1613         with sess:
   1614           pass
   1615 
   1616   def testInvalidArgument(self):
   1617     with self.assertRaisesRegexp(TypeError, 'target must be a string'):
   1618       session.Session(37)
   1619     with self.assertRaisesRegexp(TypeError, 'config must be a tf.ConfigProto'):
   1620       session.Session(config=37)
   1621     with self.assertRaisesRegexp(TypeError, 'graph must be a tf.Graph'):
   1622       session.Session(graph=37)
   1623 
   1624   def testTimeoutWithShortOperations(self):
   1625     num_epochs = 5
   1626     q = data_flow_ops.FIFOQueue(capacity=50, dtypes=[dtypes.int32], shapes=[()])
   1627     enqueue_op = q.enqueue_many(constant_op.constant([1, 2]))
   1628 
   1629     # Use a 10-second timeout, which should be longer than any
   1630     # non-blocking enqueue_many op.
   1631     config = config_pb2.ConfigProto(operation_timeout_in_ms=10000)
   1632     with session.Session(config=config) as sess:
   1633       for _ in range(num_epochs):
   1634         sess.run(enqueue_op)
   1635       self.assertEqual(sess.run(q.size()), num_epochs * 2)
   1636 
   1637   def testRegisterFetchAndFeedConversionFunctions(self):
   1638 
   1639     class SquaredTensor(object):
   1640 
   1641       def __init__(self, tensor):
   1642         self.sq = math_ops.square(tensor)
   1643 
   1644     fetch_fn = lambda squared_tensor: ([squared_tensor.sq], lambda val: val[0])
   1645     feed_fn1 = lambda feed, feed_val: [(feed.sq, feed_val)]
   1646     feed_fn2 = lambda feed: [feed.sq]
   1647 
   1648     session.register_session_run_conversion_functions(SquaredTensor, fetch_fn,
   1649                                                       feed_fn1, feed_fn2)
   1650     with self.assertRaises(ValueError):
   1651       session.register_session_run_conversion_functions(SquaredTensor, fetch_fn,
   1652                                                         feed_fn1, feed_fn2)
   1653     with self.test_session() as sess:
   1654       np1 = np.array([1.0, 1.5, 2.0, 2.5])
   1655       np2 = np.array([3.0, 3.5, 4.0, 4.5])
   1656       squared_tensor = SquaredTensor(np2)
   1657       squared_eval = sess.run(squared_tensor)
   1658       self.assertAllClose(np2 * np2, squared_eval)
   1659       squared_eval = sess.run(
   1660           squared_tensor, feed_dict={
   1661               squared_tensor: np1 * np1
   1662           })
   1663       self.assertAllClose(np1 * np1, squared_eval)
   1664       partial_run = sess.partial_run_setup([squared_tensor], [])
   1665       squared_eval = sess.partial_run(partial_run, squared_tensor)
   1666       self.assertAllClose(np2 * np2, squared_eval)
   1667 
   1668   def testDefaultLogDevicePlacement(self):
   1669 
   1670     class CaptureStderr(str):
   1671       """Class to capture stderr from C++ shared library."""
   1672 
   1673       def __enter__(self):
   1674         self._esc = compat.as_str('\b')
   1675         self._output = compat.as_str('')
   1676         self._stderr = sys.stderr
   1677         self._fd = self._stderr.fileno()
   1678         self._out_pipe, in_pipe = os.pipe()
   1679         # Save the original io stream.
   1680         self._dup_fd = os.dup(self._fd)
   1681         # Replace the original io stream with in pipe.
   1682         os.dup2(in_pipe, self._fd)
   1683         return self
   1684 
   1685       def __exit__(self, *args):
   1686         self._stderr.write(self._esc)
   1687         self._stderr.flush()
   1688         self.read()
   1689         os.close(self._out_pipe)
   1690         # Restore the original io stream.
   1691         os.dup2(self._dup_fd, self._fd)
   1692 
   1693       def read(self):
   1694         while True:
   1695           data = os.read(self._out_pipe, 1)
   1696           if not data or compat.as_str(data) == self._esc:
   1697             break
   1698           self._output += compat.as_str(data)
   1699 
   1700       def __str__(self):
   1701         return self._output
   1702 
   1703     # Passing the config to the server, but not the session should still result
   1704     # in logging device placement.
   1705     config = config_pb2.ConfigProto(log_device_placement=True)
   1706     server = server_lib.Server.create_local_server(config=config)
   1707     a = constant_op.constant(1)
   1708     b = constant_op.constant(2)
   1709     c = a + b
   1710     with session.Session(server.target) as sess:
   1711       with CaptureStderr() as log:
   1712         sess.run(c)
   1713       # Ensure that we did log device placement.
   1714       self.assertTrue('/job:local/replica:0/task:0/device:CPU:0' in str(log),
   1715                       str(log))
   1716 
   1717   def testLocalMasterSessionTimeout(self):
   1718     # Test that the timeout passed in a config to the session works correctly.
   1719     config = config_pb2.ConfigProto(operation_timeout_in_ms=1000)
   1720     server = server_lib.Server.create_local_server()
   1721     q = data_flow_ops.FIFOQueue(1, dtypes.float32)
   1722     dequeued_t = q.dequeue()
   1723 
   1724     with session.Session(server.target, config=config) as sess:
   1725       # Intentionally do not run any enqueue_ops so that dequeue will block
   1726       # until operation_timeout_in_ms.
   1727       with self.assertRaises(errors.DeadlineExceededError):
   1728         sess.run(dequeued_t)
   1729 
   1730   def testDefaultServerTimeout(self):
   1731     # Test that the default server config timeout gets used when no Session
   1732     # config is provided.
   1733     config = config_pb2.ConfigProto(operation_timeout_in_ms=1000)
   1734     server = server_lib.Server.create_local_server(config=config)
   1735     q = data_flow_ops.FIFOQueue(1, dtypes.float32)
   1736     dequeued_t = q.dequeue()
   1737 
   1738     with session.Session(server.target) as sess:
   1739       # Intentionally do not run any enqueue_ops so that dequeue will block
   1740       # until operation_timeout_in_ms.
   1741       with self.assertRaises(errors.DeadlineExceededError):
   1742         sess.run(dequeued_t)
   1743 
   1744   def runTestBuildGraphError(self, sess):
   1745     # Ensure that errors from building the graph get propagated.
   1746     data = array_ops.placeholder(dtypes.float32, shape=[])
   1747     # pylint: disable=protected-access
   1748     enter_1 = gen_control_flow_ops._enter(data, 'foo_1', False)
   1749     enter_2 = gen_control_flow_ops._enter(data, 'foo_2', False)
   1750     # pylint: enable=protected-access
   1751     res = math_ops.add(enter_1, enter_2)
   1752     with self.assertRaisesOpError('has inputs from different frames'):
   1753       sess.run(res, feed_dict={data: 1.0})
   1754 
   1755   def testBuildGraphErrorDirect(self):
   1756     self.runTestBuildGraphError(session.Session())
   1757 
   1758   def testBuildGraphErrorDist(self):
   1759     server = server_lib.Server.create_local_server()
   1760     self.runTestBuildGraphError(session.Session(server.target))
   1761 
   1762   def testDeviceAttributes(self):
   1763     attrs = session._DeviceAttributes(
   1764         '/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337)
   1765     self.assertEqual(1337, attrs.memory_limit_bytes)
   1766     self.assertEqual('/job:worker/replica:0/task:3/device:CPU:2', attrs.name)
   1767     self.assertEqual('TYPE', attrs.device_type)
   1768     str_repr = '%s' % attrs
   1769     self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
   1770 
   1771   def testDeviceAttributesCanonicalization(self):
   1772     attrs = session._DeviceAttributes('/job:worker/replica:0/task:3/cpu:1',
   1773                                       'TYPE', 1337)
   1774     self.assertEqual(1337, attrs.memory_limit_bytes)
   1775     self.assertEqual('/job:worker/replica:0/task:3/device:CPU:1', attrs.name)
   1776     self.assertEqual('TYPE', attrs.device_type)
   1777     str_repr = '%s' % attrs
   1778     self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
   1779 
   1780   def runTestAddFunctionToSession(self, target=''):
   1781     """Add a function to a session after the graph has already been run."""
   1782 
   1783     @function.Defun(dtypes.float32)
   1784     def foo(x):
   1785       return x + 1
   1786 
   1787     x = constant_op.constant(1.0)
   1788     with session.Session(target=target) as sess:
   1789       sess.run(x)
   1790       f = foo(x)
   1791       result = sess.run(f)
   1792       self.assertEqual(result, 2.0)
   1793 
   1794   def testAddFunctionToSession(self):
   1795     self.runTestAddFunctionToSession()
   1796 
   1797   def testAddFunctionToGrpcSession(self):
   1798     server = server_lib.Server.create_local_server()
   1799     self.runTestAddFunctionToSession(server.target)
   1800 
   1801   def testOpenAndCloseGrpcSession(self):
   1802     server = server_lib.Server.create_local_server()
   1803     with session.Session(server.target):
   1804       pass
   1805 
   1806   def testOpenAndCloseSession(self):
   1807     with session.Session():
   1808       pass
   1809 
   1810   def testAutoConvertAndCheckData(self):
   1811     with self.test_session() as sess:
   1812       a = array_ops.placeholder(dtype=dtypes.string)
   1813       with self.assertRaisesRegexp(
   1814           TypeError, 'Type of feed value 1 with type <(\w+) \'int\'> is not'):
   1815         sess.run(a, feed_dict={a: 1})
   1816 
   1817 
class GraphMutationTest(test_util.TensorFlowTestCase):
  """Checks how sessions react when graph ops are mutated after a run.

  setUp forces `ops._USE_C_API` on for every test. Ops are then mutated with
  private Operation helpers (`_update_input`, `_set_device`, `_set_attr`). A
  session that already ran the mutated op is expected to raise
  FailedPreconditionError on its next run. A session that had not run that op
  keeps working.
  """

  def setUp(self):
    # Save the current flag so tearDown can restore it, then force the C API
    # on before the base-class setUp runs.
    self._original_use_c_api_value = ops._USE_C_API
    ops._USE_C_API = True
    super(GraphMutationTest, self).setUp()

  def tearDown(self):
    # Restore the flag saved in setUp before the base-class tearDown runs.
    ops._USE_C_API = self._original_use_c_api_value
    super(GraphMutationTest, self).tearDown()

  def testUpdateInputAfterRunning(self):
    """Rewiring an input of an op that was already run fails in that session."""
    with ops.Graph().as_default() as g:
      a = constant_op.constant(1.0)
      b = constant_op.constant(2.0)
      c = a + b

    with session.Session(graph=g) as sess:
      self.assertAllEqual(3.0, sess.run(c))
      # Replace input 1 (b) with a, so c now computes a + a.
      c.op._update_input(1, a)  # pylint: disable=protected-access
      with self.assertRaisesRegexp(
          errors.FailedPreconditionError,
          'add.*was changed by updating input tensor after it was run'):
        sess.run(c)

      # Check that running the graph with a new session is fine
      with session.Session(graph=g) as sess2:
        # The fresh session sees the mutated graph: 1.0 + 1.0.
        self.assertAllEqual(2.0, sess2.run(c))

  def testSetDeviceAfterRunning(self):
    """Changing the device of an op that was already run fails."""
    with ops.Graph().as_default() as g:
      a = constant_op.constant(1.0)
      b = constant_op.constant(2.0)
      c = a + b

    with session.Session(graph=g) as sess:
      self.assertAllEqual(3.0, sess.run(c))
      c.op._set_device('/cpu:0')  # pylint: disable=protected-access
      with self.assertRaisesRegexp(
          errors.FailedPreconditionError,
          'add.*was changed by setting device after it was run'):
        sess.run(c)

  def testSetAttrAfterRunning(self):
    """Changing an attr of an op that was already run fails."""
    with ops.Graph().as_default() as g:
      a = constant_op.constant(1.0, dtype=dtypes.float32)
      b = math_ops.cast(a, dtypes.float64)

    with session.Session(graph=g) as sess:
      self.assertAllEqual(1.0, sess.run(b))
      # Overwrite the Cast's destination-type attr (DstT) after it has run.
      b.op._set_attr('DstT', attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))
      with self.assertRaisesRegexp(
          errors.FailedPreconditionError,
          'Cast.*was changed by setting attribute after it was run'):
        sess.run(b)

  def testRunModifyRun(self):
    """Mutating an op that the session has never run is allowed."""
    with ops.Graph().as_default() as g:
      a = constant_op.constant(1.0)
      b = constant_op.constant(2.0)
      c = a + b

      with session.Session(graph=g) as sess:
        self.assertAllEqual(3.0, sess.run(c))

        # d is created and rewired before it is ever run: d = a + c = 4.0.
        d = b + c
        d.op._update_input(0, a)  # pylint: disable=protected-access
        self.assertAllEqual(3.0, sess.run(c))
        self.assertAllEqual(4.0, sess.run(d))

  def testRunModifyRunTwoSessions(self):
    """Only the session that ran a mutated op rejects further runs."""
    with ops.Graph().as_default() as g:
      a = constant_op.constant(1.0)
      b = constant_op.constant(2.0)
      c = a + b

      with session.Session(graph=g) as sess1:
        with session.Session(graph=g) as sess2:
          self.assertAllEqual(3.0, sess1.run(c))
          self.assertAllEqual(3.0, sess2.run(c))

          # d is rewired to a + c before either session runs it; sess2 runs it.
          d = b + c
          d.op._update_input(0, a)  # pylint: disable=protected-access
          self.assertAllEqual(3.0, sess2.run(c))
          self.assertAllEqual(4.0, sess2.run(d))

          # Rewire d back to b + c. sess1 never ran d, so it sees 5.0.
          d.op._update_input(0, b)  # pylint: disable=protected-access
          self.assertAllEqual(3.0, sess1.run(c))
          self.assertAllEqual(5.0, sess1.run(d))

          # sess2 ran d before this second mutation, so its next run fails,
          # even though it fetches c rather than d.
          with self.assertRaisesRegexp(
              errors.FailedPreconditionError,
              'add.*was changed by updating input tensor after it was run'):
            sess2.run(c)

  def testTwoSessionsOneRunBeforeModification(self):
    """A session that never ran the op is unaffected by its mutation."""
    with ops.Graph().as_default() as g, ops.device('/cpu:0'):
      a = constant_op.constant(1.0)
      b = constant_op.constant(2.0)
      c = a + b

    with session.Session(graph=g) as sess1:
      with session.Session(graph=g) as sess2:
        sess1.run(c)

        # c was built under ops.device('/cpu:0'), yet setting '/cpu:0' again
        # is still treated as a modification.
        c.op._set_device('/cpu:0')  # pylint: disable=protected-access

        with self.assertRaisesRegexp(
            errors.FailedPreconditionError,
            'add.*was changed by setting device after it was run'):
          sess1.run(c)

        # sess2 was not run before modification
        self.assertAllEqual(3.0, sess2.run(c))

  def testTwoSessionsBothRunBeforeModification(self):
    """Every session that ran the op rejects it after the mutation."""
    with ops.Graph().as_default() as g, ops.device('/cpu:0'):
      a = constant_op.constant(1.0)
      b = constant_op.constant(2.0)
      c = a + b

    with session.Session(graph=g) as sess1:
      with session.Session(graph=g) as sess2:
        sess1.run(c)
        sess2.run(c)

        c.op._set_device('/cpu:0')  # pylint: disable=protected-access

        with self.assertRaisesRegexp(
            errors.FailedPreconditionError,
            'add.*was changed by setting device after it was run'):
          sess1.run(c)

        with self.assertRaisesRegexp(
            errors.FailedPreconditionError,
            'add.*was changed by setting device after it was run'):
          sess2.run(c)
   1956 
if __name__ == '__main__':
  # Hand off to googletest.main() to run the test cases defined in this module.
  googletest.main()
   1959