# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extenders of tf.estimator.Estimator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.ops import clip_ops
from tensorflow.python.training import optimizer as optimizer_lib


_VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config'])


def add_metrics(estimator, metric_fn):
  """Creates a new ${tf.estimator.Estimator} which has the given metrics.

  Example:

  ```python
    def my_auc(labels, predictions):
      return {'auc': tf.metrics.auc(labels, predictions['logistic'])}

    estimator = tf.estimator.DNNClassifier(...)
    estimator = tf.contrib.estimator.add_metrics(estimator, my_auc)
    estimator.train(...)
    estimator.evaluate(...)
  ```

  Example usage of a custom metric which uses features:

  ```python
    def my_auc(features, labels, predictions):
      return {'auc': tf.metrics.auc(
          labels, predictions['logistic'], weights=features['weight'])}

    estimator = tf.estimator.DNNClassifier(...)
    estimator = tf.contrib.estimator.add_metrics(estimator, my_auc)
    estimator.train(...)
    estimator.evaluate(...)
  ```
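
  A `metric_fn` may also take the `estimator`'s `config` as an argument. A
  minimal sketch (illustrative only, mirroring the examples above; `config`
  is merely received and ignored here):

  ```python
    def my_auc(labels, predictions, config):
      del config  # the estimator's config; unused in this sketch
      return {'auc': tf.metrics.auc(labels, predictions['logistic'])}

    estimator = tf.estimator.DNNClassifier(...)
    estimator = tf.contrib.estimator.add_metrics(estimator, my_auc)
  ```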

  Args:
    estimator: A ${tf.estimator.Estimator} object.
    metric_fn: A function which should obey the following signature:
      - Args: can only have the following four arguments, in any order:
        * predictions: Predictions `Tensor` or dict of `Tensor` created by
          the given `estimator`.
        * features: Input `dict` of `Tensor` objects created by `input_fn`
          which is given to `estimator.evaluate` as an argument.
        * labels: Labels `Tensor` or dict of `Tensor` created by `input_fn`
          which is given to `estimator.evaluate` as an argument.
        * config: config attribute of the `estimator`.
      - Returns:
        Dict of metric results keyed by name. The final metrics are a union
        of this and the `estimator`'s existing metrics. If there is a name
        conflict between this and the `estimator`'s existing metrics, this
        will override the existing one. The values of the dict are the
        results of calling a metric function, namely a
        `(metric_tensor, update_op)` tuple.

  Returns:
      A new ${tf.estimator.Estimator} which has a union of the original
        metrics and the given ones.
  """
  _verify_metric_fn_args(metric_fn)

  def new_model_fn(features, labels, mode, config):
    spec = estimator.model_fn(features, labels, mode, config)
    if mode != model_fn_lib.ModeKeys.EVAL:
      return spec
    new_metrics = _call_metric_fn(metric_fn, features, labels, spec.predictions,
                                  config)
    all_metrics = spec.eval_metric_ops or {}
    all_metrics.update(new_metrics)
    return spec._replace(eval_metric_ops=all_metrics)

  return estimator_lib.Estimator(
      model_fn=new_model_fn,
      model_dir=estimator.model_dir,
      config=estimator.config)


def clip_gradients_by_norm(optimizer, clip_norm):
  """Returns an optimizer which clips gradients before applying them.

  Example:

  ```python
  optimizer = tf.train.ProximalAdagradOptimizer(
      learning_rate=0.1,
      l1_regularization_strength=0.001)
  optimizer = tf.contrib.estimator.clip_gradients_by_norm(
      optimizer, clip_norm)
  estimator = tf.estimator.DNNClassifier(
      feature_columns=[...],
      hidden_units=[1024, 512, 256],
      optimizer=optimizer)
  ```
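
  The wrapper rescales all gradients jointly via `tf.clip_by_global_norm`. A
  minimal sketch of that rescaling (the tensor values are illustrative only):

  ```python
  grads = [tf.constant([3.0, 4.0])]  # global norm is 5.0
  clipped, global_norm = tf.clip_by_global_norm(grads, clip_norm=1.0)
  # clipped[0] is [0.6, 0.8]: when global_norm exceeds clip_norm, every
  # gradient is scaled by clip_norm / global_norm; otherwise it is unchanged.
  ```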

  Args:
    optimizer: A `tf.Optimizer` object to apply gradients.
    clip_norm: A 0-D (scalar) `Tensor` > 0. The clipping ratio.

  Returns:
    A `tf.Optimizer`.
  """

  def clip_grads(grads_and_vars):
    gradients, variables = zip(*grads_and_vars)
    # Jointly rescale all gradients so that their global norm does not
    # exceed `clip_norm`.
    gradients = clip_ops.clip_by_global_norm(gradients, clip_norm)[0]
    grads_and_vars = list(zip(gradients, variables))
    return grads_and_vars

  return _TransformGradients(
      optimizer=optimizer,
      transform_grads_fn=clip_grads,
      name='ClipByNorm' + optimizer.get_name())


def forward_features(estimator, keys=None):
  """Forwards features to the predictions dictionary.

  In some cases, a user wants to see some of the features in the estimator's
  prediction output. As an example, consider a batch prediction service: the
  service simply runs inference on the user's graph and returns the results.
  Keys are essential because there is no order guarantee on the outputs, so
  they need to be rejoined to the inputs via keys or transclusion of the
  inputs in the outputs.

  Example:

  ```python
    def input_fn():
      features, labels = ...
      features['unique_example_id'] = ...
      return features, labels

    estimator = tf.estimator.LinearClassifier(...)
    estimator = tf.contrib.estimator.forward_features(
        estimator, 'unique_example_id')
    estimator.train(...)
    assert 'unique_example_id' in next(estimator.predict(...))
  ```
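
  Multiple keys can be forwarded at once by passing a list of strings; a
  minimal sketch (the key names here are illustrative):

  ```python
    estimator = tf.contrib.estimator.forward_features(
        estimator, ['unique_example_id', 'another_feature_key'])
  ```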

  Args:
    estimator: A ${tf.estimator.Estimator} object.
    keys: A `string` or a `list` of `string`. If it is `None`, all of the
      `features` in the `dict` are forwarded to the `predictions`. If it is a
      `string`, only the given key is forwarded. If it is a `list` of
      strings, all the given `keys` are forwarded.

  Returns:
      A new ${tf.estimator.Estimator} which forwards features to predictions.

  Raises:
    ValueError:
      * if `keys` is already part of `predictions`. We don't allow
        overrides.
      * if `keys` does not exist in `features`.
      * if a feature key refers to a `SparseTensor`, since we don't support
        `SparseTensor` in `predictions`. `SparseTensor` is common in
        `features`.
    TypeError: if the `keys` type is not one of `string` or list/tuple of
      `string`.
  """

  def verify_key_types(keys):  # pylint: disable=missing-docstring
    if keys is None:
      return keys
    if isinstance(keys, six.string_types):
      return [keys]
    if not isinstance(keys, (list, tuple)):
      raise TypeError('keys should be either a string or a list of strings. '
                      'Given: {}'.format(type(keys)))
    for key in keys:
      if not isinstance(key, six.string_types):
        raise TypeError('All items in the given keys list should be a string. '
                        'There exists an item with type: {}'.format(type(key)))
    return keys

  def get_keys(features):
    if keys is None:
      return features.keys()
    return keys

  def verify_keys_and_predictions(features, predictions):
    if not isinstance(predictions, dict):
      raise ValueError(
          'Predictions should be a dict to be able to forward features. '
          'Given: {}'.format(type(predictions)))
    for key in get_keys(features):
      if key not in features:
        raise ValueError(
            'keys should exist in features. Key "{}" is not in the features '
            'dict. The features dict has the following keys: {}. Please '
            'check the arguments of forward_features.'.format(
                key, features.keys()))
      if key in predictions:
        raise ValueError(
            'Cannot forward feature key ({}), since it already exists in '
            'predictions. Existing prediction keys: {}. Please check the '
            'arguments of forward_features.'.format(key, predictions.keys()))

  keys = verify_key_types(keys)

  def new_model_fn(features, labels, mode, config):  # pylint: disable=missing-docstring
    spec = estimator.model_fn(features, labels, mode, config)
    predictions = spec.predictions
    if predictions is None:
      return spec
    verify_keys_and_predictions(features, predictions)
    for key in get_keys(features):
      feature = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
          features[key])
      if not isinstance(feature, ops.Tensor):
        raise ValueError(
            'Forwarded feature ({}) should be a Tensor. Please use the keys '
            'argument of forward_features to filter unwanted features. Type '
            'of features[{}] is {}.'.format(key, key, type(feature)))
      predictions[key] = feature
    return spec._replace(predictions=predictions)

  return estimator_lib.Estimator(
      model_fn=new_model_fn,
      model_dir=estimator.model_dir,
      config=estimator.config)


class _TransformGradients(optimizer_lib.Optimizer):
  """Adds the given gradient transformation to the optimizer."""

  def __init__(self, optimizer, transform_grads_fn, name=None):
    """Construct a `tf.Optimizer` wrapper to apply the given transformations.

    Example:

    ```python
    optimizer = tf.train.ProximalAdagradOptimizer(
        learning_rate=0.1,
        l1_regularization_strength=0.001)
    def clip_grads(grads_and_vars):
      gradients, variables = zip(*grads_and_vars)
      gradients = tf.clip_by_global_norm(gradients, my_norm)[0]
      grads_and_vars = list(zip(gradients, variables))
      return grads_and_vars
    optimizer = _TransformGradients(
        optimizer=optimizer, transform_grads_fn=clip_grads)
    estimator = tf.estimator.DNNClassifier(
        feature_columns=[...],
        hidden_units=[1024, 512, 256],
        optimizer=optimizer)
    ```

    Args:
      optimizer: A `tf.Optimizer` object to apply gradients.
      transform_grads_fn: A function which takes a single argument, a list of
        gradient to variable pairs (tuples), performs any requested gradient
        updates, such as gradient clipping or multipliers, and returns the
        updated list.
      name: A string which will be used for debugging purposes.
    """
    super(_TransformGradients, self).__init__(
        use_locking=False, name=name or optimizer.get_name())
    self._optimizer = optimizer
    self._transform_grads_fn = transform_grads_fn

  def compute_gradients(self, *args, **kwargs):
    """See `tf.Optimizer`."""
    return self._optimizer.compute_gradients(*args, **kwargs)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    Calls `transform_grads_fn`, and then applies the real optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Defaults to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the gradients. If `global_step` was not
      `None`, that operation also increments `global_step`.

    Raises:
      ValueError: If `grads_and_vars` is malformed.
    """
    grads_and_vars = self._transform_grads_fn(grads_and_vars)
    return self._optimizer.apply_gradients(grads_and_vars, global_step, name)

  def get_slot(self, *args, **kwargs):
    """See `tf.Optimizer`."""
    return self._optimizer.get_slot(*args, **kwargs)

  def get_slot_names(self, *args, **kwargs):
    """See `tf.Optimizer`."""
    return self._optimizer.get_slot_names(*args, **kwargs)


def _verify_metric_fn_args(metric_fn):
  """Verifies that `metric_fn` only takes the supported argument names."""
  args = set(estimator_util.fn_args(metric_fn))
  invalid_args = list(args - _VALID_METRIC_FN_ARGS)
  if invalid_args:
    raise ValueError('metric_fn (%s) has the following unexpected args: %s' %
                     (metric_fn, invalid_args))


def _call_metric_fn(metric_fn, features, labels, predictions, config):
  """Calls `metric_fn` with only the arguments it accepts."""
  metric_fn_args = estimator_util.fn_args(metric_fn)
  kwargs = {}
  if 'features' in metric_fn_args:
    kwargs['features'] = features
  if 'labels' in metric_fn_args:
    kwargs['labels'] = labels
  if 'predictions' in metric_fn_args:
    kwargs['predictions'] = predictions
  if 'config' in metric_fn_args:
    kwargs['config'] = config
  return metric_fn(**kwargs)
    339