Home | History | Annotate | Download | only in generator
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # =============================================================================
     15 """Tests for create_python_api."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import imp
     22 import sys
     23 
     24 from tensorflow.python.platform import test
     25 from tensorflow.python.util.tf_export import tf_export
     26 from tensorflow.tools.api.generator import create_python_api
     27 
     28 
     29 @tf_export('test_op', 'test_op1')
     30 def test_op():
     31   pass
     32 
     33 
     34 @tf_export('TestClass', 'NewTestClass')
     35 class TestClass(object):
     36   pass
     37 
     38 
     39 _TEST_CONSTANT = 5
     40 _MODULE_NAME = 'test.tensorflow.test_module'
     41 
     42 
     43 class CreatePythonApiTest(test.TestCase):
     44 
     45   def setUp(self):
     46     # Add fake op to a module that has 'tensorflow' in the name.
     47     sys.modules[_MODULE_NAME] = imp.new_module(_MODULE_NAME)
     48     setattr(sys.modules[_MODULE_NAME], 'test_op', test_op)
     49     setattr(sys.modules[_MODULE_NAME], 'TestClass', TestClass)
     50     test_op.__module__ = _MODULE_NAME
     51     TestClass.__module__ = _MODULE_NAME
     52     tf_export('consts._TEST_CONSTANT').export_constant(
     53         _MODULE_NAME, '_TEST_CONSTANT')
     54 
     55   def tearDown(self):
     56     del sys.modules[_MODULE_NAME]
     57 
     58   def testFunctionImportIsAdded(self):
     59     imports = create_python_api.get_api_imports()
     60     expected_import = (
     61         'from test.tensorflow.test_module import test_op as test_op1')
     62     self.assertTrue(
     63         expected_import in str(imports),
     64         msg='%s not in %s' % (expected_import, str(imports)))
     65 
     66     expected_import = 'from test.tensorflow.test_module import test_op'
     67     self.assertTrue(
     68         expected_import in str(imports),
     69         msg='%s not in %s' % (expected_import, str(imports)))
     70 
     71   def testClassImportIsAdded(self):
     72     imports = create_python_api.get_api_imports()
     73     expected_import = 'from test.tensorflow.test_module import TestClass'
     74     self.assertTrue(
     75         'TestClass' in str(imports),
     76         msg='%s not in %s' % (expected_import, str(imports)))
     77 
     78   def testConstantIsAdded(self):
     79     imports = create_python_api.get_api_imports()
     80     expected = 'from test.tensorflow.test_module import _TEST_CONSTANT'
     81     self.assertTrue(expected in str(imports),
     82                     msg='%s not in %s' % (expected, str(imports)))
     83 
     84 
     85 if __name__ == '__main__':
     86   test.main()
     87