Home | History | Annotate | Download | only in estimator
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for linear.py."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import shutil
     22 import tempfile
     23 
     24 import numpy as np
     25 import six
     26 
     27 from tensorflow.contrib.estimator.python.estimator import head as head_lib
     28 from tensorflow.contrib.estimator.python.estimator import linear
     29 from tensorflow.python.estimator.canned import linear_testing_utils
     30 from tensorflow.python.estimator.canned import prediction_keys
     31 from tensorflow.python.estimator.export import export
     32 from tensorflow.python.estimator.inputs import numpy_io
     33 from tensorflow.python.feature_column import feature_column
     34 from tensorflow.python.framework import ops
     35 from tensorflow.python.platform import gfile
     36 from tensorflow.python.platform import test
     37 from tensorflow.python.summary.writer import writer_cache
     38 
     39 
     40 def _linear_estimator_fn(
     41     weight_column=None, label_dimension=1, *args, **kwargs):
     42   """Returns a LinearEstimator that uses regression_head."""
     43   return linear.LinearEstimator(
     44       head=head_lib.regression_head(
     45           weight_column=weight_column, label_dimension=label_dimension),
     46       *args, **kwargs)
     47 
     48 
     49 class LinearEstimatorEvaluateTest(
     50     linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase):
     51 
     52   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
     53     test.TestCase.__init__(self, methodName)
     54     linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(
     55         self, _linear_estimator_fn)
     56 
     57 
     58 class LinearEstimatorPredictTest(
     59     linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase):
     60 
     61   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
     62     test.TestCase.__init__(self, methodName)
     63     linear_testing_utils.BaseLinearRegressorPredictTest.__init__(
     64         self, _linear_estimator_fn)
     65 
     66 
     67 class LinearEstimatorTrainTest(
     68     linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase):
     69 
     70   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
     71     test.TestCase.__init__(self, methodName)
     72     linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
     73         self, _linear_estimator_fn)
     74 
     75 
     76 class LinearEstimatorIntegrationTest(test.TestCase):
     77 
     78   def setUp(self):
     79     self._model_dir = tempfile.mkdtemp()
     80 
     81   def tearDown(self):
     82     if self._model_dir:
     83       writer_cache.FileWriterCache.clear()
     84       shutil.rmtree(self._model_dir)
     85 
     86   def _test_complete_flow(
     87       self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
     88       label_dimension, batch_size):
     89     feature_columns = [
     90         feature_column.numeric_column('x', shape=(input_dimension,))]
     91     est = linear.LinearEstimator(
     92         head=head_lib.regression_head(label_dimension=label_dimension),
     93         feature_columns=feature_columns,
     94         model_dir=self._model_dir)
     95 
     96     # TRAIN
     97     num_steps = 10
     98     est.train(train_input_fn, steps=num_steps)
     99 
    100     # EVALUTE
    101     scores = est.evaluate(eval_input_fn)
    102     self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
    103     self.assertIn('loss', six.iterkeys(scores))
    104 
    105     # PREDICT
    106     predictions = np.array([
    107         x[prediction_keys.PredictionKeys.PREDICTIONS]
    108         for x in est.predict(predict_input_fn)
    109     ])
    110     self.assertAllEqual((batch_size, label_dimension), predictions.shape)
    111 
    112     # EXPORT
    113     feature_spec = feature_column.make_parse_example_spec(feature_columns)
    114     serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
    115         feature_spec)
    116     export_dir = est.export_savedmodel(tempfile.mkdtemp(),
    117                                        serving_input_receiver_fn)
    118     self.assertTrue(gfile.Exists(export_dir))
    119 
    120   def test_numpy_input_fn(self):
    121     """Tests complete flow with numpy_input_fn."""
    122     label_dimension = 2
    123     batch_size = 10
    124     data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
    125     data = data.reshape(batch_size, label_dimension)
    126     # learn y = x
    127     train_input_fn = numpy_io.numpy_input_fn(
    128         x={'x': data},
    129         y=data,
    130         batch_size=batch_size,
    131         num_epochs=None,
    132         shuffle=True)
    133     eval_input_fn = numpy_io.numpy_input_fn(
    134         x={'x': data},
    135         y=data,
    136         batch_size=batch_size,
    137         shuffle=False)
    138     predict_input_fn = numpy_io.numpy_input_fn(
    139         x={'x': data},
    140         batch_size=batch_size,
    141         shuffle=False)
    142 
    143     self._test_complete_flow(
    144         train_input_fn=train_input_fn,
    145         eval_input_fn=eval_input_fn,
    146         predict_input_fn=predict_input_fn,
    147         input_dimension=label_dimension,
    148         label_dimension=label_dimension,
    149         batch_size=batch_size)
    150 
    151 
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner when executed as a script.
  test.main()
    154