# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for export_output: input validation and SignatureDef building."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants


class ExportOutputTest(test.TestCase):
  """Covers RegressionOutput, ClassificationOutput and PredictOutput."""

  def _check_value_error(self, expected_message, build_fn):
    """Asserts that `build_fn()` raises a ValueError with the given text.

    Args:
      expected_message: the exact expected `str()` of the raised exception.
      build_fn: zero-argument callable expected to raise ValueError.
    """
    with self.assertRaises(ValueError) as raised:
      build_fn()
    self.assertEqual(expected_message, str(raised.exception))

  def _string_input_tensors(self):
    """Returns the one-entry input dict shared by the SignatureDef tests.

    The single value is a string placeholder of shape (1,) named
    "input-tensor-1", stored under the key "input-1".
    """
    return {
        "input-1":
            array_ops.placeholder(dtypes.string, 1, name="input-tensor-1")
    }

  def _expected_signature_def(self, inputs, outputs, method_name):
    """Builds the SignatureDef proto that as_signature_def() should return.

    Every tensor used by these tests is a placeholder of shape (1,), so all
    TensorInfo entries share one single-dimension shape proto.

    Args:
      inputs: sequence of (signature key, tensor name, DataType enum name)
        triples for the `inputs` map, inserted in the given order.
      outputs: same layout as `inputs`, for the `outputs` map.
      method_name: value assigned to `SignatureDef.method_name`.

    Returns:
      A populated `meta_graph_pb2.SignatureDef`.
    """
    signature_def = meta_graph_pb2.SignatureDef()
    shape = tensor_shape_pb2.TensorShapeProto(
        dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
    # Proto map fields are live containers, so writing through them here
    # populates `signature_def` directly.
    for proto_map, entries in ((signature_def.inputs, inputs),
                               (signature_def.outputs, outputs)):
      for key, tensor_name, dtype_name in entries:
        proto_map[key].CopyFrom(
            meta_graph_pb2.TensorInfo(
                name=tensor_name,
                dtype=types_pb2.DataType.Value(dtype_name),
                tensor_shape=shape))
    signature_def.method_name = method_name
    return signature_def

  def test_regress_value_must_be_float(self):
    """RegressionOutput rejects a non-float32 value tensor."""
    bad_value = array_ops.placeholder(
        dtypes.string, 1, name="output-tensor-1")
    self._check_value_error(
        'Regression output value must be a float32 Tensor; got '
        'Tensor("output-tensor-1:0", shape=(1,), dtype=string)',
        lambda: export_output_lib.RegressionOutput(bad_value))

  def test_classify_classes_must_be_strings(self):
    """ClassificationOutput rejects a non-string classes tensor."""
    bad_classes = array_ops.placeholder(
        dtypes.float32, 1, name="output-tensor-1")
    self._check_value_error(
        'Classification classes must be a string Tensor; got '
        'Tensor("output-tensor-1:0", shape=(1,), dtype=float32)',
        lambda: export_output_lib.ClassificationOutput(classes=bad_classes))

  def test_classify_scores_must_be_float(self):
    """ClassificationOutput rejects a non-float32 scores tensor."""
    bad_scores = array_ops.placeholder(
        dtypes.string, 1, name="output-tensor-1")
    self._check_value_error(
        'Classification scores must be a float32 Tensor; got '
        'Tensor("output-tensor-1:0", shape=(1,), dtype=string)',
        lambda: export_output_lib.ClassificationOutput(scores=bad_scores))

  def test_classify_requires_classes_or_scores(self):
    """ClassificationOutput needs at least one of scores / classes."""
    self._check_value_error(
        "At least one of scores and classes must be set.",
        export_output_lib.ClassificationOutput)

  def test_build_standardized_signature_def_regression(self):
    """Regression signature maps the input and the value tensor."""
    features = self._string_input_tensors()
    predicted = array_ops.placeholder(
        dtypes.float32, 1, name="output-tensor-1")

    regression = export_output_lib.RegressionOutput(predicted)
    actual = regression.as_signature_def(features)

    expected = self._expected_signature_def(
        inputs=[(signature_constants.REGRESS_INPUTS,
                 "input-tensor-1:0", "DT_STRING")],
        outputs=[(signature_constants.REGRESS_OUTPUTS,
                  "output-tensor-1:0", "DT_FLOAT")],
        method_name=signature_constants.REGRESS_METHOD_NAME)
    self.assertEqual(actual, expected)

  def test_build_standardized_signature_def_classify_classes_only(self):
    """Tests classification with one output tensor."""
    features = self._string_input_tensors()
    labels = array_ops.placeholder(dtypes.string, 1, name="output-tensor-1")

    classification = export_output_lib.ClassificationOutput(classes=labels)
    actual = classification.as_signature_def(features)

    expected = self._expected_signature_def(
        inputs=[(signature_constants.CLASSIFY_INPUTS,
                 "input-tensor-1:0", "DT_STRING")],
        outputs=[(signature_constants.CLASSIFY_OUTPUT_CLASSES,
                  "output-tensor-1:0", "DT_STRING")],
        method_name=signature_constants.CLASSIFY_METHOD_NAME)
    self.assertEqual(actual, expected)

  def test_build_standardized_signature_def_classify_both(self):
    """Tests multiple output tensors that include classes and scores."""
    features = self._string_input_tensors()
    labels = array_ops.placeholder(
        dtypes.string, 1, name="output-tensor-classes")
    probabilities = array_ops.placeholder(
        dtypes.float32, 1, name="output-tensor-scores")

    classification = export_output_lib.ClassificationOutput(
        scores=probabilities, classes=labels)
    actual = classification.as_signature_def(features)

    expected = self._expected_signature_def(
        inputs=[(signature_constants.CLASSIFY_INPUTS,
                 "input-tensor-1:0", "DT_STRING")],
        outputs=[(signature_constants.CLASSIFY_OUTPUT_CLASSES,
                  "output-tensor-classes:0", "DT_STRING"),
                 (signature_constants.CLASSIFY_OUTPUT_SCORES,
                  "output-tensor-scores:0", "DT_FLOAT")],
        method_name=signature_constants.CLASSIFY_METHOD_NAME)
    self.assertEqual(actual, expected)

  def test_build_standardized_signature_def_classify_scores_only(self):
    """Tests classification without classes tensor."""
    features = self._string_input_tensors()
    probabilities = array_ops.placeholder(
        dtypes.float32, 1, name="output-tensor-scores")

    classification = export_output_lib.ClassificationOutput(
        scores=probabilities)
    actual = classification.as_signature_def(features)

    expected = self._expected_signature_def(
        inputs=[(signature_constants.CLASSIFY_INPUTS,
                 "input-tensor-1:0", "DT_STRING")],
        outputs=[(signature_constants.CLASSIFY_OUTPUT_SCORES,
                  "output-tensor-scores:0", "DT_FLOAT")],
        method_name=signature_constants.CLASSIFY_METHOD_NAME)
    self.assertEqual(actual, expected)

  def test_predict_outputs_valid(self):
    """Tests that no errors are raised when provided outputs are valid."""
    # A dict of outputs; one key is a plain literal, one is u-prefixed.
    export_output_lib.PredictOutput({
        "output0": constant_op.constant([0]),
        u"output1": constant_op.constant(["foo"]),
    })

    # A bare Tensor (not wrapped in a dict) is accepted as well.
    export_output_lib.PredictOutput(constant_op.constant([0]))

  def test_predict_outputs_invalid(self):
    """PredictOutput rejects non-string keys and non-Tensor values."""
    with self.assertRaisesRegexp(
        ValueError, "Prediction output key must be a string"):
      export_output_lib.PredictOutput({1: constant_op.constant([0])})

    with self.assertRaisesRegexp(
        ValueError, "Prediction output value must be a Tensor"):
      sparse_value = sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1])
      export_output_lib.PredictOutput({"prediction1": sparse_value})


if __name__ == "__main__":
  test.main()