1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 # pylint: disable=invalid-name 17 """Test utils for tensorflow.""" 18 from __future__ import absolute_import 19 from __future__ import division 20 from __future__ import print_function 21 22 import contextlib 23 import gc 24 import math 25 import random 26 import re 27 import tempfile 28 import threading 29 30 import numpy as np 31 import six 32 33 _portpicker_import_error = None 34 try: 35 import portpicker # pylint: disable=g-import-not-at-top 36 except ImportError as _error: 37 _portpicker_import_error = _error 38 portpicker = None 39 40 # pylint: disable=g-import-not-at-top 41 from google.protobuf import descriptor_pool 42 from google.protobuf import text_format 43 44 from tensorflow.core.framework import graph_pb2 45 from tensorflow.core.protobuf import config_pb2 46 from tensorflow.core.protobuf import rewriter_config_pb2 47 from tensorflow.python import pywrap_tensorflow 48 from tensorflow.python.client import device_lib 49 from tensorflow.python.client import session 50 from tensorflow.python.eager import backprop 51 from tensorflow.python.eager import context 52 from tensorflow.python.eager import tape # pylint: disable=unused-import 53 from tensorflow.python.framework import device as pydev 54 from tensorflow.python.framework import dtypes 55 
from tensorflow.python.framework import errors
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.protobuf import compare
from tensorflow.python.util.tf_export import tf_export


@tf_export("test.gpu_device_name")
def gpu_device_name():
  """Returns the name of a GPU device if available or the empty string.

  The first local device whose type is "GPU" or "SYCL" is returned.
  """
  for dev in device_lib.list_local_devices():
    if dev.device_type in ("GPU", "SYCL"):
      return compat.as_str(dev.name)
  return ""


def assert_ops_in_graph(expected_ops, graph):
  """Assert all expected operations are found.

  Args:
    expected_ops: `dict<string, string>` of op name to op type.
    graph: Graph to check.
  Returns:
    `dict<string, node>` of node name to node.

  Raises:
    ValueError: If the expected ops are not present in the graph, or if a node
      with an expected name has a different op type.
  """
  actual_ops = {}
  gd = graph.as_graph_def()
  for node in gd.node:
    if node.name in expected_ops:
      if expected_ops[node.name] != node.op:
        raise ValueError("Expected op for node %s is different. %s vs %s" %
                         (node.name, expected_ops[node.name], node.op))
      actual_ops[node.name] = node
  if set(expected_ops.keys()) != set(actual_ops.keys()):
    # Sorted lists give a deterministic, readable message (dict key views
    # would print as `dict_keys([...])` under Python 3).
    raise ValueError("Not all expected ops are present. Expected %s, found %s" %
                     (sorted(expected_ops.keys()), sorted(actual_ops.keys())))
  return actual_ops


@tf_export("test.assert_equal_graph_def")
def assert_equal_graph_def(actual, expected, checkpoint_v2=False):
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering of
  nodes, attrs, and control inputs.  Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: boolean determining whether to ignore randomized attribute
      values that appear in V2 checkpoints. Note that when this is True the
      matching attributes are deleted from `actual` and `expected` in place.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  if not isinstance(actual, graph_pb2.GraphDef):
    raise TypeError(
        "Expected tf.GraphDef for actual, got %s" % type(actual).__name__)
  if not isinstance(expected, graph_pb2.GraphDef):
    raise TypeError(
        "Expected tf.GraphDef for expected, got %s" % type(expected).__name__)

  if checkpoint_v2:
    _strip_checkpoint_v2_randomized(actual)
    _strip_checkpoint_v2_randomized(expected)

  # The comparison itself is done natively on the serialized protos; an empty
  # result means the graphs are equal.
  diff = pywrap_tensorflow.EqualGraphDefWrapper(actual.SerializeToString(),
                                                expected.SerializeToString())
  if diff:
    raise AssertionError(compat.as_str(diff))


def assert_meta_graph_protos_equal(tester, a, b):
  """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.

  Note: this clears the `collection_def` and `graph_def` fields of both `a`
  and `b` in place after comparing them.

  Args:
    tester: The test case whose assertion methods are used.
    a: A `MetaGraphDef`.
    b: Another `MetaGraphDef`.
  """
  # Carefully check the collection_defs
  tester.assertEqual(set(a.collection_def), set(b.collection_def))
  collection_keys = a.collection_def.keys()
  for k in collection_keys:
    a_value = a.collection_def[k]
    b_value = b.collection_def[k]
    proto_type = ops.get_collection_proto_type(k)
    if proto_type:
      a_proto = proto_type()
      b_proto = proto_type()
      # Number of entries in the collections is the same
      tester.assertEqual(
          len(a_value.bytes_list.value), len(b_value.bytes_list.value))
      for (a_value_item, b_value_item) in zip(a_value.bytes_list.value,
                                              b_value.bytes_list.value):
        a_proto.ParseFromString(a_value_item)
        b_proto.ParseFromString(b_value_item)
        tester.assertProtoEquals(a_proto, b_proto)
    else:
      tester.assertEqual(a_value, b_value)
  # Compare the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("collection_def")
  b.ClearField("collection_def")

  # Check the graph_defs.
  assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True)
  # Check graph_def versions (ignored by assert_equal_graph_def).
  tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions)
  # Compare the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("graph_def")
  b.ClearField("graph_def")

  tester.assertProtoEquals(a, b)


# Matches attributes named via _SHARDED_SUFFIX in
# tensorflow/python/training/saver.py
_SHARDED_SAVE_OP_PATTERN = "_temp_[0-9a-z]{32}/part"


def _strip_checkpoint_v2_randomized(graph_def):
  """Deletes attrs whose single string value matches the sharded-save pattern.

  Modifies `graph_def` in place.

  Args:
    graph_def: A `GraphDef` proto.
  """
  # `string_val` holds bytes, so match with a bytes pattern. Matching against
  # `str(value)` would, under Python 3, see a leading "b'" and never match.
  pattern = compat.as_bytes(_SHARDED_SAVE_OP_PATTERN)
  for node in graph_def.node:
    delete_keys = []
    for attr_key in node.attr:
      attr_tensor_value = node.attr[attr_key].tensor
      if attr_tensor_value and len(attr_tensor_value.string_val) == 1:
        attr_tensor_string_value = attr_tensor_value.string_val[0]
        if (attr_tensor_string_value and
            re.match(pattern, compat.as_bytes(attr_tensor_string_value))):
          delete_keys.append(attr_key)
    # Delete after iterating so the attr map is not mutated mid-iteration.
    for attr_key in delete_keys:
      del node.attr[attr_key]


