1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 """Testing. 17 18 See the @{$python/test} guide. 19 20 Note: `tf.test.mock` is an alias to the python `mock` or `unittest.mock` 21 depending on the python version. 22 23 @@main 24 @@TestCase 25 @@test_src_dir_path 26 @@assert_equal_graph_def 27 @@get_temp_dir 28 @@is_built_with_cuda 29 @@is_gpu_available 30 @@gpu_device_name 31 @@compute_gradient 32 @@compute_gradient_error 33 @@create_local_cluster 34 35 """ 36 37 from __future__ import absolute_import 38 from __future__ import division 39 from __future__ import print_function 40 41 42 # pylint: disable=g-bad-import-order 43 from tensorflow.python.framework import test_util as _test_util 44 from tensorflow.python.platform import googletest as _googletest 45 from tensorflow.python.util.all_util import remove_undocumented 46 47 # pylint: disable=unused-import 48 from tensorflow.python.framework.test_util import assert_equal_graph_def 49 from tensorflow.python.framework.test_util import create_local_cluster 50 from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase 51 from tensorflow.python.framework.test_util import gpu_device_name 52 from tensorflow.python.framework.test_util import is_gpu_available 53 54 from tensorflow.python.ops.gradient_checker import compute_gradient_error 55 
from tensorflow.python.ops.gradient_checker import compute_gradient
# pylint: enable=unused-import,g-bad-import-order

import sys
from tensorflow.python.util.tf_export import tf_export

# `tf.test.mock` aliases the standalone `mock` package on Python 2 and the
# standard library's `unittest.mock` on any other major version.
if sys.version_info.major != 2:
  from unittest import mock  # pylint: disable=g-import-not-at-top
else:
  import mock  # pylint: disable=g-import-not-at-top,unused-import

# Expose googletest's Benchmark and StubOutForTesting classes from this module.
Benchmark = _googletest.Benchmark  # pylint: disable=invalid-name
StubOutForTesting = _googletest.StubOutForTesting  # pylint: disable=invalid-name


@tf_export('test.main')
def main(argv=None):
  """Runs all unit tests.

  Installs the stack-trace handler from `test_util` first, then hands control
  to googletest's `main`.

  Args:
    argv: Optional argument list, forwarded unchanged to `googletest.main`.

  Returns:
    Whatever `googletest.main(argv)` returns.
  """
  _test_util.InstallStackTraceHandler()
  result = _googletest.main(argv)
  return result


@tf_export('test.get_temp_dir')
def get_temp_dir():
  """Returns a temporary directory for use during tests.

  Callers do not need to remove the directory once the test finishes.

  Returns:
    The temporary directory, as reported by `googletest.GetTempDir()`.
  """
  temp_dir = _googletest.GetTempDir()
  return temp_dir


@tf_export('test.test_src_dir_path')
def test_src_dir_path(relative_path):
  """Creates an absolute test srcdir path given a relative path.

  Args:
    relative_path: A path relative to the tensorflow root, e.g.
      "core/platform".

  Returns:
    An absolute path to the linked-in runfiles, as computed by
    `googletest.test_src_dir_path`.
  """
  absolute_path = _googletest.test_src_dir_path(relative_path)
  return absolute_path


@tf_export('test.is_built_with_cuda')
def is_built_with_cuda():
  """Returns whether TensorFlow was built with CUDA (GPU) support.

  Thin wrapper over `test_util.IsGoogleCudaEnabled()`.
  """
  cuda_enabled = _test_util.IsGoogleCudaEnabled()
  return cuda_enabled


# Extra names handed to `remove_undocumented`. Presumably it would otherwise
# drop them, because the module docstring has no `@@` entry for them. Their
# documentation piggy-backs on googletest's.
_allowed_symbols = [
    'Benchmark',
    'mock',
    'StubOutForTesting',
]

remove_undocumented(__name__, _allowed_symbols)