Home | History | Annotate | Download | only in canned
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for dnn_linear_combined.py."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import shutil
     22 import tempfile
     23 
     24 import numpy as np
     25 import six
     26 
     27 from tensorflow.core.example import example_pb2
     28 from tensorflow.core.example import feature_pb2
     29 from tensorflow.python.estimator import warm_starting_util
     30 from tensorflow.python.estimator.canned import dnn_linear_combined
     31 from tensorflow.python.estimator.canned import dnn_testing_utils
     32 from tensorflow.python.estimator.canned import linear_testing_utils
     33 from tensorflow.python.estimator.canned import prediction_keys
     34 from tensorflow.python.estimator.export import export
     35 from tensorflow.python.estimator.inputs import numpy_io
     36 from tensorflow.python.estimator.inputs import pandas_io
     37 from tensorflow.python.feature_column import feature_column
     38 from tensorflow.python.framework import dtypes
     39 from tensorflow.python.framework import ops
     40 from tensorflow.python.ops import nn
     41 from tensorflow.python.ops import parsing_ops
     42 from tensorflow.python.ops import variables as variables_lib
     43 from tensorflow.python.platform import gfile
     44 from tensorflow.python.platform import test
     45 from tensorflow.python.summary.writer import writer_cache
     46 from tensorflow.python.training import checkpoint_utils
     47 from tensorflow.python.training import gradient_descent
     48 from tensorflow.python.training import input as input_lib
     49 from tensorflow.python.training import optimizer as optimizer_lib
     50 
     51 
     52 try:
     53   # pylint: disable=g-import-not-at-top
     54   import pandas as pd
     55   HAS_PANDAS = True
     56 except IOError:
     57   # Pandas writes a temporary file during import. If it fails, don't use pandas.
     58   HAS_PANDAS = False
     59 except ImportError:
     60   HAS_PANDAS = False
     61 
     62 
     63 class DNNOnlyModelFnTest(dnn_testing_utils.BaseDNNModelFnTest, test.TestCase):
     64 
     65   def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
     66     test.TestCase.__init__(self, methodName)
     67     dnn_testing_utils.BaseDNNModelFnTest.__init__(self, self._dnn_only_model_fn)
     68 
     69   def _dnn_only_model_fn(self,
     70                          features,
     71                          labels,
     72                          mode,
     73                          head,
     74                          hidden_units,
     75                          feature_columns,
     76                          optimizer='Adagrad',
     77                          activation_fn=nn.relu,
     78                          dropout=None,
     79                          input_layer_partitioner=None,
     80                          config=None):
     81     return dnn_linear_combined._dnn_linear_combined_model_fn(
     82         features=features,
     83         labels=labels,
     84         mode=mode,
     85         head=head,
     86         linear_feature_columns=[],
     87         dnn_hidden_units=hidden_units,
     88         dnn_feature_columns=feature_columns,
     89         dnn_optimizer=optimizer,
     90         dnn_activation_fn=activation_fn,
     91         dnn_dropout=dropout,
     92         input_layer_partitioner=input_layer_partitioner,
     93         config=config)
     94 
     95 
# A function that mimics the LinearRegressor constructor, so the shared
# linear-regressor test suites can be reused against the combined estimator.
     97 def _linear_regressor_fn(feature_columns,
     98                          model_dir=None,
     99                          label_dimension=1,
    100                          weight_column=None,
    101                          optimizer='Ftrl',
    102                          config=None,
    103                          partitioner=None):
    104   return dnn_linear_combined.DNNLinearCombinedRegressor(
    105       model_dir=model_dir,
    106       linear_feature_columns=feature_columns,
    107       linear_optimizer=optimizer,
    108       label_dimension=label_dimension,
    109       weight_column=weight_column,
    110       input_layer_partitioner=partitioner,
    111       config=config)
    112 
    113 
class LinearOnlyRegressorPartitionerTest(
    linear_testing_utils.BaseLinearRegressorPartitionerTest, test.TestCase):
  """Runs shared linear-regressor partitioner tests on the combined model."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    linear_testing_utils.BaseLinearRegressorPartitionerTest.__init__(
        self, _linear_regressor_fn)
    121 
    122 
class LinearOnlyRegressorEvaluationTest(
    linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase):
  """Runs shared linear-regressor evaluation tests on the combined model."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(
        self, _linear_regressor_fn)
    130 
    131 