def IsGoogleCudaEnabled():
  """Returns the result of `pywrap_tensorflow.IsGoogleCudaEnabled()`."""
  return pywrap_tensorflow.IsGoogleCudaEnabled()


def CudaSupportsHalfMatMulAndConv():
  """Returns the result of `pywrap_tensorflow.CudaSupportsHalfMatMulAndConv()`."""
  return pywrap_tensorflow.CudaSupportsHalfMatMulAndConv()


def InstallStackTraceHandler():
  """Calls `pywrap_tensorflow.InstallStacktraceHandler()`."""
  pywrap_tensorflow.InstallStacktraceHandler()


def _permute_layout(input_tensor, new_axes):
  """Permutes the axes of a Tensor, or the entries of a shape array.

  Args:
    input_tensor: a Tensor, or an array representing a shape.
    new_axes: dict mapping rank (number of dimensions) to the axis order to
      apply for that rank.

  Returns:
    The transposed tensor, or a new list with the permuted shape entries.
  """
  if isinstance(input_tensor, ops.Tensor):
    ndims = input_tensor.shape.ndims
    return array_ops.transpose(input_tensor, new_axes[ndims])
  ndims = len(input_tensor)
  return [input_tensor[a] for a in new_axes[ndims]]


def NHWCToNCHW(input_tensor):
  """Converts the input from the NHWC format to NCHW.

  Args:
    input_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    converted tensor or shape array
  """
  # tensor dim -> new axis order
  return _permute_layout(input_tensor, {4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]})


def NHWCToNCHW_VECT_C(input_shape_or_tensor):
  """Transforms the input from the NHWC layout to NCHW_VECT_C layout.

  Note: Does not include quantization or type conversion steps, which should
  be applied afterwards.

  Args:
    input_shape_or_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NCHW_VECT_C

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not evenly
        divisible by 4.
  """
  permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  # Work on a fresh list: the shape is modified below, and a caller-supplied
  # shape array (list or tuple) must not be changed or rejected.
  temp_shape = (
      input_shape_or_tensor.shape.as_list()
      if is_tensor else list(input_shape_or_tensor))
  if temp_shape[-1] % 4 != 0:
    raise ValueError(
        "Last dimension of input must be evenly divisible by 4 to convert to "
        "NCHW_VECT_C.")
  # Split the channel dimension C into (C // 4, 4).
  temp_shape[-1] //= 4
  temp_shape.append(4)
  permutation = permutations[len(temp_shape)]
  if is_tensor:
    t = array_ops.reshape(input_shape_or_tensor, temp_shape)
    return array_ops.transpose(t, permutation)
  else:
    return [temp_shape[a] for a in permutation]


def NCHW_VECT_CToNHWC(input_shape_or_tensor):
  """Transforms the input from the NCHW_VECT_C layout to NHWC layout.

  Note: Does not include de-quantization or type conversion steps, which should
  be applied beforehand.

  Args:
    input_shape_or_tensor: a 5- or 6-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NHWC

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not 4.
  """
  permutations = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  input_shape = (
      input_shape_or_tensor.shape.as_list()
      if is_tensor else input_shape_or_tensor)
  if input_shape[-1] != 4:
    raise ValueError("Last dimension of NCHW_VECT_C must be 4.")
  permutation = permutations[len(input_shape)]
  # Drop the trailing vector dimension and merge it back into channels.
  nhwc_shape = [input_shape[a] for a in permutation[:-1]]
  nhwc_shape[-1] *= input_shape[-1]
  if is_tensor:
    t = array_ops.transpose(input_shape_or_tensor, permutation)
    return array_ops.reshape(t, nhwc_shape)
  else:
    return nhwc_shape


def NCHWToNHWC(input_tensor):
  """Converts the input from the NCHW format to NHWC.

  Args:
    input_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    converted tensor or shape array
  """
  # tensor dim -> new axis order
  return _permute_layout(input_tensor, {4: [0, 2, 3, 1], 5: [0, 2, 3, 4, 1]})


# TODO(skyewm): remove this eventually
# pylint: disable=protected-access
def _use_c_api_wrapper(fn, use_c_api, *args, **kwargs):
  """Calls `fn` with `ops._USE_C_API` temporarily set to `use_c_api`.

  The default graph is reset before the call and again after restoring the
  previous value.

  Returns:
    Whatever `fn` returns.
  """
  prev_value = ops._USE_C_API
  ops._USE_C_API = use_c_api
  try:
    # Reset the default graph so it has the C API enabled. We call
    # reset_default_graph() instead of creating a new default Graph context to
    # make this robust to tests that call reset_default_graph(), which requires
    # that the current default graph isn't nested.
    ops.reset_default_graph()
    return fn(*args, **kwargs)
  finally:
    ops._USE_C_API = prev_value
    # Make sure default graph reflects prev_value in case next test doesn't call
    # reset_default_graph().
    ops.reset_default_graph()
# pylint: disable=protected-access


def c_api_and_cuda_enabled():
  """Returns True iff the C API is enabled and CUDA support is compiled in."""
  return ops._USE_C_API and IsGoogleCudaEnabled()


def skip_if(condition):
  """Skips the decorated function if condition is or evaluates to True.

  Args:
    condition: Either an expression that can be used in "if not condition"
      statement, or a callable whose result should be a boolean.
  Returns:
    The wrapped function
  """

  def real_skip_if(fn):

    def wrapper(*args, **kwargs):
      # A callable condition is evaluated on every call, not at decoration
      # time.
      skip = condition() if callable(condition) else condition
      if not skip:
        return fn(*args, **kwargs)
      return None

    return wrapper

  return real_skip_if


# TODO(skyewm): remove this eventually
def disable_c_api(fn):
  """Decorator for disabling the C API on a test.

  Note this disables the C API after running the test class's setup/teardown
  methods.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function
  """

  def wrapper(*args, **kwargs):
    return _use_c_api_wrapper(fn, False, *args, **kwargs)

  return wrapper


# TODO(skyewm): remove this eventually
def enable_c_api(fn):
  """Decorator for enabling the C API on a test.

  Note this enables the C API after running the test class's setup/teardown
  methods.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function
  """

  def wrapper(*args, **kwargs):
    return _use_c_api_wrapper(fn, True, *args, **kwargs)

  return wrapper


