# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extenders of tf.estimator.Estimator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.ops import clip_ops
from tensorflow.python.training import optimizer as optimizer_lib


# The only argument names a user-supplied `metric_fn` may declare; checked by
# `_verify_metric_fn_args` and used for dispatch in `_call_metric_fn`.
_VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config'])


def add_metrics(estimator, metric_fn):
  """Creates a new ${tf.estimator.Estimator} which has given metrics.

  Example:

  ```python
    def my_auc(labels, predictions):
      return {'auc': tf.metrics.auc(labels, predictions['logistic'])}

    estimator = tf.estimator.DNNClassifier(...)
    estimator = tf.contrib.estimator.add_metrics(estimator, my_auc)
    estimator.train(...)
    estimator.evaluate(...)
  ```

  Example usage of custom metric which uses features:

  ```python
    def my_auc(features, labels, predictions):
      return {'auc': tf.metrics.auc(
        labels, predictions['logistic'], weights=features['weight'])}

    estimator = tf.estimator.DNNClassifier(...)
    estimator = tf.contrib.estimator.add_metrics(estimator, my_auc)
    estimator.train(...)
    estimator.evaluate(...)
  ```

  Args:
    estimator: A ${tf.estimator.Estimator} object.
    metric_fn: A function which should obey the following signature:
      - Args: can only have the following four arguments, in any order:
        * predictions: Predictions `Tensor` or dict of `Tensor` created by the
          given `estimator`.
        * features: Input `dict` of `Tensor` objects created by the `input_fn`
          which is given to `estimator.evaluate` as an argument.
        * labels: Labels `Tensor` or dict of `Tensor` created by the `input_fn`
          which is given to `estimator.evaluate` as an argument.
        * config: config attribute of the `estimator`.
      - Returns:
        Dict of metric results keyed by name. Final metrics are a union of this
        and the `estimator`'s existing metrics. If there is a name conflict
        between this and the `estimator`'s existing metrics, this will override
        the existing one. The values of the dict are the results of calling a
        metric function, namely a `(metric_tensor, update_op)` tuple.

  Returns:
    A new ${tf.estimator.Estimator} which has a union of original metrics with
    given ones. It shares `model_dir` and `config` with `estimator`.

  Raises:
    ValueError: if `metric_fn` declares an argument other than `features`,
      `labels`, `predictions` or `config`.
  """
  # Fail fast at wrap time rather than at the first `evaluate` call.
  _verify_metric_fn_args(metric_fn)

  def new_model_fn(features, labels, mode, config):
    """Calls the wrapped model_fn and, in EVAL mode, merges in `metric_fn`."""
    spec = estimator.model_fn(features, labels, mode, config)
    if mode != model_fn_lib.ModeKeys.EVAL:
      # Metrics only matter for evaluation; TRAIN/PREDICT specs pass through.
      return spec
    new_metrics = _call_metric_fn(metric_fn, features, labels, spec.predictions,
                                  config)
    # Merge into a copy so the dict owned by the wrapped model_fn's spec is
    # left untouched. Entries from `metric_fn` win on name conflicts.
    all_metrics = dict(spec.eval_metric_ops or {})
    all_metrics.update(new_metrics)
    return spec._replace(eval_metric_ops=all_metrics)

  return estimator_lib.Estimator(
      model_fn=new_model_fn,
      model_dir=estimator.model_dir,
      config=estimator.config)


def clip_gradients_by_norm(optimizer, clip_norm):
  """Returns an optimizer which clips gradients before applying them.

  Gradients are rescaled jointly by their global norm (via
  `clip_ops.clip_by_global_norm`), not clipped one tensor at a time.

  Example:

  ```python
  optimizer = tf.train.ProximalAdagradOptimizer(
      learning_rate=0.1,
      l1_regularization_strength=0.001)
  optimizer = tf.contrib.estimator.clip_gradients_by_norm(
      optimizer, clip_norm)
  estimator = tf.estimator.DNNClassifier(
      feature_columns=[...],
      hidden_units=[1024, 512, 256],
      optimizer=optimizer)
  ```

  Args:
    optimizer: An `tf.Optimizer` object to apply gradients.
    clip_norm: A 0-D (scalar) `Tensor` > 0. The maximum global norm passed to
      `clip_by_global_norm`.

  Returns:
    A `tf.Optimizer` named `'ClipByNorm' + optimizer.get_name()`.
  """

  def clip_grads(grads_and_vars):
    """Clips the gradients of `(gradient, variable)` pairs by global norm."""
    gradients, variables = zip(*grads_and_vars)
    # `clip_by_global_norm` returns `(clipped_list, global_norm)`; keep the list.
    gradients = clip_ops.clip_by_global_norm(gradients, clip_norm)[0]
    grads_and_vars = list(zip(gradients, variables))
    return grads_and_vars

  return _TransformGradients(
      optimizer=optimizer,
      transform_grads_fn=clip_grads,
      name='ClipByNorm' + optimizer.get_name())


