# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for dnn.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import shutil
import tempfile

import numpy as np
import six

from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.estimator.canned import dnn
from tensorflow.python.estimator.canned import dnn_testing_utils
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.estimator.inputs import pandas_io
from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import queue_runner

try:
  # pylint: disable=g-import-not-at-top
  import pandas as pd
  HAS_PANDAS = True
except IOError:
  # Pandas writes a temporary file during import. If it fails, don't use
  # pandas.
  HAS_PANDAS = False
except ImportError:
  HAS_PANDAS = False
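# HAS_PANDAS gates the pandas_input_fn tests below; when pandas cannot be
# imported, those tests simply return without exercising pandas_input_fn.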


def _dnn_classifier_fn(*args, **kwargs):
  return dnn.DNNClassifier(*args, **kwargs)

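# The Base*Test classes in dnn_testing_utils hold the shared test logic and are
# parameterized on an estimator constructor, so factories such as
# _dnn_classifier_fn above and _dnn_regressor_fn below let the same suites run
# against both DNNClassifier and DNNRegressor.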

class DNNModelFnTest(dnn_testing_utils.BaseDNNModelFnTest, test.TestCase):

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNModelFnTest.__init__(self, dnn._dnn_model_fn)


class DNNLogitFnTest(dnn_testing_utils.BaseDNNLogitFnTest, test.TestCase):

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNLogitFnTest.__init__(self,
                                                  dnn._dnn_logit_fn_builder)


class DNNWarmStartingTest(dnn_testing_utils.BaseDNNWarmStartingTest,
                          test.TestCase):

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNWarmStartingTest.__init__(self, _dnn_classifier_fn,
                                                       _dnn_regressor_fn)


class DNNClassifierEvaluateTest(
    dnn_testing_utils.BaseDNNClassifierEvaluateTest, test.TestCase):

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
        self, _dnn_classifier_fn)


class DNNClassifierPredictTest(
    dnn_testing_utils.BaseDNNClassifierPredictTest, test.TestCase):

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
        self, _dnn_classifier_fn)


class DNNClassifierTrainTest(
    dnn_testing_utils.BaseDNNClassifierTrainTest, test.TestCase):

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(
        self, _dnn_classifier_fn)


def _dnn_regressor_fn(*args, **kwargs):
  return dnn.DNNRegressor(*args, **kwargs)


class DNNRegressorEvaluateTest(
    dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase):

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
        self, _dnn_regressor_fn)


class DNNRegressorPredictTest(
    dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase):

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
        self, _dnn_regressor_fn)


class DNNRegressorTrainTest(
    dnn_testing_utils.BaseDNNRegressorTrainTest, test.TestCase):

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
        self, _dnn_regressor_fn)


def _queue_parsed_features(feature_map):
  """Re-feeds a dict of parsed feature tensors through a FIFO queue."""
  tensors_to_enqueue = []
  keys = []
  for key, tensor in six.iteritems(feature_map):
    keys.append(key)
    tensors_to_enqueue.append(tensor)
  queue_dtypes = [x.dtype for x in tensors_to_enqueue]
  input_queue = data_flow_ops.FIFOQueue(capacity=100, dtypes=queue_dtypes)
  queue_runner.add_queue_runner(
      queue_runner.QueueRunner(
          input_queue,
          [input_queue.enqueue(tensors_to_enqueue)]))
  dequeued_tensors = input_queue.dequeue()
  return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))}

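# A note on the helper above: in the parse_example-based input_fns further
# down, parse_example turns the whole in-memory list of serialized examples
# into single tensors, and the QueueRunner keeps re-enqueueing those parsed
# tensors, so each dequeue yields the full parsed batch again. Training can
# therefore run for an arbitrary number of steps, while the eval and predict
# input_fns wrap their input in input_lib.limit_epochs(..., num_epochs=1) to
# stop after one pass.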

class DNNRegressorIntegrationTest(test.TestCase):

  def setUp(self):
    self._model_dir = tempfile.mkdtemp()

  def tearDown(self):
    if self._model_dir:
      writer_cache.FileWriterCache.clear()
      shutil.rmtree(self._model_dir)

  def _test_complete_flow(
      self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
      label_dimension, batch_size):
    feature_columns = [
        feature_column.numeric_column('x', shape=(input_dimension,))]
    est = dnn.DNNRegressor(
        hidden_units=(2, 2),
        feature_columns=feature_columns,
        label_dimension=label_dimension,
        model_dir=self._model_dir)

    # TRAIN
    num_steps = 10
    est.train(train_input_fn, steps=num_steps)

    # EVALUATE
    scores = est.evaluate(eval_input_fn)
    self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
    self.assertIn('loss', six.iterkeys(scores))

    # PREDICT
    predictions = np.array([
        x[prediction_keys.PredictionKeys.PREDICTIONS]
        for x in est.predict(predict_input_fn)
    ])
    self.assertAllEqual((batch_size, label_dimension), predictions.shape)

    # EXPORT
    feature_spec = feature_column.make_parse_example_spec(feature_columns)
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)
    export_dir = est.export_savedmodel(tempfile.mkdtemp(),
                                       serving_input_receiver_fn)
    self.assertTrue(gfile.Exists(export_dir))

  def test_numpy_input_fn(self):
    """Tests complete flow with numpy_input_fn."""
    label_dimension = 2
    batch_size = 10
    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
    data = data.reshape(batch_size, label_dimension)
    # learn y = x
    train_input_fn = numpy_io.numpy_input_fn(
        x={'x': data},
        y=data,
        batch_size=batch_size,
        num_epochs=None,
        shuffle=True)
    eval_input_fn = numpy_io.numpy_input_fn(
        x={'x': data},
        y=data,
        batch_size=batch_size,
        shuffle=False)
    predict_input_fn = numpy_io.numpy_input_fn(
        x={'x': data},
        batch_size=batch_size,
        shuffle=False)

    self._test_complete_flow(
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        predict_input_fn=predict_input_fn,
        input_dimension=label_dimension,
        label_dimension=label_dimension,
        batch_size=batch_size)

  def test_pandas_input_fn(self):
    """Tests complete flow with pandas_input_fn."""
    if not HAS_PANDAS:
      return
    label_dimension = 1
    batch_size = 10
    data = np.linspace(0., 2., batch_size, dtype=np.float32)
    x = pd.DataFrame({'x': data})
    y = pd.Series(data)
    train_input_fn = pandas_io.pandas_input_fn(
        x=x,
        y=y,
        batch_size=batch_size,
        num_epochs=None,
        shuffle=True)
    eval_input_fn = pandas_io.pandas_input_fn(
        x=x,
        y=y,
        batch_size=batch_size,
        shuffle=False)
    predict_input_fn = pandas_io.pandas_input_fn(
        x=x,
        batch_size=batch_size,
        shuffle=False)

    self._test_complete_flow(
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        predict_input_fn=predict_input_fn,
        input_dimension=label_dimension,
        label_dimension=label_dimension,
        batch_size=batch_size)

  def test_input_fn_from_parse_example(self):
    """Tests complete flow with input_fn constructed from parse_example."""
    label_dimension = 2
    batch_size = 10
    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
    data = data.reshape(batch_size, label_dimension)

    serialized_examples = []
    for datum in data:
      example = example_pb2.Example(features=feature_pb2.Features(
          feature={
              'x': feature_pb2.Feature(
                  float_list=feature_pb2.FloatList(value=datum)),
              'y': feature_pb2.Feature(
                  float_list=feature_pb2.FloatList(value=datum)),
          }))
      serialized_examples.append(example.SerializeToString())

    feature_spec = {
        'x': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32),
        'y': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32),
    }
    def _train_input_fn():
      feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
      features = _queue_parsed_features(feature_map)
      labels = features.pop('y')
      return features, labels
    def _eval_input_fn():
      feature_map = parsing_ops.parse_example(
          input_lib.limit_epochs(serialized_examples, num_epochs=1),
          feature_spec)
      features = _queue_parsed_features(feature_map)
      labels = features.pop('y')
      return features, labels
    def _predict_input_fn():
      feature_map = parsing_ops.parse_example(
          input_lib.limit_epochs(serialized_examples, num_epochs=1),
          feature_spec)
      features = _queue_parsed_features(feature_map)
      features.pop('y')
      return features, None

    self._test_complete_flow(
        train_input_fn=_train_input_fn,
        eval_input_fn=_eval_input_fn,
        predict_input_fn=_predict_input_fn,
        input_dimension=label_dimension,
        label_dimension=label_dimension,
        batch_size=batch_size)


class DNNClassifierIntegrationTest(test.TestCase):

  def setUp(self):
    self._model_dir = tempfile.mkdtemp()

  def tearDown(self):
    if self._model_dir:
      writer_cache.FileWriterCache.clear()
      shutil.rmtree(self._model_dir)

  def _as_label(self, data_in_float):
    return np.rint(data_in_float).astype(np.int64)

  def _test_complete_flow(
      self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
      n_classes, batch_size):
    feature_columns = [
        feature_column.numeric_column('x', shape=(input_dimension,))]
    est = dnn.DNNClassifier(
        hidden_units=(2, 2),
        feature_columns=feature_columns,
        n_classes=n_classes,
        model_dir=self._model_dir)

    # TRAIN
    num_steps = 10
    est.train(train_input_fn, steps=num_steps)

    # EVALUATE
    scores = est.evaluate(eval_input_fn)
    self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
    self.assertIn('loss', six.iterkeys(scores))

    # PREDICT
    predicted_proba = np.array([
        x[prediction_keys.PredictionKeys.PROBABILITIES]
        for x in est.predict(predict_input_fn)
    ])
    self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)

    # EXPORT
    feature_spec = feature_column.make_parse_example_spec(feature_columns)
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)
    export_dir = est.export_savedmodel(tempfile.mkdtemp(),
                                       serving_input_receiver_fn)
    self.assertTrue(gfile.Exists(export_dir))

  def test_numpy_input_fn(self):
    """Tests complete flow with numpy_input_fn."""
    n_classes = 3
    input_dimension = 2
    batch_size = 10
    data = np.linspace(
        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
    x_data = data.reshape(batch_size, input_dimension)
    y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))
    # learn y = x
    train_input_fn = numpy_io.numpy_input_fn(
        x={'x': x_data},
        y=y_data,
        batch_size=batch_size,
        num_epochs=None,
        shuffle=True)
    eval_input_fn = numpy_io.numpy_input_fn(
        x={'x': x_data},
        y=y_data,
        batch_size=batch_size,
        shuffle=False)
    predict_input_fn = numpy_io.numpy_input_fn(
        x={'x': x_data},
        batch_size=batch_size,
        shuffle=False)

    self._test_complete_flow(
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        predict_input_fn=predict_input_fn,
        input_dimension=input_dimension,
        n_classes=n_classes,
        batch_size=batch_size)

  def test_pandas_input_fn(self):
    """Tests complete flow with pandas_input_fn."""
    if not HAS_PANDAS:
      return
    input_dimension = 1
    n_classes = 3
    batch_size = 10
    data = np.linspace(0., n_classes - 1., batch_size, dtype=np.float32)
    x = pd.DataFrame({'x': data})
    y = pd.Series(self._as_label(data))
    train_input_fn = pandas_io.pandas_input_fn(
        x=x,
        y=y,
        batch_size=batch_size,
        num_epochs=None,
        shuffle=True)
    eval_input_fn = pandas_io.pandas_input_fn(
        x=x,
        y=y,
        batch_size=batch_size,
        shuffle=False)
    predict_input_fn = pandas_io.pandas_input_fn(
        x=x,
        batch_size=batch_size,
        shuffle=False)

    self._test_complete_flow(
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        predict_input_fn=predict_input_fn,
        input_dimension=input_dimension,
        n_classes=n_classes,
        batch_size=batch_size)

  def test_input_fn_from_parse_example(self):
    """Tests complete flow with input_fn constructed from parse_example."""
    input_dimension = 2
    n_classes = 3
    batch_size = 10
    data = np.linspace(
        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
    data = data.reshape(batch_size, input_dimension)

    serialized_examples = []
    for datum in data:
      example = example_pb2.Example(features=feature_pb2.Features(
          feature={
              'x':
                  feature_pb2.Feature(float_list=feature_pb2.FloatList(
                      value=datum)),
              'y':
                  feature_pb2.Feature(int64_list=feature_pb2.Int64List(
                      value=self._as_label(datum[:1]))),
          }))
      serialized_examples.append(example.SerializeToString())

    feature_spec = {
        'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32),
        'y': parsing_ops.FixedLenFeature([1], dtypes.int64),
    }
    def _train_input_fn():
      feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
      features = _queue_parsed_features(feature_map)
      labels = features.pop('y')
      return features, labels
    def _eval_input_fn():
      feature_map = parsing_ops.parse_example(
          input_lib.limit_epochs(serialized_examples, num_epochs=1),
          feature_spec)
      features = _queue_parsed_features(feature_map)
      labels = features.pop('y')
      return features, labels
    def _predict_input_fn():
      feature_map = parsing_ops.parse_example(
          input_lib.limit_epochs(serialized_examples, num_epochs=1),
          feature_spec)
      features = _queue_parsed_features(feature_map)
      features.pop('y')
      return features, None

    self._test_complete_flow(
        train_input_fn=_train_input_fn,
        eval_input_fn=_eval_input_fn,
        predict_input_fn=_predict_input_fn,
        input_dimension=input_dimension,
        n_classes=n_classes,
        batch_size=batch_size)


if __name__ == '__main__':
  test.main()