# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Aliases for logit_fn builders used by canned (core) tf.Estimators.

A logit_fn is an abstraction within model_fn that factors out the logit
construction logic. Its output can be fed into Heads or otherwise composed. It
should have the following signature:

Args:
  `features`: This is the first item returned from the `input_fn` passed to
              `train`, `evaluate`, and `predict`. This should be a single
              `Tensor` or `dict` of same, and is the only required argument.
  `mode`: Optional. Specifies whether this is training, evaluation or
          prediction. See `ModeKeys`.
  `params`: Optional `dict` of hyperparameters. Will receive what is passed to
            Estimator in `params` parameter. This allows configuration of
            Estimators from hyperparameter tuning.
  `config`: Optional configuration object. Will receive what is passed to
            Estimator in `config` parameter, or the default `config`. Allows
            updating things in your model_fn based on configuration such as
            `num_ps_replicas`, or `model_dir`.

Returns:
  A Tensor representing the logits, or a dictionary mapping string keys to
  such Tensors.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.estimator import util
from tensorflow.python.estimator.canned import dnn as dnn_core
from tensorflow.python.estimator.canned import linear as linear_core
from tensorflow.python.framework import ops

# pylint: disable=protected-access
dnn_logit_fn_builder = dnn_core._dnn_logit_fn_builder
linear_logit_fn_builder = linear_core._linear_logit_fn_builder
# pylint: enable=protected-access


def call_logit_fn(logit_fn, features, mode, params, config):
  """Calls logit_fn.

  A utility function that calls the provided logit_fn with the relevant subset
  of provided arguments. Similar to tf.estimator._call_model_fn().

  Args:
    logit_fn: A logit_fn as defined above. It is always called with
      `features` as a keyword argument.
    features: The features to pass to logit_fn (a `Tensor` or a dict of
      `Tensor`s).
    mode: TRAIN / EVAL / PREDICT ModeKeys. Forwarded only if logit_fn declares
      a `mode` argument.
    params: The hyperparameter dict. Forwarded only if logit_fn declares a
      `params` argument.
    config: The configuration object. Forwarded only if logit_fn declares a
      `config` argument.

  Returns:
    The output of logit_fn: either a logit Tensor or a dictionary mapping
    strings to logit Tensors.

  Raises:
    ValueError: if logit_fn does not return a Tensor or a dictionary mapping
      strings to Tensors.
  """
  logit_fn_args = util.fn_args(logit_fn)
  # Only forward the optional arguments that logit_fn actually declares, so
  # implementations are free to omit any of `mode`, `params` and `config`.
  kwargs = {}
  if 'mode' in logit_fn_args:
    kwargs['mode'] = mode
  if 'params' in logit_fn_args:
    kwargs['params'] = params
  if 'config' in logit_fn_args:
    kwargs['config'] = config
  logit_fn_results = logit_fn(features=features, **kwargs)

  # A generator (instead of a list) lets all() stop at the first invalid entry.
  result_is_valid_dictionary = (
      isinstance(logit_fn_results, dict) and
      all(isinstance(k, six.string_types) and isinstance(v, ops.Tensor)
          for k, v in six.iteritems(logit_fn_results)))
  result_is_tensor = isinstance(logit_fn_results, ops.Tensor)

  if not (result_is_valid_dictionary or result_is_tensor):
    # Wrap the result in a 1-tuple: if logit_fn returned a tuple, a bare `%`
    # would treat it as the format-argument tuple and raise TypeError (or
    # silently unpack a 1-tuple) instead of producing this ValueError.
    raise ValueError('logit_fn should return a Tensor or a dictionary mapping '
                     'strings to Tensors. logit_fn returned: %s' %
                     (logit_fn_results,))

  return logit_fn_results