# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Common code used for testing `Predictor`s."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator as contrib_estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn as contrib_model_fn
from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.export import export_lib
from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.saved_model import signature_constants


def get_arithmetic_estimator(core=True, model_dir=None):
  """Returns an `Estimator` that performs basic arithmetic.

  The model has no variables to learn: its training op is a no-op and its
  loss is the constant `0`.

  Args:
    core: if `True`, returns a `tensorflow.python.estimator.Estimator`.
      Otherwise, returns a `tensorflow.contrib.learn.Estimator`.
    model_dir: directory in which to export checkpoints and saved models.
      Forwarded to the estimator constructor for both the core and the
      contrib estimator; `None` defers to that estimator's default.

  Returns:
    An `Estimator` whose predictions are the `'sum'`, `'product'` and
    `'difference'` of the `'x'` and `'y'` input features.
  """
  def _model_fn(features, labels, mode):
    """Builds the arithmetic graph over `features['x']` and `features['y']`."""
    del labels  # Unused: there is nothing to fit.
    x = features['x']
    y = features['y']
    with ops.name_scope('outputs'):
      predictions = {'sum': math_ops.add(x, y, name='sum'),
                     'product': math_ops.multiply(x, y, name='product'),
                     'difference': math_ops.subtract(x, y, name='difference')}
    if core:
      # One single-output signature per prediction key. The default serving
      # signature is an alias for the 'sum' signature.
      export_outputs = {k: export_output.PredictOutput({k: v})
                        for k, v in predictions.items()}
      export_outputs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
          export_outputs['sum'])
      return model_fn.EstimatorSpec(mode=mode,
                                    predictions=predictions,
                                    export_outputs=export_outputs,
                                    loss=constant_op.constant(0),
                                    train_op=control_flow_ops.no_op())
    # contrib.learn describes exports as named output alternatives, each
    # paired with a problem type (left unspecified here).
    output_alternatives = {k: (constants.ProblemType.UNSPECIFIED, {k: v})
                           for k, v in predictions.items()}
    return contrib_model_fn.ModelFnOps(
        mode=mode,
        predictions=predictions,
        output_alternatives=output_alternatives,
        loss=constant_op.constant(0),
        train_op=control_flow_ops.no_op())

  if core:
    # Honor `model_dir` here too; it used to be silently dropped for the core
    # estimator, contrary to the documented behavior.
    return core_estimator.Estimator(_model_fn, model_dir=model_dir)
  return contrib_estimator.Estimator(_model_fn, model_dir=model_dir)


def get_arithmetic_input_fn(core=True, train=False):
  """Returns an input function or a serving input receiver function.

  Args:
    core: if `True`, the returned function targets the core `Estimator` API.
      Otherwise, it targets `tensorflow.contrib.learn`.
    train: if `True`, the returned function returns a `(features, label)`
      tuple. Otherwise it returns serving inputs: an
      `export_lib.ServingInputReceiver` (core) or an
      `input_fn_utils.InputFnOps` (contrib).

  Returns:
    A zero-argument function that builds the input graph. Features `'x'`
    and `'y'` are scalar placeholders defaulting to `0.0`; the label is the
    constant `0.0`.
  """
  def _input_fn():
    """Builds the `x`/`y` inputs and returns them in the requested form."""
    with ops.name_scope('inputs'):
      x = array_ops.placeholder_with_default(0.0, shape=[], name='x')
      y = array_ops.placeholder_with_default(0.0, shape=[], name='y')
    label = constant_op.constant(0.0)
    features = {'x': x, 'y': y}
    # Training input has the same shape for both APIs.
    if train:
      return features, label
    if core:
      # The feature tensors themselves serve as the receiver tensors.
      return export_lib.ServingInputReceiver(
          features=features,
          receiver_tensors=features)
    return input_fn_utils.InputFnOps(
        features=features,
        labels={},
        default_inputs=features)
  return _input_fn