# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for time series models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections

from tensorflow.contrib import layers
from tensorflow.contrib.layers import feature_column

from tensorflow.contrib.timeseries.python.timeseries import math_utils
from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures
from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope

from tensorflow.python.util import nest


ModelOutputs = collections.namedtuple(  # pylint: disable=invalid-name
    typename="ModelOutputs",
    field_names=[
        "loss",  # The scalar value to be minimized during training.
        "end_state",  # A nested tuple specifying the model's state after
                      # running on the specified data.
        "predictions",  # A dictionary of predictions, each with shape prefixed
                        # by the shape of `prediction_times`.
        "prediction_times"  # A [batch size x window size] integer Tensor
                            # indicating times for which values in `predictions`
                            # were computed.
    ])
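
# Illustrative sketch (editorial, not from the original library): for a
# hypothetical univariate model run on a [batch=2, window=3] chunk, a
# ModelOutputs instance might hold
#   loss:             a scalar Tensor
#   end_state:        e.g. (mean, covariance), each shaped [2, ...]
#   predictions:      e.g. {"mean": Tensor of shape [2, 3, 1]}
#   prediction_times: an integer Tensor of shape [2, 3]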


class TimeSeriesModel(object):
  """Base class for creating generative time series models."""

  __metaclass__ = abc.ABCMeta

  def __init__(self,
               num_features,
               exogenous_feature_columns=None,
               dtype=dtypes.float32):
    """Constructor for generative models.

    Args:
      num_features: Number of features for the time series.
      exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn
          objects (for example tf.contrib.layers.embedding_column) corresponding
          to exogenous features which provide extra information to the model but
          are not part of the series to be predicted. Passed to
          tf.contrib.layers.input_from_feature_columns.
      dtype: The floating point datatype to use.
    """
    if exogenous_feature_columns:
      self._exogenous_feature_columns = exogenous_feature_columns
    else:
      self._exogenous_feature_columns = []
    self.num_features = num_features
    self.dtype = dtype
    self._input_statistics = None
    self._graph_initialized = False
    self._stats_means = None
    self._stats_sigmas = None

  @property
  def exogenous_feature_columns(self):
    """`FeatureColumn` objects for features which are not predicted."""
    return self._exogenous_feature_columns

  # TODO(allenl): Move more of the generic machinery for generating and
  # predicting into TimeSeriesModel, and possibly share it between generate()
  # and predict()
  def generate(self, number_of_series, series_length,
               model_parameters=None, seed=None):
    """Sample synthetic data from model parameters, with optional substitutions.

    Returns `number_of_series` possible sequences of future values, sampled from
    the generative model with each conditioned on the previous. Samples are
    based on trained parameters, except for those parameters explicitly
    overridden in `model_parameters`.

    For distributions over future observations, see predict().

    Args:
      number_of_series: Number of time series to create.
      series_length: Length of each time series.
      model_parameters: A dictionary mapping model parameters to values, which
          replace trained parameters when generating data.
      seed: If specified, return deterministic time series according to this
          value.
    Returns:
      A dictionary with keys TrainEvalFeatures.TIMES (mapping to an array with
      shape [number_of_series, series_length]) and TrainEvalFeatures.VALUES
      (mapping to an array with shape [number_of_series, series_length,
      num_features]).
    """
    raise NotImplementedError("This model does not support generation.")

  def initialize_graph(self, input_statistics=None):
    """Define ops for the model, not depending on any previously defined ops.

    Args:
      input_statistics: A math_utils.InputStatistics object containing input
          statistics. If None, data-independent defaults are used, which may
          result in slower or less stable training.
    """
    self._graph_initialized = True
    self._input_statistics = input_statistics
    if self._input_statistics:
      self._stats_means, variances = (
          self._input_statistics.overall_feature_moments)
      self._stats_sigmas = math_ops.sqrt(variances)

  def _scale_data(self, data):
    """Scale data according to stats (input scale -> model scale)."""
    if self._input_statistics is not None:
      return (data - self._stats_means) / self._stats_sigmas
    else:
      return data

  def _scale_variance(self, variance):
    """Scale variances according to stats (input scale -> model scale)."""
    if self._input_statistics is not None:
      return variance / self._input_statistics.overall_feature_moments.variance
    else:
      return variance

  def _scale_back_data(self, data):
    """Scale back data according to stats (model scale -> input scale)."""
    if self._input_statistics is not None:
      return (data * self._stats_sigmas) + self._stats_means
    else:
      return data

  def _scale_back_variance(self, variance):
    """Scale back variances according to stats (model scale -> input scale)."""
    if self._input_statistics is not None:
      return variance * self._input_statistics.overall_feature_moments.variance
    else:
      return variance
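
  # Editorial note: when input statistics are set, the four helpers above form
  # inverse pairs, so (up to floating point error)
  #   model._scale_back_data(model._scale_data(x))          # ~= x
  #   model._scale_back_variance(model._scale_variance(v))  # ~= v
  # Without input statistics, all four are identity functions.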

  def _check_graph_initialized(self):
    if not self._graph_initialized:
      raise ValueError(
          "TimeSeriesModels require initialize_graph() to be called before "
          "use. This defines variables and ops in the default graph, and "
          "allows Tensor-valued input statistics to be specified.")

  def define_loss(self, features, mode):
    """Default loss definition with state replicated across a batch.

    Time series passed to this model have a batch dimension, and each series in
    a batch can be operated on in parallel. This loss definition assumes that
    each element of the batch represents an independent sample conditioned on
    the same initial state (i.e. it is simply replicated across the batch). A
    batch size of one provides sequential operations on a single time series.

    More complex processing may operate instead on get_start_state() and
    get_batch_loss() directly.

    Args:
      features: A dictionary (such as is produced by a chunker) with at minimum
        the following key/value pairs (others corresponding to the
        `exogenous_feature_columns` argument to `__init__` may be included
        representing exogenous regressors):
        TrainEvalFeatures.TIMES: A [batch size x window size] integer Tensor
            with times for each observation. If there is no artificial chunking,
            the window size is simply the length of the time series.
        TrainEvalFeatures.VALUES: A [batch size x window size x num features]
            Tensor with values for each observation.
      mode: The tf.estimator.ModeKeys mode to use (TRAIN, EVAL). For INFER,
        see predict().
    Returns:
      A ModelOutputs object.
    """
    self._check_graph_initialized()
    start_state = math_utils.replicate_state(
        start_state=self.get_start_state(),
        batch_size=array_ops.shape(features[TrainEvalFeatures.TIMES])[0])
    return self.get_batch_loss(features=features, mode=mode, state=start_state)
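
  # Usage sketch (editorial; `MyModel` and `optimizer` are hypothetical):
  #   model = MyModel(num_features=1)
  #   model.initialize_graph()
  #   outputs = model.define_loss(
  #       features={TrainEvalFeatures.TIMES: times,     # [batch, window]
  #                 TrainEvalFeatures.VALUES: values},  # [batch, window, 1]
  #       mode=tf.estimator.ModeKeys.TRAIN)
  #   train_op = optimizer.minimize(outputs.loss)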

  # TODO(vitalyk,allenl): Better documentation surrounding options for chunking,
  # references to papers, etc.
  @abc.abstractmethod
  def get_start_state(self):
    """Returns a tuple of state for the start of the time series.

    For example, a mean and covariance. State should not have a batch
    dimension, and will often consist of TensorFlow Variables to be learned
    along with the rest of the model parameters.
    """
    pass

  @abc.abstractmethod
  def get_batch_loss(self, features, mode, state):
    """Return predictions, losses, and end state for a time series.

    Args:
      features: A dictionary with times, values, and (optionally) exogenous
          regressors. See `define_loss`.
      mode: The tf.estimator.ModeKeys mode to use (TRAIN, EVAL, INFER).
      state: Model-dependent state, each with size [batch size x ...]. The
          number and type will typically be fixed by the model (for example a
          mean and variance).
    Returns:
      A ModelOutputs object.
    """
    pass

  @abc.abstractmethod
  def predict(self, features):
    """Returns predictions of future observations given an initial state.

    Computes distributions for future observations. For sampled draws from the
    model where each is conditioned on the previous, see generate().

    Args:
      features: A dictionary with at minimum the following key/value pairs
        (others corresponding to the `exogenous_feature_columns` argument to
        `__init__` may be included representing exogenous regressors):
        PredictionFeatures.TIMES: A [batch size x window size] Tensor with
          times to make predictions for. Times must be increasing within each
          part of the batch, and must be greater than the last time `state` was
          updated.
        PredictionFeatures.STATE_TUPLE: Model-dependent state, each with size
          [batch size x ...]. The number and type will typically be fixed by the
          model (for example a mean and variance). Typically these will be the
          end state returned by get_batch_loss, predicting beyond that data.
    Returns:
      A dictionary with model-dependent predictions corresponding to the
      requested times. Keys indicate the type of prediction, and values have
      shape [batch size x window size x ...]. For example state space models
      return a "predicted_mean" and "predicted_covariance".
    """
    pass
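
  # Typical filter-then-forecast flow (editorial sketch; variable names are
  # assumptions):
  #   outputs = model.define_loss(features, mode=tf.estimator.ModeKeys.EVAL)
  #   forecasts = model.predict({
  #       PredictionFeatures.TIMES: future_times,  # later than filtered times
  #       PredictionFeatures.STATE_TUPLE: outputs.end_state})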

  def _get_exogenous_embedding_shape(self):
    """Computes the shape of the vector returned by _process_exogenous_features.

    Returns:
      The shape as a list. Does not include a batch dimension.
    """
    if not self._exogenous_feature_columns:
      return (0,)
    with ops.Graph().as_default():
      placeholder_features = (
          feature_column.make_place_holder_tensors_for_base_features(
              self._exogenous_feature_columns))
      embedded = layers.input_from_feature_columns(
          columns_to_tensors=placeholder_features,
          feature_columns=self._exogenous_feature_columns)
      return embedded.get_shape().as_list()[1:]

  def _process_exogenous_features(self, times, features):
    """Create a single vector from exogenous features.

    Args:
      times: A [batch size, window size] vector of times for this batch,
          primarily used to check the shape information of exogenous features.
      features: A dictionary of exogenous features corresponding to the columns
          in self._exogenous_feature_columns. Each value should have a shape
          prefixed by [batch size, window size].
    Returns:
      A Tensor with shape [batch size, window size, exogenous dimension], where
      the size of the exogenous dimension depends on the exogenous feature
      columns passed to the model's constructor.
    Raises:
      ValueError: If an exogenous feature has an unknown rank.
    """
    if self._exogenous_feature_columns:
      exogenous_features_single_batch_dimension = {}
      for name, tensor in features.items():
        if tensor.get_shape().ndims is None:
          # input_from_feature_columns does not support completely unknown
          # feature shapes, so we save on a bit of logic and provide a better
          # error message by checking that here.
          raise ValueError(
              ("Features with unknown rank are not supported. Got shape {} for "
               "feature {}.").format(tensor.get_shape(), name))
        tensor_shape_dynamic = array_ops.shape(tensor)
        tensor = array_ops.reshape(
            tensor,
            array_ops.concat([[tensor_shape_dynamic[0]
                               * tensor_shape_dynamic[1]],
                              tensor_shape_dynamic[2:]], axis=0))
        # Avoid shape warnings when embedding "scalar" exogenous features (those
        # with only batch and window dimensions); input_from_feature_columns
        # expects input ranks to match the embedded rank.
        if tensor.get_shape().ndims == 1:
          exogenous_features_single_batch_dimension[name] = tensor[:, None]
        else:
          exogenous_features_single_batch_dimension[name] = tensor
      embedded_exogenous_features_single_batch_dimension = (
          layers.input_from_feature_columns(
              columns_to_tensors=exogenous_features_single_batch_dimension,
              feature_columns=self._exogenous_feature_columns,
              trainable=True))
      exogenous_regressors = array_ops.reshape(
          embedded_exogenous_features_single_batch_dimension,
          array_ops.concat(
              [
                  array_ops.shape(times), array_ops.shape(
                      embedded_exogenous_features_single_batch_dimension)[1:]
              ],
              axis=0))
      exogenous_regressors.set_shape(times.get_shape().concatenate(
          embedded_exogenous_features_single_batch_dimension.get_shape()[1:]))
      exogenous_regressors = math_ops.cast(
          exogenous_regressors, dtype=self.dtype)
    else:
      # Not having any exogenous features is a special case so that models can
      # avoid superfluous updates, which may not be free of side effects due to
      # bias terms in transformations.
      exogenous_regressors = None
    return exogenous_regressors
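
  # Shape sketch (editorial): with a single
  # tf.contrib.layers.real_valued_column("temp", dimension=2) and
  # features={"temp": Tensor of shape [batch, window, 2]}, the method flattens
  # to [batch * window, 2], embeds, and reshapes back to a [batch, window, 2]
  # Tensor; the final dimension depends on the configured feature columns.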


# TODO(allenl): Add a superclass of SequentialTimeSeriesModel which fuses
# filtering/prediction/exogenous into one step, and move looping constructs to
# that class.
class SequentialTimeSeriesModel(TimeSeriesModel):
  """Base class for recurrent generative models.

  Models implementing this interface have four main functions, corresponding to
  abstract methods:
    _filtering_step: Updates state based on observations and computes a loss.
    _prediction_step: Predicts a batch of observations and new model state.
    _imputation_step: Updates model state across a gap.
    _exogenous_input_step: Updates state to account for exogenous regressors.

  Models may also specify a _window_initializer to prepare for a window of data.

  See StateSpaceModel for a concrete example of a model implementing this
  interface.
  """

  def __init__(self,
               train_output_names,
               predict_output_names,
               num_features,
               normalize_features=False,
               dtype=dtypes.float32,
               exogenous_feature_columns=None,
               exogenous_update_condition=None,
               static_unrolling_window_size_threshold=None):
    """Initialize a SequentialTimeSeriesModel.

    Args:
      train_output_names: A list of products/predictions returned from
          _filtering_step.
      predict_output_names: A list of products/predictions returned from
          _prediction_step.
      num_features: Number of features for the time series.
      normalize_features: Boolean. If True, `values` are passed normalized to
          the model (via self._scale_data). Scaling is done for the whole window
          as a batch, which is slightly more efficient than scaling inside the
          window loop. The model must then define _scale_back_predictions, which
          may use _scale_back_data or _scale_back_variance to return predictions
          to the input scale.
      dtype: The floating point datatype to use.
      exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn
          objects. See `TimeSeriesModel`.
      exogenous_update_condition: A function taking two Tensor arguments `times`
          (shape [batch size]) and `features` (a dictionary mapping exogenous
          feature keys to Tensors with shapes [batch size, ...]) and returning a
          boolean Tensor with shape [batch size] indicating whether state should
          be updated using exogenous features for each part of the batch. Where
          it is False, no exogenous update is performed. If None (default),
          exogenous updates are always performed. Useful for avoiding "leaky"
          frequent exogenous updates when sparse updates are desired. Called
          only during graph construction.
      static_unrolling_window_size_threshold: Controls whether a `tf.while_loop`
          is used when looping over a window of data. If
          `static_unrolling_window_size_threshold` is None, a `tf.while_loop` is
          always used. Otherwise it must be an integer, and the graph is
          replicated for each step taken whenever the window size is less than
          or equal to this value (if the window size is available in the static
          shape information of the TrainEvalFeatures.TIMES feature). Static
          unrolling generally decreases the per-step time for small window/batch
          sizes, but increases graph construction time.
    """
    super(SequentialTimeSeriesModel, self).__init__(
        num_features=num_features, dtype=dtype,
        exogenous_feature_columns=exogenous_feature_columns)
    self._exogenous_update_condition = exogenous_update_condition
    self._train_output_names = train_output_names
    self._predict_output_names = predict_output_names
    self._normalize_features = normalize_features
    self._static_unrolling_window_size_threshold = (
        static_unrolling_window_size_threshold)
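
  # Editorial example of an `exogenous_update_condition` (hypothetical
  # "is_holiday" feature, assumed to slice to a [batch size] integer Tensor):
  #   def update_on_holiday(times, features):
  #     del times  # unused in this condition
  #     return math_ops.greater(features["is_holiday"], 0)
  # Where the returned boolean is False, _apply_exogenous_update (below) keeps
  # the original state for that batch entry.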

  def _scale_back_predictions(self, predictions):
    """Return a window of predictions to the input scale.

    Args:
      predictions: A dictionary mapping from prediction names to Tensors.
    Returns:
      A dictionary with values corrected for input normalization (e.g. with
      self._scale_back_data and possibly self._scale_back_variance). May be a
      mutated version of the argument.
    """
    raise NotImplementedError(
        "SequentialTimeSeriesModel normalized input data"
        " (normalize_features=True), but no method was provided to transform "
        "the predictions back to the input scale.")

  @abc.abstractmethod
  def _filtering_step(self, current_times, current_values, state, predictions):
    """Compute a single-step loss for a batch of data.

    Args:
      current_times: A [batch size] Tensor of times for each observation.
      current_values: A [batch size x num features] Tensor of values for each
          observation.
      state: Model state, updated to current_times.
      predictions: The outputs of _prediction_step.
    Returns:
      A tuple of (updated state, outputs):
        updated state: Model state taking current_values into account.
        outputs: A dictionary of Tensors with keys corresponding to
            self._train_output_names, plus a special "loss" key. The value
            corresponding to "loss" is minimized during training. Other outputs
            may include one-step-ahead predictions, for example a predicted
            location and scale.
    """
    pass

  @abc.abstractmethod
  def _prediction_step(self, current_times, state):
    """Compute a batch of single-step predictions.

    Args:
      current_times: A [batch size] Tensor of times for each observation.
      state: Model state, imputed to one step before current_times.
    Returns:
      A tuple of (updated state, outputs):
        updated state: Model state updated to current_times.
        outputs: A dictionary of Tensors with keys corresponding to
            self._predict_output_names.
    """
    pass

  @abc.abstractmethod
  def _imputation_step(self, current_times, state):
    """Update model state across missing values.

    Called to prepare model state for _filtering_step and _prediction_step.

    Args:
      current_times: A [batch size] Tensor; state will be imputed up to, but not
          including, these timesteps.
      state: The pre-imputation model state, Tensors with shape [batch size x
          ...].
    Returns:
      Updated/imputed model state, corresponding to `state`.
    """
    pass

  @abc.abstractmethod
  def _exogenous_input_step(
      self, current_times, current_exogenous_regressors, state):
    """Update state to account for exogenous regressors.

    Args:
      current_times: A [batch size] Tensor of times for the exogenous values
          being input.
      current_exogenous_regressors: A [batch size x exogenous input dimension]
          Tensor of exogenous values for each part of the batch.
      state: Model state, a possibly nested list of Tensors, each with shape
          [batch size x ...].
    Returns:
      Updated model state, structure and shapes matching the `state` argument.
    """
    pass

  # TODO(allenl): Move regularization to a separate object (optional and
  # configurable)
  def _loss_additions(self, times, values, mode):
    """Additions to per-observation normalized loss, e.g. regularization.

    Args:
      times: A [batch size x window size] Tensor with times for each
          observation.
      values: A [batch size x window size x num features] Tensor with values for
          each observation.
      mode: The tf.estimator.ModeKeys mode to use (TRAIN, EVAL, INFER).
    Returns:
      A scalar value to add to the per-observation normalized loss.
    """
    del times, values, mode
    return 0.

  def _window_initializer(self, times, state):
    """Prepare for training or prediction on a window of data.

    Args:
      times: A [batch size x window size] Tensor with times for each
          observation.
      state: Model-dependent state, each with size [batch size x ...]. The
          number and type will typically be fixed by the model (for example a
          mean and variance).
    Returns:
      Nothing.
    """
    pass

  def get_batch_loss(self, features, mode, state):
    """Calls self._filtering_step. See TimeSeriesModel.get_batch_loss."""
    per_observation_loss, state, outputs = self.per_step_batch_loss(
        features, mode, state)
    # per_step_batch_loss returns [batch size, window size, ...] state, whereas
    # get_batch_loss is expected to return [batch size, ...] state for the last
    # element of a window.
    state = nest.pack_sequence_as(
        state,
        [state_element[:, -1] for state_element in nest.flatten(state)])
    outputs["observed"] = features[TrainEvalFeatures.VALUES]
    return ModelOutputs(
        loss=per_observation_loss,
        end_state=state,
        predictions=outputs,
        prediction_times=features[TrainEvalFeatures.TIMES])

  def _apply_exogenous_update(
      self, current_times, step_number, state, raw_features,
      embedded_exogenous_regressors):
    """Performs a conditional state update based on exogenous features."""
    if embedded_exogenous_regressors is None:
      return state
    else:
      current_exogenous_regressors = embedded_exogenous_regressors[
          :, step_number, :]
      exogenous_updated_state = self._exogenous_input_step(
          current_times=current_times,
          current_exogenous_regressors=current_exogenous_regressors,
          state=state)
      if self._exogenous_update_condition is not None:
        current_raw_exogenous_features = {
            key: value[:, step_number] for key, value in raw_features.items()
            if key not in [PredictionFeatures.STATE_TUPLE,
                           TrainEvalFeatures.TIMES,
                           TrainEvalFeatures.VALUES]}
        conditionally_updated_state_flat = []
        for updated_state_element, original_state_element in zip(
            nest.flatten(exogenous_updated_state),
            nest.flatten(state)):
          conditionally_updated_state_flat.append(
              array_ops.where(
                  self._exogenous_update_condition(
                      times=current_times,
                      features=current_raw_exogenous_features),
                  updated_state_element,
                  original_state_element))
        return nest.pack_sequence_as(state, conditionally_updated_state_flat)
      else:
        return exogenous_updated_state
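
  # Semantics sketch (editorial, hypothetical values): with a [batch size]
  # condition, array_ops.where selects whole rows, so
  #   condition=[True, False], updated=[[1., 1.], [2., 2.]],
  #   original=[[9., 9.], [8., 8.]]
  # yields [[1., 1.], [8., 8.]]: entry 0 takes the exogenous update, entry 1
  # keeps its previous state.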

  def per_step_batch_loss(self, features, mode, state):
    """Computes predictions, losses, and intermediate model states.

    Args:
      features: A dictionary with times, values, and (optionally) exogenous
          regressors. See `define_loss`.
      mode: The tf.estimator.ModeKeys mode to use (TRAIN, EVAL, INFER).
      state: Model-dependent state, each with size [batch size x ...]. The
          number and type will typically be fixed by the model (for example a
          mean and variance).
    Returns:
      A tuple of (loss, filtered_states, predictions):
        loss: Average loss values across the batch.
        filtered_states: For each Tensor in `state` with shape [batch size x
            ...], `filtered_states` has a Tensor with shape [batch size x window
            size x ...] with filtered state for each part of the batch and
            window.
        predictions: A dictionary with model-dependent one-step-ahead (or
            at-least-one-step-ahead with missing values) predictions, with keys
            indicating the type of prediction and values having shape [batch
            size x window size x ...]. For example state space models provide
            "mean", "covariance", and "log_likelihood".
    """
    self._check_graph_initialized()
    times = math_ops.cast(features[TrainEvalFeatures.TIMES], dtype=dtypes.int64)
    values = math_ops.cast(features[TrainEvalFeatures.VALUES], dtype=self.dtype)
    if self._normalize_features:
      values = self._scale_data(values)
    exogenous_regressors = self._process_exogenous_features(
        times=times,
        features={key: value for key, value in features.items()
                  if key not in [TrainEvalFeatures.TIMES,
                                 TrainEvalFeatures.VALUES]})

    def _batch_loss_filtering_step(step_number, current_times, state):
      """Make a prediction and update it based on data."""
      current_values = values[:, step_number, :]
      state = self._apply_exogenous_update(
          step_number=step_number, current_times=current_times, state=state,
          raw_features=features,
          embedded_exogenous_regressors=exogenous_regressors)
      predicted_state, predictions = self._prediction_step(
          current_times=current_times,
          state=state)
      filtered_state, outputs = self._filtering_step(
          current_times=current_times,
          current_values=current_values,
          state=predicted_state,
          predictions=predictions)
      return filtered_state, outputs

    state, outputs = self._state_update_loop(
        times=times, state=state, state_update_fn=_batch_loss_filtering_step,
        outputs=["loss"] + self._train_output_names)
    outputs["loss"].set_shape(times.get_shape())
    loss_sum = math_ops.reduce_sum(outputs["loss"])
    per_observation_loss = (loss_sum / math_ops.cast(
        math_ops.reduce_prod(array_ops.shape(times)), dtype=self.dtype))
    per_observation_loss += self._loss_additions(times, values, mode)
    # Since we have window-level additions to the loss, its per-step value is
    # misleading, so we avoid returning it.
    del outputs["loss"]
    if self._normalize_features:
      outputs = self._scale_back_predictions(outputs)
    return per_observation_loss, state, outputs

  def predict(self, features):
    """Calls self._prediction_step in a loop. See TimeSeriesModel.predict."""
    predict_times = ops.convert_to_tensor(features[PredictionFeatures.TIMES],
                                          dtypes.int64)
    start_state = features[PredictionFeatures.STATE_TUPLE]
    exogenous_regressors = self._process_exogenous_features(
        times=predict_times,
        features={
            key: value
            for key, value in features.items()
            if key not in
            [PredictionFeatures.TIMES, PredictionFeatures.STATE_TUPLE]
        })

    def _call_prediction_step(step_number, current_times, state):
      state = self._apply_exogenous_update(
          step_number=step_number, current_times=current_times, state=state,
          raw_features=features,
          embedded_exogenous_regressors=exogenous_regressors)
      state, outputs = self._prediction_step(
          current_times=current_times, state=state)
      return state, outputs

    _, predictions = self._state_update_loop(
        times=predict_times, state=start_state,
        state_update_fn=_call_prediction_step,
        outputs=self._predict_output_names)
    if self._normalize_features:
      predictions = self._scale_back_predictions(predictions)
    return predictions

  class _FakeTensorArray(object):
    """An interface for Python lists that is similar to TensorArray.

    Used for easy switching between static and dynamic looping.
    """

    def __init__(self):
      self.values = []

    def write(self, unused_position, value):
      del unused_position
      self.values.append(value)
      return self
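
  # Editorial note: _FakeTensorArray mirrors just enough of TensorArray for the
  # static-unrolling path. Since `write` appends and returns `self`,
  # accumulator-threading code such as
  #   accumulator = accumulator.write(step_number, value)
  # works unchanged whether `accumulator` is a TensorArray (dynamic while_loop)
  # or a _FakeTensorArray (static unrolling); see _state_update_loop below.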

  def _state_update_loop(self, times, state, state_update_fn, outputs):
    """Iterates over `times`, calling `state_update_fn` to collect outputs.

    Args:
      times: A [batch size x window size] Tensor of integers to iterate over.
      state: A list of model-specific state Tensors, each with shape [batch size
          x ...].
      state_update_fn: A callback taking the following arguments:
            step_number: A scalar integer Tensor indicating the current position
              in the window.
            current_times: A [batch size] vector of integers indicating times
              for each part of the batch.
            state: Current model state.
          It returns a tuple of (updated state, output_values), output_values
          being a dictionary of Tensors with keys corresponding to `outputs`.
      outputs: A list of strings indicating values which will be saved while
          iterating. Must match the keys of the dictionary returned by
          state_update_fn.
    Returns:
      A tuple of (state, output_dict):
        state: The final model state.
        output_dict: A dictionary of outputs corresponding to those specified in
            `outputs` and computed in state_update_fn.
    """
    times = ops.convert_to_tensor(times, dtype=dtypes.int64)
    window_static_shape = times.get_shape()[1].value
    if self._static_unrolling_window_size_threshold is None:
      static_unroll = False
    else:
      # The user has specified a threshold for static loop unrolling.
      if window_static_shape is None:
        # We don't have static shape information for the window size, so dynamic
        # looping is our only option.
        static_unroll = False
      elif window_static_shape <= self._static_unrolling_window_size_threshold:
        # The threshold is satisfied; unroll statically.
        static_unroll = True
      else:
        # A threshold was set but not satisfied.
        static_unroll = False

    self._window_initializer(times, state)

    def _run_condition(step_number, *unused):
      del unused  # not part of while loop run condition
      return math_ops.less(step_number, window_size)

    def _state_update_step(
        step_number, state, state_accumulators, output_accumulators,
        reuse=False):
      """Impute, then take one state_update_fn step, accumulating outputs."""
      with variable_scope.variable_scope("state_update_step", reuse=reuse):
        current_times = times[:, step_number]
        state = self._imputation_step(current_times=current_times, state=state)
        output_accumulators_dict = {
            accumulator_key: accumulator
            for accumulator_key, accumulator
            in zip(outputs, output_accumulators)}
        step_state, output_values = state_update_fn(
            step_number=step_number,
            current_times=current_times,
            state=state)
        assert set(output_values.keys()) == set(outputs)
        new_output_accumulators = []
        for output_key in outputs:
          accumulator = output_accumulators_dict[output_key]
          output_value = output_values[output_key]
          new_output_accumulators.append(
              accumulator.write(step_number, output_value))
        flat_step_state = nest.flatten(step_state)
        assert len(state_accumulators) == len(flat_step_state)
        new_state_accumulators = []
        new_state_flat = []
        for step_state_value, state_accumulator, original_state in zip(
            flat_step_state, state_accumulators, nest.flatten(state)):
          # Make sure the static shape information is complete so while_loop
          # does not complain about shape information changing.
          step_state_value.set_shape(original_state.get_shape())
          new_state_flat.append(step_state_value)
          new_state_accumulators.append(state_accumulator.write(
              step_number, step_state_value))
        step_state = nest.pack_sequence_as(state, new_state_flat)
        return (step_number + 1, step_state,
                new_state_accumulators, new_output_accumulators)

    window_size = array_ops.shape(times)[1]

    def _window_size_tensor_array(dtype):
      if static_unroll:
        return self._FakeTensorArray()
      else:
        return tensor_array_ops.TensorArray(
            dtype=dtype, size=window_size, dynamic_size=False)

    initial_loop_arguments = [
        array_ops.zeros([], dtypes.int32),
        state,
        [_window_size_tensor_array(element.dtype)
         for element in nest.flatten(state)],
        [_window_size_tensor_array(self.dtype) for _ in outputs]]
    if static_unroll:
      arguments = initial_loop_arguments
      for step_number in range(times.get_shape()[1].value):
        arguments = _state_update_step(
            array_ops.constant(step_number, dtypes.int32), *arguments[1:],
            reuse=(step_number > 0))  # Variable sharing between steps
    else:
      arguments = control_flow_ops.while_loop(
          cond=_run_condition,
          body=_state_update_step,
          loop_vars=initial_loop_arguments)
    (_, _, state_loop_result, outputs_loop_result) = arguments

    def _stack_and_transpose(tensor_array):
      """Stack and re-order the dimensions of a TensorArray."""
      if static_unroll:
        return array_ops.stack(tensor_array.values, axis=1)
      else:
        # TensorArrays from while_loop stack with window size as the first
        # dimension, so this function swaps it and the batch dimension to
        # maintain the [batch x window size x ...] convention used elsewhere.
        stacked = tensor_array.stack()
        return array_ops.transpose(
            stacked,
            perm=array_ops.concat([[1, 0], math_ops.range(
                2, array_ops.rank(stacked))], 0))
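
    # Shape sketch (editorial): a TensorArray holding `window size` writes of
    # [batch, ...] Tensors stacks to [window, batch, ...]; the transpose above
    # permutes the first two axes to give [batch, window, ...]. In the static
    # case the Python list is stacked directly on axis=1 with the same result.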

    outputs_dict = {output_key: _stack_and_transpose(output)
                    for output_key, output
                    in zip(outputs, outputs_loop_result)}
    full_state = nest.pack_sequence_as(
        state,
        [_stack_and_transpose(state_element)
         for state_element in state_loop_result])
    return full_state, outputs_dict
    823