Home | History | Annotate | Download | only in predictor
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Common code used for testing `Predictor`s."""
     17 
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 from tensorflow.contrib.learn.python.learn.estimators import constants
     23 from tensorflow.contrib.learn.python.learn.estimators import estimator as contrib_estimator
     24 from tensorflow.contrib.learn.python.learn.estimators import model_fn as contrib_model_fn
     25 from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
     26 from tensorflow.python.estimator import estimator as core_estimator
     27 from tensorflow.python.estimator import model_fn
     28 from tensorflow.python.estimator.export import export_lib
     29 from tensorflow.python.estimator.export import export_output
     30 from tensorflow.python.framework import constant_op
     31 from tensorflow.python.framework import ops
     32 from tensorflow.python.ops import array_ops
     33 from tensorflow.python.ops import control_flow_ops
     34 from tensorflow.python.ops import math_ops
     35 from tensorflow.python.saved_model import signature_constants
     36 
     37 
     38 def get_arithmetic_estimator(core=True, model_dir=None):
     39   """Returns an `Estimator` that performs basic arithmetic.
     40 
     41   Args:
     42     core: if `True`, returns a `tensorflow.python.estimator.Estimator`.
     43       Otherwise, returns a `tensorflow.contrib.learn.Estimator`.
     44     model_dir: directory in which to export checkpoints and saved models.
     45   Returns:
     46     An `Estimator` that performs arithmetic operations on its inputs.
     47   """
     48   def _model_fn(features, labels, mode):
     49     _ = labels
     50     x = features['x']
     51     y = features['y']
     52     with ops.name_scope('outputs'):
     53       predictions = {'sum': math_ops.add(x, y, name='sum'),
     54                      'product': math_ops.multiply(x, y, name='product'),
     55                      'difference': math_ops.subtract(x, y, name='difference')}
     56     if core:
     57       export_outputs = {k: export_output.PredictOutput({k: v})
     58                         for k, v in predictions.items()}
     59       export_outputs[signature_constants.
     60                      DEFAULT_SERVING_SIGNATURE_DEF_KEY] = export_outputs['sum']
     61       return model_fn.EstimatorSpec(mode=mode,
     62                                     predictions=predictions,
     63                                     export_outputs=export_outputs,
     64                                     loss=constant_op.constant(0),
     65                                     train_op=control_flow_ops.no_op())
     66     else:
     67       output_alternatives = {k: (constants.ProblemType.UNSPECIFIED, {k: v})
     68                              for k, v in predictions.items()}
     69       return contrib_model_fn.ModelFnOps(
     70           mode=mode,
     71           predictions=predictions,
     72           output_alternatives=output_alternatives,
     73           loss=constant_op.constant(0),
     74           train_op=control_flow_ops.no_op())
     75   if core:
     76     return core_estimator.Estimator(_model_fn)
     77   else:
     78     return contrib_estimator.Estimator(_model_fn, model_dir=model_dir)
     79 
     80 
     81 def get_arithmetic_input_fn(core=True, train=False):
     82   """Returns a input functions or serving input receiver function."""
     83   def _input_fn():
     84     with ops.name_scope('inputs'):
     85       x = array_ops.placeholder_with_default(0.0, shape=[], name='x')
     86       y = array_ops.placeholder_with_default(0.0, shape=[], name='y')
     87     label = constant_op.constant(0.0)
     88     features = {'x': x, 'y': y}
     89     if core:
     90       if train:
     91         return features, label
     92       return export_lib.ServingInputReceiver(
     93           features=features,
     94           receiver_tensors=features)
     95     else:
     96       if train:
     97         return features, label
     98       return input_fn_utils.InputFnOps(
     99           features=features,
    100           labels={},
    101           default_inputs=features)
    102   return _input_fn
    103