1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for dnn.py.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import shutil 22 import tempfile 23 24 import numpy as np 25 import six 26 27 from tensorflow.core.example import example_pb2 28 from tensorflow.core.example import feature_pb2 29 from tensorflow.python.estimator.canned import dnn 30 from tensorflow.python.estimator.canned import dnn_testing_utils 31 from tensorflow.python.estimator.canned import prediction_keys 32 from tensorflow.python.estimator.export import export 33 from tensorflow.python.estimator.inputs import numpy_io 34 from tensorflow.python.estimator.inputs import pandas_io 35 from tensorflow.python.feature_column import feature_column 36 from tensorflow.python.framework import dtypes 37 from tensorflow.python.framework import ops 38 from tensorflow.python.ops import data_flow_ops 39 from tensorflow.python.ops import parsing_ops 40 from tensorflow.python.platform import gfile 41 from tensorflow.python.platform import test 42 from tensorflow.python.summary.writer import writer_cache 43 from tensorflow.python.training import input as input_lib 44 from tensorflow.python.training import queue_runner 45 46 try: 47 # pylint: disable=g-import-not-at-top 48 import pandas as pd 49 HAS_PANDAS = True 50 except IOError: 51 # Pandas writes a temporary file during import. If it fails, don't use pandas. 52 HAS_PANDAS = False 53 except ImportError: 54 HAS_PANDAS = False 55 56 57 def _dnn_classifier_fn(*args, **kwargs): 58 return dnn.DNNClassifier(*args, **kwargs) 59 60 61 class DNNModelFnTest(dnn_testing_utils.BaseDNNModelFnTest, test.TestCase): 62 63 def __init__(self, methodName='runTest'): # pylint: disable=invalid-name 64 test.TestCase.__init__(self, methodName) 65 dnn_testing_utils.BaseDNNModelFnTest.__init__(self, dnn._dnn_model_fn) 66 67 68 class DNNLogitFnTest(dnn_testing_utils.BaseDNNLogitFnTest, test.TestCase): 69 70 def __init__(self, methodName='runTest'): # pylint: disable=invalid-name 71 test.TestCase.__init__(self, methodName) 72 dnn_testing_utils.BaseDNNLogitFnTest.__init__(self, 73 dnn._dnn_logit_fn_builder) 74 75 76 class DNNWarmStartingTest(dnn_testing_utils.BaseDNNWarmStartingTest, 77 test.TestCase): 78 79 def __init__(self, methodName='runTest'): # pylint: disable=invalid-name 80 test.TestCase.__init__(self, methodName) 81 dnn_testing_utils.BaseDNNWarmStartingTest.__init__(self, _dnn_classifier_fn, 82 _dnn_regressor_fn) 83 84 85 class DNNClassifierEvaluateTest( 86 dnn_testing_utils.BaseDNNClassifierEvaluateTest, test.TestCase): 87 88 def __init__(self, methodName='runTest'): # pylint: disable=invalid-name 89 test.TestCase.__init__(self, methodName) 90 dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__( 91 self, _dnn_classifier_fn) 92 93 94 class DNNClassifierPredictTest( 95 dnn_testing_utils.BaseDNNClassifierPredictTest, test.TestCase): 96 97 def __init__(self, methodName='runTest'): # pylint: disable=invalid-name 98 test.TestCase.__init__(self, methodName) 99 dnn_testing_utils.BaseDNNClassifierPredictTest.__init__( 100 self, _dnn_classifier_fn) 101 102 103 class DNNClassifierTrainTest( 104 dnn_testing_utils.BaseDNNClassifierTrainTest, test.TestCase): 105 106 def __init__(self, methodName='runTest'): # pylint: disable=invalid-name 107 test.TestCase.__init__(self, methodName) 108 dnn_testing_utils.BaseDNNClassifierTrainTest.__init__( 109 self, _dnn_classifier_fn) 110 111 112 def _dnn_regressor_fn(*args, **kwargs): 113 return dnn.DNNRegressor(*args, **kwargs) 114 115 116 class DNNRegressorEvaluateTest( 117 dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase): 118 119 def __init__(self, methodName='runTest'): # pylint: disable=invalid-name 120 test.TestCase.__init__(self, methodName) 121 dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__( 122 self, _dnn_regressor_fn) 123 124 125 class DNNRegressorPredictTest( 126 dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase): 127 128 def __init__(self, methodName='runTest'): # pylint: disable=invalid-name 129 test.TestCase.__init__(self, methodName) 130 dnn_testing_utils.BaseDNNRegressorPredictTest.__init__( 131 self, _dnn_regressor_fn) 132 133 134 class DNNRegressorTrainTest( 135 dnn_testing_utils.BaseDNNRegressorTrainTest, test.TestCase): 136 137 def __init__(self, methodName='runTest'): # pylint: disable=invalid-name 138 test.TestCase.__init__(self, methodName) 139 dnn_testing_utils.BaseDNNRegressorTrainTest.__init__( 140 self, _dnn_regressor_fn) 141 142 143 def _queue_parsed_features(feature_map): 144 tensors_to_enqueue = [] 145 keys = [] 146 for key, tensor in six.iteritems(feature_map): 147 keys.append(key) 148 tensors_to_enqueue.append(tensor) 149 queue_dtypes = [x.dtype for x in tensors_to_enqueue] 150 input_queue = data_flow_ops.FIFOQueue(capacity=100, dtypes=queue_dtypes) 151 queue_runner.add_queue_runner( 152 queue_runner.QueueRunner( 153 input_queue, 154 [input_queue.enqueue(tensors_to_enqueue)])) 155 dequeued_tensors = input_queue.dequeue() 156 return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))} 157 158 159 class DNNRegressorIntegrationTest(test.TestCase): 160 161 def setUp(self): 162 self._model_dir = tempfile.mkdtemp() 163 164 def tearDown(self): 165 if self._model_dir: 166 writer_cache.FileWriterCache.clear() 167 shutil.rmtree(self._model_dir) 168 169 def _test_complete_flow( 170 self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension, 171 label_dimension, batch_size): 172 feature_columns = [ 173 feature_column.numeric_column('x', shape=(input_dimension,))] 174 est = dnn.DNNRegressor( 175 hidden_units=(2, 2), 176 feature_columns=feature_columns, 177 label_dimension=label_dimension, 178 model_dir=self._model_dir) 179 180 # TRAIN 181 num_steps = 10 182 est.train(train_input_fn, steps=num_steps) 183 184 # EVALUTE 185 scores = est.evaluate(eval_input_fn) 186 self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) 187 self.assertIn('loss', six.iterkeys(scores)) 188 189 # PREDICT 190 predictions = np.array([ 191 x[prediction_keys.PredictionKeys.PREDICTIONS] 192 for x in est.predict(predict_input_fn) 193 ]) 194 self.assertAllEqual((batch_size, label_dimension), predictions.shape) 195 196 # EXPORT 197 feature_spec = feature_column.make_parse_example_spec(feature_columns) 198 serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( 199 feature_spec) 200 export_dir = est.export_savedmodel(tempfile.mkdtemp(), 201 serving_input_receiver_fn) 202 self.assertTrue(gfile.Exists(export_dir)) 203 204 def test_numpy_input_fn(self): 205 """Tests complete flow with numpy_input_fn.""" 206 label_dimension = 2 207 batch_size = 10 208 data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) 209 data = data.reshape(batch_size, label_dimension) 210 # learn y = x 211 train_input_fn = numpy_io.numpy_input_fn( 212 x={'x': data}, 213 y=data, 214 batch_size=batch_size, 215 num_epochs=None, 216 shuffle=True) 217 eval_input_fn = numpy_io.numpy_input_fn( 218 x={'x': data}, 219 y=data, 220 batch_size=batch_size, 221 shuffle=False) 222 predict_input_fn = numpy_io.numpy_input_fn( 223 x={'x': data}, 224 batch_size=batch_size, 225 shuffle=False) 226 227 self._test_complete_flow( 228 train_input_fn=train_input_fn, 229 eval_input_fn=eval_input_fn, 230 predict_input_fn=predict_input_fn, 231 input_dimension=label_dimension, 232 label_dimension=label_dimension, 233 batch_size=batch_size) 234 235 def test_pandas_input_fn(self): 236 """Tests complete flow with pandas_input_fn.""" 237 if not HAS_PANDAS: 238 return 239 label_dimension = 1 240 batch_size = 10 241 data = np.linspace(0., 2., batch_size, dtype=np.float32) 242 x = pd.DataFrame({'x': data}) 243 y = pd.Series(data) 244 train_input_fn = pandas_io.pandas_input_fn( 245 x=x, 246 y=y, 247 batch_size=batch_size, 248 num_epochs=None, 249 shuffle=True) 250 eval_input_fn = pandas_io.pandas_input_fn( 251 x=x, 252 y=y, 253 batch_size=batch_size, 254 shuffle=False) 255 predict_input_fn = pandas_io.pandas_input_fn( 256 x=x, 257 batch_size=batch_size, 258 shuffle=False) 259 260 self._test_complete_flow( 261 train_input_fn=train_input_fn, 262 eval_input_fn=eval_input_fn, 263 predict_input_fn=predict_input_fn, 264 input_dimension=label_dimension, 265 label_dimension=label_dimension, 266 batch_size=batch_size) 267 268 def test_input_fn_from_parse_example(self): 269 """Tests complete flow with input_fn constructed from parse_example.""" 270 label_dimension = 2 271 batch_size = 10 272 data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) 273 data = data.reshape(batch_size, label_dimension) 274 275 serialized_examples = [] 276 for datum in data: 277 example = example_pb2.Example(features=feature_pb2.Features( 278 feature={ 279 'x': feature_pb2.Feature( 280 float_list=feature_pb2.FloatList(value=datum)), 281 'y': feature_pb2.Feature( 282 float_list=feature_pb2.FloatList(value=datum)), 283 })) 284 serialized_examples.append(example.SerializeToString()) 285 286 feature_spec = { 287 'x': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32), 288 'y': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32), 289 } 290 def _train_input_fn(): 291 feature_map = parsing_ops.parse_example(serialized_examples, feature_spec) 292 features = _queue_parsed_features(feature_map) 293 labels = features.pop('y') 294 return features, labels 295 def _eval_input_fn(): 296 feature_map = parsing_ops.parse_example( 297 input_lib.limit_epochs(serialized_examples, num_epochs=1), 298 feature_spec) 299 features = _queue_parsed_features(feature_map) 300 labels = features.pop('y') 301 return features, labels 302 def _predict_input_fn(): 303 feature_map = parsing_ops.parse_example( 304 input_lib.limit_epochs(serialized_examples, num_epochs=1), 305 feature_spec) 306 features = _queue_parsed_features(feature_map) 307 features.pop('y') 308 return features, None 309 310 self._test_complete_flow( 311 train_input_fn=_train_input_fn, 312 eval_input_fn=_eval_input_fn, 313 predict_input_fn=_predict_input_fn, 314 input_dimension=label_dimension, 315 label_dimension=label_dimension, 316 batch_size=batch_size) 317 318 319 class DNNClassifierIntegrationTest(test.TestCase): 320 321 def setUp(self): 322 self._model_dir = tempfile.mkdtemp() 323 324 def tearDown(self): 325 if self._model_dir: 326 writer_cache.FileWriterCache.clear() 327 shutil.rmtree(self._model_dir) 328 329 def _as_label(self, data_in_float): 330 return np.rint(data_in_float).astype(np.int64) 331 332 def _test_complete_flow( 333 self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension, 334 n_classes, batch_size): 335 feature_columns = [ 336 feature_column.numeric_column('x', shape=(input_dimension,))] 337 est = dnn.DNNClassifier( 338 hidden_units=(2, 2), 339 feature_columns=feature_columns, 340 n_classes=n_classes, 341 model_dir=self._model_dir) 342 343 # TRAIN 344 num_steps = 10 345 est.train(train_input_fn, steps=num_steps) 346 347 # EVALUTE 348 scores = est.evaluate(eval_input_fn) 349 self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) 350 self.assertIn('loss', six.iterkeys(scores)) 351 352 # PREDICT 353 predicted_proba = np.array([ 354 x[prediction_keys.PredictionKeys.PROBABILITIES] 355 for x in est.predict(predict_input_fn) 356 ]) 357 self.assertAllEqual((batch_size, n_classes), predicted_proba.shape) 358 359 # EXPORT 360 feature_spec = feature_column.make_parse_example_spec(feature_columns) 361 serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( 362 feature_spec) 363 export_dir = est.export_savedmodel(tempfile.mkdtemp(), 364 serving_input_receiver_fn) 365 self.assertTrue(gfile.Exists(export_dir)) 366 367 def test_numpy_input_fn(self): 368 """Tests complete flow with numpy_input_fn.""" 369 n_classes = 3 370 input_dimension = 2 371 batch_size = 10 372 data = np.linspace( 373 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32) 374 x_data = data.reshape(batch_size, input_dimension) 375 y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1)) 376 # learn y = x 377 train_input_fn = numpy_io.numpy_input_fn( 378 x={'x': x_data}, 379 y=y_data, 380 batch_size=batch_size, 381 num_epochs=None, 382 shuffle=True) 383 eval_input_fn = numpy_io.numpy_input_fn( 384 x={'x': x_data}, 385 y=y_data, 386 batch_size=batch_size, 387 shuffle=False) 388 predict_input_fn = numpy_io.numpy_input_fn( 389 x={'x': x_data}, 390 batch_size=batch_size, 391 shuffle=False) 392 393 self._test_complete_flow( 394 train_input_fn=train_input_fn, 395 eval_input_fn=eval_input_fn, 396 predict_input_fn=predict_input_fn, 397 input_dimension=input_dimension, 398 n_classes=n_classes, 399 batch_size=batch_size) 400 401 def test_pandas_input_fn(self): 402 """Tests complete flow with pandas_input_fn.""" 403 if not HAS_PANDAS: 404 return 405 input_dimension = 1 406 n_classes = 3 407 batch_size = 10 408 data = np.linspace(0., n_classes - 1., batch_size, dtype=np.float32) 409 x = pd.DataFrame({'x': data}) 410 y = pd.Series(self._as_label(data)) 411 train_input_fn = pandas_io.pandas_input_fn( 412 x=x, 413 y=y, 414 batch_size=batch_size, 415 num_epochs=None, 416 shuffle=True) 417 eval_input_fn = pandas_io.pandas_input_fn( 418 x=x, 419 y=y, 420 batch_size=batch_size, 421 shuffle=False) 422 predict_input_fn = pandas_io.pandas_input_fn( 423 x=x, 424 batch_size=batch_size, 425 shuffle=False) 426 427 self._test_complete_flow( 428 train_input_fn=train_input_fn, 429 eval_input_fn=eval_input_fn, 430 predict_input_fn=predict_input_fn, 431 input_dimension=input_dimension, 432 n_classes=n_classes, 433 batch_size=batch_size) 434 435 def test_input_fn_from_parse_example(self): 436 """Tests complete flow with input_fn constructed from parse_example.""" 437 input_dimension = 2 438 n_classes = 3 439 batch_size = 10 440 data = np.linspace( 441 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32) 442 data = data.reshape(batch_size, input_dimension) 443 444 serialized_examples = [] 445 for datum in data: 446 example = example_pb2.Example(features=feature_pb2.Features( 447 feature={ 448 'x': 449 feature_pb2.Feature(float_list=feature_pb2.FloatList( 450 value=datum)), 451 'y': 452 feature_pb2.Feature(int64_list=feature_pb2.Int64List( 453 value=self._as_label(datum[:1]))), 454 })) 455 serialized_examples.append(example.SerializeToString()) 456 457 feature_spec = { 458 'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32), 459 'y': parsing_ops.FixedLenFeature([1], dtypes.int64), 460 } 461 def _train_input_fn(): 462 feature_map = parsing_ops.parse_example(serialized_examples, feature_spec) 463 features = _queue_parsed_features(feature_map) 464 labels = features.pop('y') 465 return features, labels 466 def _eval_input_fn(): 467 feature_map = parsing_ops.parse_example( 468 input_lib.limit_epochs(serialized_examples, num_epochs=1), 469 feature_spec) 470 features = _queue_parsed_features(feature_map) 471 labels = features.pop('y') 472 return features, labels 473 def _predict_input_fn(): 474 feature_map = parsing_ops.parse_example( 475 input_lib.limit_epochs(serialized_examples, num_epochs=1), 476 feature_spec) 477 features = _queue_parsed_features(feature_map) 478 features.pop('y') 479 return features, None 480 481 self._test_complete_flow( 482 train_input_fn=_train_input_fn, 483 eval_input_fn=_eval_input_fn, 484 predict_input_fn=_predict_input_fn, 485 input_dimension=input_dimension, 486 n_classes=n_classes, 487 batch_size=batch_size) 488 489 490 if __name__ == '__main__': 491 test.main() 492