Home | History | Annotate | Download | only in estimator
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
"""Aliases for logit_fn builders used by canned (core) tf.Estimators.
     16 
     17 A logit_fn is an abstraction within model_fn that factors out the logit
     18 construction logic.  Its output can be fed into Heads or otherwise composed.  It
should have the following signature:
     20 
     21 Args:
     22 `features`: This is the first item returned from the `input_fn` passed to
     23             `train`, `evaluate`, and `predict`. This should be a single
     24             `Tensor` or `dict` of same, and is the only required argument.
`mode`: Optional. Specifies whether this is training, evaluation or
        prediction. See `ModeKeys`.
     27 `params`: Optional `dict` of hyperparameters.  Will receive what is passed to
     28           Estimator in `params` parameter. This allows configuration of
     29           Estimators from hyperparameter tuning.
     30 `config`: Optional configuration object. Will receive what is passed to
     31           Estimator in `config` parameter, or the default `config`. Allows
     32           updating things in your model_fn based on configuration such as
     33           `num_ps_replicas`, or `model_dir`.
     34 
     35 Returns:
    A `Tensor` representing the logits, or a `dict` mapping string keys to
    such `Tensor`s.
     37 """
     38 from __future__ import absolute_import
     39 from __future__ import division
     40 from __future__ import print_function
     41 
     42 import six
     43 
     44 from tensorflow.python.estimator import util
     45 from tensorflow.python.estimator.canned import dnn as dnn_core
     46 from tensorflow.python.estimator.canned import linear as linear_core
     47 from tensorflow.python.framework import ops
     48 
     49 # pylint: disable=protected-access
     50 dnn_logit_fn_builder = dnn_core._dnn_logit_fn_builder
     51 linear_logit_fn_builder = linear_core._linear_logit_fn_builder
     52 # pylint: enable=protected-access
     53 
     54 
     55 def call_logit_fn(logit_fn, features, mode, params, config):
     56   """Calls logit_fn.
     57 
     58   A utility function that calls the provided logit_fn with the relevant subset
     59   of provided arguments.  Similar to tf.estimator._call_model_fn().
     60 
     61   Args:
     62     logit_fn: A logit_fn as defined above.
     63     features: The features dict.
     64     mode: TRAIN / EVAL / PREDICT ModeKeys.
     65     params: The hyperparameter dict.
     66     config: The configuration object.
     67 
     68   Returns:
     69     A logit Tensor, the output of logit_fn.
     70 
     71   Raises:
     72     ValueError: if logit_fn does not return a Tensor or a dictionary mapping
     73       strings to Tensors.
     74   """
     75   logit_fn_args = util.fn_args(logit_fn)
     76   kwargs = {}
     77   if 'mode' in logit_fn_args:
     78     kwargs['mode'] = mode
     79   if 'params' in logit_fn_args:
     80     kwargs['params'] = params
     81   if 'config' in logit_fn_args:
     82     kwargs['config'] = config
     83   logit_fn_results = logit_fn(features=features, **kwargs)
     84 
     85   result_is_valid_dictionary = (
     86       isinstance(logit_fn_results, dict) and
     87       all([(isinstance(k, six.string_types) and isinstance(v, ops.Tensor))
     88            for k, v in six.iteritems(logit_fn_results)]))
     89   result_is_tensor = isinstance(logit_fn_results, ops.Tensor)
     90 
     91   if not (result_is_valid_dictionary or result_is_tensor):
     92     raise ValueError('logit_fn should return a Tensor or a dictionary mapping '
     93                      'strings to Tensors.  logit_fn returned: %s' %
     94                      logit_fn_results)
     95 
     96   return logit_fn_results
     97