Home | History | Annotate | Download | only in integration_tests
      1 # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """SavedModel integration tests."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import os
     22 import subprocess
     23 
     24 import tensorflow.compat.v2 as tf
     25 
     26 from tensorflow.python.framework import test_util
     27 from tensorflow.python.platform import resource_loader
     28 from tensorflow.python.platform import tf_logging as logging
     29 
     30 
     31 class SavedModelPart1Test(tf.test.TestCase):
     32 
     33   def assertCommandSucceeded(self, binary, **flags):
     34     command_parts = [binary]
     35     for flag_key, flag_value in flags.items():
     36       command_parts.append("--%s=%s" % (flag_key, flag_value))
     37 
     38     logging.info("Running: %s" % command_parts)
     39     subprocess.check_call(
     40         command_parts, env=dict(os.environ, TF2_BEHAVIOR="enabled"))
     41 
     42   @test_util.run_v2_only
     43   def test_text_rnn(self):
     44     export_dir = self.get_temp_dir()
     45     export_binary = resource_loader.get_path_to_datafile(
     46         "export_text_rnn_model")
     47     self.assertCommandSucceeded(export_binary, export_dir=export_dir)
     48 
     49     use_binary = resource_loader.get_path_to_datafile("use_text_rnn_model")
     50     self.assertCommandSucceeded(use_binary, model_dir=export_dir)
     51 
     52   @test_util.run_v2_only
     53   def test_rnn_cell(self):
     54     export_dir = self.get_temp_dir()
     55     export_binary = resource_loader.get_path_to_datafile(
     56         "export_rnn_cell")
     57     self.assertCommandSucceeded(export_binary, export_dir=export_dir)
     58 
     59     use_binary = resource_loader.get_path_to_datafile("use_rnn_cell")
     60     self.assertCommandSucceeded(use_binary, model_dir=export_dir)
     61 
     62   @test_util.run_v2_only
     63   def test_text_embedding_in_sequential_keras(self):
     64     export_dir = self.get_temp_dir()
     65     export_binary = resource_loader.get_path_to_datafile(
     66         "export_simple_text_embedding")
     67     self.assertCommandSucceeded(export_binary, export_dir=export_dir)
     68 
     69     use_binary = resource_loader.get_path_to_datafile(
     70         "use_model_in_sequential_keras")
     71     self.assertCommandSucceeded(use_binary, model_dir=export_dir)
     72 
     73 if __name__ == "__main__":
     74   tf.enable_v2_behavior()
     75   tf.test.main()
     76