class LinearOnlyRegressorPredictTest(
    linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase):
  """Runs shared linear-regressor predict tests on the combined model."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    linear_testing_utils.BaseLinearRegressorPredictTest.__init__(
        self, _linear_regressor_fn)
    139 
    140 
class LinearOnlyRegressorIntegrationTest(
    linear_testing_utils.BaseLinearRegressorIntegrationTest, test.TestCase):
  """Runs shared linear-regressor integration tests on the combined model."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    linear_testing_utils.BaseLinearRegressorIntegrationTest.__init__(
        self, _linear_regressor_fn)
    148 
    149 
class LinearOnlyRegressorTrainingTest(
    linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase):
  """Runs shared linear-regressor training tests on the combined model."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
        self, _linear_regressor_fn)
    157 
    158 
    159 def _linear_classifier_fn(feature_columns,
    160                           model_dir=None,
    161                           n_classes=2,
    162                           weight_column=None,
    163                           label_vocabulary=None,
    164                           optimizer='Ftrl',
    165                           config=None,
    166                           partitioner=None):
    167   return dnn_linear_combined.DNNLinearCombinedClassifier(
    168       model_dir=model_dir,
    169       linear_feature_columns=feature_columns,
    170       linear_optimizer=optimizer,
    171       n_classes=n_classes,
    172       weight_column=weight_column,
    173       label_vocabulary=label_vocabulary,
    174       input_layer_partitioner=partitioner,
    175       config=config)
    176 
    177 
class LinearOnlyClassifierTrainingTest(
    linear_testing_utils.BaseLinearClassifierTrainingTest, test.TestCase):
  """Runs shared linear-classifier training tests on the combined model."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    linear_testing_utils.BaseLinearClassifierTrainingTest.__init__(
        self, linear_classifier_fn=_linear_classifier_fn)
    185 
    186 
class LinearOnlyClassifierClassesEvaluationTest(
    linear_testing_utils.BaseLinearClassifierEvaluationTest, test.TestCase):
  """Runs shared linear-classifier evaluation tests on the combined model."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    linear_testing_utils.BaseLinearClassifierEvaluationTest.__init__(
        self, linear_classifier_fn=_linear_classifier_fn)
    194 
    195 
class LinearOnlyClassifierPredictTest(
    linear_testing_utils.BaseLinearClassifierPredictTest, test.TestCase):
  """Runs shared linear-classifier predict tests on the combined model."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    linear_testing_utils.BaseLinearClassifierPredictTest.__init__(
        self, linear_classifier_fn=_linear_classifier_fn)
    203 
    204 
