# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Estimator input."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import tempfile

import numpy as np

from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn import models
from tensorflow.contrib.learn.python.learn.datasets import base
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.metrics.python.ops import metric_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import training_util

# Per-example feature width used when reshaping the Boston housing data.
_BOSTON_INPUT_DIM = 13
# Per-example feature width used when reshaping the Iris data.
_IRIS_INPUT_DIM = 4


def boston_input_fn(num_epochs=None):
  """Returns Boston housing `(features, labels)` tensors.

  Args:
    num_epochs: Forwarded to `input_lib.limit_epochs` for the features tensor.
      `testPredictInputFn` / `testPredictInputFnWithQueue` below expect
      `predict()` to yield `num_epochs` passes over the data.

  Returns:
    A `(features, labels)` tuple: `boston.data` reshaped to
    `[-1, _BOSTON_INPUT_DIM]` (wrapped in `limit_epochs`) and `boston.target`
    reshaped to `[-1, 1]`.
  """
  boston = base.load_boston()
  features = input_lib.limit_epochs(
      array_ops.reshape(
          constant_op.constant(boston.data), [-1, _BOSTON_INPUT_DIM]),
      num_epochs=num_epochs)
  labels = array_ops.reshape(constant_op.constant(boston.target), [-1, 1])
  return features, labels


def boston_input_fn_with_queue(num_epochs=None):
  """Like `boston_input_fn`, but also registers a minimal `QueueRunner`.

  The runner's only op is `constant_op.constant(0)`, not an enqueue on
  `fake_queue`, and nothing in this file dequeues from `fake_queue`. It exists
  so `testPredictInputFnWithQueue` runs `predict()` with a queue runner
  registered via `add_queue_runner`.
  """
  features, labels = boston_input_fn(num_epochs=num_epochs)

  # Create a minimal queue runner.
  fake_queue = data_flow_ops.FIFOQueue(30, dtypes.int32)
  queue_runner = queue_runner_impl.QueueRunner(fake_queue,
                                               [constant_op.constant(0)])
  queue_runner_impl.add_queue_runner(queue_runner)

  return features, labels


def iris_input_fn():
  """Returns Iris `(features, labels)` tensors.

  Features are `iris.data` reshaped to `[-1, _IRIS_INPUT_DIM]`; labels are
  `iris.target` flattened to rank 1.
  """
  iris = base.load_iris()
  features = array_ops.reshape(
      constant_op.constant(iris.data), [-1, _IRIS_INPUT_DIM])
  labels = array_ops.reshape(constant_op.constant(iris.target), [-1])
  return features, labels


def iris_input_fn_labels_dict():
  """Like `iris_input_fn`, but returns labels as `{'labels': tensor}`."""
  iris = base.load_iris()
  features = array_ops.reshape(
      constant_op.constant(iris.data), [-1, _IRIS_INPUT_DIM])
  labels = {
      'labels': array_ops.reshape(constant_op.constant(iris.target), [-1])
  }
  return features, labels


def boston_eval_fn():
  """Returns the Boston data concatenated with itself along axis 0.

  Features are reshaped to `[n_examples, _BOSTON_INPUT_DIM]` and labels to
  `[n_examples, 1]`, then each is stacked twice, so evaluation sees every
  example two times.
  """
  boston = base.load_boston()
  n_examples = len(boston.target)
  features = array_ops.reshape(
      constant_op.constant(boston.data), [n_examples, _BOSTON_INPUT_DIM])
  labels = array_ops.reshape(
      constant_op.constant(boston.target), [n_examples, 1])
  return (array_ops.concat([features, features], 0),
          array_ops.concat([labels, labels], 0))


def extract(data, key):
  """Returns `data[key]` if `data` is a dict, otherwise `data` unchanged.

  Raises:
    AssertionError: if `data` is a dict without `key`.
  """
  if isinstance(data, dict):
    assert key in data
    return data[key]
  else:
    return data


