1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for dnn.py.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import shutil 22 import tempfile 23 24 import numpy as np 25 import six 26 27 from tensorflow.contrib.estimator.python.estimator import dnn 28 from tensorflow.contrib.estimator.python.estimator import head as head_lib 29 from tensorflow.python.estimator.canned import dnn_testing_utils 30 from tensorflow.python.estimator.canned import prediction_keys 31 from tensorflow.python.estimator.export import export 32 from tensorflow.python.estimator.inputs import numpy_io 33 from tensorflow.python.feature_column import feature_column 34 from tensorflow.python.framework import ops 35 from tensorflow.python.platform import gfile 36 from tensorflow.python.platform import test 37 from tensorflow.python.summary.writer import writer_cache 38 39 40 def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs): 41 """Returns a DNNEstimator that uses regression_head.""" 42 return dnn.DNNEstimator( 43 head=head_lib.regression_head( 44 weight_column=weight_column, label_dimension=label_dimension), 45 *args, **kwargs) 46 47 48 class DNNEstimatorEvaluateTest( 49 dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase): 50 51 def __init__(self, methodName='runTest'): # pylint: disable=invalid-name 52 test.TestCase.__init__(self, methodName) 53 dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__( 54 self, _dnn_estimator_fn) 55 56 57 class DNNEstimatorPredictTest( 58 dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase): 59 60 def __init__(self, methodName='runTest'): # pylint: disable=invalid-name 61 test.TestCase.__init__(self, methodName) 62 dnn_testing_utils.BaseDNNRegressorPredictTest.__init__( 63 self, _dnn_estimator_fn) 64 65 66 class DNNEstimatorTrainTest( 67 dnn_testing_utils.BaseDNNRegressorTrainTest, test.TestCase): 68 69 def __init__(self, methodName='runTest'): # pylint: disable=invalid-name 70 test.TestCase.__init__(self, methodName) 71 dnn_testing_utils.BaseDNNRegressorTrainTest.__init__( 72 self, _dnn_estimator_fn) 73 74 75 class DNNEstimatorIntegrationTest(test.TestCase): 76 77 def setUp(self): 78 self._model_dir = tempfile.mkdtemp() 79 80 def tearDown(self): 81 if self._model_dir: 82 writer_cache.FileWriterCache.clear() 83 shutil.rmtree(self._model_dir) 84 85 def _test_complete_flow( 86 self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension, 87 label_dimension, batch_size): 88 feature_columns = [ 89 feature_column.numeric_column('x', shape=(input_dimension,))] 90 est = dnn.DNNEstimator( 91 head=head_lib.regression_head(label_dimension=label_dimension), 92 hidden_units=(2, 2), 93 feature_columns=feature_columns, 94 model_dir=self._model_dir) 95 96 # TRAIN 97 num_steps = 10 98 est.train(train_input_fn, steps=num_steps) 99 100 # EVALUTE 101 scores = est.evaluate(eval_input_fn) 102 self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) 103 self.assertIn('loss', six.iterkeys(scores)) 104 105 # PREDICT 106 predictions = np.array([ 107 x[prediction_keys.PredictionKeys.PREDICTIONS] 108 for x in est.predict(predict_input_fn) 109 ]) 110 self.assertAllEqual((batch_size, label_dimension), predictions.shape) 111 112 # EXPORT 113 feature_spec = feature_column.make_parse_example_spec(feature_columns) 114 serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( 115 feature_spec) 116 export_dir = est.export_savedmodel(tempfile.mkdtemp(), 117 serving_input_receiver_fn) 118 self.assertTrue(gfile.Exists(export_dir)) 119 120 def test_numpy_input_fn(self): 121 """Tests complete flow with numpy_input_fn.""" 122 label_dimension = 2 123 batch_size = 10 124 data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) 125 data = data.reshape(batch_size, label_dimension) 126 # learn y = x 127 train_input_fn = numpy_io.numpy_input_fn( 128 x={'x': data}, 129 y=data, 130 batch_size=batch_size, 131 num_epochs=None, 132 shuffle=True) 133 eval_input_fn = numpy_io.numpy_input_fn( 134 x={'x': data}, 135 y=data, 136 batch_size=batch_size, 137 shuffle=False) 138 predict_input_fn = numpy_io.numpy_input_fn( 139 x={'x': data}, 140 batch_size=batch_size, 141 shuffle=False) 142 143 self._test_complete_flow( 144 train_input_fn=train_input_fn, 145 eval_input_fn=eval_input_fn, 146 predict_input_fn=predict_input_fn, 147 input_dimension=label_dimension, 148 label_dimension=label_dimension, 149 batch_size=batch_size) 150 151 152 if __name__ == '__main__': 153 test.main() 154