1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 """Tests for predictor.core_estimator_predictor.""" 17 18 from __future__ import absolute_import 19 from __future__ import division 20 from __future__ import print_function 21 22 import tempfile 23 import numpy as np 24 25 from tensorflow.contrib.predictor import core_estimator_predictor 26 from tensorflow.contrib.predictor import testing_common 27 from tensorflow.python.platform import test 28 29 30 KEYS_AND_OPS = (('sum', lambda x, y: x + y), 31 ('product', lambda x, y: x * y,), 32 ('difference', lambda x, y: x - y)) 33 34 35 class CoreEstimatorPredictorTest(test.TestCase): 36 """Test fixture for `CoreEstimatorPredictor`.""" 37 38 def setUp(self): 39 model_dir = tempfile.mkdtemp() 40 self._estimator = testing_common.get_arithmetic_estimator( 41 core=True, model_dir=model_dir) 42 self._serving_input_receiver_fn = testing_common.get_arithmetic_input_fn( 43 core=True, train=False) 44 45 def testDefault(self): 46 """Test prediction with default signature.""" 47 np.random.seed(1111) 48 x = np.random.rand() 49 y = np.random.rand() 50 predictor = core_estimator_predictor.CoreEstimatorPredictor( 51 estimator=self._estimator, 52 serving_input_receiver_fn=self._serving_input_receiver_fn) 53 output = predictor({'x': x, 'y': y})['sum'] 54 self.assertAlmostEqual(output, x + y, places=3) 55 56 def testSpecifiedSignatureKey(self): 57 """Test prediction with spedicified signatures.""" 58 np.random.seed(1234) 59 for output_key, op in KEYS_AND_OPS: 60 x = np.random.rand() 61 y = np.random.rand() 62 expected_output = op(x, y) 63 64 predictor = core_estimator_predictor.CoreEstimatorPredictor( 65 estimator=self._estimator, 66 serving_input_receiver_fn=self._serving_input_receiver_fn, 67 output_key=output_key) 68 output_tensor_name = predictor.fetch_tensors[output_key].name 69 self.assertRegexpMatches( 70 output_tensor_name, 71 output_key, 72 msg='Unexpected fetch tensor.') 73 output = predictor({'x': x, 'y': y})[output_key] 74 self.assertAlmostEqual( 75 expected_output, output, places=3, 76 msg='Failed for output key "{}." ' 77 'Got output {} for x = {} and y = {}'.format( 78 output_key, output, x, y)) 79 80 if __name__ == '__main__': 81 test.main() 82