def forward_features(estimator, keys=None):
  """Forward features to predictions dictionary.

  In some cases, users want to see some of the features in an estimator's
  prediction output. As an example, consider a batch prediction service: the
  service simply runs inference on the user's graph and returns the results.
  Keys are essential because there is no order guarantee on the outputs, so
  they need to be rejoined to the inputs via keys or transclusion of the inputs
  in the outputs.

  Example:

  ```python
    def input_fn():
      features, labels = ...
      features['unique_example_id'] = ...
      return features, labels

    estimator = tf.estimator.LinearClassifier(...)
    estimator = tf.contrib.estimator.forward_features(
        estimator, 'unique_example_id')
    estimator.train(...)
    # `predict` yields one predictions dict per example.
    assert 'unique_example_id' in next(estimator.predict(...))
  ```

  Args:
    estimator: A ${tf.estimator.Estimator} object.
    keys: a `string` or a `list` of `string`. If it is `None`, all of the
      `features` in the `dict` are forwarded to the `predictions`. If it is a
      `string`, only the given key is forwarded. If it is a `list` of strings,
      all the given `keys` are forwarded.

  Returns:
    A new ${tf.estimator.Estimator} which forwards features to predictions.

  Raises:
    ValueError:
      * if `keys` is already part of `predictions`. We don't allow
        override.
      * if 'keys' does not exist in `features`.
      * if feature key refers to a `SparseTensor`, since we don't support
        `SparseTensor` in `predictions`. `SparseTensor` is common in `features`.
      * if the wrapped estimator's `predictions` is not a `dict`.
    TypeError: if `keys` type is not one of `string` or list/tuple of `string`.
  """

  def verify_key_types(keys):  # pylint: disable=missing-docstring
    # Normalizes `keys` to None or a list/tuple of strings.
    if keys is None:
      return keys
    if isinstance(keys, six.string_types):
      return [keys]
    if not isinstance(keys, (list, tuple)):
      raise TypeError('keys should be either a string or a list of strings. '
                      'Given: {}'.format(type(keys)))
    for key in keys:
      if not isinstance(key, six.string_types):
        raise TypeError('All items in the given keys list should be a string. '
                        'There exist an item with type: {}'.format(type(key)))
    return keys

  def get_keys(features):
    # `keys is None` means "forward every feature".
    if keys is None:
      return features.keys()
    return keys

  def verify_keys_and_predictions(features, predictions):
    if not isinstance(predictions, dict):
      raise ValueError(
          'Predictions should be a dict to be able to forward features. '
          'Given: {}'.format(type(predictions)))
    for key in get_keys(features):
      if key not in features:
        raise ValueError(
            'keys should be exist in features. Key "{}" is not in features '
            'dict. features dict has following keys: {}. Please check '
            'arguments of forward_features.'.format(key, features.keys()))
      if key in predictions:
        raise ValueError(
            'Cannot forward feature key ({}). Since it does exist in '
            'predictions. Existing prediction keys: {}. Please check arguments '
            'of forward_features.'.format(key, predictions.keys()))

  # Validate eagerly so a bad `keys` argument fails at wrap time. The closures
  # above read the rebound (normalized) `keys`.
  keys = verify_key_types(keys)

  def new_model_fn(features, labels, mode, config):  # pylint: disable=missing-docstring
    spec = estimator.model_fn(features, labels, mode, config)
    predictions = spec.predictions
    if predictions is None:
      # Nothing to forward into (e.g. a spec without predictions).
      return spec
    verify_keys_and_predictions(features, predictions)
    for key in get_keys(features):
      feature = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
          features[key])
      if not isinstance(feature, ops.Tensor):
        raise ValueError(
            'Forwarded feature ({}) should be a Tensor. Please use keys '
            'argument of forward_features to filter unwanted features. Type of '
            'features[{}] is {}.'.format(key, key, type(feature)))
      # NOTE: this writes into the wrapped spec's own predictions dict, so any
      # other object holding a reference to that dict also sees the forwarded
      # features. Kept as-is to preserve existing behavior.
      predictions[key] = feature
    return spec._replace(predictions=predictions)

  return estimator_lib.Estimator(
      model_fn=new_model_fn,
      model_dir=estimator.model_dir,
      config=estimator.config)


