Home | History | Annotate | Download | only in export
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for export."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import os
     22 import tempfile
     23 import time
     24 
     25 from google.protobuf import text_format
     26 
     27 from tensorflow.core.example import example_pb2
     28 from tensorflow.python.estimator.export import export
     29 from tensorflow.python.estimator.export import export_output
     30 from tensorflow.python.framework import constant_op
     31 from tensorflow.python.framework import dtypes
     32 from tensorflow.python.framework import ops
     33 from tensorflow.python.framework import sparse_tensor
     34 from tensorflow.python.framework import test_util
     35 from tensorflow.python.ops import array_ops
     36 from tensorflow.python.ops import parsing_ops
     37 from tensorflow.python.platform import test
     38 from tensorflow.python.saved_model import signature_constants
     39 from tensorflow.python.saved_model import signature_def_utils
     40 
     41 
     42 class ExportTest(test_util.TensorFlowTestCase):
     43 
     44   def test_serving_input_receiver_constructor(self):
     45     """Tests that no errors are raised when input is expected."""
     46     features = {
     47         "feature0": constant_op.constant([0]),
     48         u"feature1": constant_op.constant([1]),
     49         "feature2": sparse_tensor.SparseTensor(
     50             indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
     51     }
     52     receiver_tensors = {
     53         "example0": array_ops.placeholder(dtypes.string, name="example0"),
     54         u"example1": array_ops.placeholder(dtypes.string, name="example1"),
     55     }
     56     export.ServingInputReceiver(features, receiver_tensors)
     57 
     58   def test_serving_input_receiver_features_invalid(self):
     59     receiver_tensors = {
     60         "example0": array_ops.placeholder(dtypes.string, name="example0"),
     61         u"example1": array_ops.placeholder(dtypes.string, name="example1"),
     62     }
     63 
     64     with self.assertRaisesRegexp(ValueError, "features must be defined"):
     65       export.ServingInputReceiver(
     66           features=None,
     67           receiver_tensors=receiver_tensors)
     68 
     69     with self.assertRaisesRegexp(ValueError, "feature keys must be strings"):
     70       export.ServingInputReceiver(
     71           features={1: constant_op.constant([1])},
     72           receiver_tensors=receiver_tensors)
     73 
     74     with self.assertRaisesRegexp(
     75         ValueError, "feature feature1 must be a Tensor or SparseTensor"):
     76       export.ServingInputReceiver(
     77           features={"feature1": [1]},
     78           receiver_tensors=receiver_tensors)
     79 
     80   def test_serving_input_receiver_receiver_tensors_invalid(self):
     81     features = {
     82         "feature0": constant_op.constant([0]),
     83         u"feature1": constant_op.constant([1]),
     84         "feature2": sparse_tensor.SparseTensor(
     85             indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
     86     }
     87 
     88     with self.assertRaisesRegexp(
     89         ValueError, "receiver_tensors must be defined"):
     90       export.ServingInputReceiver(
     91           features=features,
     92           receiver_tensors=None)
     93 
     94     with self.assertRaisesRegexp(
     95         ValueError, "receiver_tensors keys must be strings"):
     96       export.ServingInputReceiver(
     97           features=features,
     98           receiver_tensors={
     99               1: array_ops.placeholder(dtypes.string, name="example0")})
    100 
    101     with self.assertRaisesRegexp(
    102         ValueError, "receiver_tensor example1 must be a Tensor"):
    103       export.ServingInputReceiver(
    104           features=features,
    105           receiver_tensors={"example1": [1]})
    106 
    107   def test_single_feature_single_receiver(self):
    108     feature = constant_op.constant(5)
    109     receiver_tensor = array_ops.placeholder(dtypes.string)
    110     input_receiver = export.ServingInputReceiver(
    111         feature, receiver_tensor)
    112     # single feature is automatically named
    113     feature_key, = input_receiver.features.keys()
    114     self.assertEqual("feature", feature_key)
    115     # single receiver is automatically named
    116     receiver_key, = input_receiver.receiver_tensors.keys()
    117     self.assertEqual("input", receiver_key)
    118 
    119   def test_multi_feature_single_receiver(self):
    120     features = {"foo": constant_op.constant(5),
    121                 "bar": constant_op.constant(6)}
    122     receiver_tensor = array_ops.placeholder(dtypes.string)
    123     _ = export.ServingInputReceiver(features, receiver_tensor)
    124 
    125   def test_multi_feature_multi_receiver(self):
    126     features = {"foo": constant_op.constant(5),
    127                 "bar": constant_op.constant(6)}
    128     receiver_tensors = {"baz": array_ops.placeholder(dtypes.int64),
    129                         "qux": array_ops.placeholder(dtypes.float32)}
    130     _ = export.ServingInputReceiver(features, receiver_tensors)
    131 
    132   def test_feature_wrong_type(self):
    133     feature = "not a tensor"
    134     receiver_tensor = array_ops.placeholder(dtypes.string)
    135     with self.assertRaises(ValueError):
    136       _ = export.ServingInputReceiver(feature, receiver_tensor)
    137 
    138   def test_receiver_wrong_type(self):
    139     feature = constant_op.constant(5)
    140     receiver_tensor = "not a tensor"
    141     with self.assertRaises(ValueError):
    142       _ = export.ServingInputReceiver(feature, receiver_tensor)
    143 
    144   def test_build_parsing_serving_input_receiver_fn(self):
    145     feature_spec = {"int_feature": parsing_ops.VarLenFeature(dtypes.int64),
    146                     "float_feature": parsing_ops.VarLenFeature(dtypes.float32)}
    147     serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
    148         feature_spec)
    149     with ops.Graph().as_default():
    150       serving_input_receiver = serving_input_receiver_fn()
    151       self.assertEqual(set(["int_feature", "float_feature"]),
    152                        set(serving_input_receiver.features.keys()))
    153       self.assertEqual(set(["examples"]),
    154                        set(serving_input_receiver.receiver_tensors.keys()))
    155 
    156       example = example_pb2.Example()
    157       text_format.Parse("features: { "
    158                         "  feature: { "
    159                         "    key: 'int_feature' "
    160                         "    value: { "
    161                         "      int64_list: { "
    162                         "        value: [ 21, 2, 5 ] "
    163                         "      } "
    164                         "    } "
    165                         "  } "
    166                         "  feature: { "
    167                         "    key: 'float_feature' "
    168                         "    value: { "
    169                         "      float_list: { "
    170                         "        value: [ 525.25 ] "
    171                         "      } "
    172                         "    } "
    173                         "  } "
    174                         "} ", example)
    175 
    176       with self.test_session() as sess:
    177         sparse_result = sess.run(
    178             serving_input_receiver.features,
    179             feed_dict={
    180                 serving_input_receiver.receiver_tensors["examples"].name:
    181                 [example.SerializeToString()]})
    182         self.assertAllEqual([[0, 0], [0, 1], [0, 2]],
    183                             sparse_result["int_feature"].indices)
    184         self.assertAllEqual([21, 2, 5],
    185                             sparse_result["int_feature"].values)
    186         self.assertAllEqual([[0, 0]],
    187                             sparse_result["float_feature"].indices)
    188         self.assertAllEqual([525.25],
    189                             sparse_result["float_feature"].values)
    190 
    191   def test_build_raw_serving_input_receiver_fn_name(self):
    192     """Test case for issue #12755."""
    193     f = {
    194         "feature":
    195             array_ops.placeholder(
    196                 name="feature", shape=[32], dtype=dtypes.float32)
    197     }
    198     serving_input_receiver_fn = export.build_raw_serving_input_receiver_fn(f)
    199     v = serving_input_receiver_fn()
    200     self.assertTrue(isinstance(v, export.ServingInputReceiver))
    201 
    202   def test_build_raw_serving_input_receiver_fn(self):
    203     features = {"feature_1": constant_op.constant(["hello"]),
    204                 "feature_2": constant_op.constant([42])}
    205     serving_input_receiver_fn = export.build_raw_serving_input_receiver_fn(
    206         features)
    207     with ops.Graph().as_default():
    208       serving_input_receiver = serving_input_receiver_fn()
    209       self.assertEqual(set(["feature_1", "feature_2"]),
    210                        set(serving_input_receiver.features.keys()))
    211       self.assertEqual(set(["feature_1", "feature_2"]),
    212                        set(serving_input_receiver.receiver_tensors.keys()))
    213       self.assertEqual(
    214           dtypes.string,
    215           serving_input_receiver.receiver_tensors["feature_1"].dtype)
    216       self.assertEqual(
    217           dtypes.int32,
    218           serving_input_receiver.receiver_tensors["feature_2"].dtype)
    219 
    220   def test_build_all_signature_defs_without_receiver_alternatives(self):
    221     receiver_tensor = array_ops.placeholder(dtypes.string)
    222     output_1 = constant_op.constant([1.])
    223     output_2 = constant_op.constant(["2"])
    224     output_3 = constant_op.constant(["3"])
    225     export_outputs = {
    226         signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
    227             export_output.RegressionOutput(value=output_1),
    228         "head-2": export_output.ClassificationOutput(classes=output_2),
    229         "head-3": export_output.PredictOutput(outputs={
    230             "some_output_3": output_3
    231         }),
    232     }
    233 
    234     signature_defs = export.build_all_signature_defs(
    235         receiver_tensor, export_outputs)
    236 
    237     expected_signature_defs = {
    238         "serving_default":
    239             signature_def_utils.regression_signature_def(receiver_tensor,
    240                                                          output_1),
    241         "head-2":
    242             signature_def_utils.classification_signature_def(receiver_tensor,
    243                                                              output_2, None),
    244         "head-3":
    245             signature_def_utils.predict_signature_def({
    246                 "input": receiver_tensor
    247             }, {"some_output_3": output_3})
    248     }
    249 
    250     self.assertDictEqual(expected_signature_defs, signature_defs)
    251 
    252   def test_build_all_signature_defs_with_dict_alternatives(self):
    253     receiver_tensor = array_ops.placeholder(dtypes.string)
    254     receiver_tensors_alternative_1 = {
    255         "foo": array_ops.placeholder(dtypes.int64),
    256         "bar": array_ops.sparse_placeholder(dtypes.float32)}
    257     receiver_tensors_alternatives = {"other": receiver_tensors_alternative_1}
    258     output_1 = constant_op.constant([1.])
    259     output_2 = constant_op.constant(["2"])
    260     output_3 = constant_op.constant(["3"])
    261     export_outputs = {
    262         signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
    263             export_output.RegressionOutput(value=output_1),
    264         "head-2": export_output.ClassificationOutput(classes=output_2),
    265         "head-3": export_output.PredictOutput(outputs={
    266             "some_output_3": output_3
    267         }),
    268     }
    269 
    270     signature_defs = export.build_all_signature_defs(
    271         receiver_tensor, export_outputs, receiver_tensors_alternatives)
    272 
    273     expected_signature_defs = {
    274         "serving_default":
    275             signature_def_utils.regression_signature_def(
    276                 receiver_tensor,
    277                 output_1),
    278         "head-2":
    279             signature_def_utils.classification_signature_def(
    280                 receiver_tensor,
    281                 output_2, None),
    282         "head-3":
    283             signature_def_utils.predict_signature_def(
    284                 {"input": receiver_tensor},
    285                 {"some_output_3": output_3}),
    286         "other:head-3":
    287             signature_def_utils.predict_signature_def(
    288                 receiver_tensors_alternative_1,
    289                 {"some_output_3": output_3})
    290 
    291         # Note that the alternatives 'other:serving_default' and 'other:head-2'
    292         # are invalid, because regession and classification signatures must take
    293         # a single string input.  Here we verify that these invalid signatures
    294         # are not included in the export.
    295     }
    296 
    297     self.assertDictEqual(expected_signature_defs, signature_defs)
    298 
    299   def test_build_all_signature_defs_with_single_alternatives(self):
    300     receiver_tensor = array_ops.placeholder(dtypes.string)
    301     receiver_tensors_alternative_1 = array_ops.placeholder(dtypes.int64)
    302     receiver_tensors_alternative_2 = array_ops.sparse_placeholder(
    303         dtypes.float32)
    304     # Note we are passing single Tensors as values of
    305     # receiver_tensors_alternatives, where normally that is a dict.
    306     # In this case a dict will be created using the default receiver tensor
    307     # name "input".
    308     receiver_tensors_alternatives = {"other1": receiver_tensors_alternative_1,
    309                                      "other2": receiver_tensors_alternative_2}
    310     output_1 = constant_op.constant([1.])
    311     output_2 = constant_op.constant(["2"])
    312     output_3 = constant_op.constant(["3"])
    313     export_outputs = {
    314         signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
    315             export_output.RegressionOutput(value=output_1),
    316         "head-2": export_output.ClassificationOutput(classes=output_2),
    317         "head-3": export_output.PredictOutput(outputs={
    318             "some_output_3": output_3
    319         }),
    320     }
    321 
    322     signature_defs = export.build_all_signature_defs(
    323         receiver_tensor, export_outputs, receiver_tensors_alternatives)
    324 
    325     expected_signature_defs = {
    326         "serving_default":
    327             signature_def_utils.regression_signature_def(
    328                 receiver_tensor,
    329                 output_1),
    330         "head-2":
    331             signature_def_utils.classification_signature_def(
    332                 receiver_tensor,
    333                 output_2, None),
    334         "head-3":
    335             signature_def_utils.predict_signature_def(
    336                 {"input": receiver_tensor},
    337                 {"some_output_3": output_3}),
    338         "other1:head-3":
    339             signature_def_utils.predict_signature_def(
    340                 {"input": receiver_tensors_alternative_1},
    341                 {"some_output_3": output_3}),
    342         "other2:head-3":
    343             signature_def_utils.predict_signature_def(
    344                 {"input": receiver_tensors_alternative_2},
    345                 {"some_output_3": output_3})
    346 
    347         # Note that the alternatives 'other:serving_default' and 'other:head-2'
    348         # are invalid, because regession and classification signatures must take
    349         # a single string input.  Here we verify that these invalid signatures
    350         # are not included in the export.
    351     }
    352 
    353     self.assertDictEqual(expected_signature_defs, signature_defs)
    354 
    355   def test_build_all_signature_defs_export_outputs_required(self):
    356     receiver_tensor = constant_op.constant(["11"])
    357 
    358     with self.assertRaises(ValueError) as e:
    359       export.build_all_signature_defs(receiver_tensor, None)
    360 
    361     self.assertTrue(str(e.exception).startswith(
    362         "export_outputs must be a dict"))
    363 
    364   def test_get_timestamped_export_dir(self):
    365     export_dir_base = tempfile.mkdtemp() + "export/"
    366     export_dir_1 = export.get_timestamped_export_dir(
    367         export_dir_base)
    368     time.sleep(2)
    369     export_dir_2 = export.get_timestamped_export_dir(
    370         export_dir_base)
    371     time.sleep(2)
    372     export_dir_3 = export.get_timestamped_export_dir(
    373         export_dir_base)
    374 
    375     # Export directories should be named using a timestamp that is seconds
    376     # since epoch.  Such a timestamp is 10 digits long.
    377     time_1 = os.path.basename(export_dir_1)
    378     self.assertEqual(10, len(time_1))
    379     time_2 = os.path.basename(export_dir_2)
    380     self.assertEqual(10, len(time_2))
    381     time_3 = os.path.basename(export_dir_3)
    382     self.assertEqual(10, len(time_3))
    383 
    384     self.assertTrue(int(time_1) < int(time_2))
    385     self.assertTrue(int(time_2) < int(time_3))
    386 
    387 
    388 if __name__ == "__main__":
    389   test.main()
    390