# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""extenders tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.estimator.python.estimator import extenders
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator_lib
from tensorflow.python.estimator.canned import linear
from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import training


def get_input_fn(x, y):
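  """Returns an input_fn feeding `x` as feature 'x' and `y` as the labels."""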

  def input_fn():
    dataset = dataset_ops.Dataset.from_tensor_slices({'x': x, 'y': y})
    iterator = dataset.make_one_shot_iterator()
    features = iterator.get_next()
    labels = features.pop('y')
    return features, labels

  return input_fn


class AddMetricsTest(test.TestCase):
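  """Tests add_metrics."""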

  def test_should_add_metrics(self):
    input_fn = get_input_fn(
        x=np.arange(4)[:, None, None], y=np.ones(4)[:, None])
    estimator = linear.LinearClassifier([fc.numeric_column('x')])

    def metric_fn(features):
      return {'mean_x': metrics_lib.mean(features['x'])}

    estimator = extenders.add_metrics(estimator, metric_fn)

    estimator.train(input_fn=input_fn)
    metrics = estimator.evaluate(input_fn=input_fn)
    self.assertIn('mean_x', metrics)
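    # x holds the values 0, 1, 2 and 3, so their mean is 1.5.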
    self.assertEqual(1.5, metrics['mean_x'])
    # Assert that the original estimator's metrics are kept.
    self.assertIn('auc', metrics)

  def test_should_error_out_for_not_recognized_args(self):
    estimator = linear.LinearClassifier([fc.numeric_column('x')])

    def metric_fn(features, not_recognized):
      _, _ = features, not_recognized
      return {}

    with self.assertRaisesRegexp(ValueError, 'not_recognized'):
      estimator = extenders.add_metrics(estimator, metric_fn)

  def test_all_supported_args(self):
    input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]])
    estimator = linear.LinearClassifier([fc.numeric_column('x')])

    def metric_fn(features, predictions, labels, config):
      self.assertIn('x', features)
      self.assertIsNotNone(labels)
      self.assertIn('logistic', predictions)
      self.assertTrue(isinstance(config, estimator_lib.RunConfig))
      return {}

    estimator = extenders.add_metrics(estimator, metric_fn)

    estimator.train(input_fn=input_fn)
    estimator.evaluate(input_fn=input_fn)

  def test_all_supported_args_in_different_order(self):
    input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]])
    estimator = linear.LinearClassifier([fc.numeric_column('x')])

    def metric_fn(labels, config, features, predictions):
      self.assertIn('x', features)
      self.assertIsNotNone(labels)
      self.assertIn('logistic', predictions)
      self.assertTrue(isinstance(config, estimator_lib.RunConfig))
      return {}

    estimator = extenders.add_metrics(estimator, metric_fn)

    estimator.train(input_fn=input_fn)
    estimator.evaluate(input_fn=input_fn)

  def test_all_args_are_optional(self):
    input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]])
    estimator = linear.LinearClassifier([fc.numeric_column('x')])

    def metric_fn():
      return {'two': metrics_lib.mean(constant_op.constant([2.]))}

    estimator = extenders.add_metrics(estimator, metric_fn)

    estimator.train(input_fn=input_fn)
    metrics = estimator.evaluate(input_fn=input_fn)
    self.assertEqual(2., metrics['two'])

  def test_overrides_existing_metrics(self):
    input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]])
    estimator = linear.LinearClassifier([fc.numeric_column('x')])
    estimator.train(input_fn=input_fn)
    metrics = estimator.evaluate(input_fn=input_fn)
    self.assertNotEqual(2., metrics['auc'])

    def metric_fn():
      return {'auc': metrics_lib.mean(constant_op.constant([2.]))}

    estimator = extenders.add_metrics(estimator, metric_fn)
    metrics = estimator.evaluate(input_fn=input_fn)
    self.assertEqual(2., metrics['auc'])


class ClipGradientsByNormTest(test.TestCase):
  """Tests clip_gradients_by_norm."""

  def test_applies_norm(self):
    optimizer = extenders.clip_gradients_by_norm(
        training.GradientDescentOptimizer(1.0), clip_norm=3.)
    with ops.Graph().as_default():
      w = variables.Variable(1., name='weight')
      x = constant_op.constant(5.)
      y = -x * w
      grads = optimizer.compute_gradients(y, var_list=[w])[0]
      opt_op = optimizer.minimize(y, var_list=[w])
      with training.MonitoredSession() as sess:
        grads_value = sess.run(grads)
        self.assertEqual(-5., grads_value[0])
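        # clip_norm=3. caps the gradient norm, so -5. is clipped to -3.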
        sess.run(opt_op)
        new_w = sess.run(w)
        self.assertEqual(4., new_w)  # w - lr * clipped_grad = 1. - 1. * (-3.)

  def test_name(self):
    optimizer = extenders.clip_gradients_by_norm(
        training.GradientDescentOptimizer(1.0), clip_norm=3.)
    self.assertEqual('ClipByNormGradientDescent', optimizer.get_name())


class ForwardFeaturesTest(test.TestCase):
  """Tests forward_features."""

  def test_forward_single_key(self):

    def input_fn():
      return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]]

    estimator = linear.LinearRegressor([fc.numeric_column('x')])
    estimator.train(input_fn=input_fn, steps=1)

    self.assertNotIn('id', next(estimator.predict(input_fn=input_fn)))
    estimator = extenders.forward_features(estimator, 'id')
    predictions = next(estimator.predict(input_fn=input_fn))
    self.assertIn('id', predictions)
    self.assertEqual(101, predictions['id'])

  def test_forward_list(self):

    def input_fn():
      return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]]

    estimator = linear.LinearRegressor([fc.numeric_column('x')])
    estimator.train(input_fn=input_fn, steps=1)

    self.assertNotIn('id', next(estimator.predict(input_fn=input_fn)))
    estimator = extenders.forward_features(estimator, ['x', 'id'])
    predictions = next(estimator.predict(input_fn=input_fn))
    self.assertIn('id', predictions)
    self.assertIn('x', predictions)
    self.assertEqual(101, predictions['id'])
    self.assertEqual(3., predictions['x'])

  def test_forward_all(self):

    def input_fn():
      return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]]

    estimator = linear.LinearRegressor([fc.numeric_column('x')])
    estimator.train(input_fn=input_fn, steps=1)

    self.assertNotIn('id', next(estimator.predict(input_fn=input_fn)))
    self.assertNotIn('x', next(estimator.predict(input_fn=input_fn)))
    estimator = extenders.forward_features(estimator)
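    # With no keys given, all features are forwarded to the predictions.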
    predictions = next(estimator.predict(input_fn=input_fn))
    self.assertIn('id', predictions)
    self.assertIn('x', predictions)
    self.assertEqual(101, predictions['id'])
    self.assertEqual(3., predictions['x'])

  def test_key_should_be_string(self):
    estimator = linear.LinearRegressor([fc.numeric_column('x')])
    with self.assertRaisesRegexp(TypeError, 'keys should be either a string'):
      extenders.forward_features(estimator, estimator)

  def test_key_should_be_list_of_string(self):
    estimator = linear.LinearRegressor([fc.numeric_column('x')])
    with self.assertRaisesRegexp(TypeError, 'should be a string'):
      extenders.forward_features(estimator, ['x', estimator])

  def test_key_should_be_in_features(self):

    def input_fn():
      return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]]

    estimator = linear.LinearRegressor([fc.numeric_column('x')])
    estimator.train(input_fn=input_fn, steps=1)

    estimator = extenders.forward_features(estimator, 'y')
    with self.assertRaisesRegexp(ValueError,
                                 'keys should be exist in features'):
      next(estimator.predict(input_fn=input_fn))

  def test_forwarded_feature_should_not_be_a_sparse_tensor(self):

    def input_fn():
      return {
          'x': [[3.], [5.]],
          'id':
              sparse_tensor.SparseTensor(
                  values=['1', '2'],
                  indices=[[0, 0], [1, 0]],
                  dense_shape=[2, 1])
      }, [[1.], [2.]]

    estimator = linear.LinearRegressor([fc.numeric_column('x')])
    estimator.train(input_fn=input_fn, steps=1)

    estimator = extenders.forward_features(estimator)
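    # 'id' is a SparseTensor; forward_features only accepts dense Tensors.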
    with self.assertRaisesRegexp(ValueError,
                                 'Forwarded feature.* should be a Tensor.'):
      next(estimator.predict(input_fn=input_fn))

  def test_predictions_should_be_dict(self):

    def input_fn():
      return {'x': [[3.], [5.]], 'id': [[101], [102]]}

    def model_fn(features, mode):
      del features
      global_step = training.get_global_step()
      return estimator_lib.EstimatorSpec(
          mode,
          loss=constant_op.constant([5.]),
          predictions=constant_op.constant([5.]),
          train_op=global_step.assign_add(1))

    estimator = estimator_lib.Estimator(model_fn=model_fn)
    estimator.train(input_fn=input_fn, steps=1)

    estimator = extenders.forward_features(estimator)
    with self.assertRaisesRegexp(ValueError, 'Predictions should be a dict'):
      next(estimator.predict(input_fn=input_fn))

  def test_should_not_conflict_with_existing_predictions(self):

    def input_fn():
      return {'x': [[3.], [5.]], 'id': [[101], [102]]}

    def model_fn(features, mode):
      del features
      global_step = training.get_global_step()
      return estimator_lib.EstimatorSpec(
          mode,
          loss=constant_op.constant([5.]),
          predictions={'x': constant_op.constant([5.])},
          train_op=global_step.assign_add(1))

    estimator = estimator_lib.Estimator(model_fn=model_fn)
    estimator.train(input_fn=input_fn, steps=1)

    estimator = extenders.forward_features(estimator)
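    # model_fn already predicts under key 'x', so forwarding 'x' clashes.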
    with self.assertRaisesRegexp(ValueError, 'Cannot forward feature key'):
      next(estimator.predict(input_fn=input_fn))


if __name__ == '__main__':
  test.main()