Home | History | Annotate | Download | only in estimators
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for Estimator input."""
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     21 import functools
     22 import tempfile
     24 import numpy as np
     26 from tensorflow.python.training import training_util
     27 from tensorflow.contrib.layers.python.layers import optimizers
     28 from tensorflow.contrib.learn.python.learn import metric_spec
     29 from tensorflow.contrib.learn.python.learn import models
     30 from tensorflow.contrib.learn.python.learn.datasets import base
     31 from tensorflow.contrib.learn.python.learn.estimators import _sklearn
     32 from tensorflow.contrib.learn.python.learn.estimators import estimator
     33 from tensorflow.contrib.learn.python.learn.estimators import model_fn
     34 from tensorflow.contrib.metrics.python.ops import metric_ops
     35 from tensorflow.python.framework import constant_op
     36 from tensorflow.python.framework import dtypes
     37 from tensorflow.python.ops import array_ops
     38 from tensorflow.python.ops import data_flow_ops
     39 from tensorflow.python.ops import math_ops
     40 from tensorflow.python.platform import test
     41 from tensorflow.python.training import input as input_lib
     42 from tensorflow.python.training import queue_runner_impl
     44 _BOSTON_INPUT_DIM = 13
     45 _IRIS_INPUT_DIM = 4
     48 def boston_input_fn(num_epochs=None):
     49   boston = base.load_boston()
     50   features = input_lib.limit_epochs(
     51       array_ops.reshape(
     52           constant_op.constant(boston.data), [-1, _BOSTON_INPUT_DIM]),
     53       num_epochs=num_epochs)
     54   labels = array_ops.reshape(constant_op.constant(boston.target), [-1, 1])
     55   return features, labels
     58 def boston_input_fn_with_queue(num_epochs=None):
     59   features, labels = boston_input_fn(num_epochs=num_epochs)
     61   # Create a minimal queue runner.
     62   fake_queue = data_flow_ops.FIFOQueue(30, dtypes.int32)
     63   queue_runner = queue_runner_impl.QueueRunner(fake_queue,
     64                                                [constant_op.constant(0)])
     65   queue_runner_impl.add_queue_runner(queue_runner)
     67   return features, labels
     70 def iris_input_fn():
     71   iris = base.load_iris()
     72   features = array_ops.reshape(
     73       constant_op.constant(iris.data), [-1, _IRIS_INPUT_DIM])
     74   labels = array_ops.reshape(constant_op.constant(iris.target), [-1])
     75   return features, labels
     78 def iris_input_fn_labels_dict():
     79   iris = base.load_iris()
     80   features = array_ops.reshape(
     81       constant_op.constant(iris.data), [-1, _IRIS_INPUT_DIM])
     82   labels = {
     83       'labels': array_ops.reshape(constant_op.constant(iris.target), [-1])
     84   }
     85   return features, labels
     88 def boston_eval_fn():
     89   boston = base.load_boston()
     90   n_examples = len(boston.target)
     91   features = array_ops.reshape(
     92       constant_op.constant(boston.data), [n_examples, _BOSTON_INPUT_DIM])
     93   labels = array_ops.reshape(
     94       constant_op.constant(boston.target), [n_examples, 1])
     95   return array_ops.concat([features, features],
     96                           0), array_ops.concat([labels, labels], 0)
     99 def extract(data, key):
    100   if isinstance(data, dict):
    101     assert key in data
    102     return data[key]
    103   else:
    104     return data
    107 def linear_model_params_fn(features, labels, mode, params):
    108   features = extract(features, 'input')
    109   labels = extract(labels, 'labels')
    111   assert mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
    112                   model_fn.ModeKeys.INFER)
    113   prediction, loss = (models.linear_regression_zero_init(features, labels))
    114   train_op = optimizers.optimize_loss(
    115       loss,
    116       training_util.get_global_step(),
    117       optimizer='Adagrad',
    118       learning_rate=params['learning_rate'])
    119   return prediction, loss, train_op
    122 def linear_model_fn(features, labels, mode):
    123   features = extract(features, 'input')
    124   labels = extract(labels, 'labels')
    125   assert mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
    126                   model_fn.ModeKeys.INFER)
    127   if isinstance(features, dict):
    128     (_, features), = features.items()
    129   prediction, loss = (models.linear_regression_zero_init(features, labels))
    130   train_op = optimizers.optimize_loss(
    131       loss,
    132       training_util.get_global_step(),
    133       optimizer='Adagrad',
    134       learning_rate=0.1)
    135   return prediction, loss, train_op
    138 def linear_model_fn_with_model_fn_ops(features, labels, mode):
    139   """Same as linear_model_fn, but returns `ModelFnOps`."""
    140   assert mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
    141                   model_fn.ModeKeys.INFER)
    142   prediction, loss = (models.linear_regression_zero_init(features, labels))
    143   train_op = optimizers.optimize_loss(
    144       loss,
    145       training_util.get_global_step(),
    146       optimizer='Adagrad',
    147       learning_rate=0.1)
    148   return model_fn.ModelFnOps(
    149       mode=mode, predictions=prediction, loss=loss, train_op=train_op)
    152 def logistic_model_no_mode_fn(features, labels):
    153   features = extract(features, 'input')
    154   labels = extract(labels, 'labels')
    155   labels = array_ops.one_hot(labels, 3, 1, 0)
    156   prediction, loss = (models.logistic_regression_zero_init(features, labels))
    157   train_op = optimizers.optimize_loss(
    158       loss,
    159       training_util.get_global_step(),
    160       optimizer='Adagrad',
    161       learning_rate=0.1)
    162   return {
    163       'class': math_ops.argmax(prediction, 1),
    164       'prob': prediction
    165   }, loss, train_op
    168 VOCAB_FILE_CONTENT = 'emerson\nlake\npalmer\n'
    169 EXTRA_FILE_CONTENT = 'kermit\npiggy\nralph\n'
    172 class EstimatorInputTest(test.TestCase):
    174   def testContinueTrainingDictionaryInput(self):
    175     boston = base.load_boston()
    176     output_dir = tempfile.mkdtemp()
    177     est = estimator.Estimator(model_fn=linear_model_fn, model_dir=output_dir)
    178     boston_input = {'input': boston.data}
    179     float64_target = {'labels': boston.target.astype(np.float64)}
    180     est.fit(x=boston_input, y=float64_target, steps=50)
    181     scores = est.evaluate(
    182         x=boston_input,
    183         y=float64_target,
    184         metrics={
    185             'MSE': metric_ops.streaming_mean_squared_error
    186         })
    187     del est
    188     # Create another estimator object with the same output dir.
    189     est2 = estimator.Estimator(model_fn=linear_model_fn, model_dir=output_dir)
    191     # Check we can evaluate and predict.
    192     scores2 = est2.evaluate(
    193         x=boston_input,
    194         y=float64_target,
    195         metrics={
    196             'MSE': metric_ops.streaming_mean_squared_error
    197         })
    198     self.assertAllClose(scores2['MSE'], scores['MSE'])
    199     predictions = np.array(list(est2.predict(x=boston_input)))
    200     other_score = _sklearn.mean_squared_error(predictions,
    201                                               float64_target['labels'])
    202     self.assertAllClose(other_score, scores['MSE'])
    204   def testBostonAll(self):
    205     boston = base.load_boston()
    206     est = estimator.SKCompat(estimator.Estimator(model_fn=linear_model_fn))
    207     float64_labels = boston.target.astype(np.float64)
    208     est.fit(x=boston.data, y=float64_labels, steps=100)
    209     scores = est.score(
    210         x=boston.data,
    211         y=float64_labels,
    212         metrics={
    213             'MSE': metric_ops.streaming_mean_squared_error
    214         })
    215     predictions = np.array(list(est.predict(x=boston.data)))
    216     other_score = _sklearn.mean_squared_error(predictions, boston.target)
    217     self.assertAllClose(scores['MSE'], other_score)
    218     self.assertTrue('global_step' in scores)
    219     self.assertEqual(100, scores['global_step'])
    221   def testBostonAllDictionaryInput(self):
    222     boston = base.load_boston()
    223     est = estimator.Estimator(model_fn=linear_model_fn)
    224     boston_input = {'input': boston.data}
    225     float64_target = {'labels': boston.target.astype(np.float64)}
    226     est.fit(x=boston_input, y=float64_target, steps=100)
    227     scores = est.evaluate(
    228         x=boston_input,
    229         y=float64_target,
    230         metrics={
    231             'MSE': metric_ops.streaming_mean_squared_error
    232         })
    233     predictions = np.array(list(est.predict(x=boston_input)))
    234     other_score = _sklearn.mean_squared_error(predictions, boston.target)
    235     self.assertAllClose(other_score, scores['MSE'])
    236     self.assertTrue('global_step' in scores)
    237     self.assertEqual(scores['global_step'], 100)
    239   def testIrisAll(self):
    240     iris = base.load_iris()
    241     est = estimator.SKCompat(
    242         estimator.Estimator(model_fn=logistic_model_no_mode_fn))
    243     est.fit(iris.data, iris.target, steps=100)
    244     scores = est.score(
    245         x=iris.data,
    246         y=iris.target,
    247         metrics={
    248             ('accuracy', 'class'): metric_ops.streaming_accuracy
    249         })
    250     predictions = est.predict(x=iris.data)
    251     predictions_class = est.predict(x=iris.data, outputs=['class'])['class']
    252     self.assertEqual(predictions['prob'].shape[0], iris.target.shape[0])
    253     self.assertAllClose(predictions['class'], predictions_class)
    254     self.assertAllClose(predictions['class'],
    255                         np.argmax(predictions['prob'], axis=1))
    256     other_score = _sklearn.accuracy_score(iris.target, predictions['class'])
    257     self.assertAllClose(scores['accuracy'], other_score)
    258     self.assertTrue('global_step' in scores)
    259     self.assertEqual(100, scores['global_step'])
    261   def testIrisAllDictionaryInput(self):
    262     iris = base.load_iris()
    263     est = estimator.Estimator(model_fn=logistic_model_no_mode_fn)
    264     iris_data = {'input': iris.data}
    265     iris_target = {'labels': iris.target}
    266     est.fit(iris_data, iris_target, steps=100)
    267     scores = est.evaluate(
    268         x=iris_data,
    269         y=iris_target,
    270         metrics={
    271             ('accuracy', 'class'): metric_ops.streaming_accuracy
    272         })
    273     predictions = list(est.predict(x=iris_data))
    274     predictions_class = list(est.predict(x=iris_data, outputs=['class']))
    275     self.assertEqual(len(predictions), iris.target.shape[0])
    276     classes_batch = np.array([p['class'] for p in predictions])
    277     self.assertAllClose(classes_batch,
    278                         np.array([p['class'] for p in predictions_class]))
    279     self.assertAllClose(classes_batch,
    280                         np.argmax(
    281                             np.array([p['prob'] for p in predictions]), axis=1))
    282     other_score = _sklearn.accuracy_score(iris.target, classes_batch)
    283     self.assertAllClose(other_score, scores['accuracy'])
    284     self.assertTrue('global_step' in scores)
    285     self.assertEqual(scores['global_step'], 100)
    287   def testIrisInputFn(self):
    288     iris = base.load_iris()
    289     est = estimator.Estimator(model_fn=logistic_model_no_mode_fn)
    290     est.fit(input_fn=iris_input_fn, steps=100)
    291     _ = est.evaluate(input_fn=iris_input_fn, steps=1)
    292     predictions = list(est.predict(x=iris.data))
    293     self.assertEqual(len(predictions), iris.target.shape[0])
    295   def testIrisInputFnLabelsDict(self):
    296     iris = base.load_iris()
    297     est = estimator.Estimator(model_fn=logistic_model_no_mode_fn)
    298     est.fit(input_fn=iris_input_fn_labels_dict, steps=100)
    299     _ = est.evaluate(
    300         input_fn=iris_input_fn_labels_dict,
    301         steps=1,
    302         metrics={
    303             'accuracy':
    304                 metric_spec.MetricSpec(
    305                     metric_fn=metric_ops.streaming_accuracy,
    306                     prediction_key='class',
    307                     label_key='labels')
    308         })
    309     predictions = list(est.predict(x=iris.data))
    310     self.assertEqual(len(predictions), iris.target.shape[0])
    312   def testTrainInputFn(self):
    313     est = estimator.Estimator(model_fn=linear_model_fn)
    314     est.fit(input_fn=boston_input_fn, steps=1)
    315     _ = est.evaluate(input_fn=boston_eval_fn, steps=1)
    317   def testPredictInputFn(self):
    318     est = estimator.Estimator(model_fn=linear_model_fn)
    319     boston = base.load_boston()
    320     est.fit(input_fn=boston_input_fn, steps=1)
    321     input_fn = functools.partial(boston_input_fn, num_epochs=1)
    322     output = list(est.predict(input_fn=input_fn))
    323     self.assertEqual(len(output), boston.target.shape[0])
    325   def testPredictInputFnWithQueue(self):
    326     est = estimator.Estimator(model_fn=linear_model_fn)
    327     boston = base.load_boston()
    328     est.fit(input_fn=boston_input_fn, steps=1)
    329     input_fn = functools.partial(boston_input_fn_with_queue, num_epochs=2)
    330     output = list(est.predict(input_fn=input_fn))
    331     self.assertEqual(len(output), boston.target.shape[0] * 2)
    333   def testPredictConstInputFn(self):
    334     est = estimator.Estimator(model_fn=linear_model_fn)
    335     boston = base.load_boston()
    336     est.fit(input_fn=boston_input_fn, steps=1)
    338     def input_fn():
    339       features = array_ops.reshape(
    340           constant_op.constant(boston.data), [-1, _BOSTON_INPUT_DIM])
    341       labels = array_ops.reshape(constant_op.constant(boston.target), [-1, 1])
    342       return features, labels
    344     output = list(est.predict(input_fn=input_fn))
    345     self.assertEqual(len(output), boston.target.shape[0])
if __name__ == '__main__':
  # Delegate to TensorFlow's test runner when executed as a script.
  test.main()