class LinearOnlyClassifierIntegrationTest(
    linear_testing_utils.BaseLinearClassifierIntegrationTest, test.TestCase):
  """Runs shared linear-classifier integration tests on the combined model."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    linear_testing_utils.BaseLinearClassifierIntegrationTest.__init__(
        self, linear_classifier_fn=_linear_classifier_fn)
    212 
    213 
class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
  """End-to-end train/evaluate/predict/export tests for the regressor."""

  def setUp(self):
    # A fresh model dir per test keeps checkpoints isolated between tests.
    self._model_dir = tempfile.mkdtemp()

  def tearDown(self):
    if self._model_dir:
      # Flush cached summary writers before removing the directory they
      # write into.
      writer_cache.FileWriterCache.clear()
      shutil.rmtree(self._model_dir)

  def _test_complete_flow(
      self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
      label_dimension, batch_size):
    """Trains, evaluates, predicts and exports; checks basic invariants.

    Both towers consume the same numeric feature 'x'.
    """
    linear_feature_columns = [
        feature_column.numeric_column('x', shape=(input_dimension,))]
    dnn_feature_columns = [
        feature_column.numeric_column('x', shape=(input_dimension,))]
    feature_columns = linear_feature_columns + dnn_feature_columns
    est = dnn_linear_combined.DNNLinearCombinedRegressor(
        linear_feature_columns=linear_feature_columns,
        dnn_hidden_units=(2, 2),
        dnn_feature_columns=dnn_feature_columns,
        label_dimension=label_dimension,
        model_dir=self._model_dir)

    # TRAIN
    num_steps = 10
    est.train(train_input_fn, steps=num_steps)

    # EVALUATE
    scores = est.evaluate(eval_input_fn)
    self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
    self.assertIn('loss', six.iterkeys(scores))

    # PREDICT
    predictions = np.array([
        x[prediction_keys.PredictionKeys.PREDICTIONS]
        for x in est.predict(predict_input_fn)
    ])
    self.assertAllEqual((batch_size, label_dimension), predictions.shape)

    # EXPORT
    feature_spec = feature_column.make_parse_example_spec(feature_columns)
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)
    export_dir = est.export_savedmodel(tempfile.mkdtemp(),
                                       serving_input_receiver_fn)
    self.assertTrue(gfile.Exists(export_dir))

  def test_numpy_input_fn(self):
    """Tests complete flow with numpy_input_fn."""
    label_dimension = 2
    batch_size = 10
    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
    data = data.reshape(batch_size, label_dimension)
    # learn y = x
    train_input_fn = numpy_io.numpy_input_fn(
        x={'x': data},
        y=data,
        batch_size=batch_size,
        num_epochs=None,
        shuffle=True)
    eval_input_fn = numpy_io.numpy_input_fn(
        x={'x': data},
        y=data,
        batch_size=batch_size,
        shuffle=False)
    predict_input_fn = numpy_io.numpy_input_fn(
        x={'x': data},
        batch_size=batch_size,
        shuffle=False)

    self._test_complete_flow(
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        predict_input_fn=predict_input_fn,
        input_dimension=label_dimension,
        label_dimension=label_dimension,
        batch_size=batch_size)

  def test_pandas_input_fn(self):
    """Tests complete flow with pandas_input_fn."""
    if not HAS_PANDAS:
      return
    label_dimension = 1
    batch_size = 10
    data = np.linspace(0., 2., batch_size, dtype=np.float32)
    x = pd.DataFrame({'x': data})
    y = pd.Series(data)
    train_input_fn = pandas_io.pandas_input_fn(
        x=x,
        y=y,
        batch_size=batch_size,
        num_epochs=None,
        shuffle=True)
    eval_input_fn = pandas_io.pandas_input_fn(
        x=x,
        y=y,
        batch_size=batch_size,
        shuffle=False)
    predict_input_fn = pandas_io.pandas_input_fn(
        x=x,
        batch_size=batch_size,
        shuffle=False)

    self._test_complete_flow(
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        predict_input_fn=predict_input_fn,
        input_dimension=label_dimension,
        label_dimension=label_dimension,
        batch_size=batch_size)

  def test_input_fn_from_parse_example(self):
    """Tests complete flow with input_fn constructed from parse_example."""
    label_dimension = 2
    batch_size = 10
    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
    data = data.reshape(batch_size, label_dimension)

    # Serialize each row as a tf.Example with identical 'x' and 'y' so the
    # model should learn the identity mapping.
    serialized_examples = []
    for datum in data:
      example = example_pb2.Example(features=feature_pb2.Features(
          feature={
              'x': feature_pb2.Feature(
                  float_list=feature_pb2.FloatList(value=datum)),
              'y': feature_pb2.Feature(
                  float_list=feature_pb2.FloatList(value=datum)),
          }))
      serialized_examples.append(example.SerializeToString())

    feature_spec = {
        'x': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32),
        'y': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32),
    }
    def _train_input_fn():
      feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
      features = linear_testing_utils.queue_parsed_features(feature_map)
      labels = features.pop('y')
      return features, labels
    def _eval_input_fn():
      feature_map = parsing_ops.parse_example(
          input_lib.limit_epochs(serialized_examples, num_epochs=1),
          feature_spec)
      features = linear_testing_utils.queue_parsed_features(feature_map)
      labels = features.pop('y')
      return features, labels
    def _predict_input_fn():
      feature_map = parsing_ops.parse_example(
          input_lib.limit_epochs(serialized_examples, num_epochs=1),
          feature_spec)
      features = linear_testing_utils.queue_parsed_features(feature_map)
      features.pop('y')
      return features, None

    self._test_complete_flow(
        train_input_fn=_train_input_fn,
        eval_input_fn=_eval_input_fn,
        predict_input_fn=_predict_input_fn,
        input_dimension=label_dimension,
        label_dimension=label_dimension,
        batch_size=batch_size)
    376 
    377 
# A function that mimics the DNNClassifier constructor, so the shared
# dnn-classifier test suites can be reused against the combined estimator.
    379 def _dnn_classifier_fn(hidden_units,
    380                        feature_columns,
    381                        model_dir=None,
    382                        n_classes=2,
    383                        weight_column=None,
    384                        label_vocabulary=None,
    385                        optimizer='Adagrad',
    386                        config=None,
    387                        input_layer_partitioner=None):
    388   return dnn_linear_combined.DNNLinearCombinedClassifier(
    389       model_dir=model_dir,
    390       dnn_hidden_units=hidden_units,
    391       dnn_feature_columns=feature_columns,
    392       dnn_optimizer=optimizer,
    393       n_classes=n_classes,
    394       weight_column=weight_column,
    395       label_vocabulary=label_vocabulary,
    396       input_layer_partitioner=input_layer_partitioner,
    397       config=config)
    398 
    399 
class DNNOnlyClassifierEvaluateTest(
    dnn_testing_utils.BaseDNNClassifierEvaluateTest, test.TestCase):
  """Runs shared dnn-classifier evaluate tests on the combined model."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
        self, _dnn_classifier_fn)
    407 
    408 
