# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for testing time series models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.timeseries.python.timeseries import estimators
from tensorflow.contrib.timeseries.python.timeseries import input_pipeline
from tensorflow.contrib.timeseries.python.timeseries import state_management
from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures

from tensorflow.python.client import session
from tensorflow.python.estimator import estimator_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import adam
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import coordinator as coordinator_lib
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.util import nest


class AllWindowInputFn(input_pipeline.TimeSeriesInputFn):
  """Returns all contiguous windows of data from a full dataset.

  In contrast to `WholeDatasetInputFn`, which does basic shape checking but
  maintains the flat sequencing of data, this `TimeSeriesInputFn` creates a
  batch of windows. Unlike `RandomWindowInputFn`, these windows are
  deterministic, starting at every possible offset (i.e., a single batch of
  size `series_length - window_size + 1` is produced).
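
  For example, with a `NumpyReader` over a five-step series (a sketch; assumes
  NumPy is imported as `numpy`):

    data = {TrainEvalFeatures.TIMES: numpy.arange(5),
            TrainEvalFeatures.VALUES: numpy.arange(5.)[:, None]}
    input_fn = AllWindowInputFn(
        input_pipeline.NumpyReader(data), window_size=3)
    features, _ = input_fn.create_batch()
    # features[TrainEvalFeatures.TIMES] evaluates to
    # [[0, 1, 2], [1, 2, 3], [2, 3, 4]].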
  """

  def __init__(self, time_series_reader, window_size):
    """Initialize the input pipeline.

    Args:
      time_series_reader: An `input_pipeline.TimeSeriesReader` object.
      window_size: The size of contiguous windows of data to produce.
    """
    self._window_size = window_size
    self._reader = time_series_reader
    super(AllWindowInputFn, self).__init__()

  def create_batch(self):
    features = self._reader.read_full()
    times = features[TrainEvalFeatures.TIMES]
    num_windows = array_ops.shape(times)[0] - self._window_size + 1
    indices = array_ops.reshape(math_ops.range(num_windows), [num_windows, 1])
    # indices contains the starting point for each window. We now extend these
    # indices to include the elements inside the windows as well by doing a
    # broadcast addition.
    increments = array_ops.reshape(math_ops.range(self._window_size), [1, -1])
    all_indices = array_ops.reshape(indices + increments, [-1])
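    # As a concrete illustration, with window_size=3 and five time steps:
    #   indices     = [[0], [1], [2]]
    #   increments  = [[0, 1, 2]]
    #   all_indices = [0, 1, 2, 1, 2, 3, 2, 3, 4]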
    # Gather the windowed elements for each feature and reshape to
    # [num_windows, window_size] plus any trailing feature dimensions.
    features = {
        key: array_ops.reshape(
            array_ops.gather(value, all_indices),
            array_ops.concat(
                [[num_windows, self._window_size], array_ops.shape(value)[1:]],
                axis=0))
        for key, value in features.items()
    }
    return (features, None)


class _SavingTensorHook(basic_session_run_hooks.LoggingTensorHook):
  """A hook which saves evaluated Tensor values during training.

  Unlike its `LoggingTensorHook` parent, evaluated values are stored in
  `tensor_values` rather than logged, so tests can inspect them afterwards.
  """

  def __init__(self, tensors, every_n_iter=None, every_n_secs=None):
    self.tensor_values = {}
    super(_SavingTensorHook, self).__init__(
        tensors=tensors, every_n_iter=every_n_iter,
        every_n_secs=every_n_secs)

  def after_run(self, run_context, run_values):
    del run_context
    if self._should_trigger:
      for tag in self._current_tensors.keys():
        self.tensor_values[tag] = run_values.results[tag]
      self._timer.update_last_triggered_step(self._iter_count)
    self._iter_count += 1


def _train_on_generated_data(
    generate_fn, generative_model, train_iterations, seed,
    learning_rate=0.1, ignore_params_fn=lambda _: (),
    derived_param_test_fn=lambda _: (),
    train_input_fn_type=input_pipeline.WholeDatasetInputFn,
    train_state_manager=state_management.PassthroughStateManager()):
  """The training portion of parameter recovery tests.

  Returns:
    A tuple (ignore_params, true_parameters, true_transformed_params,
    trained_loss, true_param_loss, saving_hook, true_parameter_eval_graph)
    used by `test_parameter_recovery` for its assertions.
  """
  random_seed.set_random_seed(seed)
  generate_graph = ops.Graph()
  with generate_graph.as_default():
    with session.Session(graph=generate_graph):
      generative_model.initialize_graph()
      time_series_reader, true_parameters = generate_fn(generative_model)
      true_parameters = {
          tensor.name: value for tensor, value in true_parameters.items()}
  eval_input_fn = input_pipeline.WholeDatasetInputFn(time_series_reader)
  eval_state_manager = state_management.PassthroughStateManager()
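  # Evaluate the loss under the true generating parameters in a separate
  # graph, by feeding the generated parameter values directly.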
  true_parameter_eval_graph = ops.Graph()
  with true_parameter_eval_graph.as_default():
    generative_model.initialize_graph()
    ignore_params = ignore_params_fn(generative_model)
    feature_dict, _ = eval_input_fn()
    eval_state_manager.initialize_graph(generative_model)
    feature_dict[TrainEvalFeatures.VALUES] = math_ops.cast(
        feature_dict[TrainEvalFeatures.VALUES], generative_model.dtype)
    model_outputs = eval_state_manager.define_loss(
        model=generative_model,
        features=feature_dict,
        mode=estimator_lib.ModeKeys.EVAL)
    with session.Session(graph=true_parameter_eval_graph) as sess:
      variables.global_variables_initializer().run()
      coordinator = coordinator_lib.Coordinator()
      queue_runner_impl.start_queue_runners(sess, coord=coordinator)
      true_param_loss = model_outputs.loss.eval(feed_dict=true_parameters)
      true_transformed_params = {
          param: param.eval(feed_dict=true_parameters)
          for param in derived_param_test_fn(generative_model)}
      coordinator.request_stop()
      coordinator.join()

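  # Save the trained parameter values; with every_n_iter=train_iterations - 1
  # the hook's last recorded values come from the final training step.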
  saving_hook = _SavingTensorHook(
      tensors=true_parameters.keys(),
      every_n_iter=train_iterations - 1)

  class _RunConfig(estimator_lib.RunConfig):
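    """A `RunConfig` pinning the graph-level seed so training is repeatable."""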

    @property
    def tf_random_seed(self):
      return seed

  estimator = estimators.TimeSeriesRegressor(
      model=generative_model,
      config=_RunConfig(),
      state_manager=train_state_manager,
      optimizer=adam.AdamOptimizer(learning_rate))
  train_input_fn = train_input_fn_type(time_series_reader=time_series_reader)
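  # Train on the generated data, then evaluate the trained model's loss on
  # the same (full) series.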
  trained_loss = (estimator.train(
      input_fn=train_input_fn,
      max_steps=train_iterations,
      hooks=[saving_hook]).evaluate(
          input_fn=eval_input_fn, steps=1))["loss"]
  logging.info("Final trained loss: %f", trained_loss)
  logging.info("True parameter loss: %f", true_param_loss)
  return (ignore_params, true_parameters, true_transformed_params,
          trained_loss, true_param_loss, saving_hook,
          true_parameter_eval_graph)


def test_parameter_recovery(
    generate_fn, generative_model, train_iterations, test_case, seed,
    learning_rate=0.1, rtol=0.2, atol=0.1, train_loss_tolerance_coeff=0.99,
    ignore_params_fn=lambda _: (),
    derived_param_test_fn=lambda _: (),
    train_input_fn_type=input_pipeline.WholeDatasetInputFn,
    train_state_manager=state_management.PassthroughStateManager()):
  """Test that a generative model fits generated data.

  Args:
    generate_fn: A function taking a model and returning a `TimeSeriesReader`
        object and a dictionary mapping parameters to their
        values. model.initialize_graph() will have been called on the model
        before it is passed to this function.
    generative_model: A timeseries.model.TimeSeriesModel instance to test.
    train_iterations: Number of training steps.
    test_case: A tf.test.TestCase to run assertions on.
    seed: Same as for TimeSeriesModel.unconditional_generate().
    learning_rate: Step size for optimization.
    rtol: Relative tolerance for tests.
    atol: Absolute tolerance for tests.
    train_loss_tolerance_coeff: Trained loss times this value must be less
        than the loss evaluated using the true parameters.
    ignore_params_fn: Function mapping from a Model to a list of parameters
        which are not tested for accurate recovery.
    derived_param_test_fn: Function returning a list of derived parameters
        (Tensors) which are checked for accurate recovery (comparing the value
        evaluated with trained parameters to the value under the true
        parameters).

        As an example, for VARMA, in addition to checking AR and MA parameters,
        this function can be used to also check lagged covariance. See
        varma_ssm.py for details.
    train_input_fn_type: The `TimeSeriesInputFn` type to use when training
        (likely `WholeDatasetInputFn` or `RandomWindowInputFn`). Defaults to
        `WholeDatasetInputFn`.
    train_state_manager: The state manager to use when training (likely
        `PassthroughStateManager` or `ChainingStateManager`). Defaults to
        `PassthroughStateManager`.
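
  Example (a sketch; `MyModel` and `my_generate_fn` are hypothetical):

    class MyModelTest(test.TestCase):

      def test_parameters_recovered(self):
        test_parameter_recovery(
            generate_fn=my_generate_fn,
            generative_model=MyModel(),
            train_iterations=1000,
            test_case=self,
            seed=2)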
  """
  (ignore_params, true_parameters, true_transformed_params,
   trained_loss, true_param_loss, saving_hook, true_parameter_eval_graph
  ) = _train_on_generated_data(
      generate_fn=generate_fn, generative_model=generative_model,
      train_iterations=train_iterations, seed=seed, learning_rate=learning_rate,
      ignore_params_fn=ignore_params_fn,
      derived_param_test_fn=derived_param_test_fn,
      train_input_fn_type=train_input_fn_type,
      train_state_manager=train_state_manager)
  trained_parameter_substitutions = {}
  for param in true_parameters.keys():
    evaled_value = saving_hook.tensor_values[param]
    trained_parameter_substitutions[param] = evaled_value
    true_value = true_parameters[param]
    logging.info("True %s: %s, learned: %s",
                 param, true_value, evaled_value)
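  # Evaluate each derived parameter with the trained values substituted in,
  # and compare against its value under the true parameters.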
  with session.Session(graph=true_parameter_eval_graph):
    for transformed_param, true_value in true_transformed_params.items():
      trained_value = transformed_param.eval(
          feed_dict=trained_parameter_substitutions)
      logging.info("True %s [transformed parameter]: %s, learned: %s",
                   transformed_param, true_value, trained_value)
      test_case.assertAllClose(true_value, trained_value,
                               rtol=rtol, atol=atol)

  if ignore_params is None:
    ignore_params = []
  else:
    ignore_params = nest.flatten(ignore_params)
  ignore_params = [tensor.name for tensor in ignore_params]
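  # Check that the trained loss is at most slightly worse than the loss under
  # the true parameters; the comparison direction depends on the loss's sign.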
  if trained_loss > 0:
    test_case.assertLess(trained_loss * train_loss_tolerance_coeff,
                         true_param_loss)
  else:
    test_case.assertLess(trained_loss / train_loss_tolerance_coeff,
                         true_param_loss)
  for param in true_parameters.keys():
    if param in ignore_params:
      continue
    evaled_value = saving_hook.tensor_values[param]
    true_value = true_parameters[param]
    test_case.assertAllClose(true_value, evaled_value,
                             rtol=rtol, atol=atol)


def parameter_recovery_dry_run(
    generate_fn, generative_model, seed,
    learning_rate=0.1,
    train_input_fn_type=input_pipeline.WholeDatasetInputFn,
    train_state_manager=state_management.PassthroughStateManager()):
  """Test that a generative model can train on generated data.

  Args:
    generate_fn: A function taking a model and returning an
        `input_pipeline.TimeSeriesReader` object and a dictionary mapping
        parameters to their values. model.initialize_graph() will have been
        called on the model before it is passed to this function.
    generative_model: A timeseries.model.TimeSeriesModel instance to test.
    seed: Same as for TimeSeriesModel.unconditional_generate().
    learning_rate: Step size for optimization.
    train_input_fn_type: The type of `TimeSeriesInputFn` to use when training
        (likely `WholeDatasetInputFn` or `RandomWindowInputFn`). Defaults to
        `WholeDatasetInputFn`.
    train_state_manager: The state manager to use when training (likely
        `PassthroughStateManager` or `ChainingStateManager`). Defaults to
        `PassthroughStateManager`.
  """
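  # Two training iterations are enough to exercise the training path without
  # attempting real parameter recovery.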
  _train_on_generated_data(
      generate_fn=generate_fn, generative_model=generative_model,
      seed=seed, learning_rate=learning_rate,
      train_input_fn_type=train_input_fn_type,
      train_state_manager=train_state_manager,
      train_iterations=2)
    283