# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to bridge `Distribution`s and `tf.contrib.learn.estimator` APIs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.learn.python.learn.estimators.head import _compute_weighted_loss
from tensorflow.contrib.learn.python.learn.estimators.head import _RegressionHead
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops


__all__ = [
    "estimator_head_distribution_regression",
]


def estimator_head_distribution_regression(make_distribution_fn,
                                           label_dimension=1,
                                           logits_dimension=None,
                                           label_name=None,
                                           weight_column_name=None,
                                           enable_centered_bias=False,
                                           head_name=None):
  """Builds a regression `Head` whose loss comes from a `Distribution`.

  This is a thin factory: every argument is forwarded unchanged to
  `_DistributionRegressionHead`.

  Args:
    make_distribution_fn: Python `callable` which returns a `tf.Distribution`
      instance created using only logits.
    label_dimension: Number of regression labels per example. This is the size
      of the last dimension of the labels `Tensor` (typically, this has shape
      `[batch_size, label_dimension]`).
    logits_dimension: Number of logits per example. This is the size of the
      last dimension of the logits `Tensor` (typically, this has shape
      `[batch_size, logits_dimension]`).
      Default value: `label_dimension`.
    label_name: Python `str`, name of the key in label `dict`. Can be `None`
      if label is a `Tensor` (single headed models).
    weight_column_name: Python `str` defining feature column name representing
      weights. It is used to down weight or boost examples during training.
      It will be multiplied by the loss of the example.
    enable_centered_bias: Python `bool`. If `True`, estimator will learn a
      centered bias variable for each class. Rest of the model structure
      learns the residual after centered bias.
    head_name: Python `str`, name of the head. Predictions, summary and
      metrics keys are suffixed by `"/" + head_name` and the default variable
      scope is `head_name`.

  Returns:
    An instance of `Head` for generic regression.
  """
  head = _DistributionRegressionHead(
      make_distribution_fn=make_distribution_fn,
      label_dimension=label_dimension,
      logits_dimension=logits_dimension,
      label_name=label_name,
      weight_column_name=weight_column_name,
      enable_centered_bias=enable_centered_bias,
      head_name=head_name)
  return head


class _DistributionRegressionHead(_RegressionHead):
  """A `_RegressionHead` driven by an arbitrary `Distribution`.

  The loss is the weighted negative log-probability of the labels under the
  distribution built from the logits, and the prediction ("link") is that
  distribution's mean.
  """

  def __init__(self,
               make_distribution_fn,
               label_dimension,
               logits_dimension=None,
               label_name=None,
               weight_column_name=None,
               enable_centered_bias=False,
               head_name=None):
    """Initializes the distribution-backed regression `Head`.

    Args:
      make_distribution_fn: Python `callable` which returns a
        `tf.Distribution` instance created using only logits.
      label_dimension: Number of regression labels per example. This is the
        size of the last dimension of the labels `Tensor` (typically, this
        has shape `[batch_size, label_dimension]`).
      logits_dimension: Number of logits per example. This is the size of the
        last dimension of the logits `Tensor` (typically, this has shape
        `[batch_size, logits_dimension]`).
        Default value: `label_dimension`.
      label_name: Python `str`, name of the key in label `dict`. Can be
        `None` if label is a tensor (single headed models).
      weight_column_name: Python `str` defining feature column name
        representing weights. It is used to down weight or boost examples
        during training. It will be multiplied by the loss of the example.
      enable_centered_bias: Python `bool`. If `True`, estimator will learn a
        centered bias variable for each class. Rest of the model structure
        learns the residual after centered bias.
      head_name: Python `str`, name of the head. Predictions, summary and
        metrics keys are suffixed by `"/" + head_name` and the default
        variable scope is `head_name`.

    Raises:
      TypeError: if `make_distribution_fn` is not `callable`.
    """
    if not callable(make_distribution_fn):
      raise TypeError("`make_distribution_fn` must be a callable function.")

    # Cache of distributions already built, keyed by the logits they were
    # built from; see `distribution`.
    self._distributions = {}
    self._make_distribution_fn = make_distribution_fn

    def _static_or_none(value):
      """Gives the statically-known value of `value`, else `None`."""
      as_tensor = ops.convert_to_tensor(value)
      return tensor_util.constant_value(as_tensor)

    def _join_vectors(*vectors):
      """Joins vectors end to end, as a Python list when all are static."""
      # Every vector is resolved up front (no short-circuiting) before
      # deciding which path to take.
      resolved = [_static_or_none(vec) for vec in vectors]
      if all(vec is not None for vec in resolved):
        flattened = []
        for vec in resolved:
          flattened.extend(vec)
        return flattened
      # At least one piece is only known at graph-run time.
      return array_ops.concat(vectors, axis=0)

    def _negative_log_prob_loss(labels, logits, weights=None):
      """Computes the weighted negative log-likelihood of `labels`."""
      dist = self.distribution(logits)
      # Everything but the last axis of `labels` is treated as batch shape.
      static_batch_shape = labels.shape.with_rank_at_least(1)[:-1]
      if static_batch_shape.is_fully_defined():
        batch_shape = static_batch_shape.as_list()
      else:
        batch_shape = array_ops.shape(labels)[:-1]
      # Swap the trailing label axis for the distribution's event shape so
      # that `log_prob` sees labels laid out as batch + event.
      target_shape = _join_vectors(batch_shape, dist.event_shape_tensor())
      reshaped_labels = array_ops.reshape(labels, shape=target_shape)
      return _compute_weighted_loss(
          loss_unweighted=-dist.log_prob(reshaped_labels),
          weight=weights)

    def _mean_link(logits):
      """Maps `logits` to the mean of the corresponding distribution."""
      # Note: What the API calls a "link function" is really the inverse-link
      # function, i.e., the "mean".
      return self.distribution(logits).mean()

    super(_DistributionRegressionHead, self).__init__(
        label_dimension=label_dimension,
        loss_fn=_negative_log_prob_loss,
        link_fn=_mean_link,
        logits_dimension=logits_dimension,
        label_name=label_name,
        weight_column_name=weight_column_name,
        enable_centered_bias=enable_centered_bias,
        head_name=head_name)

  @property
  def distributions(self):
    """Returns all distributions created by `DistributionRegressionHead`."""
    return self._distributions

  def distribution(self, logits, name=None):
    """Fetches (building on first use) the distribution for `logits`.

    Distributions are memoized per `logits` object, so repeated calls with
    the same `logits` return the same instance.

    Args:
      logits: `float`-like `Tensor` representing the parameters of the
        underlying distribution.
      name: The Python `str` name to give to this op.
        Default value: "distribution".

    Returns:
      distribution: `tf.Distribution` instance parameterized by `logits`.
    """
    cache = self._distributions
    with ops.name_scope(name, "distribution", [logits]):
      existing = cache.get(logits)
      if existing is not None:
        return existing
      created = self._make_distribution_fn(logits)
      cache[logits] = created
      return created