class DNNOnlyClassifierPredictTest(
    dnn_testing_utils.BaseDNNClassifierPredictTest, test.TestCase):
  """Runs shared dnn-classifier predict tests on the combined model."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
        self, _dnn_classifier_fn)
    416 
    417 
class DNNOnlyClassifierTrainTest(
    dnn_testing_utils.BaseDNNClassifierTrainTest, test.TestCase):
  """Runs shared dnn-classifier train tests on the combined model."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(
        self, _dnn_classifier_fn)
    425 
    426 
# A function that mimics the DNNRegressor constructor, so the shared
# dnn-regressor test suites can be reused against the combined estimator.
    428 def _dnn_regressor_fn(hidden_units,
    429                       feature_columns,
    430                       model_dir=None,
    431                       label_dimension=1,
    432                       weight_column=None,
    433                       optimizer='Adagrad',
    434                       config=None,
    435                       input_layer_partitioner=None):
    436   return dnn_linear_combined.DNNLinearCombinedRegressor(
    437       model_dir=model_dir,
    438       dnn_hidden_units=hidden_units,
    439       dnn_feature_columns=feature_columns,
    440       dnn_optimizer=optimizer,
    441       label_dimension=label_dimension,
    442       weight_column=weight_column,
    443       input_layer_partitioner=input_layer_partitioner,
    444       config=config)
    445 
    446 
class DNNOnlyRegressorEvaluateTest(
    dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase):
  """Runs shared dnn-regressor evaluate tests on the combined model."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
        self, _dnn_regressor_fn)
    454 
    455 
class DNNOnlyRegressorPredictTest(
    dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase):
  """Runs shared dnn-regressor predict tests on the combined model."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
        self, _dnn_regressor_fn)
    463 
    464 
