Home | History | Annotate | Download | only in estimator
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Abstractions for the head(s) of a model.
     16 """
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import six
     22 
     23 from tensorflow.python.estimator import model_fn
     24 from tensorflow.python.estimator.canned import head as head_lib
     25 from tensorflow.python.estimator.canned import metric_keys
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.ops import array_ops
     28 from tensorflow.python.ops import control_flow_ops
     29 from tensorflow.python.ops import math_ops
     30 from tensorflow.python.ops import metrics as metrics_lib
     31 from tensorflow.python.saved_model import signature_constants
     32 from tensorflow.python.summary import summary
     33 
     34 
     35 _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
     36 
     37 
     38 def multi_head(heads, head_weights=None):
     39   """Creates a `_Head` for multi-objective learning.
     40 
     41   This class merges the output of multiple `_Head` objects.
     42   Specifically:
     43   * For training, sums losses of each head, calls `train_op_fn` with this
     44     final loss.
     45   * For eval, merges metrics by adding `head.name` suffix to the keys in eval
     46     metrics, such as `precision/head1`, `precision/head2`.
     47   * For prediction, merges predictions and updates keys in prediction dict to a
     48     2-tuple, `(head.name, prediction_key)`. Merges `export_outputs` such that
     49     by default the first head is served.
     50 
     51   Usage:
     52 
     53   ```python
     54   # In `input_fn` specify labels as a dict keyed by head name:
     55   def input_fn():
     56     features = ...
     57     labels1 = ...
     58     labels2 = ...
     59     return features, {'head1': labels1, 'head2': labels2}
     60 
     61   # In `model_fn`, specify logits as a dict keyed by head name:
     62   def model_fn(features, labels, mode):
     63     # Create simple heads and specify head name.
     64     head1 = multi_class_head(n_classes=3, name='head1')
     65     head2 = binary_classification_head(name='head2')
     66     # Create multi-head from two simple heads.
     67     head = multi_head([head1, head2])
     68     # Create logits for each head, and combine them into a dict.
     69     logits1, logits2 = logit_fn()
     70     logits = {'head1': logits1, 'head2': logits2}
     71     # Return the merged EstimatorSpec
     72     return head.create_estimator_spec(..., logits=logits, ...)
     73 
     74   # Create an estimator with this model_fn.
     75   estimator = tf.estimator.Estimator(model_fn=model_fn)
     76   estimator.train(input_fn=input_fn, steps=100)
     77   ```
     78 
     79   Also supports `logits` as a `Tensor` of shape
     80   `[D0, D1, ... DN, logits_dimension]`. It will split the `Tensor` along the
     81   last dimension and distribute it appropriately among the heads. E.g.:
     82 
     83   ```python
     84   def model_fn(features, labels, mode):
     85     # Create simple heads and specify head name.
     86     head1 = multi_class_head(n_classes=3, name='head1')
     87     head2 = binary_classification_head(name='head2')
     88     # Create multi-head from two simple heads.
     89     head = multi_head([head1, head2])
     90     # Create logits for the multihead.
     91     logits = logit_fn(logits_dimension=head.logits_dimension)
     92     # Return the merged EstimatorSpec
     93     return head.create_estimator_spec(..., logits=logits, ...)
     94   ```
     95 
     96   Args:
     97     heads: List or tuple of `_Head` instances. All heads must have `name`
     98       specified. The first head in the list is the default used at serving time.
     99     head_weights: Optional list of weights, same length as `heads`. Used when
    100       merging losses to calculate the weighted sum of losses from each head. If
    101       `None`, all losses are weighted equally.
    102 
    103   Returns:
    104     A instance of `_Head` that merges multiple heads.
    105 
    106   Raises:
    107     ValueError: If `heads` is empty.
    108     ValueError: If any of the `heads` does not have `name` specified.
    109     ValueError: If `heads` and `head_weights` have different size.
    110   """
    111   if head_weights:
    112     if len(head_weights) != len(heads):
    113       raise ValueError(
    114           'heads and head_weights must have the same size. '
    115           'Given len(heads): {}. Given len(head_weights): {}.'.format(
    116               len(heads), len(head_weights)))
    117   if not heads:
    118     raise ValueError('Must specify heads. Given: {}'.format(heads))
    119   for head in heads:
    120     if not head.name:
    121       raise ValueError(
    122           'All given heads must have name specified. '
    123           'Given: {}'.format(head))
    124 
    125   return _MultiHead(
    126       heads=tuple(heads),
    127       head_weights=tuple(head_weights) if head_weights else tuple())
    128 
    129 
    130 def _no_op_train_fn(loss):
    131   del loss
    132   return control_flow_ops.no_op()
    133 
    134 
    135 def _merge_losses(losses, head_weights=None):
    136   """Merges the given losses into one tensor."""
    137   losses = tuple(losses)
    138   with ops.name_scope(
    139       'merge_losses', values=losses + (head_weights or tuple())):
    140     if head_weights:
    141       weighted_losses = []
    142       for loss, weight in zip(losses, head_weights):
    143         weighted_losses.append(math_ops.multiply(loss, weight))
    144     else:
    145       weighted_losses = losses
    146     return math_ops.add_n(weighted_losses)
    147 
    148 
    149 def _default_export_output(export_outputs, head_name):
    150   """Extracts the default export output from the given export_outputs dict."""
    151   if len(export_outputs) == 1:
    152     return next(six.itervalues(export_outputs))
    153   for k, v in six.iteritems(export_outputs):
    154     if k == _DEFAULT_SERVING_KEY:
    155       return v
    156   raise ValueError(
    157       '{} did not specify default export_outputs. '
    158       'Given: {} '
    159       'Suggested fix: Use one of the heads in tf.contrib.estimator, or include '
    160       'key {} in export_outputs.'.format(
    161           head_name, export_outputs, _DEFAULT_SERVING_KEY))
    162 
    163 
    164 class _MultiHead(head_lib._Head):  # pylint:disable=protected-access
    165   """`_Head` for multi objective learning."""
    166 
    167   def __init__(self, heads, head_weights):
    168     self._logits_dimension = 0
    169     for head in heads:
    170       self._logits_dimension += head.logits_dimension
    171 
    172     self._heads = heads
    173     self._head_weights = head_weights
    174 
    175   @property
    176   def name(self):
    177     return '_'.join([h.name for h in self._heads])
    178 
    179   @property
    180   def logits_dimension(self):
    181     return self._logits_dimension
    182 
    183   def create_loss(self, features, mode, logits, labels):
    184     """See `Head`."""
    185     if isinstance(logits, dict):
    186       logits_dict = logits
    187     else:
    188       logits_dict = self._split_logits(logits)
    189     training_losses = []
    190     labels_by_head = {}
    191     unreduced_losses_by_head = {}
    192     example_weights_by_head = {}
    193     for i, head in enumerate(self._heads):
    194       (training_loss, unreduced_loss,
    195        weights, processed_labels) = head.create_loss(
    196            features, mode, logits_dict[head.name], labels[head.name])
    197       training_losses.append(training_loss)
    198       labels_by_head[head.name] = processed_labels
    199       if self._head_weights:
    200         head_weight = self._head_weights[i]
    201         unreduced_losses_by_head[head.name] = math_ops.multiply(
    202             unreduced_loss, head_weight)
    203         example_weights_by_head[head.name] = math_ops.multiply(
    204             weights, head_weight)
    205       else:
    206         unreduced_losses_by_head[head.name] = unreduced_loss
    207         example_weights_by_head[head.name] = weights
    208 
    209     training_losses = tuple(training_losses)
    210     with ops.name_scope(
    211         'merge_losses',
    212         values=training_losses + (self._head_weights or tuple())):
    213       if self._head_weights:
    214         head_weighted_training_losses = []
    215         for training_loss, head_weight in zip(
    216             training_losses, self._head_weights):
    217           head_weighted_training_losses.append(
    218               math_ops.multiply(training_loss, head_weight))
    219         merged_training_loss = math_ops.add_n(head_weighted_training_losses)
    220       else:
    221         merged_training_loss = math_ops.add_n(training_losses)
    222 
    223     return head_lib.LossSpec(
    224         training_loss=merged_training_loss,
    225         unreduced_loss=unreduced_losses_by_head,
    226         weights=example_weights_by_head,
    227         processed_labels=labels_by_head)
    228 
    229   def create_estimator_spec(
    230       self, features, mode, logits, labels=None, train_op_fn=None):
    231     """See `_Head`."""
    232     if isinstance(logits, dict):
    233       logits_dict = logits
    234     else:
    235       logits_dict = self._split_logits(logits)
    236     if labels and not isinstance(labels, dict):
    237       raise ValueError('labels must be a dict. Given: {}'.format(labels))
    238 
    239     all_estimator_spec = []
    240     for head in self._heads:
    241       head_name = head.name
    242       all_estimator_spec.append(
    243           head.create_estimator_spec(
    244               features=features,
    245               mode=mode,
    246               logits=logits_dict[head_name],
    247               labels=labels[head_name] if labels else None,
    248               train_op_fn=_no_op_train_fn))
    249 
    250     if mode == model_fn.ModeKeys.TRAIN:
    251       if train_op_fn is None:
    252         raise ValueError('train_op_fn can not be None in TRAIN mode.')
    253       spec = self._merge_train(all_estimator_spec, train_op_fn)
    254       with ops.name_scope(''):
    255         summary.scalar(metric_keys.MetricKeys.LOSS, spec.loss)
    256       return spec
    257     if mode == model_fn.ModeKeys.PREDICT:
    258       return self._merge_predict(all_estimator_spec)
    259     if mode == model_fn.ModeKeys.EVAL:
    260       return self._merge_eval(all_estimator_spec)
    261     raise ValueError('mode={} unrecognized'.format(mode))
    262 
    263   def _split_logits(self, logits):
    264     """Splits logits along the last dimension and returns a dict."""
    265     logits_dict = {}
    266     with ops.name_scope(None, 'split_logits', values=[logits]):
    267       logits = ops.convert_to_tensor(logits)
    268       batch_shape = array_ops.shape(logits)[:-1]
    269       zeros_like_batch_shape = array_ops.zeros_like(batch_shape)
    270       minus_ones_like_batch_shape = -1 * array_ops.ones_like(batch_shape)
    271       begin_idx = 0
    272       for head in self._heads:
    273         begin_tensor = array_ops.concat(
    274             [zeros_like_batch_shape, [begin_idx]], axis=0)
    275         size_tensor = array_ops.concat(
    276             [minus_ones_like_batch_shape, [head.logits_dimension]], axis=0)
    277         logits_dict[head.name] = array_ops.slice(
    278             logits, begin=begin_tensor, size=size_tensor)
    279         begin_idx += head.logits_dimension
    280     return logits_dict
    281 
    282   def _merge_train(self, all_estimator_spec, train_op_fn):
    283     """Merges list of `EstimatorSpec` for training.
    284 
    285     Args:
    286       all_estimator_spec: list of `EstimatorSpec` for the individual heads.
    287       train_op_fn: Function to create train op. See `create_estimator_spec`
    288         documentation for more details.
    289 
    290     Returns:
    291       `EstimatorSpec` that merges all heads for TRAIN.
    292     """
    293     losses = []
    294     metrics = {}
    295     for spec in all_estimator_spec:
    296       losses.append(spec.loss)
    297       # Metric keys already contain head.name.
    298       metrics.update(spec.eval_metric_ops or {})
    299     loss = _merge_losses(losses, self._head_weights)
    300 
    301     return model_fn.EstimatorSpec(
    302         mode=model_fn.ModeKeys.TRAIN,
    303         loss=loss,
    304         train_op=train_op_fn(loss),
    305         eval_metric_ops=metrics)
    306 
    307   def _merge_predict(self, all_estimator_spec):
    308     """Merges list of `EstimatorSpec` for prediction.
    309 
    310     Args:
    311       all_estimator_spec: list of `EstimatorSpec` for the individual heads.
    312 
    313     Returns:
    314       `EstimatorSpec` that merges all heads for PREDICT.
    315     """
    316     predictions = {}
    317     export_outputs = {
    318         _DEFAULT_SERVING_KEY: _default_export_output(
    319             all_estimator_spec[0].export_outputs,
    320             self._heads[0].name),
    321     }
    322     for head, spec in zip(self._heads, all_estimator_spec):
    323       head_name = head.name
    324       for k, v in six.iteritems(spec.export_outputs):
    325         if k == _DEFAULT_SERVING_KEY:
    326           key = head_name
    327         else:
    328           key = '%s/%s' % (k, head_name)
    329         export_outputs[key] = v
    330       for k, v in six.iteritems(spec.predictions):
    331         predictions[(head_name, k)] = v
    332 
    333     return model_fn.EstimatorSpec(
    334         mode=model_fn.ModeKeys.PREDICT,
    335         predictions=predictions,
    336         export_outputs=export_outputs)
    337 
    338   def _merge_eval(self, all_estimator_spec):
    339     """Merges list of `EstimatorSpec` for eval.
    340 
    341     Args:
    342       all_estimator_spec: list of `EstimatorSpec` for the individual heads.
    343 
    344     Returns:
    345       `EstimatorSpec` that merges all heads for EVAL.
    346     """
    347     predictions = {}
    348     metrics = {}
    349     losses = []
    350     with ops.name_scope('merge_eval'):
    351       for head, spec in zip(self._heads, all_estimator_spec):
    352         losses.append(spec.loss)
    353         head_name = head.name
    354         # Loss metric is not added by default.
    355         loss_name = head_lib._summary_key(  # pylint:disable=protected-access
    356             head_name, metric_keys.MetricKeys.LOSS)
    357         metrics[loss_name] = metrics_lib.mean(spec.loss, name=loss_name)
    358         # Metric keys already contain head.name.
    359         metrics.update(spec.eval_metric_ops or {})
    360         for k, v in six.iteritems(spec.predictions):
    361           predictions[(head_name, k)] = v
    362       loss = _merge_losses(losses, self._head_weights)
    363 
    364     return model_fn.EstimatorSpec(
    365         mode=model_fn.ModeKeys.EVAL,
    366         predictions=predictions,
    367         loss=loss,
    368         eval_metric_ops=metrics)
    369