Home | History | Annotate | Download | only in ops
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Functions to bridge `Distribution`s and `tf.contrib.learn.estimator` APIs."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.contrib.learn.python.learn.estimators.head import _compute_weighted_loss
     22 from tensorflow.contrib.learn.python.learn.estimators.head import _RegressionHead
     23 from tensorflow.python.framework import ops
     24 from tensorflow.python.framework import tensor_util
     25 from tensorflow.python.ops import array_ops
     26 
     27 
     28 __all__ = [
     29     "estimator_head_distribution_regression",
     30 ]
     31 
     32 
     33 def estimator_head_distribution_regression(make_distribution_fn,
     34                                            label_dimension=1,
     35                                            logits_dimension=None,
     36                                            label_name=None,
     37                                            weight_column_name=None,
     38                                            enable_centered_bias=False,
     39                                            head_name=None):
     40   """Creates a `Head` for regression under a generic distribution.
     41 
     42   Args:
     43     make_distribution_fn: Python `callable` which returns a `tf.Distribution`
     44       instance created using only logits.
     45     label_dimension: Number of regression labels per example. This is the size
     46       of the last dimension of the labels `Tensor` (typically, this has shape
     47       `[batch_size, label_dimension]`).
     48     logits_dimension: Number of logits per example. This is the size of the last
     49       dimension of the logits `Tensor` (typically, this has shape
     50       `[batch_size, logits_dimension]`).
     51       Default value: `label_dimension`.
     52     label_name: Python `str`, name of the key in label `dict`. Can be `None` if
     53       label is a `Tensor` (single headed models).
     54     weight_column_name: Python `str` defining feature column name representing
     55       weights. It is used to down weight or boost examples during training. It
     56       will be multiplied by the loss of the example.
     57     enable_centered_bias: Python `bool`. If `True`, estimator will learn a
     58       centered bias variable for each class. Rest of the model structure learns
     59       the residual after centered bias.
     60     head_name: Python `str`, name of the head. Predictions, summary and metrics
     61       keys are suffixed by `"/" + head_name` and the default variable scope is
     62       `head_name`.
     63 
     64   Returns:
     65     An instance of `Head` for generic regression.
     66   """
     67   return _DistributionRegressionHead(
     68       make_distribution_fn=make_distribution_fn,
     69       label_dimension=label_dimension,
     70       logits_dimension=logits_dimension,
     71       label_name=label_name,
     72       weight_column_name=weight_column_name,
     73       enable_centered_bias=enable_centered_bias,
     74       head_name=head_name)
     75 
     76 
     77 class _DistributionRegressionHead(_RegressionHead):
     78   """Creates a _RegressionHead instance from an arbitray `Distribution`."""
     79 
     80   def __init__(self,
     81                make_distribution_fn,
     82                label_dimension,
     83                logits_dimension=None,
     84                label_name=None,
     85                weight_column_name=None,
     86                enable_centered_bias=False,
     87                head_name=None):
     88     """`Head` for regression.
     89 
     90     Args:
     91       make_distribution_fn: Python `callable` which returns a `tf.Distribution`
     92         instance created using only logits.
     93       label_dimension: Number of regression labels per example. This is the
     94         size of the last dimension of the labels `Tensor` (typically, this has
     95         shape `[batch_size, label_dimension]`).
     96       logits_dimension: Number of logits per example. This is the size of the
     97         last dimension of the logits `Tensor` (typically, this has shape
     98         `[batch_size, logits_dimension]`).
     99         Default value: `label_dimension`.
    100       label_name: Python `str`, name of the key in label `dict`. Can be `None`
    101         if label is a tensor (single headed models).
    102       weight_column_name: Python `str` defining feature column name representing
    103         weights. It is used to down weight or boost examples during training. It
    104         will be multiplied by the loss of the example.
    105       enable_centered_bias: Python `bool`. If `True`, estimator will learn a
    106         centered bias variable for each class. Rest of the model structure
    107         learns the residual after centered bias.
    108       head_name: Python `str`, name of the head. Predictions, summary and
    109         metrics keys are suffixed by `"/" + head_name` and the default variable
    110         scope is `head_name`.
    111 
    112     Raises:
    113       TypeError: if `make_distribution_fn` is not `callable`.
    114     """
    115     if not callable(make_distribution_fn):
    116       raise TypeError("`make_distribution_fn` must be a callable function.")
    117 
    118     self._distributions = {}
    119     self._make_distribution_fn = make_distribution_fn
    120 
    121     def static_value(x):
    122       """Returns the static value of a `Tensor` or `None`."""
    123       return tensor_util.constant_value(ops.convert_to_tensor(x))
    124 
    125     def concat_vectors(*args):
    126       """Concatenates input vectors, statically if possible."""
    127       args_ = [static_value(x) for x in args]
    128       if any(vec is None for vec in args_):
    129         return array_ops.concat(args, axis=0)
    130       return [val for vec in args_ for val in vec]
    131 
    132     def loss_fn(labels, logits, weights=None):
    133       """Returns the loss of using `logits` to predict `labels`."""
    134       d = self.distribution(logits)
    135       labels_batch_shape = labels.shape.with_rank_at_least(1)[:-1]
    136       labels_batch_shape = (
    137           labels_batch_shape.as_list() if labels_batch_shape.is_fully_defined()
    138           else array_ops.shape(labels)[:-1])
    139       labels = array_ops.reshape(
    140           labels,
    141           shape=concat_vectors(labels_batch_shape, d.event_shape_tensor()))
    142       return _compute_weighted_loss(
    143           loss_unweighted=-d.log_prob(labels),
    144           weight=weights)
    145 
    146     def link_fn(logits):
    147       """Returns the inverse link function at `logits`."""
    148       # Note: What the API calls a "link function" is really the inverse-link
    149       # function, i.e., the "mean".
    150       d = self.distribution(logits)
    151       return d.mean()
    152 
    153     super(_DistributionRegressionHead, self).__init__(
    154         label_dimension=label_dimension,
    155         loss_fn=loss_fn,
    156         link_fn=link_fn,
    157         logits_dimension=logits_dimension,
    158         label_name=label_name,
    159         weight_column_name=weight_column_name,
    160         enable_centered_bias=enable_centered_bias,
    161         head_name=head_name)
    162 
    163   @property
    164   def distributions(self):
    165     """Returns all distributions created by `DistributionRegressionHead`."""
    166     return self._distributions
    167 
    168   def distribution(self, logits, name=None):
    169     """Retrieves a distribution instance, parameterized by `logits`.
    170 
    171     Args:
    172       logits: `float`-like `Tensor` representing the parameters of the
    173         underlying distribution.
    174       name: The Python `str` name to given to this op.
    175         Default value: "distribution".
    176 
    177     Returns:
    178       distribution: `tf.Distribution` instance parameterized by `logits`.
    179     """
    180     with ops.name_scope(name, "distribution", [logits]):
    181       d = self._distributions.get(logits, None)
    182       if d is None:
    183         d = self._make_distribution_fn(logits)
    184         self._distributions[logits] = d
    185       return d
    186