class _TransformGradients(optimizer_lib.Optimizer):
  """Add given gradient transformation to the optimizer."""

  def __init__(self, optimizer, transform_grads_fn, name=None):
    """Construct an `tf.Optimizer` wrapper to apply given transformations.

    Example:

    ```python
    optimizer = tf.train.ProximalAdagradOptimizer(
        learning_rate=0.1,
        l1_regularization_strength=0.001)
    def clip_grads(grads_and_vars):
      gradients, variables = zip(*grads_and_vars)
      gradients = tf.clip_by_global_norm(gradients, my_norm)[0]
      grads_and_vars = list(zip(gradients, variables))
      return grads_and_vars
    optimizer = _TransformGradients(
        optimizer=optimizer, transform_grads_fn=clip_grads)
    estimator = tf.estimator.DNNClassifier(
        feature_columns=[...],
        hidden_units=[1024, 512, 256],
        optimizer=optimizer)
    ```

    Args:
      optimizer: An `tf.Optimizer` object to apply gradients.
      transform_grads_fn: A function which takes a single argument, a list of
        gradient to variable pairs (tuples), performs any requested gradient
        updates, such as gradient clipping or multipliers, and returns the
        updated list.
      name: A string which will be used for debugging purposes. Defaults to
        the name of `optimizer` when `None` or empty.
    """
    super(_TransformGradients, self).__init__(
        use_locking=False, name=name or optimizer.get_name())
    self._optimizer = optimizer
    self._transform_grads_fn = transform_grads_fn

  def compute_gradients(self, *args, **kwargs):
    """See `tf.Optimizer`. Delegates to the wrapped optimizer unchanged."""
    return self._optimizer.compute_gradients(*args, **kwargs)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    Calls `transform_grads_fn`, and then applies the real optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        compute_gradients().
      global_step: Optional Variable to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Default to the
        name passed to the Optimizer constructor.

    Returns:
      An `Operation` that applies the gradients. If `global_step` was not None,
      that operation also increments `global_step`.

    Raises:
      ValueError: If the grads_and_vars is malformed.
    """
    grads_and_vars = self._transform_grads_fn(grads_and_vars)
    return self._optimizer.apply_gradients(grads_and_vars, global_step, name)

  def get_slot(self, *args, **kwargs):
    """See `tf.Optimizer`. Slots live on the wrapped optimizer."""
    return self._optimizer.get_slot(*args, **kwargs)

  def get_slot_names(self, *args, **kwargs):
    """See `tf.Optimizer`. Slots live on the wrapped optimizer."""
    return self._optimizer.get_slot_names(*args, **kwargs)


def _verify_metric_fn_args(metric_fn):
  """Raises ValueError if `metric_fn` declares unsupported argument names."""
  args = set(estimator_util.fn_args(metric_fn))
  # Sorted so the error message is deterministic (set order is arbitrary).
  invalid_args = sorted(args - _VALID_METRIC_FN_ARGS)
  if invalid_args:
    raise ValueError('metric_fn (%s) has following not expected args: %s' %
                     (metric_fn, invalid_args))


def _call_metric_fn(metric_fn, features, labels, predictions, config):
  """Calls metric fn with proper arguments.

  Only the arguments that `metric_fn` actually declares are passed, by keyword.
  """
  metric_fn_args = estimator_util.fn_args(metric_fn)
  kwargs = {}
  if 'features' in metric_fn_args:
    kwargs['features'] = features
  if 'labels' in metric_fn_args:
    kwargs['labels'] = labels
  if 'predictions' in metric_fn_args:
    kwargs['predictions'] = predictions
  if 'config' in metric_fn_args:
    kwargs['config'] = config
  return metric_fn(**kwargs)