Home | History | Annotate | Download | only in timeseries
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Classes for wrapping a model to operate on different data shapes."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import abc
     22 
     23 from tensorflow.contrib.timeseries.python.timeseries import feature_keys
     24 from tensorflow.contrib.timeseries.python.timeseries import math_utils
     25 from tensorflow.contrib.timeseries.python.timeseries.model import ModelOutputs
     26 
     27 from tensorflow.python.estimator import estimator_lib
     28 from tensorflow.python.framework import dtypes
     29 from tensorflow.python.framework import ops
     30 from tensorflow.python.ops import array_ops
     31 from tensorflow.python.ops import math_ops
     32 from tensorflow.python.util import nest
     33 
     34 
     35 class PassthroughStateManager(object):
     36   """A minimal wrapper for models which do not need state management."""
     37 
     38   def __init__(self):
     39     self._input_statistics = None
     40     self._graph_initialized = False
     41 
     42   def initialize_graph(self, model, input_statistics=None):
     43     """Adds required operations to the graph."""
     44     del model  # unused
     45     self._graph_initialized = True
     46     self._input_statistics = input_statistics
     47 
     48   def define_loss(self, model, features, mode):
     49     """Wrap "model" with StateManager-specific operations.
     50 
     51     Args:
     52       model: The model (inheriting from TimeSeriesModel) to manage state for.
     53       features: A dictionary with the following key/value pairs:
     54         feature_keys.TrainEvalFeatures.TIMES: A [batch size x window size]
     55             Tensor with times for each observation.
     56         feature_keys.TrainEvalFeatures.VALUES: A [batch size x window size x num
     57             features] Tensor with values for each observation.
     58       mode: The tf.estimator.ModeKeys mode to use (TRAIN or EVAL).
     59     Returns:
     60       A ModelOutputs object.
     61     Raises:
     62       ValueError: If start state was specified.
     63     """
     64     if feature_keys.State.STATE_TUPLE in features:
     65       raise ValueError(
     66           "Overriding start state is not supported for this model.")
     67     return model.define_loss(features, mode)
     68 
     69 
     70 class _OverridableStateManager(PassthroughStateManager):
     71   """Base class for state managers which support overriding model state."""
     72 
     73   @abc.abstractmethod
     74   def _define_loss_with_saved_state(self, model, features, mode):
     75     pass
     76 
     77   def define_loss(self, model, features, mode):
     78     """Switches between explicit start state and managed state."""
     79     if feature_keys.FilteringFeatures.STATE_TUPLE in features:
     80       # Explicit start state has been provided, so we should use that.
     81       if mode == estimator_lib.ModeKeys.TRAIN:
     82         raise ValueError(
     83             "Overriding saved state for training is not supported (but a value "
     84             "for feature {} was specified).".format(
     85                 feature_keys.FilteringFeatures.STATE_TUPLE))
     86       start_state = features[feature_keys.FilteringFeatures.STATE_TUPLE]
     87       del features[feature_keys.FilteringFeatures.STATE_TUPLE]
     88       return model.get_batch_loss(
     89           features=features, mode=mode, state=start_state)
     90     else:
     91       # No explicit start state; use managed state.
     92       return self._define_loss_with_saved_state(
     93           model=model, features=features, mode=mode)
     94 
     95 
     96 class FilteringOnlyStateManager(_OverridableStateManager):
     97   """State manager for models which use state only for filtering.
     98 
     99   Window-based models (ARModel) do not require state to be fed during training
    100   (instead requiring a specific window size). Rather than requiring a minimum
    101   window size for filtering, these models maintain this window in their state,
    102   and so need state to be fed.
    103   """
    104 
    105   def _define_loss_with_saved_state(self, model, features, mode):
    106     return model.define_loss(features, mode)
    107 
    108 
    109 class ChainingStateManager(_OverridableStateManager):
    110   """Maintains state across a batch for SequentialTimeSeriesModel subclasses.
    111 
    112   The batch dimension is treated as indexing sequential chunks of the same
    113   timeseries. End state from each chunk is fed as start state to the next chunk
    114   during the next timestep. This is an approximation to full-batch training for
    115   sequential models, but is typically much faster while still accurately
    116   recovering parameters. The speedup comes from reduced scheduling overhead of
    117   TensorFlow ops, since each operation can do much more work.
    118   """
    119 
    120   def __init__(self, state_saving_interval=20, checkpoint_state=False):
    121     """Initialize the state manager.
    122 
    123     Args:
    124       state_saving_interval: This state manager saves intermediate model state
    125           every `state_saving_interval` times. Larger values save memory, and
    126           checkpoint size if `checkpoint_state` is enabled, but models
    127           will need to impute across artificial gaps of up to this size
    128           (i.e. gaps not appearing in the original data). This imputation may
    129           affect training. Set state_saving_interval to 1 to avoid any
    130           artificial imputation.
    131       checkpoint_state: If True, saved intermediate model state will be
    132           written to checkpoints. Checkpoints will then scale with dataset
    133           size. If False, state will be freshly imputed from the beginning of a
    134           series each time the model is restored, which means it may take a few
    135           iterations for state to warm up.
    136     """
    137     super(ChainingStateManager, self).__init__()
    138     self._checkpoint_state = checkpoint_state
    139     self._state_saving_interval = state_saving_interval
    140     self._start_state = None
    141     self._cached_states = None
    142 
    143   def initialize_graph(self, model, input_statistics=None):
    144     """Adds required operations to the graph."""
    145     super(ChainingStateManager, self).initialize_graph(
    146         model=model, input_statistics=input_statistics)
    147     self._start_state = model.get_start_state()
    148     self._cached_states = math_utils.TupleOfTensorsLookup(
    149         key_dtype=dtypes.int64,
    150         default_values=self._start_state,
    151         empty_key=-1,
    152         name="cached_states",
    153         checkpoint=self._checkpoint_state)
    154 
    155   def _define_loss_with_saved_state(self, model, features, mode):
    156     """Feeds end state from one training iteration into the next.
    157 
    158     Args:
    159       model: The model to wrap. Compatible with children of TimeSeriesModel.
    160       features: Dictionary with Tensor values defining the data to be
    161         processed. The expected key/value pairs are at minimum:
    162           feature_keys.TrainEvalFeatures.TIMES: A [number of chunks x window
    163             size] Tensor with times for each observation, the result of chunking
    164             a single longer time series.
    165           feature_keys.TrainEvalFeatures.VALUES: A [number of chunks x window
    166             size x num features] Tensor with values for each observation,
    167             corresponding to times.
    168       mode: The tf.estimator.ModeKeys mode to use. For EVAL and INFER, no
    169           batching is performed, which may be slow. This is to avoid giving
    170           cached and almost certainly stale values.
    171     Returns:
    172       A ModelOutputs object.
    173     Raises:
    174       ValueError: If initialize_graph has not been called.
    175     """
    176     if not self._graph_initialized:
    177       raise ValueError("ChainingStateManager requires initialize_graph() to be "
    178                        "called before use.")
    179     (loss_op, end_state, batch_predictions) = self._update_cached_states(
    180         model=model,
    181         features=features,
    182         mode=mode)
    183     # Add a batch dimension so state can be used directly (e.g. for predictions)
    184     # without the user manually reshaping it.
    185     last_end_state_flat = [end_state_value[-1][None]
    186                            for end_state_value in nest.flatten(end_state)]
    187     batch_predictions["observed"] = features[
    188         feature_keys.TrainEvalFeatures.VALUES]
    189     return ModelOutputs(
    190         loss=loss_op,
    191         end_state=nest.pack_sequence_as(end_state, last_end_state_flat),
    192         predictions=batch_predictions,
    193         prediction_times=features[feature_keys.TrainEvalFeatures.TIMES])
    194 
    195   def _get_chunk_number(self, time):
    196     return time // self._state_saving_interval
    197 
    198   def _get_cached_states(self, times):
    199     """Retrieve cached states for a batch of times."""
    200     read_chunk_numbers = self._get_chunk_number(times)
    201     looked_up_state = list(self._cached_states.lookup(
    202         math_ops.cast(read_chunk_numbers, dtypes.int64)))
    203     looked_up_state = tuple(looked_up_state)
    204     # We need to special-case the first chunk in a series to explicitly rely on
    205     # the model's starting state so that gradients flow back to it. Otherwise it
    206     # would affect only initialization, and would not be read from or updated
    207     # during training. Not doing this also isolates that part of the graph,
    208     # leading to errors on model reload if there are trainable variables
    209     # affecting a model's start state.
    210     if self._input_statistics is not None:
    211       start_time = self._input_statistics.start_time
    212     else:
    213       start_time = 0
    214     set_to_start_state = math_ops.equal(read_chunk_numbers,
    215                                         self._get_chunk_number(start_time))
    216     new_states = []
    217     for start_state_value, cache_variable in zip(
    218         nest.flatten(
    219             math_utils.replicate_state(self._start_state,
    220                                        array_ops.shape(times)[0])),
    221         nest.flatten(looked_up_state)):
    222 
    223       new_states.append(
    224           array_ops.where(set_to_start_state, start_state_value,
    225                           cache_variable))
    226     looked_up_state = nest.pack_sequence_as(looked_up_state, new_states)
    227     return looked_up_state
    228 
    229   def _update_cached_states(self, model, features, mode):
    230     """Read, process, and write chunks to the cache."""
    231     times = features[feature_keys.TrainEvalFeatures.TIMES]
    232     looked_up_state = self._get_cached_states(times[:, 0])
    233     (model_loss, intermediate_states,
    234      batch_predictions) = model.per_step_batch_loss(
    235          features=features,
    236          mode=mode,
    237          state=looked_up_state)
    238     # We need to at least write to the bucket after the one we read from.
    239     min_chunk_numbers = self._get_chunk_number(times) + 1
    240     # We write to the bucket that would have been read had the window started at
    241     # the next sample (except for the last sample in the window, which gets
    242     # written to the next bucket). This assumes fixed missing times (i.e. if we
    243     # were presented with times [10, 50] we will never see times [30, 50]).
    244     #
    245     # TODO(allenl): Retrieve the highest time less than the current time rather
    246     # than relying on fixed bucketing.
    247     write_chunk_numbers = math_ops.maximum(
    248         self._get_chunk_number(array_ops.concat(
    249             [times[:, 1:], times[:, -1:] + 1], axis=1)),
    250         min_chunk_numbers)
    251     # Write once for every computed state; this may mean that we write multiple
    252     # times to the same cell, but later writes will take precedence.
    253     save_ops = [
    254         self._cached_states.insert(
    255             keys=write_chunk_numbers,
    256             values=intermediate_states)]
    257     end_state = nest.pack_sequence_as(
    258         intermediate_states,
    259         [state_element[:, -1]
    260          for state_element in nest.flatten(intermediate_states)])
    261     with ops.control_dependencies(save_ops):
    262       # Make sure end states get saved at each iteration
    263       loss_op = array_ops.identity(model_loss)
    264     return loss_op, end_state, batch_predictions
    265