def _adagrad_train_op(loss, learning_rate):
  """Builds an Adagrad `optimize_loss` train op for `loss` at the global step."""
  return optimizers.optimize_loss(
      loss,
      training_util.get_global_step(),
      optimizer='Adagrad',
      learning_rate=learning_rate)


def linear_model_params_fn(features, labels, mode, params):
  """Zero-initialized linear regression model_fn with a tunable learning rate.

  Args:
    features: Features tensor, or a dict holding it under `'input'`.
    labels: Labels tensor, or a dict holding it under `'labels'`.
    mode: One of `model_fn.ModeKeys.{TRAIN, EVAL, INFER}`.
    params: Dict; `params['learning_rate']` is the Adagrad learning rate.

  Returns:
    A `(prediction, loss, train_op)` tuple.
  """
  features = extract(features, 'input')
  labels = extract(labels, 'labels')

  assert mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
                  model_fn.ModeKeys.INFER)
  prediction, loss = models.linear_regression_zero_init(features, labels)
  train_op = _adagrad_train_op(loss, params['learning_rate'])
  return prediction, loss, train_op


def linear_model_fn(features, labels, mode):
  """Zero-initialized linear regression model_fn (Adagrad, learning rate 0.1).

  Args:
    features: Features tensor, or a dict holding it under `'input'`.
    labels: Labels tensor, or a dict holding it under `'labels'`.
    mode: One of `model_fn.ModeKeys.{TRAIN, EVAL, INFER}`.

  Returns:
    A `(prediction, loss, train_op)` tuple.
  """
  features = extract(features, 'input')
  labels = extract(labels, 'labels')
  assert mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
                  model_fn.ModeKeys.INFER)
  if isinstance(features, dict):
    # Only reachable if features['input'] was itself a dict; its single value
    # is used (the unpacking fails if it holds more than one entry).
    (_, features), = features.items()
  prediction, loss = models.linear_regression_zero_init(features, labels)
  train_op = _adagrad_train_op(loss, 0.1)
  return prediction, loss, train_op


def linear_model_fn_with_model_fn_ops(features, labels, mode):
  """Same as linear_model_fn, but returns `ModelFnOps`.

  Unlike `linear_model_fn`, `features`/`labels` are passed straight to the
  model without dict unwrapping.
  """
  assert mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
                  model_fn.ModeKeys.INFER)
  prediction, loss = models.linear_regression_zero_init(features, labels)
  train_op = _adagrad_train_op(loss, 0.1)
  return model_fn.ModelFnOps(
      mode=mode, predictions=prediction, loss=loss, train_op=train_op)


def logistic_model_no_mode_fn(features, labels):
  """Zero-initialized 3-class logistic regression model_fn with no `mode` arg.

  Labels are one-hot encoded with depth 3 before being fed to the model.

  Returns:
    A `(predictions, loss, train_op)` tuple where `predictions` is a dict with
    `'class'` (argmax over axis 1) and `'prob'` (the raw model prediction).
  """
  features = extract(features, 'input')
  labels = extract(labels, 'labels')
  labels = array_ops.one_hot(labels, 3, 1, 0)
  prediction, loss = models.logistic_regression_zero_init(features, labels)
  train_op = _adagrad_train_op(loss, 0.1)
  return {
      'class': math_ops.argmax(prediction, 1),
      'prob': prediction
  }, loss, train_op


VOCAB_FILE_CONTENT = 'emerson\nlake\npalmer\n'
EXTRA_FILE_CONTENT = 'kermit\npiggy\nralph\n'


