Home | History | Annotate | Download | only in estimator
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for dnn.py."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import shutil
     22 import tempfile
     23 
     24 import numpy as np
     25 import six
     26 
     27 from tensorflow.contrib.estimator.python.estimator import dnn
     28 from tensorflow.contrib.estimator.python.estimator import head as head_lib
     29 from tensorflow.python.estimator.canned import dnn_testing_utils
     30 from tensorflow.python.estimator.canned import prediction_keys
     31 from tensorflow.python.estimator.export import export
     32 from tensorflow.python.estimator.inputs import numpy_io
     33 from tensorflow.python.feature_column import feature_column
     34 from tensorflow.python.framework import ops
     35 from tensorflow.python.platform import gfile
     36 from tensorflow.python.platform import test
     37 from tensorflow.python.summary.writer import writer_cache
     38 
     39 
     40 def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs):
     41   """Returns a DNNEstimator that uses regression_head."""
     42   return dnn.DNNEstimator(
     43       head=head_lib.regression_head(
     44           weight_column=weight_column, label_dimension=label_dimension),
     45       *args, **kwargs)
     46 
     47 
     48 class DNNEstimatorEvaluateTest(
     49     dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase):
     50 
     51   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
     52     test.TestCase.__init__(self, methodName)
     53     dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
     54         self, _dnn_estimator_fn)
     55 
     56 
     57 class DNNEstimatorPredictTest(
     58     dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase):
     59 
     60   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
     61     test.TestCase.__init__(self, methodName)
     62     dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
     63         self, _dnn_estimator_fn)
     64 
     65 
     66 class DNNEstimatorTrainTest(
     67     dnn_testing_utils.BaseDNNRegressorTrainTest, test.TestCase):
     68 
     69   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
     70     test.TestCase.__init__(self, methodName)
     71     dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
     72         self, _dnn_estimator_fn)
     73 
     74 
     75 class DNNEstimatorIntegrationTest(test.TestCase):
     76 
     77   def setUp(self):
     78     self._model_dir = tempfile.mkdtemp()
     79 
     80   def tearDown(self):
     81     if self._model_dir:
     82       writer_cache.FileWriterCache.clear()
     83       shutil.rmtree(self._model_dir)
     84 
     85   def _test_complete_flow(
     86       self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
     87       label_dimension, batch_size):
     88     feature_columns = [
     89         feature_column.numeric_column('x', shape=(input_dimension,))]
     90     est = dnn.DNNEstimator(
     91         head=head_lib.regression_head(label_dimension=label_dimension),
     92         hidden_units=(2, 2),
     93         feature_columns=feature_columns,
     94         model_dir=self._model_dir)
     95 
     96     # TRAIN
     97     num_steps = 10
     98     est.train(train_input_fn, steps=num_steps)
     99 
    100     # EVALUTE
    101     scores = est.evaluate(eval_input_fn)
    102     self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
    103     self.assertIn('loss', six.iterkeys(scores))
    104 
    105     # PREDICT
    106     predictions = np.array([
    107         x[prediction_keys.PredictionKeys.PREDICTIONS]
    108         for x in est.predict(predict_input_fn)
    109     ])
    110     self.assertAllEqual((batch_size, label_dimension), predictions.shape)
    111 
    112     # EXPORT
    113     feature_spec = feature_column.make_parse_example_spec(feature_columns)
    114     serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
    115         feature_spec)
    116     export_dir = est.export_savedmodel(tempfile.mkdtemp(),
    117                                        serving_input_receiver_fn)
    118     self.assertTrue(gfile.Exists(export_dir))
    119 
    120   def test_numpy_input_fn(self):
    121     """Tests complete flow with numpy_input_fn."""
    122     label_dimension = 2
    123     batch_size = 10
    124     data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
    125     data = data.reshape(batch_size, label_dimension)
    126     # learn y = x
    127     train_input_fn = numpy_io.numpy_input_fn(
    128         x={'x': data},
    129         y=data,
    130         batch_size=batch_size,
    131         num_epochs=None,
    132         shuffle=True)
    133     eval_input_fn = numpy_io.numpy_input_fn(
    134         x={'x': data},
    135         y=data,
    136         batch_size=batch_size,
    137         shuffle=False)
    138     predict_input_fn = numpy_io.numpy_input_fn(
    139         x={'x': data},
    140         batch_size=batch_size,
    141         shuffle=False)
    142 
    143     self._test_complete_flow(
    144         train_input_fn=train_input_fn,
    145         eval_input_fn=eval_input_fn,
    146         predict_input_fn=predict_input_fn,
    147         input_dimension=label_dimension,
    148         label_dimension=label_dimension,
    149         batch_size=batch_size)
    150 
    151 
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner when run as a script.
  test.main()
    154