# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=invalid-name
"""Test utils for tensorflow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import gc
import math
import random
import re
import tempfile
import threading

import numpy as np
import six

_portpicker_import_error = None
try:
  import portpicker  # pylint: disable=g-import-not-at-top
except ImportError as _error:
  _portpicker_import_error = _error
  portpicker = None

# pylint: disable=g-import-not-at-top
from google.protobuf import descriptor_pool
from google.protobuf import text_format

from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import device_lib
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import tape  # pylint: disable=unused-import
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.protobuf import compare
from tensorflow.python.util.tf_export import tf_export

@tf_export("test.gpu_device_name")
def gpu_device_name():
  """Returns the name of a GPU device if available or the empty string."""
  for x in device_lib.list_local_devices():
    if x.device_type == "GPU" or x.device_type == "SYCL":
      return compat.as_str(x.name)
  return ""


def assert_ops_in_graph(expected_ops, graph):
  """Assert all expected operations are found.

  Args:
    expected_ops: `dict<string, string>` of op name to op type.
    graph: Graph to check.
  Returns:
    `dict<string, node>` of node name to node.

  Raises:
    ValueError: If the expected ops are not present in the graph.
  """
  actual_ops = {}
  gd = graph.as_graph_def()
  for node in gd.node:
    if node.name in expected_ops:
      if expected_ops[node.name] != node.op:
        raise ValueError("Expected op for node %s is different. %s vs %s" %
                         (node.name, expected_ops[node.name], node.op))
      actual_ops[node.name] = node
  if set(expected_ops.keys()) != set(actual_ops.keys()):
    raise ValueError("Not all expected ops are present. Expected %s, found %s" %
                     (expected_ops.keys(), actual_ops.keys()))
  return actual_ops

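# Usage sketch for assert_ops_in_graph (illustrative only; assumes
# constant_op from tensorflow.python.framework, which this module does not
# import):
#
#   g = ops.Graph()
#   with g.as_default():
#     constant_op.constant(1.0, name="c")
#   assert_ops_in_graph({"c": "Const"}, g)  # passes, returns {"c": <NodeDef>}
#   assert_ops_in_graph({"c": "Add"}, g)    # raises ValueError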

@tf_export("test.assert_equal_graph_def")
def assert_equal_graph_def(actual, expected, checkpoint_v2=False):
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering of
  nodes, attrs, and control inputs.  Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: boolean determining whether to ignore randomized attribute
        values that appear in V2 checkpoints.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  if not isinstance(actual, graph_pb2.GraphDef):
    raise TypeError(
        "Expected tf.GraphDef for actual, got %s" % type(actual).__name__)
  if not isinstance(expected, graph_pb2.GraphDef):
    raise TypeError(
        "Expected tf.GraphDef for expected, got %s" % type(expected).__name__)

  if checkpoint_v2:
    _strip_checkpoint_v2_randomized(actual)
    _strip_checkpoint_v2_randomized(expected)

  diff = pywrap_tensorflow.EqualGraphDefWrapper(actual.SerializeToString(),
                                                expected.SerializeToString())
  if diff:
    raise AssertionError(compat.as_str(diff))

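# Usage sketch: two graphs built independently but with identically named
# nodes compare equal, since node ordering is ignored (constant_op assumed
# here purely for illustration):
#
#   g1, g2 = ops.Graph(), ops.Graph()
#   with g1.as_default():
#     constant_op.constant(1.0, name="c")
#   with g2.as_default():
#     constant_op.constant(1.0, name="c")
#   assert_equal_graph_def(g1.as_graph_def(), g2.as_graph_def())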
    141 
    142 def assert_meta_graph_protos_equal(tester, a, b):
    143   """Compares MetaGraphDefs `a` and `b` in unit test class `tester`."""
    144   # Carefully check the collection_defs
    145   tester.assertEqual(set(a.collection_def), set(b.collection_def))
    146   collection_keys = a.collection_def.keys()
    147   for k in collection_keys:
    148     a_value = a.collection_def[k]
    149     b_value = b.collection_def[k]
    150     proto_type = ops.get_collection_proto_type(k)
    151     if proto_type:
    152       a_proto = proto_type()
    153       b_proto = proto_type()
    154       # Number of entries in the collections is the same
    155       tester.assertEqual(
    156           len(a_value.bytes_list.value), len(b_value.bytes_list.value))
    157       for (a_value_item, b_value_item) in zip(a_value.bytes_list.value,
    158                                               b_value.bytes_list.value):
    159         a_proto.ParseFromString(a_value_item)
    160         b_proto.ParseFromString(b_value_item)
    161         tester.assertProtoEquals(a_proto, b_proto)
    162     else:
    163       tester.assertEquals(a_value, b_value)
    164   # Compared the fields directly, remove their raw values from the
    165   # proto comparison below.
    166   a.ClearField("collection_def")
    167   b.ClearField("collection_def")
    168 
    169   # Check the graph_defs.
    170   assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True)
    171   # Check graph_def versions (ignored by assert_equal_graph_def).
    172   tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions)
    173   # Compared the fields directly, remove their raw values from the
    174   # proto comparison below.
    175   a.ClearField("graph_def")
    176   b.ClearField("graph_def")
    177 
    178   tester.assertProtoEquals(a, b)
    179 

# Matches attributes named via _SHARDED_SUFFIX in
# tensorflow/python/training/saver.py
_SHARDED_SAVE_OP_PATTERN = "_temp_[0-9a-z]{32}/part"


def _strip_checkpoint_v2_randomized(graph_def):
  for node in graph_def.node:
    delete_keys = []
    for attr_key in node.attr:
      attr_tensor_value = node.attr[attr_key].tensor
      if attr_tensor_value and len(attr_tensor_value.string_val) == 1:
        attr_tensor_string_value = attr_tensor_value.string_val[0]
        if (attr_tensor_string_value and
            re.match(_SHARDED_SAVE_OP_PATTERN, str(attr_tensor_string_value))):
          delete_keys.append(attr_key)
    for attr_key in delete_keys:
      del node.attr[attr_key]


def IsGoogleCudaEnabled():
  return pywrap_tensorflow.IsGoogleCudaEnabled()


def CudaSupportsHalfMatMulAndConv():
  return pywrap_tensorflow.CudaSupportsHalfMatMulAndConv()


def InstallStackTraceHandler():
  pywrap_tensorflow.InstallStacktraceHandler()


def NHWCToNCHW(input_tensor):
  """Converts the input from the NHWC format to NCHW.

  Args:
    input_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    converted tensor or shape array
  """
  # tensor dim -> new axis order
  new_axes = {4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]}
  if isinstance(input_tensor, ops.Tensor):
    ndims = input_tensor.shape.ndims
    return array_ops.transpose(input_tensor, new_axes[ndims])
  else:
    ndims = len(input_tensor)
    return [input_tensor[a] for a in new_axes[ndims]]


def NHWCToNCHW_VECT_C(input_shape_or_tensor):
  """Transforms the input from the NHWC layout to NCHW_VECT_C layout.

  Note: Does not include quantization or type conversion steps, which should
  be applied afterwards.

  Args:
    input_shape_or_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NCHW_VECT_C

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not evenly
        divisible by 4.
  """
  permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  temp_shape = (
      input_shape_or_tensor.shape.as_list()
      if is_tensor else input_shape_or_tensor)
  if temp_shape[-1] % 4 != 0:
    raise ValueError(
        "Last dimension of input must be evenly divisible by 4 to convert to "
        "NCHW_VECT_C.")
  temp_shape[-1] //= 4
  temp_shape.append(4)
  permutation = permutations[len(temp_shape)]
  if is_tensor:
    t = array_ops.reshape(input_shape_or_tensor, temp_shape)
    return array_ops.transpose(t, permutation)
  else:
    return [temp_shape[a] for a in permutation]


def NCHW_VECT_CToNHWC(input_shape_or_tensor):
  """Transforms the input from the NCHW_VECT_C layout to NHWC layout.

  Note: Does not include de-quantization or type conversion steps, which should
  be applied beforehand.

  Args:
    input_shape_or_tensor: a 5- or 6-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NHWC

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not 4.
  """
  permutations = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  input_shape = (
      input_shape_or_tensor.shape.as_list()
      if is_tensor else input_shape_or_tensor)
  if input_shape[-1] != 4:
    raise ValueError("Last dimension of NCHW_VECT_C must be 4.")
  permutation = permutations[len(input_shape)]
  nhwc_shape = [input_shape[a] for a in permutation[:-1]]
  nhwc_shape[-1] *= input_shape[-1]
  if is_tensor:
    t = array_ops.transpose(input_shape_or_tensor, permutation)
    return array_ops.reshape(t, nhwc_shape)
  else:
    return nhwc_shape


def NCHWToNHWC(input_tensor):
  """Converts the input from the NCHW format to NHWC.

  Args:
    input_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    converted tensor or shape array
  """
  # tensor dim -> new axis order
  new_axes = {4: [0, 2, 3, 1], 5: [0, 2, 3, 4, 1]}
  if isinstance(input_tensor, ops.Tensor):
    ndims = input_tensor.shape.ndims
    return array_ops.transpose(input_tensor, new_axes[ndims])
  else:
    ndims = len(input_tensor)
    return [input_tensor[a] for a in new_axes[ndims]]

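# Shape-array sketch for the four layout helpers above (values illustrative).
# When given a plain list instead of a Tensor, each helper permutes (and, for
# the *_VECT_C variants, reshapes) the shape directly:
#
#   NHWCToNCHW([8, 224, 224, 3])             => [8, 3, 224, 224]
#   NCHWToNHWC([8, 3, 224, 224])             => [8, 224, 224, 3]
#   NHWCToNCHW_VECT_C([8, 224, 224, 32])     => [8, 8, 224, 224, 4]
#   NCHW_VECT_CToNHWC([8, 8, 224, 224, 4])   => [8, 224, 224, 32]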

# TODO(skyewm): remove this eventually
# pylint: disable=protected-access
def _use_c_api_wrapper(fn, use_c_api, *args, **kwargs):
  prev_value = ops._USE_C_API
  ops._USE_C_API = use_c_api
  try:
    # Reset the default graph so it has the C API enabled. We call
    # reset_default_graph() instead of creating a new default Graph context to
    # make this robust to tests that call reset_default_graph(), which requires
    # that the current default graph isn't nested.
    ops.reset_default_graph()
    fn(*args, **kwargs)
  finally:
    ops._USE_C_API = prev_value
    # Make sure default graph reflects prev_value in case next test doesn't
    # call reset_default_graph().
    ops.reset_default_graph()
# pylint: disable=protected-access


def c_api_and_cuda_enabled():
  return ops._USE_C_API and IsGoogleCudaEnabled()


def skip_if(condition):
  """Skips the decorated function if `condition` is or evaluates to True.

  Args:
    condition: Either an expression that can be used in an "if not condition"
               statement, or a callable whose result should be a boolean.
  Returns:
    The wrapped function
  """

  def real_skip_if(fn):

    def wrapper(*args, **kwargs):
      if callable(condition):
        skip = condition()
      else:
        skip = condition
      if not skip:
        fn(*args, **kwargs)

    return wrapper

  return real_skip_if

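# Usage sketch for skip_if (the condition may be a plain value or a callable
# that is evaluated when the test runs):
#
#   @skip_if(True)                            # always skipped
#   def testNotReadyYet(self):
#     ...
#
#   @skip_if(lambda: not is_gpu_available())  # skipped when no GPU is found
#   def testGpuOnly(self):
#     ...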

# TODO(skyewm): remove this eventually
def disable_c_api(fn):
  """Decorator for disabling the C API on a test.

  Note this disables the C API after running the test class's setup/teardown
  methods.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function
  """

  def wrapper(*args, **kwargs):
    _use_c_api_wrapper(fn, False, *args, **kwargs)

  return wrapper


# TODO(skyewm): remove this eventually
def enable_c_api(fn):
  """Decorator for enabling the C API on a test.

  Note this enables the C API after running the test class's setup/teardown
  methods.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function
  """

  def wrapper(*args, **kwargs):
    _use_c_api_wrapper(fn, True, *args, **kwargs)

  return wrapper


# This decorator is a hacky way to run all the test methods in a decorated
# class with and without C API enabled.
# TODO(iga): Remove this and its uses once we switch to using C API by default.
def with_c_api(cls):
  """Adds methods that call original methods but with C API enabled.

  Note this enables the C API in new methods after running the test class's
  setup method. This can be a problem if some objects are created in it
  before the C API is enabled.

  Args:
    cls: class to decorate

  Returns:
    cls with new test methods added
  """
  for name, value in cls.__dict__.copy().items():
    if callable(value) and name.startswith("test"):
      setattr(cls, name + "WithCApi", enable_c_api(value))
  return cls

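# Usage sketch: decorating a test class duplicates each test method, so
# testFoo gains a sibling testFooWithCApi that runs with the C API enabled.
#
#   @with_c_api
#   class MyOperatorTest(TensorFlowTestCase):
#
#     def testFoo(self):
#       ...  # runs once as testFoo and again as testFooWithCApi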

def assert_no_new_tensors(f):
  """Decorator for asserting that no new Tensors persist after a test.

  Mainly useful for checking that code using the Python C API has correctly
  manipulated reference counts.

  Clears the caches that it knows about, runs the garbage collector, then checks
  that there are no Tensor or Tensor-like objects still around. This includes
  Tensors to which something still has a reference (e.g. from missing
  Py_DECREFs) and uncollectable cycles (i.e. Python reference cycles where one
  of the objects has __del__ defined).

  Args:
    f: The test case to run.
  Returns:
    The decorated test case.
  """

  def decorator(self, **kwargs):
    """Finds existing Tensors, runs the test, checks for new Tensors."""

    def _is_tensor(obj):
      try:
        return (isinstance(obj, ops.Tensor) or
                isinstance(obj, variables.Variable))
      except ReferenceError:
        # If the object no longer exists, we don't care about it.
        return False

    tensors_before = set(id(obj) for obj in gc.get_objects() if _is_tensor(obj))
    outside_graph_key = ops.get_default_graph()._graph_key
    with ops.Graph().as_default():
      # Run the test in a new graph so that collections get cleared when it's
      # done, but inherit the graph key so optimizers behave.
      ops.get_default_graph()._graph_key = outside_graph_key
      f(self, **kwargs)
    # Make an effort to clear caches, which would otherwise look like leaked
    # Tensors.
    backprop._zeros_cache.flush()
    context.get_default_context().scalar_cache().clear()
    gc.collect()
    tensors_after = [
        obj for obj in gc.get_objects()
        if _is_tensor(obj) and id(obj) not in tensors_before
    ]
    if tensors_after:
      raise AssertionError(("%d Tensors not deallocated after test: %s" % (
          len(tensors_after),
          str(tensors_after),
      )))

  return decorator


def assert_no_garbage_created(f):
  """Test method decorator to assert that no garbage has been created.

  Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters
  cannot be un-set (i.e. will disable garbage collection for any other unit
  tests in the same file/shard).

  Args:
    f: The function to decorate.
  Returns:
    The decorated function.
  """

  def decorator(self, **kwargs):
    """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
    gc.disable()
    previous_debug_flags = gc.get_debug()
    gc.set_debug(gc.DEBUG_SAVEALL)
    gc.collect()
    previous_garbage = len(gc.garbage)
    f(self, **kwargs)
    gc.collect()
    # This will fail if any garbage has been created, typically because of a
    # reference cycle.
    self.assertEqual(previous_garbage, len(gc.garbage))
    # TODO(allenl): Figure out why this debug flag reset doesn't work. It would
    # be nice to be able to decorate arbitrary tests in a large test suite and
    # not hold on to every object in other tests.
    gc.set_debug(previous_debug_flags)
    gc.enable()

  return decorator

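# Usage sketch: the two decorators above can be stacked on a test method.
# Note assert_no_garbage_created sets gc.DEBUG_SAVEALL, which may stay in
# effect for the rest of the test shard.
#
#   class LeakTest(TensorFlowTestCase):
#
#     @assert_no_new_tensors
#     @assert_no_garbage_created
#     def testNoLeaks(self):
#       ...  # fails if new Tensors or gc garbage survive the test body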

def run_in_graph_and_eager_modes(__unused__=None,
                                 graph=None,
                                 config=None,
                                 use_gpu=False,
                                 force_gpu=False,
                                 reset_test=True,
                                 assert_no_eager_garbage=False):
  """Runs the test in both graph and eager modes.

  Args:
    __unused__: Prevents silently skipping tests.
    graph: Optional graph to use during the returned session.
    config: An optional config_pb2.ConfigProto to use to configure the
      session.
    use_gpu: If True, attempt to run as many ops as possible on GPU.
    force_gpu: If True, pin all ops to `/device:GPU:0`.
    reset_test: If True, tearDown and setUp the test case between the two runs.
    assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage
      collector and asserts that no extra garbage has been created when running
      the test in eager mode. This will fail if there are reference cycles
      (e.g. a = []; a.append(a)). Off by default because some tests may create
      garbage for legitimate reasons (e.g. they define a class which inherits
      from `object`), and because DEBUG_SAVEALL is sticky in some Python
      interpreters (meaning that tests which rely on objects being collected
      elsewhere in the unit test file will not work). Additionally, checks that
      nothing still has a reference to Tensors that the test allocated.
  Returns:
    Returns a decorator that will run the decorated test function
        using both graph construction and eager execution.
  """

  assert not __unused__, "Add () after run_in_graph_and_eager_modes."

  def decorator(f):
    """Test method decorator."""

    def decorated(self, **kwargs):
      """Runs the test method in graph mode, then in eager mode."""
      with context.graph_mode():
        with self.test_session(graph, config, use_gpu, force_gpu):
          f(self, **kwargs)

      if reset_test:
        # This decorator runs the wrapped test twice.
        # Reset the test environment between runs.
        self.tearDown()
        self.setUp()

      def run_eager_mode(self, **kwargs):
        if force_gpu:
          gpu_name = gpu_device_name()
          if not gpu_name:
            gpu_name = "/device:GPU:0"
          with context.device(gpu_name):
            f(self, **kwargs)
        elif use_gpu:
          # TODO(xpan): Support softplacement and gpu by default when available.
          f(self, **kwargs)
        else:
          with context.device("/device:CPU:0"):
            f(self, **kwargs)

      if assert_no_eager_garbage:
        run_eager_mode = assert_no_new_tensors(
            assert_no_garbage_created(run_eager_mode))

      with context.eager_mode():
        with ops.Graph().as_default():
          run_eager_mode(self, **kwargs)

    return decorated

  return decorator

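# Usage sketch (constant_op and math_ops are assumed here purely for
# illustration; they are not imported by this module):
#
#   class SquareTest(TensorFlowTestCase):
#
#     @run_in_graph_and_eager_modes()
#     def testSquare(self):
#       x = constant_op.constant(2.0)
#       # evaluate() works in both modes: Session.run in graph mode,
#       # Tensor.numpy() in eager mode.
#       self.assertAllClose(self.evaluate(math_ops.square(x)), 4.0)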

@tf_export("test.is_gpu_available")
def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None):
  """Returns whether TensorFlow can access a GPU.

  Args:
    cuda_only: limit the search to CUDA gpus.
    min_cuda_compute_capability: a (major,minor) pair that indicates the minimum
      CUDA compute capability required, or None if no requirement.

  Returns:
    True iff a gpu device of the requested kind is available.
  """

  def compute_capability_from_device_desc(device_desc):
    # TODO(jingyue): The device description generator has to be in sync with
    # this file. Another option is to put compute capability in
    # DeviceAttributes, but I avoided that to keep DeviceAttributes
    # target-independent. Reconsider this option when we have more things like
    # this to keep in sync.
    # LINT.IfChange
    match = re.search(r"compute capability: (\d+)\.(\d+)", device_desc)
    # LINT.ThenChange(//tensorflow/core/\
    #                 common_runtime/gpu/gpu_device.cc)
    if not match:
      return 0, 0
    return int(match.group(1)), int(match.group(2))

  for local_device in device_lib.list_local_devices():
    if local_device.device_type == "GPU":
      if (min_cuda_compute_capability is None or
          compute_capability_from_device_desc(local_device.physical_device_desc)
          >= min_cuda_compute_capability):
        return True
    if local_device.device_type == "SYCL" and not cuda_only:
      return True
  return False


@contextlib.contextmanager
def device(use_gpu):
  """Uses gpu when requested and available."""
  if use_gpu and is_gpu_available():
    dev = "/device:GPU:0"
  else:
    dev = "/device:CPU:0"
  with ops.device(dev):
    yield

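# Usage sketch combining the two helpers above:
#
#   if is_gpu_available(cuda_only=True, min_cuda_compute_capability=(3, 5)):
#     ...  # safe to run kernels that need CUDA compute capability >= 3.5
#
#   with device(use_gpu=True):  # falls back to the CPU when no GPU is found
#     ...  # ops created here are placed on /device:GPU:0 if available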

@tf_export("test.TestCase")
class TensorFlowTestCase(googletest.TestCase):
  """Base class for tests that need to test TensorFlow."""

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    super(TensorFlowTestCase, self).__init__(methodName)
    self._threads = []
    self._tempdir = None
    self._cached_session = None

  def setUp(self):
    self._ClearCachedSession()
    random.seed(random_seed.DEFAULT_GRAPH_SEED)
    np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
    # Note: The following line is necessary because some test methods may error
    # out from within nested graph contexts (e.g., via assertRaises and
    # assertRaisesRegexp), which may leave ops._default_graph_stack non-empty
    # under certain versions of Python. That would cause
    # ops.reset_default_graph() to throw an exception if the stack were not
    # cleared first.
    ops._default_graph_stack.reset()  # pylint: disable=protected-access
    ops.reset_default_graph()
    random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)

  def tearDown(self):
    for thread in self._threads:
      thread.check_termination()

    self._ClearCachedSession()

  def _ClearCachedSession(self):
    if self._cached_session is not None:
      self._cached_session.close()
      self._cached_session = None

  def get_temp_dir(self):
    """Returns a unique temporary directory for the test to use.

    If you call this method multiple times in a test, it will return the
    same folder. However, across different runs the directories will be
    different. This ensures that tests in different runs cannot pollute
    each other's environment.
    If you need multiple unique directories within a single test, you should
    use tempfile.mkdtemp as follows:
      tempfile.mkdtemp(dir=self.get_temp_dir())

    Returns:
      string, the path to the unique temporary directory created for this test.
    """
    if not self._tempdir:
      self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
    return self._tempdir

  def _AssertProtoEquals(self, a, b, msg=None):
    """Asserts that a and b are the same proto.

    Uses ProtoEq() first, as it returns correct results
    for floating point attributes, and then uses assertProtoEqual()
    in case of failure as it provides good error messages.

    Args:
      a: a proto.
      b: another proto.
      msg: Optional message to report on failure.
    """
    if not compare.ProtoEq(a, b):
      compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg)

  def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None):
    """Asserts that message is the same as the parsed expected_message_ascii.

    Creates another proto message of the same type as message, reads the ascii
    message into it and then compares them using self._AssertProtoEquals().

    Args:
      expected_message_maybe_ascii: proto message in original or ascii form.
      message: the message to validate.
      msg: Optional message to report on failure.
    """
    msg = msg if msg else ""
    if isinstance(expected_message_maybe_ascii, type(message)):
      expected_message = expected_message_maybe_ascii
      self._AssertProtoEquals(expected_message, message, msg=msg)
    elif isinstance(expected_message_maybe_ascii, str):
      expected_message = type(message)()
      text_format.Merge(
          expected_message_maybe_ascii,
          expected_message,
          descriptor_pool=descriptor_pool.Default())
      self._AssertProtoEquals(expected_message, message, msg=msg)
    else:
      assert False, ("Can't compare protos of type %s and %s. %s" %
                     (type(expected_message_maybe_ascii), type(message), msg))

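  # Usage sketch: the expected message may be a proto of the same type as
  # `message`, or a text-format string for it, e.g.
  #
  #   self.assertProtoEquals("versions { producer: 24 }", my_graph_def)
  #
  # (my_graph_def is an illustrative GraphDef with only its versions field
  # set.)
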
  def assertProtoEqualsVersion(
      self,
      expected,
      actual,
      producer=versions.GRAPH_DEF_VERSION,
      min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER,
      msg=None):
    expected = "versions { producer: %d min_consumer: %d };\n%s" % (
        producer, min_consumer, expected)
    self.assertProtoEquals(expected, actual, msg=msg)

  def assertStartsWith(self, actual, expected_start, msg=None):
    """Assert that actual.startswith(expected_start) is True.

    Args:
      actual: str
      expected_start: str
      msg: Optional message to report on failure.
    """
    if not actual.startswith(expected_start):
      fail_msg = "%r does not start with %r" % (actual, expected_start)
      fail_msg += " : %r" % (msg) if msg else ""
      self.fail(fail_msg)

  def _eval_tensor(self, tensor):
    if tensor is None:
      return None
    elif isinstance(tensor, ops.EagerTensor):
      return tensor.numpy()
    elif isinstance(tensor, resource_variable_ops.ResourceVariable):
      return tensor.read_value().numpy()
    elif callable(tensor):
      return self._eval_helper(tensor())
    else:
      raise ValueError("Unsupported type %s." % type(tensor))

  def _eval_helper(self, tensors):
    if tensors is None:
      return None
    return nest.map_structure(self._eval_tensor, tensors)

  def evaluate(self, tensors):
    """Evaluates tensors and returns numpy values.

    Args:
      tensors: A Tensor or a nested list/tuple of Tensors.

    Returns:
      tensors numpy values.
    """
    if context.in_eager_mode():
      return self._eval_helper(tensors)
    else:
      sess = ops.get_default_session()
      if sess is None:
        with self.test_session() as sess:
          return sess.run(tensors)
      else:
        return sess.run(tensors)

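  # Usage sketch for evaluate(): it accepts nested structures and returns
  # matching numpy values in both graph and eager modes, e.g.
  #
  #   vals = self.evaluate({"x": x_tensor, "y": (y1_tensor, y2_tensor)})
  #   self.assertAllClose(vals["x"], expected_x)
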
  # pylint: disable=g-doc-return-or-yield
  @contextlib.contextmanager
  def test_session(self,
                   graph=None,
                   config=None,
                   use_gpu=False,
                   force_gpu=False):
    """Returns a TensorFlow Session for use in executing tests.

    This method should be used for all functional tests.

    This method behaves differently from session.Session: for performance
    reasons `test_session` will by default (if `graph` is None) reuse the same
    session across tests. This means you may want to either call the function
    `reset_default_graph()` before tests, or if creating an explicit new graph,
    pass it here (simply setting it with `as_default()` won't do it), which will
    trigger the creation of a new session.

    Use the `use_gpu` and `force_gpu` options to control where ops are run. If
    `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
    `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
    possible. If both `force_gpu` and `use_gpu` are False, all ops are pinned
    to the CPU.

    Example:
    ```python
    class MyOperatorTest(test_util.TensorFlowTestCase):
      def testMyOperator(self):
        with self.test_session(use_gpu=True):
          valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
          result = MyOperator(valid_input).eval()
          self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0])
          invalid_input = [-1.0, 2.0, 7.0]
          with self.assertRaisesOpError("negative input not supported"):
            MyOperator(invalid_input).eval()
    ```

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      use_gpu: If True, attempt to run as many ops as possible on GPU.
      force_gpu: If True, pin all ops to `/device:GPU:0`.

    Returns:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case.
    """
    if self.id().endswith(".test_session"):
      self.skipTest("Not a test.")

    def prepare_config(config):
      """Returns a config for sessions.

      Args:
        config: An optional config_pb2.ConfigProto to use to configure the
          session.
      Returns:
        A config_pb2.ConfigProto object.
      """
      if config is None:
        config = config_pb2.ConfigProto()
        config.allow_soft_placement = not force_gpu
        config.gpu_options.per_process_gpu_memory_fraction = 0.3
      elif force_gpu and config.allow_soft_placement:
        # Note: CopyFrom() returns None, so copy into a fresh proto first.
        new_config = config_pb2.ConfigProto()
        new_config.CopyFrom(config)
        new_config.allow_soft_placement = False
        config = new_config
      # Don't perform optimizations for tests so we don't inadvertently run
      # gpu ops on cpu.
      config.graph_options.optimizer_options.opt_level = -1
      config.graph_options.rewrite_options.constant_folding = (
          rewriter_config_pb2.RewriterConfig.OFF)
      config.graph_options.rewrite_options.arithmetic_optimization = (
          rewriter_config_pb2.RewriterConfig.OFF)
      return config

    if graph is None:
      if self._cached_session is None:
        self._cached_session = session.Session(
            graph=None, config=prepare_config(config))
      sess = self._cached_session
      with sess.graph.as_default(), sess.as_default():
        if force_gpu:
          # Use the name of an actual device if one is detected, or
          # '/device:GPU:0' otherwise.
          gpu_name = gpu_device_name()
          if not gpu_name:
            gpu_name = "/device:GPU:0"
          with sess.graph.device(gpu_name):
            yield sess
        elif use_gpu:
          yield sess
        else:
          with sess.graph.device("/cpu:0"):
            yield sess
    else:
      with session.Session(graph=graph, config=prepare_config(config)) as sess:
        if force_gpu:
          # Use the name of an actual device if one is detected, or
          # '/device:GPU:0' otherwise.
          gpu_name = gpu_device_name()
          if not gpu_name:
            gpu_name = "/device:GPU:0"
          with sess.graph.device(gpu_name):
            yield sess
        elif use_gpu:
          yield sess
        else:
          with sess.graph.device("/cpu:0"):
            yield sess

  # pylint: enable=g-doc-return-or-yield

  class _CheckedThread(object):
    """A wrapper class for Thread that asserts successful completion.

    This class should be created using the TensorFlowTestCase.checkedThread()
    method.
    """

    def __init__(self, testcase, target, args=None, kwargs=None):
      """Constructs a new instance of _CheckedThread.

      Args:
        testcase: The TensorFlowTestCase for which this thread is being created.
        target: A callable object representing the code to be executed in the
          thread.
        args: A tuple of positional arguments that will be passed to target.
        kwargs: A dictionary of keyword arguments that will be passed to target.
      """
      self._testcase = testcase
      self._target = target
      self._args = () if args is None else args
      self._kwargs = {} if kwargs is None else kwargs
      self._thread = threading.Thread(target=self._protected_run)
      self._exception = None

      self._is_thread_joined = False

    def _protected_run(self):
      """Target for the wrapper thread. Sets self._exception on failure."""
      try:
        self._target(*self._args, **self._kwargs)
      except Exception as e:  # pylint: disable=broad-except
        self._exception = e

    def start(self):
      """Starts the thread's activity.

      This must be called at most once per _CheckedThread object. It arranges
      for the object's target to be invoked in a separate thread of control.
      """
      self._thread.start()

    def join(self):
      """Blocks until the thread terminates.

      Raises:
        self._testcase.failureException: If the thread terminates due to
          an exception.
      """
      self._is_thread_joined = True
      self._thread.join()
      if self._exception is not None:
        self._testcase.fail("Error in checkedThread: %s" % str(self._exception))

    def is_alive(self):
      """Returns whether the thread is alive.

      This method returns True just before the run() method starts
      until just after the run() method terminates.

      Returns:
        True if the thread is alive, otherwise False.
      """
      return self._thread.is_alive()

    def check_termination(self):
      """Checks that the checked thread was properly used and did terminate.

      Every checked thread should be "join"ed after starting, and before the
      test tears down. If it is not joined, it is possible the thread will hang
      and cause flaky failures in tests.

      Raises:
        self._testcase.failureException: If check_termination was called before
        the thread was joined.

        RuntimeError: If the thread is not terminated. This means the thread
        was not joined with the main thread.
      """
      if self._is_thread_joined:
        if self.is_alive():
          raise RuntimeError(
              "Thread was not joined with main thread, and is still running "
              "when the test finished.")
      else:
        self._testcase.fail("A checked thread was not joined.")

  def checkedThread(self, target, args=None, kwargs=None):
    """Returns a Thread wrapper that asserts 'target' completes successfully.

    This method should be used to create all threads in test cases, as
    otherwise there is a risk that a thread will silently fail, and/or
    assertions made in the thread will not be respected.

    Args:
      target: A callable object to be executed in the thread.
      args: The argument tuple for the target invocation. Defaults to ().
      kwargs: A dictionary of keyword arguments for the target invocation.
        Defaults to {}.

    Returns:
      A wrapper for threading.Thread that supports start() and join() methods.
    """
    ret = TensorFlowTestCase._CheckedThread(self, target, args, kwargs)
    self._threads.append(ret)
    return ret

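  # Usage sketch for checkedThread: exceptions raised in the worker are
  # surfaced as test failures when join() is called, and tearDown() verifies
  # that every checked thread was joined (my_worker/work_queue are
  # illustrative names):
  #
  #   t = self.checkedThread(target=my_worker, args=(work_queue,))
  #   t.start()
  #   ...
  #   t.join()  # fails the test if my_worker raised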

# pylint: enable=invalid-name

  def assertNear(self, f1, f2, err, msg=None):
    """Asserts that two floats are near each other.

    Checks that |f1 - f2| <= err and asserts a test failure
    if not.

    Args:
      f1: A float value.
      f2: A float value.
      err: A float value.
      msg: An optional string message to append to the failure message.
    """
    # f1 == f2 is needed here as we might have: f1, f2 = inf, inf
    self.assertTrue(f1 == f2 or math.fabs(f1 - f2) <= err,
                    "%f != %f +/- %f%s" % (f1, f2, err, " (%s)" % msg
                                           if msg is not None else ""))

  def assertArrayNear(self, farray1, farray2, err, msg=None):
    """Asserts that two float arrays are near each other.

    Checks that for all elements of farray1 and farray2
    |f1 - f2| <= err.  Asserts a test failure if not.

    Args:
      farray1: a list of float values.
      farray2: a list of float values.
      err: a float value.
      msg: Optional message to report on failure.
    """
    self.assertEqual(len(farray1), len(farray2), msg=msg)
    for f1, f2 in zip(farray1, farray2):
      self.assertNear(float(f1), float(f2), err, msg=msg)

  def _NDArrayNear(self, ndarray1, ndarray2, err):
    return np.linalg.norm(ndarray1 - ndarray2) < err

  def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None):
    """Asserts that two numpy arrays have near values.

    Args:
      ndarray1: a numpy ndarray.
      ndarray2: a numpy ndarray.
      err: a float. The maximum absolute difference allowed.
      msg: Optional message to report on failure.
    """
    self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err), msg=msg)

  def _GetNdArray(self, a):
    if not isinstance(a, np.ndarray):
      a = np.array(a)
    return a

  def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
    a = self._GetNdArray(a)
    b = self._GetNdArray(b)
    self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s." %
                     (a.shape, b.shape))
    if not np.allclose(a, b, rtol=rtol, atol=atol):
      # Prints more details than np.testing.assert_allclose.
      #
      # NOTE: numpy.allclose (and numpy.testing.assert_allclose)
      # checks whether two arrays are element-wise equal within a
      # tolerance. The relative difference (rtol * abs(b)) and the
      # absolute difference atol are added together to compare against
      # the absolute difference between a and b.  Here, we want to
      # print out which elements violate such conditions.
      cond = np.logical_or(
          np.abs(a - b) > atol + rtol * np.abs(b),
          np.isnan(a) != np.isnan(b))
      if a.ndim:
        x = a[np.where(cond)]
        y = b[np.where(cond)]
        print("not close where = ", np.where(cond))
      else:
        # np.where is broken for scalars
        x, y = a, b
      print("not close lhs = ", x)
      print("not close rhs = ", y)
      print("not close dif = ", np.abs(x - y))
      print("not close tol = ", atol + rtol * np.abs(y))
      print("dtype = %s, shape = %s" % (a.dtype, a.shape))
      # TODO(xpan): There seems to be a bug:
      # tensorflow/compiler/tests:binary_ops_test pass with float32
      # nan even though the equal_nan is False by default internally.
      np.testing.assert_allclose(
          a, b, rtol=rtol, atol=atol, err_msg=msg, equal_nan=True)

  def _assertAllCloseRecursive(self,
                               a,
                               b,
                               rtol=1e-6,
                               atol=1e-6,
                               path=None,
                               msg=None):
    path = path or []
    path_str = (("[" + "][".join([str(p) for p in path]) + "]") if path else "")
    msg = msg if msg else ""

    # Check if a and/or b are namedtuples.
    if hasattr(a, "_asdict"):
      a = a._asdict()
    if hasattr(b, "_asdict"):
      b = b._asdict()
    a_is_dict = isinstance(a, dict)
    if a_is_dict != isinstance(b, dict):
      raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" %
                       (path_str, path_str, msg))
    if a_is_dict:
      self.assertItemsEqual(
          a.keys(),
          b.keys(),
          msg="mismatched keys: a%s has keys %s, but b%s has keys %s. %s" %
          (path_str, a.keys(), path_str, b.keys(), msg))
      for k in a:
        path.append(k)
        self._assertAllCloseRecursive(
            a[k], b[k], rtol=rtol, atol=atol, path=path, msg=msg)
        del path[-1]
    elif isinstance(a, (list, tuple)):
      # Try to directly compare a, b as ndarrays; if that fails, traverse
      # through the sequence, which is more expensive.
      try:
        a_as_ndarray = np.array(a)
        b_as_ndarray = np.array(b)
        self._assertArrayLikeAllClose(
            a_as_ndarray,
            b_as_ndarray,
            rtol=rtol,
            atol=atol,
            msg="Mismatched value: a%s is different from b%s. %s" %
            (path_str, path_str, msg))
      except (ValueError, TypeError):
        if len(a) != len(b):
          raise ValueError(
              "Mismatched length: a%s has %d items, but b%s has %d items. %s" %
              (path_str, len(a), path_str, len(b), msg))
        for idx, (a_ele, b_ele) in enumerate(zip(a, b)):
          path.append(str(idx))
          self._assertAllCloseRecursive(
              a_ele, b_ele, rtol=rtol, atol=atol, path=path, msg=msg)
          del path[-1]
    # a and b are ndarray like objects
    else:
      try:
        self._assertArrayLikeAllClose(
            a,
            b,
            rtol=rtol,
            atol=atol,
            msg="Mismatched value: a%s is different from b%s." % (path_str,
                                                                  path_str))
      except TypeError as e:
        msg = "Error: a%s has %s, but b%s has %s" % (
            path_str, type(a), path_str, type(b))
        e.args = ((e.args[0] + " : " + msg,) + e.args[1:])
        raise

  def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
    """Asserts that two structures of numpy arrays have near values.

    `a` and `b` can be arbitrarily nested structures. A layer of a nested
    structure can be a `dict`, `namedtuple`, `tuple` or `list`.

    Args:
      a: The expected numpy `ndarray`, or anything that can be converted into a
          numpy `ndarray`, or any arbitrarily nested structure of these.
      b: The actual numpy `ndarray`, or anything that can be converted into a
          numpy `ndarray`, or any arbitrarily nested structure of these.
      rtol: relative tolerance.
      atol: absolute tolerance.
      msg: Optional message to report on failure.

    Raises:
      ValueError: if only one of `a[p]` and `b[p]` is a dict or
          `a[p]` and `b[p]` have different length, where `[p]` denotes a path
          to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and
          `[p] = [1]['d']`, then `a[p] = (6, 7)`.
    """
    self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg)

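  # Usage sketch: assertAllClose recurses through dicts, lists, tuples and
  # namedtuples, comparing the leaves as ndarrays with the given tolerances:
  #
  #   self.assertAllClose({"w": [1.0, 2.0]}, {"w": [1.0, 2.0000001]})
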
  def assertAllCloseAccordingToType(self,
                                    a,
                                    b,
                                    rtol=1e-6,
                                    atol=1e-6,
                                    float_rtol=1e-6,
                                    float_atol=1e-6,
                                    half_rtol=1e-3,
                                    half_atol=1e-3,
                                    bfloat16_rtol=1e-2,
                                    bfloat16_atol=1e-2,
                                    msg=None):
    """Like assertAllClose, but also suitable for comparing fp16 arrays.

    In particular, the tolerance is relaxed to 1e-3 if at least
    one of the arguments is of type float16.

    Args:
      a: the expected numpy ndarray, or anything that can be converted to one.
      b: the actual numpy ndarray, or anything that can be converted to one.
      rtol: relative tolerance.
      atol: absolute tolerance.
      float_rtol: relative tolerance for float32.
      float_atol: absolute tolerance for float32.
      half_rtol: relative tolerance for float16.
      half_atol: absolute tolerance for float16.
      bfloat16_rtol: relative tolerance for bfloat16.
      bfloat16_atol: absolute tolerance for bfloat16.
      msg: Optional message to report on failure.
    """
    a = self._GetNdArray(a)
    b = self._GetNdArray(b)
    # Types with looser tolerances are put later to overwrite previous ones.
    if (a.dtype == np.float32 or b.dtype == np.float32 or
        a.dtype == np.complex64 or b.dtype == np.complex64):
      rtol = max(rtol, float_rtol)
      atol = max(atol, float_atol)
    if a.dtype == np.float16 or b.dtype == np.float16:
      rtol = max(rtol, half_rtol)
      atol = max(atol, half_atol)
    if (a.dtype == dtypes.bfloat16.as_numpy_dtype or
        b.dtype == dtypes.bfloat16.as_numpy_dtype):
      rtol = max(rtol, bfloat16_rtol)
      atol = max(atol, bfloat16_atol)

    self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg)

  def assertAllEqual(self, a, b, msg=None):
    """Asserts that two numpy arrays have the same values.

    Args:
      a: the expected numpy ndarray, or anything that can be converted to one.
      b: the actual numpy ndarray, or anything that can be converted to one.
      msg: Optional message to report on failure.
    """
    msg = msg if msg else ""
    a = self._GetNdArray(a)
    b = self._GetNdArray(b)
    self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s."
                     " %s" % (a.shape, b.shape, msg))
    same = (a == b)

    if a.dtype == np.float32 or a.dtype == np.float64:
      same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b)))
    if not np.all(same):
      # Prints more details than np.testing.assert_array_equal.
      diff = np.logical_not(same)
      if a.ndim:
        x = a[np.where(diff)]
        y = b[np.where(diff)]
        print("not equal where = ", np.where(diff))
      else:
        # np.where is broken for scalars
        x, y = a, b
      print("not equal lhs = ", x)
      print("not equal rhs = ", y)
      np.testing.assert_array_equal(a, b, err_msg=msg)

   1275   # pylint: disable=g-doc-return-or-yield
   1276   @contextlib.contextmanager
   1277   def assertRaisesWithPredicateMatch(self, exception_type,
   1278                                      expected_err_re_or_predicate):
   1279     """Returns a context manager to enclose code expected to raise an exception.
   1280 
   1281     If the exception is an OpError, the op stack is also included in the message
   1282     predicate search.
   1283 
   1284     Args:
   1285       exception_type: The expected type of exception that should be raised.
   1286       expected_err_re_or_predicate: If this is callable, it should be a function
   1287         of one argument that inspects the passed-in exception and
   1288         returns True (success) or False (please fail the test). Otherwise, the
   1289         error message is expected to match this regular expression partially.
   1290 
   1291     Returns:
   1292       A context manager to surround code that is expected to raise an
   1293       exception.
   1294     """
   1295     if callable(expected_err_re_or_predicate):
   1296       predicate = expected_err_re_or_predicate
   1297     else:
   1298 
   1299       def predicate(e):
   1300         err_str = e.message if isinstance(e, errors.OpError) else str(e)
   1301         op = e.op if isinstance(e, errors.OpError) else None
   1302         while op is not None:
   1303           err_str += "\nCaused by: " + op.name
   1304           op = op._original_op  # pylint: disable=protected-access
   1305         logging.info("Searching within error strings: '%s' within '%s'",
   1306                      expected_err_re_or_predicate, err_str)
   1307         return re.search(expected_err_re_or_predicate, err_str)
   1308 
   1309     try:
   1310       yield
   1311       self.fail(exception_type.__name__ + " not raised")
   1312     except Exception as e:  # pylint: disable=broad-except
   1313       if not isinstance(e, exception_type) or not predicate(e):
   1314         raise AssertionError("Exception of type %s: %s" % (str(type(e)),
   1315                                                            str(e)))
   1316 
   1317   # pylint: enable=g-doc-return-or-yield
   1318 
   1319   def assertRaisesOpError(self, expected_err_re_or_predicate):
   1320     return self.assertRaisesWithPredicateMatch(errors.OpError,
   1321                                                expected_err_re_or_predicate)
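          # Usage sketches (`sess` and `failing_op` are hypothetical): the argument
          # may be either a regular expression or a predicate over the raised
          # exception:
          #
          #   with self.assertRaisesOpError("must be a scalar"):
          #     sess.run(failing_op)
          #
          #   with self.assertRaisesWithPredicateMatch(
          #       errors.OpError, lambda e: "scalar" in e.message):
          #     sess.run(failing_op)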
   1322 
   1323   def assertShapeEqual(self, np_array, tf_tensor, msg=None):
   1324     """Asserts that a Numpy ndarray and a TensorFlow tensor have the same shape.
   1325 
   1326     Args:
   1327       np_array: A Numpy ndarray or Numpy scalar.
   1328       tf_tensor: A Tensor.
   1329       msg: Optional message to report on failure.
   1330 
   1331     Raises:
   1332       TypeError: If the arguments have the wrong type.
   1333     """
   1334     if not isinstance(np_array, (np.ndarray, np.generic)):
   1335       raise TypeError("np_array must be a Numpy ndarray or Numpy scalar")
   1336     if not isinstance(tf_tensor, ops.Tensor):
   1337       raise TypeError("tf_tensor must be a Tensor")
   1338     self.assertAllEqual(
   1339         np_array.shape, tf_tensor.get_shape().as_list(), msg=msg)
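          # A minimal sketch: compares the ndarray's shape against the tensor's
          # static shape, so in a graph-mode test this would pass:
          #
          #   self.assertShapeEqual(np.zeros((2, 3)), array_ops.zeros([2, 3]))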
   1340 
   1341   def assertDeviceEqual(self, device1, device2, msg=None):
   1342     """Asserts that the two given devices are the same.
   1343 
   1344     Args:
   1345       device1: A string device name or TensorFlow `DeviceSpec` object.
   1346       device2: A string device name or TensorFlow `DeviceSpec` object.
   1347       msg: Optional message to report on failure.
   1348     """
   1349     device1 = pydev.canonical_name(device1)
   1350     device2 = pydev.canonical_name(device2)
   1351     self.assertEqual(device1, device2,
   1352                      "Devices %s and %s are not equal. %s" %
   1353                      (device1, device2, msg if msg else ""))
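          # A small sketch: canonicalization lets shorthand and fully-qualified
          # device names (or `DeviceSpec` objects) compare equal, e.g.:
          #
          #   self.assertDeviceEqual("/cpu:0", "/device:CPU:0")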
   1354 
   1355   # Fix Python 3 compatibility issues
   1356   if six.PY3:
   1357     # pylint: disable=invalid-name
   1358 
   1359     # Silence a deprecation warning
   1360     assertRaisesRegexp = googletest.TestCase.assertRaisesRegex
   1361 
   1362     # assertItemsEqual is assertCountEqual as of 3.2.
   1363     assertItemsEqual = googletest.TestCase.assertCountEqual
   1364 
   1365     # pylint: enable=invalid-name
   1366 
   1367 
   1368 @tf_export("test.create_local_cluster")
   1369 def create_local_cluster(num_workers,
   1370                          num_ps,
   1371                          protocol="grpc",
   1372                          worker_config=None,
   1373                          ps_config=None):
   1374   """Create and start local servers and return the associated `Server` objects.
   1375 
   1376   Example:
   1377   ```python
   1378   workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=2)
   1379 
   1380   worker_sessions = [tf.Session(w.target) for w in workers]
   1381 
   1382   with tf.device("/job:ps/task:0"):
   1383     ...
   1384   with tf.device("/job:ps/task:1"):
   1385     ...
   1386   with tf.device("/job:worker/task:0"):
   1387     ...
   1388   with tf.device("/job:worker/task:1"):
   1389     ...
   1390 
   1391   worker_sessions[0].run(...)
   1392   ```
   1393 
   1394   Args:
   1395     num_workers: Number of worker servers to start.
   1396     num_ps: Number of PS servers to start.
   1397     protocol: Communication protocol.  Allowed values are documented in
   1398       `tf.train.Server`.
   1399     worker_config: (optional) ConfigProto to initialize workers. Can be used
   1400       to instantiate multiple devices etc.
   1401     ps_config: (optional) ConfigProto to initialize PS servers.
   1402 
   1403   Returns:
   1404     A tuple `(worker_servers, ps_servers)`.  `worker_servers` is a list
   1405     of `num_workers` objects of type `tf.train.Server` (all running locally);
   1406     and `ps_servers` is a list of `num_ps` objects of similar type.
   1407 
   1408   Raises:
   1409     ImportError: if the portpicker module was not found at load time.
   1410   """
   1411   if _portpicker_import_error:
   1412     raise _portpicker_import_error  # pylint: disable=raising-bad-type
   1413   worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
   1414   ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
   1415   cluster_dict = {
   1416       "worker": ["localhost:%s" % port for port in worker_ports],
   1417       "ps": ["localhost:%s" % port for port in ps_ports]
   1418   }
   1419   cs = server_lib.ClusterSpec(cluster_dict)
   1420 
   1421   workers = [
   1422       server_lib.Server(
   1423           cs,
   1424           job_name="worker",
   1425           protocol=protocol,
   1426           task_index=ix,
   1427           config=worker_config,
   1428           start=True) for ix in range(num_workers)
   1429   ]
   1430   ps_servers = [
   1431       server_lib.Server(
   1432           cs,
   1433           job_name="ps",
   1434           protocol=protocol,
   1435           task_index=ix,
   1436           config=ps_config,
   1437           start=True) for ix in range(num_ps)
   1438   ]
   1439 
   1440   return workers, ps_servers
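        # A follow-on sketch (assumes a graph has already been built against the
        # cluster's job names): each returned server exposes a `target` that a
        # Session can connect to:
        #
        #   workers, ps_servers = create_local_cluster(num_workers=1, num_ps=1)
        #   with session.Session(workers[0].target) as sess:
        #     ...  # run ops placed on /job:worker or /job:ps devices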
   1441 
   1442 
   1443 def get_node_def_from_graph(node_name, graph_def):
   1444   """Returns the `NodeDef` instance for the given node name in the graph def.
   1445 
   1446   This method explores only the NodeDefs in `graph_def.node`.
   1447 
   1448   Args:
   1449     node_name: Name of the NodeDef to search for.
   1450     graph_def: An instance of `GraphDef` proto.
   1451 
   1452   Returns:
   1453     The `NodeDef` instance whose name field matches `node_name`, or None.
   1454   """
   1455   for node_def in graph_def.node:
   1456     if node_def.name == node_name:
   1457       return node_def
   1458   return None
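        # A minimal usage sketch (the node name "zeros" is a hypothetical
        # example): look up a node in a GraphDef obtained from a Graph:
        #
        #   g = ops.Graph()
        #   with g.as_default():
        #     array_ops.zeros([1], name="zeros")
        #   node_def = get_node_def_from_graph("zeros", g.as_graph_def())
        #   print(node_def)  # the matching NodeDef, or None if no node matches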
   1459 
   1460 
   1461 def set_producer_version(graph, producer_version):
   1462   """Sets graph.graph_def_versions.producer to `producer_version`."""
   1463   # The C API doesn't expose altering GraphDefVersions. We can indirectly set
   1464   # it via import_graph_def though.
   1465   graph_def = graph_pb2.GraphDef()
   1466   graph_def.versions.producer = producer_version
   1467   with graph.as_default():
   1468     importer.import_graph_def(graph_def)
   1469   assert graph.graph_def_versions.producer == producer_version
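        # A short usage sketch: pin an older producer version onto a fresh graph,
        # e.g. to exercise version-dependent behavior in tests:
        #
        #   g = ops.Graph()
        #   set_producer_version(g, 7)
        #   assert g.graph_def_versions.producer == 7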
   1470