1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================= 15 """Tests for create_python_api.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import imp 22 import sys 23 24 from tensorflow.python.platform import test 25 from tensorflow.python.util.tf_export import tf_export 26 from tensorflow.tools.api.generator import create_python_api 27 28 29 @tf_export('test_op', 'test_op1') 30 def test_op(): 31 pass 32 33 34 @tf_export('TestClass', 'NewTestClass') 35 class TestClass(object): 36 pass 37 38 39 _TEST_CONSTANT = 5 40 _MODULE_NAME = 'test.tensorflow.test_module' 41 42 43 class CreatePythonApiTest(test.TestCase): 44 45 def setUp(self): 46 # Add fake op to a module that has 'tensorflow' in the name. 47 sys.modules[_MODULE_NAME] = imp.new_module(_MODULE_NAME) 48 setattr(sys.modules[_MODULE_NAME], 'test_op', test_op) 49 setattr(sys.modules[_MODULE_NAME], 'TestClass', TestClass) 50 test_op.__module__ = _MODULE_NAME 51 TestClass.__module__ = _MODULE_NAME 52 tf_export('consts._TEST_CONSTANT').export_constant( 53 _MODULE_NAME, '_TEST_CONSTANT') 54 55 def tearDown(self): 56 del sys.modules[_MODULE_NAME] 57 58 def testFunctionImportIsAdded(self): 59 imports = create_python_api.get_api_imports() 60 expected_import = ( 61 'from test.tensorflow.test_module import test_op as test_op1') 62 self.assertTrue( 63 expected_import in str(imports), 64 msg='%s not in %s' % (expected_import, str(imports))) 65 66 expected_import = 'from test.tensorflow.test_module import test_op' 67 self.assertTrue( 68 expected_import in str(imports), 69 msg='%s not in %s' % (expected_import, str(imports))) 70 71 def testClassImportIsAdded(self): 72 imports = create_python_api.get_api_imports() 73 expected_import = 'from test.tensorflow.test_module import TestClass' 74 self.assertTrue( 75 'TestClass' in str(imports), 76 msg='%s not in %s' % (expected_import, str(imports))) 77 78 def testConstantIsAdded(self): 79 imports = create_python_api.get_api_imports() 80 expected = 'from test.tensorflow.test_module import _TEST_CONSTANT' 81 self.assertTrue(expected in str(imports), 82 msg='%s not in %s' % (expected, str(imports))) 83 84 85 if __name__ == '__main__': 86 test.main() 87