# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Abstractions for the head(s) of a model.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.summary import summary


_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY


def multi_head(heads, head_weights=None):
  """Creates a `_Head` for multi-objective learning.

  The returned head merges the output of multiple `_Head` objects.
  Specifically:
  * For training, sums losses of each head, calls `train_op_fn` with this
    final loss.
  * For eval, merges metrics by adding `head.name` suffix to the keys in eval
    metrics, such as `precision/head1`, `precision/head2`.
  * For prediction, merges predictions and updates keys in prediction dict to a
    2-tuple, `(head.name, prediction_key)`. Merges `export_outputs` such that
    by default the first head is served.

  Usage:

  ```python
  # In `input_fn` specify labels as a dict keyed by head name:
  def input_fn():
    features = ...
    labels1 = ...
    labels2 = ...
    return features, {'head1': labels1, 'head2': labels2}

  # In `model_fn`, specify logits as a dict keyed by head name:
  def model_fn(features, labels, mode):
    # Create simple heads and specify head name.
    head1 = multi_class_head(n_classes=3, name='head1')
    head2 = binary_classification_head(name='head2')
    # Create multi-head from two simple heads.
    head = multi_head([head1, head2])
    # Create logits for each head, and combine them into a dict.
    logits1, logits2 = logit_fn()
    logits = {'head1': logits1, 'head2': logits2}
    # Return the merged EstimatorSpec
    return head.create_estimator_spec(..., logits=logits, ...)

  # Create an estimator with this model_fn.
  estimator = tf.estimator.Estimator(model_fn=model_fn)
  estimator.train(input_fn=input_fn, steps=100)
  ```

  Also supports `logits` as a `Tensor` of shape
  `[D0, D1, ... DN, logits_dimension]`. It will split the `Tensor` along the
  last dimension and distribute it appropriately among the heads. E.g.:

  ```python
  def model_fn(features, labels, mode):
    # Create simple heads and specify head name.
    head1 = multi_class_head(n_classes=3, name='head1')
    head2 = binary_classification_head(name='head2')
    # Create multi-head from two simple heads.
    head = multi_head([head1, head2])
    # Create logits for the multihead.
    logits = logit_fn(logits_dimension=head.logits_dimension)
    # Return the merged EstimatorSpec
    return head.create_estimator_spec(..., logits=logits, ...)
  ```

  Args:
    heads: List or tuple of `_Head` instances. All heads must have `name`
      specified. The first head in the list is the default used at serving time.
    head_weights: Optional list of weights, same length as `heads`. Used when
      merging losses to calculate the weighted sum of losses from each head. If
      `None` (or empty), all losses are weighted equally.

  Returns:
    An instance of `_Head` that merges multiple heads.

  Raises:
    ValueError: If `heads` is empty.
    ValueError: If any of the `heads` does not have `name` specified.
    ValueError: If `heads` and `head_weights` have different size.
  """
  # Validate `heads` first so that an empty (or None) `heads` is reported as
  # such, rather than as a size mismatch or a TypeError from `len(None)`.
  if not heads:
    raise ValueError('Must specify heads. Given: {}'.format(heads))
  if head_weights:
    if len(head_weights) != len(heads):
      raise ValueError(
          'heads and head_weights must have the same size. '
          'Given len(heads): {}. Given len(head_weights): {}.'.format(
              len(heads), len(head_weights)))
  for head in heads:
    if not head.name:
      raise ValueError(
          'All given heads must have name specified. '
          'Given: {}'.format(head))

  return _MultiHead(
      heads=tuple(heads),
      head_weights=tuple(head_weights) if head_weights else tuple())


def _no_op_train_fn(loss):
  """`train_op_fn` handed to the individual heads; ignores `loss`.

  `_MultiHead.create_estimator_spec` builds every sub-head with this function.
  The real train op is created once from the merged loss in
  `_MultiHead._merge_train`.
  """
  del loss  # Unused.
  return control_flow_ops.no_op()


def _merge_losses(losses, head_weights=None):
  """Merges the given losses into one tensor.

  Args:
    losses: Iterable of loss `Tensor`s, one per head.
    head_weights: Optional sequence of weights aligned with `losses`. If `None`
      or empty, the losses are summed without weighting.

  Returns:
    A `Tensor` holding the (weighted) sum of `losses`.
  """
  losses = tuple(losses)
  # Normalize to a tuple so `losses + head_weights` works for any sequence.
  head_weights = tuple(head_weights or ())
  with ops.name_scope('merge_losses', values=losses + head_weights):
    if head_weights:
      weighted_losses = [
          math_ops.multiply(loss, weight)
          for loss, weight in zip(losses, head_weights)]
    else:
      weighted_losses = losses
    return math_ops.add_n(weighted_losses)


