1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for linear.py.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import shutil 22 import tempfile 23 24 import numpy as np 25 import six 26 27 from tensorflow.contrib.estimator.python.estimator import head as head_lib 28 from tensorflow.contrib.estimator.python.estimator import linear 29 from tensorflow.python.estimator.canned import linear_testing_utils 30 from tensorflow.python.estimator.canned import prediction_keys 31 from tensorflow.python.estimator.export import export 32 from tensorflow.python.estimator.inputs import numpy_io 33 from tensorflow.python.feature_column import feature_column 34 from tensorflow.python.framework import ops 35 from tensorflow.python.platform import gfile 36 from tensorflow.python.platform import test 37 from tensorflow.python.summary.writer import writer_cache 38 39 40 def _linear_estimator_fn( 41 weight_column=None, label_dimension=1, *args, **kwargs): 42 """Returns a LinearEstimator that uses regression_head.""" 43 return linear.LinearEstimator( 44 head=head_lib.regression_head( 45 weight_column=weight_column, label_dimension=label_dimension), 46 *args, **kwargs) 47 48 49 class LinearEstimatorEvaluateTest( 50 linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase): 51 52 def __init__(self, methodName='runTest'): # pylint: disable=invalid-name 53 test.TestCase.__init__(self, methodName) 54 linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__( 55 self, _linear_estimator_fn) 56 57 58 class LinearEstimatorPredictTest( 59 linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase): 60 61 def __init__(self, methodName='runTest'): # pylint: disable=invalid-name 62 test.TestCase.__init__(self, methodName) 63 linear_testing_utils.BaseLinearRegressorPredictTest.__init__( 64 self, _linear_estimator_fn) 65 66 67 class LinearEstimatorTrainTest( 68 linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase): 69 70 def __init__(self, methodName='runTest'): # pylint: disable=invalid-name 71 test.TestCase.__init__(self, methodName) 72 linear_testing_utils.BaseLinearRegressorTrainingTest.__init__( 73 self, _linear_estimator_fn) 74 75 76 class LinearEstimatorIntegrationTest(test.TestCase): 77 78 def setUp(self): 79 self._model_dir = tempfile.mkdtemp() 80 81 def tearDown(self): 82 if self._model_dir: 83 writer_cache.FileWriterCache.clear() 84 shutil.rmtree(self._model_dir) 85 86 def _test_complete_flow( 87 self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension, 88 label_dimension, batch_size): 89 feature_columns = [ 90 feature_column.numeric_column('x', shape=(input_dimension,))] 91 est = linear.LinearEstimator( 92 head=head_lib.regression_head(label_dimension=label_dimension), 93 feature_columns=feature_columns, 94 model_dir=self._model_dir) 95 96 # TRAIN 97 num_steps = 10 98 est.train(train_input_fn, steps=num_steps) 99 100 # EVALUTE 101 scores = est.evaluate(eval_input_fn) 102 self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) 103 self.assertIn('loss', six.iterkeys(scores)) 104 105 # PREDICT 106 predictions = np.array([ 107 x[prediction_keys.PredictionKeys.PREDICTIONS] 108 for x in est.predict(predict_input_fn) 109 ]) 110 self.assertAllEqual((batch_size, label_dimension), predictions.shape) 111 112 # EXPORT 113 feature_spec = feature_column.make_parse_example_spec(feature_columns) 114 serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( 115 feature_spec) 116 export_dir = est.export_savedmodel(tempfile.mkdtemp(), 117 serving_input_receiver_fn) 118 self.assertTrue(gfile.Exists(export_dir)) 119 120 def test_numpy_input_fn(self): 121 """Tests complete flow with numpy_input_fn.""" 122 label_dimension = 2 123 batch_size = 10 124 data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) 125 data = data.reshape(batch_size, label_dimension) 126 # learn y = x 127 train_input_fn = numpy_io.numpy_input_fn( 128 x={'x': data}, 129 y=data, 130 batch_size=batch_size, 131 num_epochs=None, 132 shuffle=True) 133 eval_input_fn = numpy_io.numpy_input_fn( 134 x={'x': data}, 135 y=data, 136 batch_size=batch_size, 137 shuffle=False) 138 predict_input_fn = numpy_io.numpy_input_fn( 139 x={'x': data}, 140 batch_size=batch_size, 141 shuffle=False) 142 143 self._test_complete_flow( 144 train_input_fn=train_input_fn, 145 eval_input_fn=eval_input_fn, 146 predict_input_fn=predict_input_fn, 147 input_dimension=label_dimension, 148 label_dimension=label_dimension, 149 batch_size=batch_size) 150 151 152 if __name__ == '__main__': 153 test.main() 154