# This decorator is a hacky way to run all the test methods in a decorated
# class with and without C API enabled.
408 # TODO(iga): Remove this and its uses once we switch to using C API by default. 409 def with_c_api(cls): 410 """Adds methods that call original methods but with C API enabled. 411 412 Note this enables the C API in new methods after running the test class's 413 setup method. This can be a problem if some objects are created in it 414 before the C API is enabled. 415 416 Args: 417 cls: class to decorate 418 419 Returns: 420 cls with new test methods added 421 """ 422 for name, value in cls.__dict__.copy().items(): 423 if callable(value) and name.startswith("test"): 424 setattr(cls, name + "WithCApi", enable_c_api(value)) 425 return cls 426 427 428 def assert_no_new_tensors(f): 429 """Decorator for asserting that no new Tensors persist after a test. 430 431 Mainly useful for checking that code using the Python C API has correctly 432 manipulated reference counts. 433 434 Clears the caches that it knows about, runs the garbage collector, then checks 435 that there are no Tensor or Tensor-like objects still around. This includes 436 Tensors to which something still has a reference (e.g. from missing 437 Py_DECREFs) and uncollectable cycles (i.e. Python reference cycles where one 438 of the objects has __del__ defined). 439 440 Args: 441 f: The test case to run. 442 Returns: 443 The decorated test case. 444 """ 445 446 def decorator(self, **kwargs): 447 """Finds existing Tensors, runs the test, checks for new Tensors.""" 448 449 def _is_tensor(obj): 450 try: 451 return (isinstance(obj, ops.Tensor) or 452 isinstance(obj, variables.Variable)) 453 except ReferenceError: 454 # If the object no longer exists, we don't care about it. 455 return False 456 457 tensors_before = set(id(obj) for obj in gc.get_objects() if _is_tensor(obj)) 458 outside_graph_key = ops.get_default_graph()._graph_key 459 with ops.Graph().as_default(): 460 # Run the test in a new graph so that collections get cleared when it's 461 # done, but inherit the graph key so optimizers behave. 
462 ops.get_default_graph()._graph_key = outside_graph_key 463 f(self, **kwargs) 464 # Make an effort to clear caches, which would otherwise look like leaked 465 # Tensors. 466 backprop._zeros_cache.flush() 467 context.get_default_context().scalar_cache().clear() 468 gc.collect() 469 tensors_after = [ 470 obj for obj in gc.get_objects() 471 if _is_tensor(obj) and id(obj) not in tensors_before 472 ] 473 if tensors_after: 474 raise AssertionError(("%d Tensors not deallocated after test: %s" % ( 475 len(tensors_after), 476 str(tensors_after), 477 ))) 478 479 return decorator 480 481 482 def assert_no_garbage_created(f): 483 """Test method decorator to assert that no garbage has been created. 484 485 Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters 486 cannot be un-set (i.e. will disable garbage collection for any other unit 487 tests in the same file/shard). 488 489 Args: 490 f: The function to decorate. 491 Returns: 492 The decorated function. 493 """ 494 495 def decorator(self, **kwargs): 496 """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage.""" 497 gc.disable() 498 previous_debug_flags = gc.get_debug() 499 gc.set_debug(gc.DEBUG_SAVEALL) 500 gc.collect() 501 previous_garbage = len(gc.garbage) 502 f(self, **kwargs) 503 gc.collect() 504 # This will fail if any garbage has been created, typically because of a 505 # reference cycle. 506 self.assertEqual(previous_garbage, len(gc.garbage)) 507 # TODO(allenl): Figure out why this debug flag reset doesn't work. It would 508 # be nice to be able to decorate arbitrary tests in a large test suite and 509 # not hold on to every object in other tests. 510 gc.set_debug(previous_debug_flags) 511 gc.enable() 512 513 return decorator 514 515 516 def run_in_graph_and_eager_modes(__unused__=None, 517 graph=None, 518 config=None, 519 use_gpu=False, 520 force_gpu=False, 521 reset_test=True, 522 assert_no_eager_garbage=False): 523 """Runs the test in both graph and eager modes. 
524 525 Args: 526 __unused__: Prevents sliently skipping tests. 527 graph: Optional graph to use during the returned session. 528 config: An optional config_pb2.ConfigProto to use to configure the 529 session. 530 use_gpu: If True, attempt to run as many ops as possible on GPU. 531 force_gpu: If True, pin all ops to `/device:GPU:0`. 532 reset_test: If True, tearDown and SetUp the test case again. 533 assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage 534 collector and asserts that no extra garbage has been created when running 535 the test in eager mode. This will fail if there are reference cycles 536 (e.g. a = []; a.append(a)). Off by default because some tests may create 537 garbage for legitimate reasons (e.g. they define a class which inherits 538 from `object`), and because DEBUG_SAVEALL is sticky in some Python 539 interpreters (meaning that tests which rely on objects being collected 540 elsewhere in the unit test file will not work). Additionally, checks that 541 nothing still has a reference to Tensors that the test allocated. 542 Returns: 543 Returns a decorator that will run the decorated test function 544 using both a graph and using eager execution. 545 """ 546 547 assert not __unused__, "Add () after run_in_graph_and_eager_modes." 548 549 def decorator(f): 550 """Test method decorator.""" 551 552 def decorated(self, **kwargs): 553 """Decorated the test method.""" 554 with context.graph_mode(): 555 with self.test_session(graph, config, use_gpu, force_gpu): 556 f(self, **kwargs) 557 558 if reset_test: 559 # This decorator runs the wrapped test twice. 560 # Reset the test environment between runs. 561 self.tearDown() 562 self.setUp() 563 564 def run_eager_mode(self, **kwargs): 565 if force_gpu: 566 gpu_name = gpu_device_name() 567 if not gpu_name: 568 gpu_name = "/device:GPU:0" 569 with context.device(gpu_name): 570 f(self) 571 elif use_gpu: 572 # TODO(xpan): Support softplacement and gpu by default when available. 
573 f(self, **kwargs) 574 else: 575 with context.device("/device:CPU:0"): 576 f(self, **kwargs) 577 578 if assert_no_eager_garbage: 579 run_eager_mode = assert_no_new_tensors( 580 assert_no_garbage_created(run_eager_mode)) 581 582 with context.eager_mode(): 583 with ops.Graph().as_default(): 584 run_eager_mode(self, **kwargs) 585 586 return decorated 587 588 return decorator 589 590 591 @tf_export("test.is_gpu_available") 592 def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None): 593 """Returns whether TensorFlow can access a GPU. 594 595 Args: 596 cuda_only: limit the search to CUDA gpus. 597 min_cuda_compute_capability: a (major,minor) pair that indicates the minimum 598 CUDA compute capability required, or None if no requirement. 599 600 Returns: 601 True iff a gpu device of the requested kind is available. 602 """ 603 604 def compute_capability_from_device_desc(device_desc): 605 # TODO(jingyue): The device description generator has to be in sync with 606 # this file. Another option is to put compute capability in 607 # DeviceAttributes, but I avoided that to keep DeviceAttributes 608 # target-independent. Reconsider this option when we have more things like 609 # this to keep in sync. 
610 # LINT.IfChange 611 match = re.search(r"compute capability: (\d+)\.(\d+)", device_desc) 612 # LINT.ThenChange(//tensorflow/core/\ 613 # common_runtime/gpu/gpu_device.cc) 614 if not match: 615 return 0, 0 616 return int(match.group(1)), int(match.group(2)) 617 618 for local_device in device_lib.list_local_devices(): 619 if local_device.device_type == "GPU": 620 if (min_cuda_compute_capability is None or 621 compute_capability_from_device_desc(local_device.physical_device_desc) 622 >= min_cuda_compute_capability): 623 return True 624 if local_device.device_type == "SYCL" and not cuda_only: 625 return True 626 return False 627 628 629 @contextlib.contextmanager 630 def device(use_gpu): 631 """Uses gpu when requested and available.""" 632 if use_gpu and is_gpu_available(): 633 dev = "/device:GPU:0" 634 else: 635 dev = "/device:CPU:0" 636 with ops.device(dev): 637 yield 638 639 640 @tf_export("test.TestCase") 641 class TensorFlowTestCase(googletest.TestCase): 642 """Base class for tests that need to test TensorFlow. 643 """ 644 645 def __init__(self, methodName="runTest"): # pylint: disable=invalid-name 646 super(TensorFlowTestCase, self).__init__(methodName) 647 self._threads = [] 648 self._tempdir = None 649 self._cached_session = None 650 651 def setUp(self): 652 self._ClearCachedSession() 653 random.seed(random_seed.DEFAULT_GRAPH_SEED) 654 np.random.seed(random_seed.DEFAULT_GRAPH_SEED) 655 # Note: The following line is necessary because some test methods may error 656 # out from within nested graph contexts (e.g., via assertRaises and 657 # assertRaisesRegexp), which may leave ops._default_graph_stack non-empty 658 # under certain versions of Python. That would cause 659 # ops.reset_default_graph() to throw an exception if the stack were not 660 # cleared first. 
661 ops._default_graph_stack.reset() # pylint: disable=protected-access 662 ops.reset_default_graph() 663 random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED) 664 665 def tearDown(self): 666 for thread in self._threads: 667 thread.check_termination() 668 669 self._ClearCachedSession() 670 671 def _ClearCachedSession(self): 672 if self._cached_session is not None: 673 self._cached_session.close() 674 self._cached_session = None 675 676 def get_temp_dir(self): 677 """Returns a unique temporary directory for the test to use. 678 679 If you call this method multiple times during in a test, it will return the 680 same folder. However, across different runs the directories will be 681 different. This will ensure that across different runs tests will not be 682 able to pollute each others environment. 683 If you need multiple unique directories within a single test, you should 684 use tempfile.mkdtemp as follows: 685 tempfile.mkdtemp(dir=self.get_temp_dir()): 686 687 Returns: 688 string, the path to the unique temporary directory created for this test. 689 """ 690 if not self._tempdir: 691 self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir()) 692 return self._tempdir 693 694 def _AssertProtoEquals(self, a, b, msg=None): 695 """Asserts that a and b are the same proto. 696 697 Uses ProtoEq() first, as it returns correct results 698 for floating point attributes, and then use assertProtoEqual() 699 in case of failure as it provides good error messages. 700 701 Args: 702 a: a proto. 703 b: another proto. 704 msg: Optional message to report on failure. 705 """ 706 if not compare.ProtoEq(a, b): 707 compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg) 708 709 def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None): 710 """Asserts that message is same as parsed expected_message_ascii. 711 712 Creates another prototype of message, reads the ascii message into it and 713 then compares them using self._AssertProtoEqual(). 
714 715 Args: 716 expected_message_maybe_ascii: proto message in original or ascii form. 717 message: the message to validate. 718 msg: Optional message to report on failure. 719 """ 720 msg = msg if msg else "" 721 if isinstance(expected_message_maybe_ascii, type(message)): 722 expected_message = expected_message_maybe_ascii 723 self._AssertProtoEquals(expected_message, message) 724 elif isinstance(expected_message_maybe_ascii, str): 725 expected_message = type(message)() 726 text_format.Merge( 727 expected_message_maybe_ascii, 728 expected_message, 729 descriptor_pool=descriptor_pool.Default()) 730 self._AssertProtoEquals(expected_message, message, msg=msg) 731 else: 732 assert False, ("Can't compare protos of type %s and %s. %s" % 733 (type(expected_message_maybe_ascii), type(message), msg)) 734 735 def assertProtoEqualsVersion( 736 self, 737 expected, 738 actual, 739 producer=versions.GRAPH_DEF_VERSION, 740 min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER, 741 msg=None): 742 expected = "versions { producer: %d min_consumer: %d };\n%s" % ( 743 producer, min_consumer, expected) 744 self.assertProtoEquals(expected, actual, msg=msg) 745 746 def assertStartsWith(self, actual, expected_start, msg=None): 747 """Assert that actual.startswith(expected_start) is True. 748 749 Args: 750 actual: str 751 expected_start: str 752 msg: Optional message to report on failure. 753 """ 754 if not actual.startswith(expected_start): 755 fail_msg = "%r does not start with %r" % (actual, expected_start) 756 fail_msg += " : %r" % (msg) if msg else "" 757 self.fail(fail_msg) 758 759 def _eval_tensor(self, tensor): 760 if tensor is None: 761 return None 762 elif isinstance(tensor, ops.EagerTensor): 763 return tensor.numpy() 764 elif isinstance(tensor, resource_variable_ops.ResourceVariable): 765 return tensor.read_value().numpy() 766 elif callable(tensor): 767 return self._eval_helper(tensor()) 768 else: 769 raise ValueError("Unsupported type %s." 
% type(tensor)) 770 771 def _eval_helper(self, tensors): 772 if tensors is None: 773 return None 774 return nest.map_structure(self._eval_tensor, tensors) 775 776 def evaluate(self, tensors): 777 """Evaluates tensors and returns numpy values. 778 779 Args: 780 tensors: A Tensor or a nested list/tuple of Tensors. 781 782 Returns: 783 tensors numpy values. 784 """ 785 if context.in_eager_mode(): 786 return self._eval_helper(tensors) 787 else: 788 sess = ops.get_default_session() 789 if sess is None: 790 with self.test_session() as sess: 791 return sess.run(tensors) 792 else: 793 return sess.run(tensors) 794 795 # pylint: disable=g-doc-return-or-yield 796 @contextlib.contextmanager 797 def test_session(self, 798 graph=None, 799 config=None, 800 use_gpu=False, 801 force_gpu=False): 802 """Returns a TensorFlow Session for use in executing tests. 803 804 This method should be used for all functional tests. 805 806 This method behaves different than session.Session: for performance reasons 807 `test_session` will by default (if `graph` is None) reuse the same session 808 across tests. This means you may want to either call the function 809 `reset_default_graph()` before tests, or if creating an explicit new graph, 810 pass it here (simply setting it with `as_default()` won't do it), which will 811 trigger the creation of a new session. 812 813 Use the `use_gpu` and `force_gpu` options to control where ops are run. If 814 `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if 815 `use_gpu` 816 is True, TensorFlow tries to run as many ops on the GPU as possible. If both 817 `force_gpu and `use_gpu` are False, all ops are pinned to the CPU. 
818 819 Example: 820 ```python 821 class MyOperatorTest(test_util.TensorFlowTestCase): 822 def testMyOperator(self): 823 with self.test_session(use_gpu=True): 824 valid_input = [1.0, 2.0, 3.0, 4.0, 5.0] 825 result = MyOperator(valid_input).eval() 826 self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0] 827 invalid_input = [-1.0, 2.0, 7.0] 828 with self.assertRaisesOpError("negative input not supported"): 829 MyOperator(invalid_input).eval() 830 ``` 831 832 Args: 833 graph: Optional graph to use during the returned session. 834 config: An optional config_pb2.ConfigProto to use to configure the 835 session. 836 use_gpu: If True, attempt to run as many ops as possible on GPU. 837 force_gpu: If True, pin all ops to `/device:GPU:0`. 838 839 Returns: 840 A Session object that should be used as a context manager to surround 841 the graph building and execution code in a test case. 842 """ 843 if self.id().endswith(".test_session"): 844 self.skipTest("Not a test.") 845 846 def prepare_config(config): 847 """Returns a config for sessions. 848 849 Args: 850 config: An optional config_pb2.ConfigProto to use to configure the 851 session. 852 Returns: 853 A config_pb2.ConfigProto object. 
854 """ 855 if config is None: 856 config = config_pb2.ConfigProto() 857 config.allow_soft_placement = not force_gpu 858 config.gpu_options.per_process_gpu_memory_fraction = 0.3 859 elif force_gpu and config.allow_soft_placement: 860 config = config_pb2.ConfigProto().CopyFrom(config) 861 config.allow_soft_placement = False 862 # Don't perform optimizations for tests so we don't inadvertently run 863 # gpu ops on cpu 864 config.graph_options.optimizer_options.opt_level = -1 865 config.graph_options.rewrite_options.constant_folding = ( 866 rewriter_config_pb2.RewriterConfig.OFF) 867 config.graph_options.rewrite_options.arithmetic_optimization = ( 868 rewriter_config_pb2.RewriterConfig.OFF) 869 return config 870 871 if graph is None: 872 if self._cached_session is None: 873 self._cached_session = session.Session( 874 graph=None, config=prepare_config(config)) 875 sess = self._cached_session 876 with sess.graph.as_default(), sess.as_default(): 877 if force_gpu: 878 # Use the name of an actual device if one is detected, or '/device:GPU:0' 879 # otherwise 880 gpu_name = gpu_device_name() 881 if not gpu_name: 882 gpu_name = "/device:GPU:0" 883 with sess.graph.device(gpu_name): 884 yield sess 885 elif use_gpu: 886 yield sess 887 else: 888 with sess.graph.device("/cpu:0"): 889 yield sess 890 else: 891 with session.Session(graph=graph, config=prepare_config(config)) as sess: 892 if force_gpu: 893 # Use the name of an actual device if one is detected, or '/device:GPU:0' 894 # otherwise 895 gpu_name = gpu_device_name() 896 if not gpu_name: 897 gpu_name = "/device:GPU:0" 898 with sess.graph.device(gpu_name): 899 yield sess 900 elif use_gpu: 901 yield sess 902 else: 903 with sess.graph.device("/cpu:0"): 904 yield sess 905 906 # pylint: enable=g-doc-return-or-yield 907 908 class _CheckedThread(object): 909 """A wrapper class for Thread that asserts successful completion. 910 911 This class should be created using the TensorFlowTestCase.checkedThread() 912 method. 
913 """ 914 915 def __init__(self, testcase, target, args=None, kwargs=None): 916 """Constructs a new instance of _CheckedThread. 917 918 Args: 919 testcase: The TensorFlowTestCase for which this thread is being created. 920 target: A callable object representing the code to be executed in the 921 thread. 922 args: A tuple of positional arguments that will be passed to target. 923 kwargs: A dictionary of keyword arguments that will be passed to target. 924 """ 925 self._testcase = testcase 926 self._target = target 927 self._args = () if args is None else args 928 self._kwargs = {} if kwargs is None else kwargs 929 self._thread = threading.Thread(target=self._protected_run) 930 self._exception = None 931 932 self._is_thread_joined = False 933 934 def _protected_run(self): 935 """Target for the wrapper thread. Sets self._exception on failure.""" 936 try: 937 self._target(*self._args, **self._kwargs) 938 except Exception as e: # pylint: disable=broad-except 939 self._exception = e 940 941 def start(self): 942 """Starts the thread's activity. 943 944 This must be called at most once per _CheckedThread object. It arranges 945 for the object's target to be invoked in a separate thread of control. 946 """ 947 self._thread.start() 948 949 def join(self): 950 """Blocks until the thread terminates. 951 952 Raises: 953 self._testcase.failureException: If the thread terminates with due to 954 an exception. 955 """ 956 self._is_thread_joined = True 957 self._thread.join() 958 if self._exception is not None: 959 self._testcase.fail("Error in checkedThread: %s" % str(self._exception)) 960 961 def is_alive(self): 962 """Returns whether the thread is alive. 963 964 This method returns True just before the run() method starts 965 until just after the run() method terminates. 966 967 Returns: 968 True if the thread is alive, otherwise False. 
969 """ 970 return self._thread.is_alive() 971 972 def check_termination(self): 973 """Returns whether the checked thread was properly used and did terminate. 974 975 Every checked thread should be "join"ed after starting, and before the 976 test tears down. If it is not joined, it is possible the thread will hang 977 and cause flaky failures in tests. 978 979 Raises: 980 self._testcase.failureException: If check_termination was called before 981 thread was joined. 982 983 RuntimeError: If the thread is not terminated. This means thread was not 984 joined with the main thread. 985 """ 986 if self._is_thread_joined: 987 if self.is_alive(): 988 raise RuntimeError( 989 "Thread was not joined with main thread, and is still running " 990 "when the test finished.") 991 else: 992 self._testcase.fail("A checked thread was not joined.") 993 994 def checkedThread(self, target, args=None, kwargs=None): 995 """Returns a Thread wrapper that asserts 'target' completes successfully. 996 997 This method should be used to create all threads in test cases, as 998 otherwise there is a risk that a thread will silently fail, and/or 999 assertions made in the thread will not be respected. 1000 1001 Args: 1002 target: A callable object to be executed in the thread. 1003 args: The argument tuple for the target invocation. Defaults to (). 1004 kwargs: A dictionary of keyword arguments for the target invocation. 1005 Defaults to {}. 1006 1007 Returns: 1008 A wrapper for threading.Thread that supports start() and join() methods. 1009 """ 1010 ret = TensorFlowTestCase._CheckedThread(self, target, args, kwargs) 1011 self._threads.append(ret) 1012 return ret 1013 1014 1015 # pylint: enable=invalid-name 1016 1017 def assertNear(self, f1, f2, err, msg=None): 1018 """Asserts that two floats are near each other. 1019 1020 Checks that |f1 - f2| < err and asserts a test failure 1021 if not. 1022 1023 Args: 1024 f1: A float value. 1025 f2: A float value. 1026 err: A float value. 
1027 msg: An optional string message to append to the failure message. 1028 """ 1029 # f1 == f2 is needed here as we might have: f1, f2 = inf, inf 1030 self.assertTrue(f1 == f2 or math.fabs(f1 - f2) <= err, 1031 "%f != %f +/- %f%s" % (f1, f2, err, " (%s)" % msg 1032 if msg is not None else "")) 1033 1034 def assertArrayNear(self, farray1, farray2, err, msg=None): 1035 """Asserts that two float arrays are near each other. 1036 1037 Checks that for all elements of farray1 and farray2 1038 |f1 - f2| < err. Asserts a test failure if not. 1039 1040 Args: 1041 farray1: a list of float values. 1042 farray2: a list of float values. 1043 err: a float value. 1044 msg: Optional message to report on failure. 1045 """ 1046 self.assertEqual(len(farray1), len(farray2), msg=msg) 1047 for f1, f2 in zip(farray1, farray2): 1048 self.assertNear(float(f1), float(f2), err, msg=msg) 1049 1050 def _NDArrayNear(self, ndarray1, ndarray2, err): 1051 return np.linalg.norm(ndarray1 - ndarray2) < err 1052 1053 def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None): 1054 """Asserts that two numpy arrays have near values. 1055 1056 Args: 1057 ndarray1: a numpy ndarray. 1058 ndarray2: a numpy ndarray. 1059 err: a float. The maximum absolute difference allowed. 1060 msg: Optional message to report on failure. 1061 """ 1062 self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err), msg=msg) 1063 1064 def _GetNdArray(self, a): 1065 if not isinstance(a, np.ndarray): 1066 a = np.array(a) 1067 return a 1068 1069 def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): 1070 a = self._GetNdArray(a) 1071 b = self._GetNdArray(b) 1072 self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s." % 1073 (a.shape, b.shape)) 1074 if not np.allclose(a, b, rtol=rtol, atol=atol): 1075 # Prints more details than np.testing.assert_allclose. 
1076 # 1077 # NOTE: numpy.allclose (and numpy.testing.assert_allclose) 1078 # checks whether two arrays are element-wise equal within a 1079 # tolerance. The relative difference (rtol * abs(b)) and the 1080 # absolute difference atol are added together to compare against 1081 # the absolute difference between a and b. Here, we want to 1082 # print out which elements violate such conditions. 1083 cond = np.logical_or( 1084 np.abs(a - b) > atol + rtol * np.abs(b), 1085 np.isnan(a) != np.isnan(b)) 1086 if a.ndim: 1087 x = a[np.where(cond)] 1088 y = b[np.where(cond)] 1089 print("not close where = ", np.where(cond)) 1090 else: 1091 # np.where is broken for scalars 1092 x, y = a, b 1093 print("not close lhs = ", x) 1094 print("not close rhs = ", y) 1095 print("not close dif = ", np.abs(x - y)) 1096 print("not close tol = ", atol + rtol * np.abs(y)) 1097 print("dtype = %s, shape = %s" % (a.dtype, a.shape)) 1098 # TODO(xpan): There seems to be a bug: 1099 # tensorflow/compiler/tests:binary_ops_test pass with float32 1100 # nan even though the equal_nan is False by default internally. 1101 np.testing.assert_allclose( 1102 a, b, rtol=rtol, atol=atol, err_msg=msg, equal_nan=True) 1103 1104 def _assertAllCloseRecursive(self, 1105 a, 1106 b, 1107 rtol=1e-6, 1108 atol=1e-6, 1109 path=None, 1110 msg=None): 1111 path = path or [] 1112 path_str = (("[" + "][".join([str(p) for p in path]) + "]") if path else "") 1113 msg = msg if msg else "" 1114 1115 # Check if a and/or b are namedtuples. 1116 if hasattr(a, "_asdict"): 1117 a = a._asdict() 1118 if hasattr(b, "_asdict"): 1119 b = b._asdict() 1120 a_is_dict = isinstance(a, dict) 1121 if a_is_dict != isinstance(b, dict): 1122 raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" % 1123 (path_str, path_str, msg)) 1124 if a_is_dict: 1125 self.assertItemsEqual( 1126 a.keys(), 1127 b.keys(), 1128 msg="mismatched keys: a%s has keys %s, but b%s has keys %s. 
%s" % 1129 (path_str, a.keys(), path_str, b.keys(), msg)) 1130 for k in a: 1131 path.append(k) 1132 self._assertAllCloseRecursive( 1133 a[k], b[k], rtol=rtol, atol=atol, path=path, msg=msg) 1134 del path[-1] 1135 elif isinstance(a, (list, tuple)): 1136 # Try to directly compare a, b as ndarrays; if not work, then traverse 1137 # through the sequence, which is more expensive. 1138 try: 1139 a_as_ndarray = np.array(a) 1140 b_as_ndarray = np.array(b) 1141 self._assertArrayLikeAllClose( 1142 a_as_ndarray, 1143 b_as_ndarray, 1144 rtol=rtol, 1145 atol=atol, 1146 msg="Mismatched value: a%s is different from b%s. %s" % 1147 (path_str, path_str, msg)) 1148 except (ValueError, TypeError) as e: 1149 if len(a) != len(b): 1150 raise ValueError( 1151 "Mismatched length: a%s has %d items, but b%s has %d items. %s" % 1152 (path_str, len(a), path_str, len(b), msg)) 1153 for idx, (a_ele, b_ele) in enumerate(zip(a, b)): 1154 path.append(str(idx)) 1155 self._assertAllCloseRecursive( 1156 a_ele, b_ele, rtol=rtol, atol=atol, path=path, msg=msg) 1157 del path[-1] 1158 # a and b are ndarray like objects 1159 else: 1160 try: 1161 self._assertArrayLikeAllClose( 1162 a, 1163 b, 1164 rtol=rtol, 1165 atol=atol, 1166 msg="Mismatched value: a%s is different from b%s." % (path_str, 1167 path_str)) 1168 except TypeError as e: 1169 msg = "Error: a%s has %s, but b%s has %s" % ( 1170 path_str, type(a), path_str, type(b)) 1171 e.args = ((e.args[0] + ' : ' + msg,) + e.args[1:]) 1172 raise 1173 1174 def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): 1175 """Asserts that two structures of numpy arrays, have near values. 1176 1177 `a` and `b` can be arbitrarily nested structures. A layer of a nested 1178 structure can be a `dict`, `namedtuple`, `tuple` or `list`. 1179 1180 Args: 1181 a: The expected numpy `ndarray`, or anything that can be converted into a 1182 numpy `ndarray`, or any arbitrarily nested of structure of these. 
1183 b: The actual numpy `ndarray`, or anything that can be converted into a 1184 numpy `ndarray`, or any arbitrarily nested of structure of these. 1185 rtol: relative tolerance. 1186 atol: absolute tolerance. 1187 msg: Optional message to report on failure. 1188 1189 Raises: 1190 ValueError: if only one of `a[p]` and `b[p]` is a dict or 1191 `a[p]` and `b[p]` have different length, where `[p]` denotes a path 1192 to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and 1193 `[p] = [1]['d']`, then `a[p] = (6, 7)`. 1194 """ 1195 self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg) 1196 1197 def assertAllCloseAccordingToType(self, 1198 a, 1199 b, 1200 rtol=1e-6, 1201 atol=1e-6, 1202 float_rtol=1e-6, 1203 float_atol=1e-6, 1204 half_rtol=1e-3, 1205 half_atol=1e-3, 1206 bfloat16_rtol=1e-2, 1207 bfloat16_atol=1e-2, 1208 msg=None): 1209 """Like assertAllClose, but also suitable for comparing fp16 arrays. 1210 1211 In particular, the tolerance is reduced to 1e-3 if at least 1212 one of the arguments is of type float16. 1213 1214 Args: 1215 a: the expected numpy ndarray or anything can be converted to one. 1216 b: the actual numpy ndarray or anything can be converted to one. 1217 rtol: relative tolerance. 1218 atol: absolute tolerance. 1219 float_rtol: relative tolerance for float32. 1220 float_atol: absolute tolerance for float32. 1221 half_rtol: relative tolerance for float16. 1222 half_atol: absolute tolerance for float16. 1223 bfloat16_rtol: relative tolerance for bfloat16. 1224 bfloat16_atol: absolute tolerance for bfloat16. 1225 msg: Optional message to report on failure. 1226 """ 1227 a = self._GetNdArray(a) 1228 b = self._GetNdArray(b) 1229 # types with lower tol are put later to overwrite previous ones. 
1230 if (a.dtype == np.float32 or b.dtype == np.float32 or 1231 a.dtype == np.complex64 or b.dtype == np.complex64): 1232 rtol = max(rtol, float_rtol) 1233 atol = max(atol, float_atol) 1234 if a.dtype == np.float16 or b.dtype == np.float16: 1235 rtol = max(rtol, half_rtol) 1236 atol = max(atol, half_atol) 1237 if (a.dtype == dtypes.bfloat16.as_numpy_dtype or 1238 b.dtype == dtypes.bfloat16.as_numpy_dtype): 1239 rtol = max(rtol, bfloat16_rtol) 1240 atol = max(atol, bfloat16_atol) 1241 1242 self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg) 1243 1244 def assertAllEqual(self, a, b, msg=None): 1245 """Asserts that two numpy arrays have the same values. 1246 1247 Args: 1248 a: the expected numpy ndarray or anything can be converted to one. 1249 b: the actual numpy ndarray or anything can be converted to one. 1250 msg: Optional message to report on failure. 1251 """ 1252 msg = msg if msg else "" 1253 a = self._GetNdArray(a) 1254 b = self._GetNdArray(b) 1255 self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s." 1256 " %s" % (a.shape, b.shape, msg)) 1257 same = (a == b) 1258 1259 if a.dtype == np.float32 or a.dtype == np.float64: 1260 same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b))) 1261 if not np.all(same): 1262 # Prints more details than np.testing.assert_array_equal. 1263 diff = np.logical_not(same) 1264 if a.ndim: 1265 x = a[np.where(diff)] 1266 y = b[np.where(diff)] 1267 print("not equal where = ", np.where(diff)) 1268 else: 1269 # np.where is broken for scalars 1270 x, y = a, b 1271 print("not equal lhs = ", x) 1272 print("not equal rhs = ", y) 1273 np.testing.assert_array_equal(a, b, err_msg=msg) 1274 1275 # pylint: disable=g-doc-return-or-yield 1276 @contextlib.contextmanager 1277 def assertRaisesWithPredicateMatch(self, exception_type, 1278 expected_err_re_or_predicate): 1279 """Returns a context manager to enclose code expected to raise an exception. 
1280 1281 If the exception is an OpError, the op stack is also included in the message 1282 predicate search. 1283 1284 Args: 1285 exception_type: The expected type of exception that should be raised. 1286 expected_err_re_or_predicate: If this is callable, it should be a function 1287 of one argument that inspects the passed-in exception and 1288 returns True (success) or False (please fail the test). Otherwise, the 1289 error message is expected to match this regular expression partially. 1290 1291 Returns: 1292 A context manager to surround code that is expected to raise an 1293 exception. 1294 """ 1295 if callable(expected_err_re_or_predicate): 1296 predicate = expected_err_re_or_predicate 1297 else: 1298 1299 def predicate(e): 1300 err_str = e.message if isinstance(e, errors.OpError) else str(e) 1301 op = e.op if isinstance(e, errors.OpError) else None 1302 while op is not None: 1303 err_str += "\nCaused by: " + op.name 1304 op = op._original_op # pylint: disable=protected-access 1305 logging.info("Searching within error strings: '%s' within '%s'", 1306 expected_err_re_or_predicate, err_str) 1307 return re.search(expected_err_re_or_predicate, err_str) 1308 1309 try: 1310 yield 1311 self.fail(exception_type.__name__ + " not raised") 1312 except Exception as e: # pylint: disable=broad-except 1313 if not isinstance(e, exception_type) or not predicate(e): 1314 raise AssertionError("Exception of type %s: %s" % (str(type(e)), 1315 str(e))) 1316 1317 # pylint: enable=g-doc-return-or-yield 1318 1319 def assertRaisesOpError(self, expected_err_re_or_predicate): 1320 return self.assertRaisesWithPredicateMatch(errors.OpError, 1321 expected_err_re_or_predicate) 1322 1323 def assertShapeEqual(self, np_array, tf_tensor, msg=None): 1324 """Asserts that a Numpy ndarray and a TensorFlow tensor have the same shape. 1325 1326 Args: 1327 np_array: A Numpy ndarray or Numpy scalar. 1328 tf_tensor: A Tensor. 1329 msg: Optional message to report on failure. 
1330 1331 Raises: 1332 TypeError: If the arguments have the wrong type. 1333 """ 1334 if not isinstance(np_array, (np.ndarray, np.generic)): 1335 raise TypeError("np_array must be a Numpy ndarray or Numpy scalar") 1336 if not isinstance(tf_tensor, ops.Tensor): 1337 raise TypeError("tf_tensor must be a Tensor") 1338 self.assertAllEqual( 1339 np_array.shape, tf_tensor.get_shape().as_list(), msg=msg) 1340 1341 def assertDeviceEqual(self, device1, device2, msg=None): 1342 """Asserts that the two given devices are the same. 1343 1344 Args: 1345 device1: A string device name or TensorFlow `DeviceSpec` object. 1346 device2: A string device name or TensorFlow `DeviceSpec` object. 1347 msg: Optional message to report on failure. 1348 """ 1349 device1 = pydev.canonical_name(device1) 1350 device2 = pydev.canonical_name(device2) 1351 self.assertEqual(device1, device2, 1352 "Devices %s and %s are not equal. %s" % 1353 (device1, device2, msg)) 1354 1355 # Fix Python 3 compatibility issues 1356 if six.PY3: 1357 # pylint: disable=invalid-name 1358 1359 # Silence a deprecation warning 1360 assertRaisesRegexp = googletest.TestCase.assertRaisesRegex 1361 1362 # assertItemsEqual is assertCountEqual as of 3.2. 1363 assertItemsEqual = googletest.TestCase.assertCountEqual 1364 1365 # pylint: enable=invalid-name 1366 1367 1368 @tf_export("test.create_local_cluster") 1369 def create_local_cluster(num_workers, 1370 num_ps, 1371 protocol="grpc", 1372 worker_config=None, 1373 ps_config=None): 1374 """Create and start local servers and return the associated `Server` objects. 1375 1376 Example: 1377 ```python 1378 workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=2) 1379 1380 worker_sessions = [tf.Session(w.target) for w in workers] 1381 1382 with tf.device("/job:ps/task:0"): 1383 ... 1384 with tf.device("/job:ps/task:1"): 1385 ... 1386 with tf.device("/job:worker/task:0"): 1387 ... 1388 with tf.device("/job:worker/task:1"): 1389 ... 1390 1391 worker_sessions[0].run(...) 
1392 ``` 1393 1394 Args: 1395 num_workers: Number of worker servers to start. 1396 num_ps: Number of PS servers to start. 1397 protocol: Communication protocol. Allowed values are documented in 1398 the documentation of `tf.train.Server`. 1399 worker_config: (optional) ConfigProto to initialize workers. Can be used 1400 to instantiate multiple devices etc. 1401 ps_config: (optional) ConfigProto to initialize PS servers. 1402 1403 Returns: 1404 A tuple `(worker_servers, ps_servers)`. `worker_servers` is a list 1405 of `num_workers` objects of type `tf.train.Server` (all running locally); 1406 and `ps_servers` is a list of `num_ps` objects of similar type. 1407 1408 Raises: 1409 ImportError: if portpicker module was not found at load time 1410 """ 1411 if _portpicker_import_error: 1412 raise _portpicker_import_error # pylint: disable=raising-bad-type 1413 worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)] 1414 ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)] 1415 cluster_dict = { 1416 "worker": ["localhost:%s" % port for port in worker_ports], 1417 "ps": ["localhost:%s" % port for port in ps_ports] 1418 } 1419 cs = server_lib.ClusterSpec(cluster_dict) 1420 1421 workers = [ 1422 server_lib.Server( 1423 cs, 1424 job_name="worker", 1425 protocol=protocol, 1426 task_index=ix, 1427 config=worker_config, 1428 start=True) for ix in range(num_workers) 1429 ] 1430 ps_servers = [ 1431 server_lib.Server( 1432 cs, 1433 job_name="ps", 1434 protocol=protocol, 1435 task_index=ix, 1436 config=ps_config, 1437 start=True) for ix in range(num_ps) 1438 ] 1439 1440 return workers, ps_servers 1441 1442 1443 def get_node_def_from_graph(node_name, graph_def): 1444 """Returns the `NodeDef` instance for given node name in the graph def. 1445 1446 This method explores only the NodeDefs in `graph_def.node`. 1447 1448 Args: 1449 node_name: Name of the NodeDef to search for. 1450 graph_def: An instance of `GraphDef` proto. 
1451 1452 Returns: 1453 the `NodeDef` instance whose name field matches the given node_name or None. 1454 """ 1455 for node_def in graph_def.node: 1456 if node_def.name == node_name: 1457 return node_def 1458 return None 1459 1460 1461 def set_producer_version(graph, producer_version): 1462 """Sets graph.graph_def_versions.producer to `producer_version`.""" 1463 # The C API doesn't expose altering GraphDefVersions. We can indirectly set 1464 # it via import_graph_def though. 1465 graph_def = graph_pb2.GraphDef() 1466 graph_def.versions.producer = producer_version 1467 with graph.as_default(): 1468 importer.import_graph_def(graph_def) 1469 assert graph.graph_def_versions.producer, producer_version 1470