1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """extenders tests.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.contrib.estimator.python.estimator import extenders 24 from tensorflow.python.data.ops import dataset_ops 25 from tensorflow.python.estimator import estimator_lib 26 from tensorflow.python.estimator.canned import linear 27 from tensorflow.python.feature_column import feature_column as fc 28 from tensorflow.python.framework import constant_op 29 from tensorflow.python.framework import ops 30 from tensorflow.python.framework import sparse_tensor 31 from tensorflow.python.ops import metrics as metrics_lib 32 from tensorflow.python.ops import variables 33 from tensorflow.python.platform import test 34 from tensorflow.python.training import training 35 36 37 def get_input_fn(x, y): 38 39 def input_fn(): 40 dataset = dataset_ops.Dataset.from_tensor_slices({'x': x, 'y': y}) 41 iterator = dataset.make_one_shot_iterator() 42 features = iterator.get_next() 43 labels = features.pop('y') 44 return features, labels 45 46 return input_fn 47 48 49 class AddMetricsTest(test.TestCase): 50 51 def test_should_add_metrics(self): 52 input_fn = get_input_fn( 53 x=np.arange(4)[:, None, None], y=np.ones(4)[:, None]) 54 estimator = linear.LinearClassifier([fc.numeric_column('x')]) 55 56 def metric_fn(features): 57 return {'mean_x': metrics_lib.mean(features['x'])} 58 59 estimator = extenders.add_metrics(estimator, metric_fn) 60 61 estimator.train(input_fn=input_fn) 62 metrics = estimator.evaluate(input_fn=input_fn) 63 self.assertIn('mean_x', metrics) 64 self.assertEqual(1.5, metrics['mean_x']) 65 # assert that it keeps original estimators metrics 66 self.assertIn('auc', metrics) 67 68 def test_should_error_out_for_not_recognized_args(self): 69 estimator = linear.LinearClassifier([fc.numeric_column('x')]) 70 71 def metric_fn(features, not_recognized): 72 _, _ = features, not_recognized 73 return {} 74 75 with self.assertRaisesRegexp(ValueError, 'not_recognized'): 76 estimator = extenders.add_metrics(estimator, metric_fn) 77 78 def test_all_supported_args(self): 79 input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]]) 80 estimator = linear.LinearClassifier([fc.numeric_column('x')]) 81 82 def metric_fn(features, predictions, labels, config): 83 self.assertIn('x', features) 84 self.assertIsNotNone(labels) 85 self.assertIn('logistic', predictions) 86 self.assertTrue(isinstance(config, estimator_lib.RunConfig)) 87 return {} 88 89 estimator = extenders.add_metrics(estimator, metric_fn) 90 91 estimator.train(input_fn=input_fn) 92 estimator.evaluate(input_fn=input_fn) 93 94 def test_all_supported_args_in_different_order(self): 95 input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]]) 96 estimator = linear.LinearClassifier([fc.numeric_column('x')]) 97 98 def metric_fn(labels, config, features, predictions): 99 self.assertIn('x', features) 100 self.assertIsNotNone(labels) 101 self.assertIn('logistic', predictions) 102 self.assertTrue(isinstance(config, estimator_lib.RunConfig)) 103 return {} 104 105 estimator = extenders.add_metrics(estimator, metric_fn) 106 107 estimator.train(input_fn=input_fn) 108 estimator.evaluate(input_fn=input_fn) 109 110 def test_all_args_are_optional(self): 111 input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]]) 112 estimator = linear.LinearClassifier([fc.numeric_column('x')]) 113 114 def metric_fn(): 115 return {'two': metrics_lib.mean(constant_op.constant([2.]))} 116 117 estimator = extenders.add_metrics(estimator, metric_fn) 118 119 estimator.train(input_fn=input_fn) 120 metrics = estimator.evaluate(input_fn=input_fn) 121 self.assertEqual(2., metrics['two']) 122 123 def test_overrides_existing_metrics(self): 124 input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]]) 125 estimator = linear.LinearClassifier([fc.numeric_column('x')]) 126 estimator.train(input_fn=input_fn) 127 metrics = estimator.evaluate(input_fn=input_fn) 128 self.assertNotEqual(2., metrics['auc']) 129 130 def metric_fn(): 131 return {'auc': metrics_lib.mean(constant_op.constant([2.]))} 132 133 estimator = extenders.add_metrics(estimator, metric_fn) 134 metrics = estimator.evaluate(input_fn=input_fn) 135 self.assertEqual(2., metrics['auc']) 136 137 138 class ClipGradientsByNormTest(test.TestCase): 139 """Tests clip_gradients_by_norm.""" 140 141 def test_applies_norm(self): 142 optimizer = extenders.clip_gradients_by_norm( 143 training.GradientDescentOptimizer(1.0), clip_norm=3.) 144 with ops.Graph().as_default(): 145 w = variables.Variable(1., name='weight') 146 x = constant_op.constant(5.) 147 y = -x * w 148 grads = optimizer.compute_gradients(y, var_list=[w])[0] 149 opt_op = optimizer.minimize(y, var_list=[w]) 150 with training.MonitoredSession() as sess: 151 grads_value = sess.run(grads) 152 self.assertEqual(-5., grads_value[0]) 153 sess.run(opt_op) 154 new_w = sess.run(w) 155 self.assertEqual(4., new_w) # 1 + 1*3 (w - lr * clipped_grad) 156 157 def test_name(self): 158 optimizer = extenders.clip_gradients_by_norm( 159 training.GradientDescentOptimizer(1.0), clip_norm=3.) 160 self.assertEqual('ClipByNormGradientDescent', optimizer.get_name()) 161 162 163 class ForwardFeaturesTest(test.TestCase): 164 """Tests forward_features.""" 165 166 def test_forward_single_key(self): 167 168 def input_fn(): 169 return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]] 170 171 estimator = linear.LinearRegressor([fc.numeric_column('x')]) 172 estimator.train(input_fn=input_fn, steps=1) 173 174 self.assertNotIn('id', next(estimator.predict(input_fn=input_fn))) 175 estimator = extenders.forward_features(estimator, 'id') 176 predictions = next(estimator.predict(input_fn=input_fn)) 177 self.assertIn('id', predictions) 178 self.assertEqual(101, predictions['id']) 179 180 def test_forward_list(self): 181 182 def input_fn(): 183 return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]] 184 185 estimator = linear.LinearRegressor([fc.numeric_column('x')]) 186 estimator.train(input_fn=input_fn, steps=1) 187 188 self.assertNotIn('id', next(estimator.predict(input_fn=input_fn))) 189 estimator = extenders.forward_features(estimator, ['x', 'id']) 190 predictions = next(estimator.predict(input_fn=input_fn)) 191 self.assertIn('id', predictions) 192 self.assertIn('x', predictions) 193 self.assertEqual(101, predictions['id']) 194 self.assertEqual(3., predictions['x']) 195 196 def test_forward_all(self): 197 198 def input_fn(): 199 return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]] 200 201 estimator = linear.LinearRegressor([fc.numeric_column('x')]) 202 estimator.train(input_fn=input_fn, steps=1) 203 204 self.assertNotIn('id', next(estimator.predict(input_fn=input_fn))) 205 self.assertNotIn('x', next(estimator.predict(input_fn=input_fn))) 206 estimator = extenders.forward_features(estimator) 207 predictions = next(estimator.predict(input_fn=input_fn)) 208 self.assertIn('id', predictions) 209 self.assertIn('x', predictions) 210 self.assertEqual(101, predictions['id']) 211 self.assertEqual(3., predictions['x']) 212 213 def test_key_should_be_string(self): 214 estimator = linear.LinearRegressor([fc.numeric_column('x')]) 215 with self.assertRaisesRegexp(TypeError, 'keys should be either a string'): 216 extenders.forward_features(estimator, estimator) 217 218 def test_key_should_be_list_of_string(self): 219 estimator = linear.LinearRegressor([fc.numeric_column('x')]) 220 with self.assertRaisesRegexp(TypeError, 'should be a string'): 221 extenders.forward_features(estimator, ['x', estimator]) 222 223 def test_key_should_be_in_features(self): 224 225 def input_fn(): 226 return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]] 227 228 estimator = linear.LinearRegressor([fc.numeric_column('x')]) 229 estimator.train(input_fn=input_fn, steps=1) 230 231 estimator = extenders.forward_features(estimator, 'y') 232 with self.assertRaisesRegexp(ValueError, 233 'keys should be exist in features'): 234 next(estimator.predict(input_fn=input_fn)) 235 236 def test_forwarded_feature_should_not_be_a_sparse_tensor(self): 237 238 def input_fn(): 239 return { 240 'x': [[3.], [5.]], 241 'id': 242 sparse_tensor.SparseTensor( 243 values=['1', '2'], 244 indices=[[0, 0], [1, 0]], 245 dense_shape=[2, 1]) 246 }, [[1.], [2.]] 247 248 estimator = linear.LinearRegressor([fc.numeric_column('x')]) 249 estimator.train(input_fn=input_fn, steps=1) 250 251 estimator = extenders.forward_features(estimator) 252 with self.assertRaisesRegexp(ValueError, 253 'Forwarded feature.* should be a Tensor.'): 254 next(estimator.predict(input_fn=input_fn)) 255 256 def test_predictions_should_be_dict(self): 257 258 def input_fn(): 259 return {'x': [[3.], [5.]], 'id': [[101], [102]]} 260 261 def model_fn(features, mode): 262 del features 263 global_step = training.get_global_step() 264 return estimator_lib.EstimatorSpec( 265 mode, 266 loss=constant_op.constant([5.]), 267 predictions=constant_op.constant([5.]), 268 train_op=global_step.assign_add(1)) 269 270 estimator = estimator_lib.Estimator(model_fn=model_fn) 271 estimator.train(input_fn=input_fn, steps=1) 272 273 estimator = extenders.forward_features(estimator) 274 with self.assertRaisesRegexp(ValueError, 'Predictions should be a dict'): 275 next(estimator.predict(input_fn=input_fn)) 276 277 def test_should_not_conflict_with_existing_predictions(self): 278 279 def input_fn(): 280 return {'x': [[3.], [5.]], 'id': [[101], [102]]} 281 282 def model_fn(features, mode): 283 del features 284 global_step = training.get_global_step() 285 return estimator_lib.EstimatorSpec( 286 mode, 287 loss=constant_op.constant([5.]), 288 predictions={'x': constant_op.constant([5.])}, 289 train_op=global_step.assign_add(1)) 290 291 estimator = estimator_lib.Estimator(model_fn=model_fn) 292 estimator.train(input_fn=input_fn, steps=1) 293 294 estimator = extenders.forward_features(estimator) 295 with self.assertRaisesRegexp(ValueError, 'Cannot forward feature key'): 296 next(estimator.predict(input_fn=input_fn)) 297 298 299 if __name__ == '__main__': 300 test.main() 301