Home | History | Annotate | Download | only in platform
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Testing.
     17 
     18 See the @{$python/test} guide.
     19 
Note: `tf.test.mock` is an alias for the Python `mock` or `unittest.mock`
module, depending on the Python version.
     22 
     23 @@main
     24 @@TestCase
     25 @@test_src_dir_path
     26 @@assert_equal_graph_def
     27 @@get_temp_dir
     28 @@is_built_with_cuda
     29 @@is_gpu_available
     30 @@gpu_device_name
     31 @@compute_gradient
     32 @@compute_gradient_error
     33 @@create_local_cluster
     34 
     35 """
     36 
     37 from __future__ import absolute_import
     38 from __future__ import division
     39 from __future__ import print_function
     40 
     41 
     42 # pylint: disable=g-bad-import-order
     43 from tensorflow.python.framework import test_util as _test_util
     44 from tensorflow.python.platform import googletest as _googletest
     45 from tensorflow.python.util.all_util import remove_undocumented
     46 
     47 # pylint: disable=unused-import
     48 from tensorflow.python.framework.test_util import assert_equal_graph_def
     49 from tensorflow.python.framework.test_util import create_local_cluster
     50 from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase
     51 from tensorflow.python.framework.test_util import gpu_device_name
     52 from tensorflow.python.framework.test_util import is_gpu_available
     53 
     54 from tensorflow.python.ops.gradient_checker import compute_gradient_error
     55 from tensorflow.python.ops.gradient_checker import compute_gradient
     56 # pylint: enable=unused-import,g-bad-import-order
     57 
     58 import sys
     59 from tensorflow.python.util.tf_export import tf_export
     60 if sys.version_info.major == 2:
     61   import mock                # pylint: disable=g-import-not-at-top,unused-import
     62 else:
     63   from unittest import mock  # pylint: disable=g-import-not-at-top
     64 
# Expose googletest's Benchmark class as an attribute of this module.
Benchmark = _googletest.Benchmark  # pylint: disable=invalid-name

# Expose googletest's StubOutForTesting class as an attribute of this module.
StubOutForTesting = _googletest.StubOutForTesting  # pylint: disable=invalid-name
     70 
     71 
     72 @tf_export('test.main')
     73 def main(argv=None):
     74   """Runs all unit tests."""
     75   _test_util.InstallStackTraceHandler()
     76   return _googletest.main(argv)
     77 
     78 
     79 @tf_export('test.get_temp_dir')
     80 def get_temp_dir():
     81   """Returns a temporary directory for use during tests.
     82 
     83   There is no need to delete the directory after the test.
     84 
     85   Returns:
     86     The temporary directory.
     87   """
     88   return _googletest.GetTempDir()
     89 
     90 
     91 @tf_export('test.test_src_dir_path')
     92 def test_src_dir_path(relative_path):
     93   """Creates an absolute test srcdir path given a relative path.
     94 
     95   Args:
     96     relative_path: a path relative to tensorflow root.
     97       e.g. "core/platform".
     98 
     99   Returns:
    100     An absolute path to the linked in runfiles.
    101   """
    102   return _googletest.test_src_dir_path(relative_path)
    103 
    104 
    105 @tf_export('test.is_built_with_cuda')
    106 def is_built_with_cuda():
    107   """Returns whether TensorFlow was built with CUDA (GPU) support."""
    108   return _test_util.IsGoogleCudaEnabled()
    109 
    110 
# Names passed to `remove_undocumented` below as explicitly allowed. None of
# them appears in the `@@` list of the module docstring.
_allowed_symbols = [
    # We piggy-back googletest documentation.
    'Benchmark',
    'mock',
    'StubOutForTesting',
]

# NOTE(review): presumably strips module attributes that are neither listed
# with `@@` in the module docstring nor named in `_allowed_symbols` -- confirm
# against `all_util.remove_undocumented`.
remove_undocumented(__name__, _allowed_symbols)
    119