def _default_export_output(export_outputs, head_name):
  """Extracts the default export output from the given export_outputs dict.

  Args:
    export_outputs: Dict of export outputs produced by a single head.
    head_name: Name of that head; only used in the error message.

  Returns:
    The single value of `export_outputs` if it has exactly one entry,
    otherwise the value keyed by `_DEFAULT_SERVING_KEY`.

  Raises:
    ValueError: If `export_outputs` does not have exactly one entry and has no
      `_DEFAULT_SERVING_KEY` entry.
  """
  if len(export_outputs) == 1:
    return next(six.itervalues(export_outputs))
  if _DEFAULT_SERVING_KEY in export_outputs:
    return export_outputs[_DEFAULT_SERVING_KEY]
  raise ValueError(
      '{} did not specify default export_outputs. '
      'Given: {} '
      'Suggested fix: Use one of the heads in tf.contrib.estimator, or include '
      'key {} in export_outputs.'.format(
          head_name, export_outputs, _DEFAULT_SERVING_KEY))


class _MultiHead(head_lib._Head):  # pylint:disable=protected-access
  """`_Head` for multi objective learning."""

  def __init__(self, heads, head_weights):
    """Initializes the merged head.

    Args:
      heads: Tuple of `_Head` instances, each with a non-empty `name`.
      head_weights: Tuple of loss weights aligned with `heads`, or an empty
        tuple for unweighted merging.
    """
    # A single logits `Tensor` is the per-head logits laid out side by side on
    # the last dimension (see `_split_logits`), so the dimensions add up.
    self._logits_dimension = sum(head.logits_dimension for head in heads)
    self._heads = heads
    self._head_weights = head_weights

  @property
  def name(self):
    return '_'.join([h.name for h in self._heads])

  @property
  def logits_dimension(self):
    return self._logits_dimension

  def _logits_dict(self, logits):
    """Returns `logits` as a dict keyed by head name, splitting if needed."""
    if isinstance(logits, dict):
      return logits
    return self._split_logits(logits)

  def create_loss(self, features, mode, logits, labels):
    """See `_Head`.

    Args:
      features: Input features, passed through to every head.
      mode: Estimator `ModeKeys` value, passed through to every head.
      logits: Dict of logits keyed by head name, or a single `Tensor` that is
        split along its last dimension by `_split_logits`.
      labels: Dict of labels keyed by head name.

    Returns:
      A `head_lib.LossSpec` whose `training_loss` is the (weighted) sum of the
      heads' training losses, and whose `unreduced_loss`, `weights` and
      `processed_labels` are dicts keyed by head name. When head weights are
      set, each head's unreduced loss and example weights are multiplied by
      that head's weight.
    """
    logits_dict = self._logits_dict(logits)
    training_losses = []
    labels_by_head = {}
    unreduced_losses_by_head = {}
    example_weights_by_head = {}
    for i, head in enumerate(self._heads):
      (training_loss, unreduced_loss,
       weights, processed_labels) = head.create_loss(
           features, mode, logits_dict[head.name], labels[head.name])
      training_losses.append(training_loss)
      labels_by_head[head.name] = processed_labels
      if self._head_weights:
        head_weight = self._head_weights[i]
        unreduced_loss = math_ops.multiply(unreduced_loss, head_weight)
        weights = math_ops.multiply(weights, head_weight)
      unreduced_losses_by_head[head.name] = unreduced_loss
      example_weights_by_head[head.name] = weights

    return head_lib.LossSpec(
        training_loss=_merge_losses(training_losses, self._head_weights),
        unreduced_loss=unreduced_losses_by_head,
        weights=example_weights_by_head,
        processed_labels=labels_by_head)

  def create_estimator_spec(
      self, features, mode, logits, labels=None, train_op_fn=None):
    """See `_Head`.

    Raises:
      ValueError: If `labels` is given but is not a dict.
      ValueError: If `mode` is TRAIN and `train_op_fn` is `None`.
      ValueError: If `mode` is not recognized.
    """
    # Compare against None explicitly: `labels` may be a `Tensor` or array,
    # for which `bool(labels)` raises instead of producing this ValueError.
    if labels is not None and not isinstance(labels, dict):
      raise ValueError('labels must be a dict. Given: {}'.format(labels))
    # Fail before building any per-head graph.
    if mode == model_fn.ModeKeys.TRAIN and train_op_fn is None:
      raise ValueError('train_op_fn can not be None in TRAIN mode.')
    logits_dict = self._logits_dict(logits)

    all_estimator_spec = []
    for head in self._heads:
      head_name = head.name
      all_estimator_spec.append(
          head.create_estimator_spec(
              features=features,
              mode=mode,
              logits=logits_dict[head_name],
              # `labels` is None or a dict here; an empty dict means no labels.
              labels=labels[head_name] if labels else None,
              # The real train op is built from the merged loss below.
              train_op_fn=_no_op_train_fn))

    if mode == model_fn.ModeKeys.TRAIN:
      spec = self._merge_train(all_estimator_spec, train_op_fn)
      with ops.name_scope(''):
        summary.scalar(metric_keys.MetricKeys.LOSS, spec.loss)
      return spec
    if mode == model_fn.ModeKeys.PREDICT:
      return self._merge_predict(all_estimator_spec)
    if mode == model_fn.ModeKeys.EVAL:
      return self._merge_eval(all_estimator_spec)
    raise ValueError('mode={} unrecognized'.format(mode))

  def _split_logits(self, logits):
    """Splits logits along the last dimension and returns a dict.

    Heads take consecutive slices of the last dimension in the order of
    `self._heads`, each of width `head.logits_dimension`; all leading
    dimensions are kept whole.

    Args:
      logits: `Tensor` (or convertible) whose last dimension holds the
        concatenated logits of all heads.

    Returns:
      Dict mapping head name to that head's slice of `logits`.
    """
    logits_dict = {}
    with ops.name_scope(None, 'split_logits', values=[logits]):
      logits = ops.convert_to_tensor(logits)
      batch_shape = array_ops.shape(logits)[:-1]
      zeros_like_batch_shape = array_ops.zeros_like(batch_shape)
      # A size of -1 takes every remaining element of a leading dimension.
      minus_ones_like_batch_shape = -1 * array_ops.ones_like(batch_shape)
      begin_idx = 0
      for head in self._heads:
        begin_tensor = array_ops.concat(
            [zeros_like_batch_shape, [begin_idx]], axis=0)
        size_tensor = array_ops.concat(
            [minus_ones_like_batch_shape, [head.logits_dimension]], axis=0)
        logits_dict[head.name] = array_ops.slice(
            logits, begin=begin_tensor, size=size_tensor)
        begin_idx += head.logits_dimension
    return logits_dict

  def _merge_train(self, all_estimator_spec, train_op_fn):
    """Merges list of `EstimatorSpec` for training.

    Args:
      all_estimator_spec: list of `EstimatorSpec` for the individual heads.
      train_op_fn: Function to create train op. See `create_estimator_spec`
        documentation for more details.

    Returns:
      `EstimatorSpec` that merges all heads for TRAIN.
    """
    losses = []
    metrics = {}
    for spec in all_estimator_spec:
      losses.append(spec.loss)
      # Metric keys already contain head.name.
      metrics.update(spec.eval_metric_ops or {})
    loss = _merge_losses(losses, self._head_weights)

    return model_fn.EstimatorSpec(
        mode=model_fn.ModeKeys.TRAIN,
        loss=loss,
        train_op=train_op_fn(loss),
        eval_metric_ops=metrics)

  def _merge_predict(self, all_estimator_spec):
    """Merges list of `EstimatorSpec` for prediction.

    Export outputs are re-keyed so they stay unique across heads: a head's
    default serving output is stored under the head's name, and any other
    output `k` under `'k/<head name>'`. The first head's default output is
    also stored under `_DEFAULT_SERVING_KEY`.

    Args:
      all_estimator_spec: list of `EstimatorSpec` for the individual heads.

    Returns:
      `EstimatorSpec` that merges all heads for PREDICT.
    """
    predictions = {}
    export_outputs = {
        _DEFAULT_SERVING_KEY: _default_export_output(
            all_estimator_spec[0].export_outputs,
            self._heads[0].name),
    }
    for head, spec in zip(self._heads, all_estimator_spec):
      head_name = head.name
      for k, v in six.iteritems(spec.export_outputs):
        if k == _DEFAULT_SERVING_KEY:
          key = head_name
        else:
          key = '%s/%s' % (k, head_name)
        export_outputs[key] = v
      for k, v in six.iteritems(spec.predictions):
        predictions[(head_name, k)] = v

    return model_fn.EstimatorSpec(
        mode=model_fn.ModeKeys.PREDICT,
        predictions=predictions,
        export_outputs=export_outputs)

  def _merge_eval(self, all_estimator_spec):
    """Merges list of `EstimatorSpec` for eval.

    Args:
      all_estimator_spec: list of `EstimatorSpec` for the individual heads.

    Returns:
      `EstimatorSpec` that merges all heads for EVAL.
    """
    predictions = {}
    metrics = {}
    losses = []
    with ops.name_scope('merge_eval'):
      for head, spec in zip(self._heads, all_estimator_spec):
        losses.append(spec.loss)
        head_name = head.name
        # Loss metric is not added by default.
        loss_name = head_lib._summary_key(  # pylint:disable=protected-access
            head_name, metric_keys.MetricKeys.LOSS)
        metrics[loss_name] = metrics_lib.mean(spec.loss, name=loss_name)
        # Metric keys already contain head.name.
        metrics.update(spec.eval_metric_ops or {})
        for k, v in six.iteritems(spec.predictions):
          predictions[(head_name, k)] = v
      loss = _merge_losses(losses, self._head_weights)

    return model_fn.EstimatorSpec(
        mode=model_fn.ModeKeys.EVAL,
        predictions=predictions,
        loss=loss,
        eval_metric_ops=metrics)