class DNNOnlyRegressorTrainTest(
    dnn_testing_utils.BaseDNNRegressorTrainTest, test.TestCase):
  """Runs shared dnn-regressor train tests on the combined model."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
        self, _dnn_regressor_fn)
    472 
    473 
class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
  """End-to-end train/evaluate/predict/export tests for the classifier."""

  def setUp(self):
    # A fresh model dir per test keeps checkpoints isolated between tests.
    self._model_dir = tempfile.mkdtemp()

  def tearDown(self):
    if self._model_dir:
      # Flush cached summary writers before removing the directory they
      # write into.
      writer_cache.FileWriterCache.clear()
      shutil.rmtree(self._model_dir)

  def _as_label(self, data_in_float):
    # Rounds float features to the nearest integer class id (int64).
    return np.rint(data_in_float).astype(np.int64)

  def _test_complete_flow(
      self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
      n_classes, batch_size):
    """Trains, evaluates, predicts and exports; checks basic invariants.

    Both towers consume the same numeric feature 'x'.
    """
    linear_feature_columns = [
        feature_column.numeric_column('x', shape=(input_dimension,))]
    dnn_feature_columns = [
        feature_column.numeric_column('x', shape=(input_dimension,))]
    feature_columns = linear_feature_columns + dnn_feature_columns
    est = dnn_linear_combined.DNNLinearCombinedClassifier(
        linear_feature_columns=linear_feature_columns,
        dnn_hidden_units=(2, 2),
        dnn_feature_columns=dnn_feature_columns,
        n_classes=n_classes,
        model_dir=self._model_dir)

    # TRAIN
    num_steps = 10
    est.train(train_input_fn, steps=num_steps)

    # EVALUATE
    scores = est.evaluate(eval_input_fn)
    self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
    self.assertIn('loss', six.iterkeys(scores))

    # PREDICT
    predicted_proba = np.array([
        x[prediction_keys.PredictionKeys.PROBABILITIES]
        for x in est.predict(predict_input_fn)
    ])
    self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)

    # EXPORT
    feature_spec = feature_column.make_parse_example_spec(feature_columns)
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)
    export_dir = est.export_savedmodel(tempfile.mkdtemp(),
                                       serving_input_receiver_fn)
    self.assertTrue(gfile.Exists(export_dir))

  def test_numpy_input_fn(self):
    """Tests complete flow with numpy_input_fn."""
    n_classes = 3
    input_dimension = 2
    batch_size = 10
    data = np.linspace(
        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
    x_data = data.reshape(batch_size, input_dimension)
    y_data = self._as_label(np.reshape(data[:batch_size], (batch_size, 1)))
    # learn y = x
    train_input_fn = numpy_io.numpy_input_fn(
        x={'x': x_data},
        y=y_data,
        batch_size=batch_size,
        num_epochs=None,
        shuffle=True)
    eval_input_fn = numpy_io.numpy_input_fn(
        x={'x': x_data},
        y=y_data,
        batch_size=batch_size,
        shuffle=False)
    predict_input_fn = numpy_io.numpy_input_fn(
        x={'x': x_data},
        batch_size=batch_size,
        shuffle=False)

    self._test_complete_flow(
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        predict_input_fn=predict_input_fn,
        input_dimension=input_dimension,
        n_classes=n_classes,
        batch_size=batch_size)

  def test_pandas_input_fn(self):
    """Tests complete flow with pandas_input_fn."""
    if not HAS_PANDAS:
      return
    input_dimension = 1
    n_classes = 2
    batch_size = 10
    data = np.linspace(0., n_classes - 1., batch_size, dtype=np.float32)
    x = pd.DataFrame({'x': data})
    y = pd.Series(self._as_label(data))
    train_input_fn = pandas_io.pandas_input_fn(
        x=x,
        y=y,
        batch_size=batch_size,
        num_epochs=None,
        shuffle=True)
    eval_input_fn = pandas_io.pandas_input_fn(
        x=x,
        y=y,
        batch_size=batch_size,
        shuffle=False)
    predict_input_fn = pandas_io.pandas_input_fn(
        x=x,
        batch_size=batch_size,
        shuffle=False)

    self._test_complete_flow(
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        predict_input_fn=predict_input_fn,
        input_dimension=input_dimension,
        n_classes=n_classes,
        batch_size=batch_size)

  def test_input_fn_from_parse_example(self):
    """Tests complete flow with input_fn constructed from parse_example."""
    input_dimension = 2
    n_classes = 3
    batch_size = 10
    data = np.linspace(0., n_classes-1., batch_size * input_dimension,
                       dtype=np.float32)
    data = data.reshape(batch_size, input_dimension)

    # Serialize each row as a tf.Example; the label 'y' is the first feature
    # value rounded to a class id.
    serialized_examples = []
    for datum in data:
      example = example_pb2.Example(features=feature_pb2.Features(
          feature={
              'x':
                  feature_pb2.Feature(float_list=feature_pb2.FloatList(
                      value=datum)),
              'y':
                  feature_pb2.Feature(int64_list=feature_pb2.Int64List(
                      value=self._as_label(datum[:1]))),
          }))
      serialized_examples.append(example.SerializeToString())

    feature_spec = {
        'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32),
        'y': parsing_ops.FixedLenFeature([1], dtypes.int64),
    }
    def _train_input_fn():
      feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
      features = linear_testing_utils.queue_parsed_features(feature_map)
      labels = features.pop('y')
      return features, labels
    def _eval_input_fn():
      feature_map = parsing_ops.parse_example(
          input_lib.limit_epochs(serialized_examples, num_epochs=1),
          feature_spec)
      features = linear_testing_utils.queue_parsed_features(feature_map)
      labels = features.pop('y')
      return features, labels
    def _predict_input_fn():
      feature_map = parsing_ops.parse_example(
          input_lib.limit_epochs(serialized_examples, num_epochs=1),
          feature_spec)
      features = linear_testing_utils.queue_parsed_features(feature_map)
      features.pop('y')
      return features, None

    self._test_complete_flow(
        train_input_fn=_train_input_fn,
        eval_input_fn=_eval_input_fn,
        predict_input_fn=_predict_input_fn,
        input_dimension=input_dimension,
        n_classes=n_classes,
        batch_size=batch_size)
    647 
    648 
class DNNLinearCombinedTests(test.TestCase):
  """Tests interactions between the linear and DNN parts of the model."""

  def setUp(self):
    # A fresh model dir per test keeps checkpoints isolated between tests.
    self._model_dir = tempfile.mkdtemp()

  def tearDown(self):
    if self._model_dir:
      shutil.rmtree(self._model_dir)

  def _mock_optimizer(self, real_optimizer, var_name_prefix):
    """Verifies global_step is None and var_names start with given prefix."""

    def _minimize(loss, global_step=None, var_list=None):
      self.assertIsNone(global_step)
      trainable_vars = var_list or ops.get_collection(
          ops.GraphKeys.TRAINABLE_VARIABLES)
      var_names = [var.name for var in trainable_vars]
      self.assertTrue(
          all([name.startswith(var_name_prefix) for name in var_names]))
      # var is used to check this op called by training.
      with ops.name_scope(''):
        var = variables_lib.Variable(0., name=(var_name_prefix + '_called'))
      # Chaining the assign as a control dependency lets the test later read
      # the checkpointed value 100. as proof that this minimize ran.
      with ops.control_dependencies([var.assign(100.)]):
        return real_optimizer.minimize(loss, global_step, var_list)

    optimizer_mock = test.mock.NonCallableMagicMock(
        spec=optimizer_lib.Optimizer, wraps=real_optimizer)
    optimizer_mock.minimize = test.mock.MagicMock(wraps=_minimize)

    return optimizer_mock

  def test_train_op_calls_both_dnn_and_linear(self):
    """Checks one train step invokes both the linear and DNN optimizers."""
    opt = gradient_descent.GradientDescentOptimizer(1.)
    x_column = feature_column.numeric_column('x')
    input_fn = numpy_io.numpy_input_fn(
        x={'x': np.array([[0.], [1.]])},
        y=np.array([[0.], [1.]]),
        batch_size=1,
        shuffle=False)
    est = dnn_linear_combined.DNNLinearCombinedClassifier(
        linear_feature_columns=[x_column],
        # verifies linear_optimizer is used only for linear part.
        linear_optimizer=self._mock_optimizer(opt, 'linear'),
        dnn_hidden_units=(2, 2),
        dnn_feature_columns=[x_column],
        # verifies dnn_optimizer is used only for dnn part.
        dnn_optimizer=self._mock_optimizer(opt, 'dnn'),
        model_dir=self._model_dir)
    est.train(input_fn, steps=1)
    # verifies train_op fires linear minimize op
    self.assertEqual(100.,
                     checkpoint_utils.load_variable(
                         self._model_dir, 'linear_called'))
    # verifies train_op fires dnn minimize op
    self.assertEqual(100.,
                     checkpoint_utils.load_variable(
                         self._model_dir, 'dnn_called'))

  def test_dnn_and_linear_logits_are_added(self):
    """Checks final logits are the sum of the linear and DNN logits."""
    # Write a checkpoint with hand-picked weights so the prediction can be
    # computed by hand below.
    with ops.Graph().as_default():
      variables_lib.Variable([[1.0]], name='linear/linear_model/x/weights')
      variables_lib.Variable([2.0], name='linear/linear_model/bias_weights')
      variables_lib.Variable([[3.0]], name='dnn/hiddenlayer_0/kernel')
      variables_lib.Variable([4.0], name='dnn/hiddenlayer_0/bias')
      variables_lib.Variable([[5.0]], name='dnn/logits/kernel')
      variables_lib.Variable([6.0], name='dnn/logits/bias')
      variables_lib.Variable(1, name='global_step', dtype=dtypes.int64)
      linear_testing_utils.save_variables_to_ckpt(self._model_dir)

    x_column = feature_column.numeric_column('x')
    est = dnn_linear_combined.DNNLinearCombinedRegressor(
        linear_feature_columns=[x_column],
        dnn_hidden_units=[1],
        dnn_feature_columns=[x_column],
        model_dir=self._model_dir)
    input_fn = numpy_io.numpy_input_fn(
        x={'x': np.array([[10.]])}, batch_size=1, shuffle=False)
    # linear logits = 10*1 + 2 = 12
    # dnn logits = (10*3 + 4)*5 + 6 = 176
    # logits = dnn + linear = 176 + 12 = 188
    self.assertAllClose(
        {
            prediction_keys.PredictionKeys.PREDICTIONS: [188.],
        },
        next(est.predict(input_fn=input_fn)))
    734 
    735 
    736 class DNNLinearCombinedWarmStartingTest(test.TestCase):
    737 
    738   def setUp(self):
    739     # Create a directory to save our old checkpoint and vocabularies to.
    740     self._ckpt_and_vocab_dir = tempfile.mkdtemp()
    741 
    742     # Make a dummy input_fn.
    743     def _input_fn():
    744       features = {
    745           'age': [[23.], [31.]],
    746           'city': [['Palo Alto'], ['Mountain View']],
    747       }
    748       return features, [0, 1]
    749 
    750     self._input_fn = _input_fn
    751 
    752   def tearDown(self):
    753     # Clean up checkpoint / vocab dir.
    754     writer_cache.FileWriterCache.clear()
    755     shutil.rmtree(self._ckpt_and_vocab_dir)
    756 
    757   def test_classifier_basic_warm_starting(self):
    758     """Tests correctness of DNNLinearCombinedClassifier default warm-start."""
    759     age = feature_column.numeric_column('age')
    760     city = feature_column.embedding_column(
    761         feature_column.categorical_column_with_vocabulary_list(
    762             'city', vocabulary_list=['Mountain View', 'Palo Alto']),
    763         dimension=5)
    764 
    765     # Create a DNNLinearCombinedClassifier and train to save a checkpoint.
    766     dnn_lc_classifier = dnn_linear_combined.DNNLinearCombinedClassifier(
    767         linear_feature_columns=[age],
    768         dnn_feature_columns=[city],
    769         dnn_hidden_units=[256, 128],
    770         model_dir=self._ckpt_and_vocab_dir,
    771         n_classes=4,
    772         linear_optimizer='SGD',
    773         dnn_optimizer='SGD')
    774     dnn_lc_classifier.train(input_fn=self._input_fn, max_steps=1)
    775 
    776     # Create a second DNNLinearCombinedClassifier, warm-started from the first.
    777     # Use a learning_rate = 0.0 optimizer to check values (use SGD so we don't
    778     # have accumulator values that change).
    779     warm_started_dnn_lc_classifier = (
    780         dnn_linear_combined.DNNLinearCombinedClassifier(
    781             linear_feature_columns=[age],
    782             dnn_feature_columns=[city],
    783             dnn_hidden_units=[256, 128],
    784             n_classes=4,
    785             linear_optimizer=gradient_descent.GradientDescentOptimizer(
    786                 learning_rate=0.0),
    787             dnn_optimizer=gradient_descent.GradientDescentOptimizer(
    788                 learning_rate=0.0),
    789             warm_start_from=dnn_lc_classifier.model_dir))
    790 
    791     warm_started_dnn_lc_classifier.train(input_fn=self._input_fn, max_steps=1)
    792     for variable_name in warm_started_dnn_lc_classifier.get_variable_names():
    793       self.assertAllClose(
    794           dnn_lc_classifier.get_variable_value(variable_name),
    795           warm_started_dnn_lc_classifier.get_variable_value(variable_name))
    796 
    797   def test_regressor_basic_warm_starting(self):
    798     """Tests correctness of DNNLinearCombinedRegressor default warm-start."""
    799     age = feature_column.numeric_column('age')
    800     city = feature_column.embedding_column(
    801         feature_column.categorical_column_with_vocabulary_list(
    802             'city', vocabulary_list=['Mountain View', 'Palo Alto']),
    803         dimension=5)
    804 
    805     # Create a DNNLinearCombinedRegressor and train to save a checkpoint.
    806     dnn_lc_regressor = dnn_linear_combined.DNNLinearCombinedRegressor(
    807         linear_feature_columns=[age],
    808         dnn_feature_columns=[city],
    809         dnn_hidden_units=[256, 128],
    810         model_dir=self._ckpt_and_vocab_dir,
    811         linear_optimizer='SGD',
    812         dnn_optimizer='SGD')
    813     dnn_lc_regressor.train(input_fn=self._input_fn, max_steps=1)
    814 
    815     # Create a second DNNLinearCombinedRegressor, warm-started from the first.
    816     # Use a learning_rate = 0.0 optimizer to check values (use SGD so we don't
    817     # have accumulator values that change).
    818     warm_started_dnn_lc_regressor = (
    819         dnn_linear_combined.DNNLinearCombinedRegressor(
    820             linear_feature_columns=[age],
    821             dnn_feature_columns=[city],
    822             dnn_hidden_units=[256, 128],
    823             linear_optimizer=gradient_descent.GradientDescentOptimizer(
    824                 learning_rate=0.0),
    825             dnn_optimizer=gradient_descent.GradientDescentOptimizer(
    826                 learning_rate=0.0),
    827             warm_start_from=dnn_lc_regressor.model_dir))
    828 
    829     warm_started_dnn_lc_regressor.train(input_fn=self._input_fn, max_steps=1)
    830     for variable_name in warm_started_dnn_lc_regressor.get_variable_names():
    831       self.assertAllClose(
    832           dnn_lc_regressor.get_variable_value(variable_name),
    833           warm_started_dnn_lc_regressor.get_variable_value(variable_name))
    834 
    835   def test_warm_starting_selective_variables(self):
    836     """Tests selecting variables to warm-start."""
    837     age = feature_column.numeric_column('age')
    838     city = feature_column.embedding_column(
    839         feature_column.categorical_column_with_vocabulary_list(
    840             'city', vocabulary_list=['Mountain View', 'Palo Alto']),
    841         dimension=5)
    842 
    843     # Create a DNNLinearCombinedClassifier and train to save a checkpoint.
    844     dnn_lc_classifier = dnn_linear_combined.DNNLinearCombinedClassifier(
    845         linear_feature_columns=[age],
    846         dnn_feature_columns=[city],
    847         dnn_hidden_units=[256, 128],
    848         model_dir=self._ckpt_and_vocab_dir,
    849         n_classes=4,
    850         linear_optimizer='SGD',
    851         dnn_optimizer='SGD')
    852     dnn_lc_classifier.train(input_fn=self._input_fn, max_steps=1)
    853 
    854     # Create a second DNNLinearCombinedClassifier, warm-started from the first.
    855     # Use a learning_rate = 0.0 optimizer to check values (use SGD so we don't
    856     # have accumulator values that change).
    857     warm_started_dnn_lc_classifier = (
    858         dnn_linear_combined.DNNLinearCombinedClassifier(
    859             linear_feature_columns=[age],
    860             dnn_feature_columns=[city],
    861             dnn_hidden_units=[256, 128],
    862             n_classes=4,
    863             linear_optimizer=gradient_descent.GradientDescentOptimizer(
    864                 learning_rate=0.0),
    865             dnn_optimizer=gradient_descent.GradientDescentOptimizer(
    866                 learning_rate=0.0),
    867             # The provided regular expression will only warm-start the deep
    868             # portion of the model.
    869             warm_start_from=warm_starting_util.WarmStartSettings(
    870                 ckpt_to_initialize_from=dnn_lc_classifier.model_dir,
    871                 vars_to_warm_start='.*(dnn).*')))
    872 
    873     warm_started_dnn_lc_classifier.train(input_fn=self._input_fn, max_steps=1)
    874     for variable_name in warm_started_dnn_lc_classifier.get_variable_names():
    875       if 'dnn' in variable_name:
    876         self.assertAllClose(
    877             dnn_lc_classifier.get_variable_value(variable_name),
    878             warm_started_dnn_lc_classifier.get_variable_value(variable_name))
    879       elif 'linear' in variable_name:
    880         linear_values = warm_started_dnn_lc_classifier.get_variable_value(
    881             variable_name)
    882         # Since they're not warm-started, the linear weights will be
    883         # zero-initialized.
    884         self.assertAllClose(np.zeros_like(linear_values), linear_values)
    885 
    886 
# Run all test cases in this module when executed as a script.
if __name__ == '__main__':
  test.main()
    889