Home | History | Annotate | Download | only in estimator
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for dnn_linear_combined.py."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import shutil
     22 import tempfile
     23 
     24 import numpy as np
     25 import six
     26 
     27 from tensorflow.contrib.estimator.python.estimator import dnn_linear_combined
     28 from tensorflow.contrib.estimator.python.estimator import head as head_lib
     29 from tensorflow.python.estimator.canned import dnn_testing_utils
     30 from tensorflow.python.estimator.canned import linear_testing_utils
     31 from tensorflow.python.estimator.canned import prediction_keys
     32 from tensorflow.python.estimator.export import export
     33 from tensorflow.python.estimator.inputs import numpy_io
     34 from tensorflow.python.feature_column import feature_column
     35 from tensorflow.python.framework import ops
     36 from tensorflow.python.ops import nn
     37 from tensorflow.python.platform import gfile
     38 from tensorflow.python.platform import test
     39 from tensorflow.python.summary.writer import writer_cache
     40 
     41 
     42 def _dnn_only_estimator_fn(
     43     hidden_units,
     44     feature_columns,
     45     model_dir=None,
     46     label_dimension=1,
     47     weight_column=None,
     48     optimizer='Adagrad',
     49     activation_fn=nn.relu,
     50     dropout=None,
     51     input_layer_partitioner=None,
     52     config=None):
     53   return dnn_linear_combined.DNNLinearCombinedEstimator(
     54       head=head_lib.regression_head(
     55           weight_column=weight_column, label_dimension=label_dimension),
     56       model_dir=model_dir,
     57       dnn_feature_columns=feature_columns,
     58       dnn_optimizer=optimizer,
     59       dnn_hidden_units=hidden_units,
     60       dnn_activation_fn=activation_fn,
     61       dnn_dropout=dropout,
     62       input_layer_partitioner=input_layer_partitioner,
     63       config=config)
     64 
     65 
     66 class DNNOnlyEstimatorEvaluateTest(
     67     dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase):
     68 
     69   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
     70     test.TestCase.__init__(self, methodName)
     71     dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
     72         self, _dnn_only_estimator_fn)
     73 
     74 
     75 class DNNOnlyEstimatorPredictTest(
     76     dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase):
     77 
     78   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
     79     test.TestCase.__init__(self, methodName)
     80     dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
     81         self, _dnn_only_estimator_fn)
     82 
     83 
     84 class DNNOnlyEstimatorTrainTest(
     85     dnn_testing_utils.BaseDNNRegressorTrainTest, test.TestCase):
     86 
     87   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
     88     test.TestCase.__init__(self, methodName)
     89     dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
     90         self, _dnn_only_estimator_fn)
     91 
     92 
     93 def _linear_only_estimator_fn(
     94     feature_columns,
     95     model_dir=None,
     96     label_dimension=1,
     97     weight_column=None,
     98     optimizer='Ftrl',
     99     config=None,
    100     partitioner=None):
    101   return dnn_linear_combined.DNNLinearCombinedEstimator(
    102       head=head_lib.regression_head(
    103           weight_column=weight_column, label_dimension=label_dimension),
    104       model_dir=model_dir,
    105       linear_feature_columns=feature_columns,
    106       linear_optimizer=optimizer,
    107       input_layer_partitioner=partitioner,
    108       config=config)
    109 
    110 
    111 class LinearOnlyEstimatorEvaluateTest(
    112     linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase):
    113 
    114   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    115     test.TestCase.__init__(self, methodName)
    116     linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(
    117         self, _linear_only_estimator_fn)
    118 
    119 
    120 class LinearOnlyEstimatorPredictTest(
    121     linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase):
    122 
    123   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    124     test.TestCase.__init__(self, methodName)
    125     linear_testing_utils.BaseLinearRegressorPredictTest.__init__(
    126         self, _linear_only_estimator_fn)
    127 
    128 
    129 class LinearOnlyEstimatorTrainTest(
    130     linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase):
    131 
    132   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    133     test.TestCase.__init__(self, methodName)
    134     linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
    135         self, _linear_only_estimator_fn)
    136 
    137 
    138 class DNNLinearCombinedEstimatorIntegrationTest(test.TestCase):
    139 
    140   def setUp(self):
    141     self._model_dir = tempfile.mkdtemp()
    142 
    143   def tearDown(self):
    144     if self._model_dir:
    145       writer_cache.FileWriterCache.clear()
    146       shutil.rmtree(self._model_dir)
    147 
    148   def _test_complete_flow(
    149       self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
    150       label_dimension, batch_size):
    151     linear_feature_columns = [
    152         feature_column.numeric_column('x', shape=(input_dimension,))]
    153     dnn_feature_columns = [
    154         feature_column.numeric_column('x', shape=(input_dimension,))]
    155     feature_columns = linear_feature_columns + dnn_feature_columns
    156     est = dnn_linear_combined.DNNLinearCombinedEstimator(
    157         head=head_lib.regression_head(label_dimension=label_dimension),
    158         linear_feature_columns=linear_feature_columns,
    159         dnn_feature_columns=dnn_feature_columns,
    160         dnn_hidden_units=(2, 2),
    161         model_dir=self._model_dir)
    162 
    163     # TRAIN
    164     num_steps = 10
    165     est.train(train_input_fn, steps=num_steps)
    166 
    167     # EVALUTE
    168     scores = est.evaluate(eval_input_fn)
    169     self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
    170     self.assertIn('loss', six.iterkeys(scores))
    171 
    172     # PREDICT
    173     predictions = np.array([
    174         x[prediction_keys.PredictionKeys.PREDICTIONS]
    175         for x in est.predict(predict_input_fn)
    176     ])
    177     self.assertAllEqual((batch_size, label_dimension), predictions.shape)
    178 
    179     # EXPORT
    180     feature_spec = feature_column.make_parse_example_spec(feature_columns)
    181     serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
    182         feature_spec)
    183     export_dir = est.export_savedmodel(tempfile.mkdtemp(),
    184                                        serving_input_receiver_fn)
    185     self.assertTrue(gfile.Exists(export_dir))
    186 
    187   def test_numpy_input_fn(self):
    188     """Tests complete flow with numpy_input_fn."""
    189     label_dimension = 2
    190     batch_size = 10
    191     data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
    192     data = data.reshape(batch_size, label_dimension)
    193     # learn y = x
    194     train_input_fn = numpy_io.numpy_input_fn(
    195         x={'x': data},
    196         y=data,
    197         batch_size=batch_size,
    198         num_epochs=None,
    199         shuffle=True)
    200     eval_input_fn = numpy_io.numpy_input_fn(
    201         x={'x': data},
    202         y=data,
    203         batch_size=batch_size,
    204         shuffle=False)
    205     predict_input_fn = numpy_io.numpy_input_fn(
    206         x={'x': data},
    207         batch_size=batch_size,
    208         shuffle=False)
    209 
    210     self._test_complete_flow(
    211         train_input_fn=train_input_fn,
    212         eval_input_fn=eval_input_fn,
    213         predict_input_fn=predict_input_fn,
    214         input_dimension=label_dimension,
    215         label_dimension=label_dimension,
    216         batch_size=batch_size)
    217 
    218 
    219 if __name__ == '__main__':
    220   test.main()
    221