1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for export.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import os 22 import tempfile 23 import time 24 25 from google.protobuf import text_format 26 27 from tensorflow.core.example import example_pb2 28 from tensorflow.python.estimator.export import export 29 from tensorflow.python.estimator.export import export_output 30 from tensorflow.python.framework import constant_op 31 from tensorflow.python.framework import dtypes 32 from tensorflow.python.framework import ops 33 from tensorflow.python.framework import sparse_tensor 34 from tensorflow.python.framework import test_util 35 from tensorflow.python.ops import array_ops 36 from tensorflow.python.ops import parsing_ops 37 from tensorflow.python.platform import test 38 from tensorflow.python.saved_model import signature_constants 39 from tensorflow.python.saved_model import signature_def_utils 40 41 42 class ExportTest(test_util.TensorFlowTestCase): 43 44 def test_serving_input_receiver_constructor(self): 45 """Tests that no errors are raised when input is expected.""" 46 features = { 47 "feature0": constant_op.constant([0]), 48 u"feature1": constant_op.constant([1]), 49 "feature2": sparse_tensor.SparseTensor( 50 indices=[[0, 0]], values=[1], dense_shape=[1, 1]), 51 } 52 receiver_tensors = { 53 "example0": array_ops.placeholder(dtypes.string, name="example0"), 54 u"example1": array_ops.placeholder(dtypes.string, name="example1"), 55 } 56 export.ServingInputReceiver(features, receiver_tensors) 57 58 def test_serving_input_receiver_features_invalid(self): 59 receiver_tensors = { 60 "example0": array_ops.placeholder(dtypes.string, name="example0"), 61 u"example1": array_ops.placeholder(dtypes.string, name="example1"), 62 } 63 64 with self.assertRaisesRegexp(ValueError, "features must be defined"): 65 export.ServingInputReceiver( 66 features=None, 67 receiver_tensors=receiver_tensors) 68 69 with self.assertRaisesRegexp(ValueError, "feature keys must be strings"): 70 export.ServingInputReceiver( 71 features={1: constant_op.constant([1])}, 72 receiver_tensors=receiver_tensors) 73 74 with self.assertRaisesRegexp( 75 ValueError, "feature feature1 must be a Tensor or SparseTensor"): 76 export.ServingInputReceiver( 77 features={"feature1": [1]}, 78 receiver_tensors=receiver_tensors) 79 80 def test_serving_input_receiver_receiver_tensors_invalid(self): 81 features = { 82 "feature0": constant_op.constant([0]), 83 u"feature1": constant_op.constant([1]), 84 "feature2": sparse_tensor.SparseTensor( 85 indices=[[0, 0]], values=[1], dense_shape=[1, 1]), 86 } 87 88 with self.assertRaisesRegexp( 89 ValueError, "receiver_tensors must be defined"): 90 export.ServingInputReceiver( 91 features=features, 92 receiver_tensors=None) 93 94 with self.assertRaisesRegexp( 95 ValueError, "receiver_tensors keys must be strings"): 96 export.ServingInputReceiver( 97 features=features, 98 receiver_tensors={ 99 1: array_ops.placeholder(dtypes.string, name="example0")}) 100 101 with self.assertRaisesRegexp( 102 ValueError, "receiver_tensor example1 must be a Tensor"): 103 export.ServingInputReceiver( 104 features=features, 105 receiver_tensors={"example1": [1]}) 106 107 def test_single_feature_single_receiver(self): 108 feature = constant_op.constant(5) 109 receiver_tensor = array_ops.placeholder(dtypes.string) 110 input_receiver = export.ServingInputReceiver( 111 feature, receiver_tensor) 112 # single feature is automatically named 113 feature_key, = input_receiver.features.keys() 114 self.assertEqual("feature", feature_key) 115 # single receiver is automatically named 116 receiver_key, = input_receiver.receiver_tensors.keys() 117 self.assertEqual("input", receiver_key) 118 119 def test_multi_feature_single_receiver(self): 120 features = {"foo": constant_op.constant(5), 121 "bar": constant_op.constant(6)} 122 receiver_tensor = array_ops.placeholder(dtypes.string) 123 _ = export.ServingInputReceiver(features, receiver_tensor) 124 125 def test_multi_feature_multi_receiver(self): 126 features = {"foo": constant_op.constant(5), 127 "bar": constant_op.constant(6)} 128 receiver_tensors = {"baz": array_ops.placeholder(dtypes.int64), 129 "qux": array_ops.placeholder(dtypes.float32)} 130 _ = export.ServingInputReceiver(features, receiver_tensors) 131 132 def test_feature_wrong_type(self): 133 feature = "not a tensor" 134 receiver_tensor = array_ops.placeholder(dtypes.string) 135 with self.assertRaises(ValueError): 136 _ = export.ServingInputReceiver(feature, receiver_tensor) 137 138 def test_receiver_wrong_type(self): 139 feature = constant_op.constant(5) 140 receiver_tensor = "not a tensor" 141 with self.assertRaises(ValueError): 142 _ = export.ServingInputReceiver(feature, receiver_tensor) 143 144 def test_build_parsing_serving_input_receiver_fn(self): 145 feature_spec = {"int_feature": parsing_ops.VarLenFeature(dtypes.int64), 146 "float_feature": parsing_ops.VarLenFeature(dtypes.float32)} 147 serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( 148 feature_spec) 149 with ops.Graph().as_default(): 150 serving_input_receiver = serving_input_receiver_fn() 151 self.assertEqual(set(["int_feature", "float_feature"]), 152 set(serving_input_receiver.features.keys())) 153 self.assertEqual(set(["examples"]), 154 set(serving_input_receiver.receiver_tensors.keys())) 155 156 example = example_pb2.Example() 157 text_format.Parse("features: { " 158 " feature: { " 159 " key: 'int_feature' " 160 " value: { " 161 " int64_list: { " 162 " value: [ 21, 2, 5 ] " 163 " } " 164 " } " 165 " } " 166 " feature: { " 167 " key: 'float_feature' " 168 " value: { " 169 " float_list: { " 170 " value: [ 525.25 ] " 171 " } " 172 " } " 173 " } " 174 "} ", example) 175 176 with self.test_session() as sess: 177 sparse_result = sess.run( 178 serving_input_receiver.features, 179 feed_dict={ 180 serving_input_receiver.receiver_tensors["examples"].name: 181 [example.SerializeToString()]}) 182 self.assertAllEqual([[0, 0], [0, 1], [0, 2]], 183 sparse_result["int_feature"].indices) 184 self.assertAllEqual([21, 2, 5], 185 sparse_result["int_feature"].values) 186 self.assertAllEqual([[0, 0]], 187 sparse_result["float_feature"].indices) 188 self.assertAllEqual([525.25], 189 sparse_result["float_feature"].values) 190 191 def test_build_raw_serving_input_receiver_fn_name(self): 192 """Test case for issue #12755.""" 193 f = { 194 "feature": 195 array_ops.placeholder( 196 name="feature", shape=[32], dtype=dtypes.float32) 197 } 198 serving_input_receiver_fn = export.build_raw_serving_input_receiver_fn(f) 199 v = serving_input_receiver_fn() 200 self.assertTrue(isinstance(v, export.ServingInputReceiver)) 201 202 def test_build_raw_serving_input_receiver_fn(self): 203 features = {"feature_1": constant_op.constant(["hello"]), 204 "feature_2": constant_op.constant([42])} 205 serving_input_receiver_fn = export.build_raw_serving_input_receiver_fn( 206 features) 207 with ops.Graph().as_default(): 208 serving_input_receiver = serving_input_receiver_fn() 209 self.assertEqual(set(["feature_1", "feature_2"]), 210 set(serving_input_receiver.features.keys())) 211 self.assertEqual(set(["feature_1", "feature_2"]), 212 set(serving_input_receiver.receiver_tensors.keys())) 213 self.assertEqual( 214 dtypes.string, 215 serving_input_receiver.receiver_tensors["feature_1"].dtype) 216 self.assertEqual( 217 dtypes.int32, 218 serving_input_receiver.receiver_tensors["feature_2"].dtype) 219 220 def test_build_all_signature_defs_without_receiver_alternatives(self): 221 receiver_tensor = array_ops.placeholder(dtypes.string) 222 output_1 = constant_op.constant([1.]) 223 output_2 = constant_op.constant(["2"]) 224 output_3 = constant_op.constant(["3"]) 225 export_outputs = { 226 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: 227 export_output.RegressionOutput(value=output_1), 228 "head-2": export_output.ClassificationOutput(classes=output_2), 229 "head-3": export_output.PredictOutput(outputs={ 230 "some_output_3": output_3 231 }), 232 } 233 234 signature_defs = export.build_all_signature_defs( 235 receiver_tensor, export_outputs) 236 237 expected_signature_defs = { 238 "serving_default": 239 signature_def_utils.regression_signature_def(receiver_tensor, 240 output_1), 241 "head-2": 242 signature_def_utils.classification_signature_def(receiver_tensor, 243 output_2, None), 244 "head-3": 245 signature_def_utils.predict_signature_def({ 246 "input": receiver_tensor 247 }, {"some_output_3": output_3}) 248 } 249 250 self.assertDictEqual(expected_signature_defs, signature_defs) 251 252 def test_build_all_signature_defs_with_dict_alternatives(self): 253 receiver_tensor = array_ops.placeholder(dtypes.string) 254 receiver_tensors_alternative_1 = { 255 "foo": array_ops.placeholder(dtypes.int64), 256 "bar": array_ops.sparse_placeholder(dtypes.float32)} 257 receiver_tensors_alternatives = {"other": receiver_tensors_alternative_1} 258 output_1 = constant_op.constant([1.]) 259 output_2 = constant_op.constant(["2"]) 260 output_3 = constant_op.constant(["3"]) 261 export_outputs = { 262 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: 263 export_output.RegressionOutput(value=output_1), 264 "head-2": export_output.ClassificationOutput(classes=output_2), 265 "head-3": export_output.PredictOutput(outputs={ 266 "some_output_3": output_3 267 }), 268 } 269 270 signature_defs = export.build_all_signature_defs( 271 receiver_tensor, export_outputs, receiver_tensors_alternatives) 272 273 expected_signature_defs = { 274 "serving_default": 275 signature_def_utils.regression_signature_def( 276 receiver_tensor, 277 output_1), 278 "head-2": 279 signature_def_utils.classification_signature_def( 280 receiver_tensor, 281 output_2, None), 282 "head-3": 283 signature_def_utils.predict_signature_def( 284 {"input": receiver_tensor}, 285 {"some_output_3": output_3}), 286 "other:head-3": 287 signature_def_utils.predict_signature_def( 288 receiver_tensors_alternative_1, 289 {"some_output_3": output_3}) 290 291 # Note that the alternatives 'other:serving_default' and 'other:head-2' 292 # are invalid, because regession and classification signatures must take 293 # a single string input. Here we verify that these invalid signatures 294 # are not included in the export. 295 } 296 297 self.assertDictEqual(expected_signature_defs, signature_defs) 298 299 def test_build_all_signature_defs_with_single_alternatives(self): 300 receiver_tensor = array_ops.placeholder(dtypes.string) 301 receiver_tensors_alternative_1 = array_ops.placeholder(dtypes.int64) 302 receiver_tensors_alternative_2 = array_ops.sparse_placeholder( 303 dtypes.float32) 304 # Note we are passing single Tensors as values of 305 # receiver_tensors_alternatives, where normally that is a dict. 306 # In this case a dict will be created using the default receiver tensor 307 # name "input". 308 receiver_tensors_alternatives = {"other1": receiver_tensors_alternative_1, 309 "other2": receiver_tensors_alternative_2} 310 output_1 = constant_op.constant([1.]) 311 output_2 = constant_op.constant(["2"]) 312 output_3 = constant_op.constant(["3"]) 313 export_outputs = { 314 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: 315 export_output.RegressionOutput(value=output_1), 316 "head-2": export_output.ClassificationOutput(classes=output_2), 317 "head-3": export_output.PredictOutput(outputs={ 318 "some_output_3": output_3 319 }), 320 } 321 322 signature_defs = export.build_all_signature_defs( 323 receiver_tensor, export_outputs, receiver_tensors_alternatives) 324 325 expected_signature_defs = { 326 "serving_default": 327 signature_def_utils.regression_signature_def( 328 receiver_tensor, 329 output_1), 330 "head-2": 331 signature_def_utils.classification_signature_def( 332 receiver_tensor, 333 output_2, None), 334 "head-3": 335 signature_def_utils.predict_signature_def( 336 {"input": receiver_tensor}, 337 {"some_output_3": output_3}), 338 "other1:head-3": 339 signature_def_utils.predict_signature_def( 340 {"input": receiver_tensors_alternative_1}, 341 {"some_output_3": output_3}), 342 "other2:head-3": 343 signature_def_utils.predict_signature_def( 344 {"input": receiver_tensors_alternative_2}, 345 {"some_output_3": output_3}) 346 347 # Note that the alternatives 'other:serving_default' and 'other:head-2' 348 # are invalid, because regession and classification signatures must take 349 # a single string input. Here we verify that these invalid signatures 350 # are not included in the export. 351 } 352 353 self.assertDictEqual(expected_signature_defs, signature_defs) 354 355 def test_build_all_signature_defs_export_outputs_required(self): 356 receiver_tensor = constant_op.constant(["11"]) 357 358 with self.assertRaises(ValueError) as e: 359 export.build_all_signature_defs(receiver_tensor, None) 360 361 self.assertTrue(str(e.exception).startswith( 362 "export_outputs must be a dict")) 363 364 def test_get_timestamped_export_dir(self): 365 export_dir_base = tempfile.mkdtemp() + "export/" 366 export_dir_1 = export.get_timestamped_export_dir( 367 export_dir_base) 368 time.sleep(2) 369 export_dir_2 = export.get_timestamped_export_dir( 370 export_dir_base) 371 time.sleep(2) 372 export_dir_3 = export.get_timestamped_export_dir( 373 export_dir_base) 374 375 # Export directories should be named using a timestamp that is seconds 376 # since epoch. Such a timestamp is 10 digits long. 377 time_1 = os.path.basename(export_dir_1) 378 self.assertEqual(10, len(time_1)) 379 time_2 = os.path.basename(export_dir_2) 380 self.assertEqual(10, len(time_2)) 381 time_3 = os.path.basename(export_dir_3) 382 self.assertEqual(10, len(time_3)) 383 384 self.assertTrue(int(time_1) < int(time_2)) 385 self.assertTrue(int(time_2) < int(time_3)) 386 387 388 if __name__ == "__main__": 389 test.main() 390