class EstimatorInputTest(test.TestCase):
  """Exercises Estimator fit/evaluate/predict with various input styles.

  Covers numpy `x`/`y` arrays (plain and dict-wrapped), `SKCompat`, and
  `input_fn`s, including one that registers an extra queue runner.
  """

  def testContinueTrainingDictionaryInput(self):
    import shutil  # pylint: disable=g-import-not-at-top
    boston = base.load_boston()
    output_dir = tempfile.mkdtemp()
    # mkdtemp() leaves deletion to the caller; remove the model dir (and
    # whatever the estimators wrote into it) once the test finishes.
    self.addCleanup(shutil.rmtree, output_dir, ignore_errors=True)
    est = estimator.Estimator(model_fn=linear_model_fn, model_dir=output_dir)
    boston_input = {'input': boston.data}
    float64_target = {'labels': boston.target.astype(np.float64)}
    est.fit(x=boston_input, y=float64_target, steps=50)
    scores = est.evaluate(
        x=boston_input,
        y=float64_target,
        metrics={'MSE': metric_ops.streaming_mean_squared_error})
    del est
    # Create another estimator object with the same output dir.
    est2 = estimator.Estimator(model_fn=linear_model_fn, model_dir=output_dir)

    # Check we can evaluate and predict.
    scores2 = est2.evaluate(
        x=boston_input,
        y=float64_target,
        metrics={'MSE': metric_ops.streaming_mean_squared_error})
    self.assertAllClose(scores2['MSE'], scores['MSE'])
    predictions = np.array(list(est2.predict(x=boston_input)))
    other_score = _sklearn.mean_squared_error(predictions,
                                              float64_target['labels'])
    self.assertAllClose(other_score, scores['MSE'])

  def testBostonAll(self):
    boston = base.load_boston()
    est = estimator.SKCompat(estimator.Estimator(model_fn=linear_model_fn))
    float64_labels = boston.target.astype(np.float64)
    est.fit(x=boston.data, y=float64_labels, steps=100)
    scores = est.score(
        x=boston.data,
        y=float64_labels,
        metrics={'MSE': metric_ops.streaming_mean_squared_error})
    predictions = np.array(list(est.predict(x=boston.data)))
    other_score = _sklearn.mean_squared_error(predictions, boston.target)
    self.assertAllClose(scores['MSE'], other_score)
    self.assertIn('global_step', scores)
    self.assertEqual(100, scores['global_step'])

  def testBostonAllDictionaryInput(self):
    boston = base.load_boston()
    est = estimator.Estimator(model_fn=linear_model_fn)
    boston_input = {'input': boston.data}
    float64_target = {'labels': boston.target.astype(np.float64)}
    est.fit(x=boston_input, y=float64_target, steps=100)
    scores = est.evaluate(
        x=boston_input,
        y=float64_target,
        metrics={'MSE': metric_ops.streaming_mean_squared_error})
    predictions = np.array(list(est.predict(x=boston_input)))
    other_score = _sklearn.mean_squared_error(predictions, boston.target)
    self.assertAllClose(other_score, scores['MSE'])
    self.assertIn('global_step', scores)
    self.assertEqual(100, scores['global_step'])

  def testIrisAll(self):
    iris = base.load_iris()
    est = estimator.SKCompat(
        estimator.Estimator(model_fn=logistic_model_no_mode_fn))
    est.fit(iris.data, iris.target, steps=100)
    scores = est.score(
        x=iris.data,
        y=iris.target,
        metrics={('accuracy', 'class'): metric_ops.streaming_accuracy})
    predictions = est.predict(x=iris.data)
    predictions_class = est.predict(x=iris.data, outputs=['class'])['class']
    self.assertEqual(predictions['prob'].shape[0], iris.target.shape[0])
    self.assertAllClose(predictions['class'], predictions_class)
    self.assertAllClose(predictions['class'],
                        np.argmax(predictions['prob'], axis=1))
    other_score = _sklearn.accuracy_score(iris.target, predictions['class'])
    self.assertAllClose(scores['accuracy'], other_score)
    self.assertIn('global_step', scores)
    self.assertEqual(100, scores['global_step'])

  def testIrisAllDictionaryInput(self):
    iris = base.load_iris()
    est = estimator.Estimator(model_fn=logistic_model_no_mode_fn)
    iris_data = {'input': iris.data}
    iris_target = {'labels': iris.target}
    est.fit(iris_data, iris_target, steps=100)
    scores = est.evaluate(
        x=iris_data,
        y=iris_target,
        metrics={('accuracy', 'class'): metric_ops.streaming_accuracy})
    predictions = list(est.predict(x=iris_data))
    predictions_class = list(est.predict(x=iris_data, outputs=['class']))
    self.assertEqual(len(predictions), iris.target.shape[0])
    classes_batch = np.array([p['class'] for p in predictions])
    self.assertAllClose(classes_batch,
                        np.array([p['class'] for p in predictions_class]))
    self.assertAllClose(
        classes_batch,
        np.argmax(np.array([p['prob'] for p in predictions]), axis=1))
    other_score = _sklearn.accuracy_score(iris.target, classes_batch)
    self.assertAllClose(other_score, scores['accuracy'])
    self.assertIn('global_step', scores)
    self.assertEqual(100, scores['global_step'])

  def testIrisInputFn(self):
    iris = base.load_iris()
    est = estimator.Estimator(model_fn=logistic_model_no_mode_fn)
    est.fit(input_fn=iris_input_fn, steps=100)
    _ = est.evaluate(input_fn=iris_input_fn, steps=1)
    predictions = list(est.predict(x=iris.data))
    self.assertEqual(len(predictions), iris.target.shape[0])

  def testIrisInputFnLabelsDict(self):
    iris = base.load_iris()
    est = estimator.Estimator(model_fn=logistic_model_no_mode_fn)
    est.fit(input_fn=iris_input_fn_labels_dict, steps=100)
    _ = est.evaluate(
        input_fn=iris_input_fn_labels_dict,
        steps=1,
        metrics={
            'accuracy':
                metric_spec.MetricSpec(
                    metric_fn=metric_ops.streaming_accuracy,
                    prediction_key='class',
                    label_key='labels')
        })
    predictions = list(est.predict(x=iris.data))
    self.assertEqual(len(predictions), iris.target.shape[0])

  def testTrainInputFn(self):
    est = estimator.Estimator(model_fn=linear_model_fn)
    est.fit(input_fn=boston_input_fn, steps=1)
    _ = est.evaluate(input_fn=boston_eval_fn, steps=1)

  def testPredictInputFn(self):
    est = estimator.Estimator(model_fn=linear_model_fn)
    boston = base.load_boston()
    est.fit(input_fn=boston_input_fn, steps=1)
    input_fn = functools.partial(boston_input_fn, num_epochs=1)
    output = list(est.predict(input_fn=input_fn))
    self.assertEqual(len(output), boston.target.shape[0])

  def testPredictInputFnWithQueue(self):
    est = estimator.Estimator(model_fn=linear_model_fn)
    boston = base.load_boston()
    est.fit(input_fn=boston_input_fn, steps=1)
    input_fn = functools.partial(boston_input_fn_with_queue, num_epochs=2)
    output = list(est.predict(input_fn=input_fn))
    self.assertEqual(len(output), boston.target.shape[0] * 2)

  def testPredictConstInputFn(self):
    est = estimator.Estimator(model_fn=linear_model_fn)
    boston = base.load_boston()
    est.fit(input_fn=boston_input_fn, steps=1)

    def input_fn():
      # No limit_epochs here, unlike boston_input_fn.
      features = array_ops.reshape(
          constant_op.constant(boston.data), [-1, _BOSTON_INPUT_DIM])
      labels = array_ops.reshape(constant_op.constant(boston.target), [-1, 1])
      return features, labels

    output = list(est.predict(input_fn=input_fn))
    self.assertEqual(len(output), boston.target.shape[0])


if __name__ == '__main__':
  test.main()