# Home | History | Annotate | Download | only in timeseries
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Convenience functions for working with time series saved_models.
     16 
     17 @@predict_continuation
     18 @@filter_continuation
     19 """
     20 
     21 from __future__ import absolute_import
     22 from __future__ import division
     23 from __future__ import print_function
     24 
     25 from tensorflow.contrib.timeseries.python.timeseries import feature_keys as _feature_keys
     26 from tensorflow.contrib.timeseries.python.timeseries import head as _head
     27 from tensorflow.contrib.timeseries.python.timeseries import input_pipeline as _input_pipeline
     28 from tensorflow.contrib.timeseries.python.timeseries import model_utils as _model_utils
     29 
     30 from tensorflow.python.util.all_util import remove_undocumented
     31 
     32 
     33 def _colate_features_to_feeds_and_fetches(continue_from, signature, features,
     34                                           graph):
     35   """Uses a saved model signature to construct feed and fetch dictionaries."""
     36   if _feature_keys.FilteringResults.STATE_TUPLE in continue_from:
     37     # We're continuing from an evaluation, so we need to unpack/flatten state.
     38     state_values = _head.state_to_dictionary(
     39         continue_from[_feature_keys.FilteringResults.STATE_TUPLE])
     40   else:
     41     state_values = continue_from
     42   input_feed_tensors_by_name = {
     43       input_key: graph.as_graph_element(input_value.name)
     44       for input_key, input_value in signature.inputs.items()
     45   }
     46   output_tensors_by_name = {
     47       output_key: graph.as_graph_element(output_value.name)
     48       for output_key, output_value in signature.outputs.items()
     49   }
     50   feed_dict = {}
     51   for state_key, state_value in state_values.items():
     52     feed_dict[input_feed_tensors_by_name[state_key]] = state_value
     53   for feature_key, feature_value in features.items():
     54     feed_dict[input_feed_tensors_by_name[feature_key]] = feature_value
     55   return output_tensors_by_name, feed_dict
     56 
     57 
     58 def predict_continuation(continue_from,
     59                          signatures,
     60                          session,
     61                          steps=None,
     62                          times=None,
     63                          exogenous_features=None):
     64   """Perform prediction using an exported saved model.
     65 
     66   Analogous to _input_pipeline.predict_continuation_input_fn, but operates on a
     67   saved model rather than feeding into Estimator's predict method.
     68 
     69   Args:
     70     continue_from: A dictionary containing the results of either an Estimator's
     71       evaluate method or filter_continuation. Used to determine the model
     72       state to make predictions starting from.
     73     signatures: The `MetaGraphDef` protocol buffer returned from
     74       `tf.saved_model.loader.load`. Used to determine the names of Tensors to
     75       feed and fetch. Must be from the same model as `continue_from`.
     76     session: The session to use. The session's graph must be the one into which
     77       `tf.saved_model.loader.load` loaded the model.
     78     steps: The number of steps to predict (scalar), starting after the
     79       evaluation or filtering. If `times` is specified, `steps` must not be; one
     80       is required.
     81     times: A [batch_size x window_size] array of integers (not a Tensor)
     82       indicating times to make predictions for. These times must be after the
     83       corresponding evaluation or filtering. If `steps` is specified, `times`
     84       must not be; one is required. If the batch dimension is omitted, it is
     85       assumed to be 1.
     86     exogenous_features: Optional dictionary. If specified, indicates exogenous
     87       features for the model to use while making the predictions. Values must
     88       have shape [batch_size x window_size x ...], where `batch_size` matches
     89       the batch dimension used when creating `continue_from`, and `window_size`
     90       is either the `steps` argument or the `window_size` of the `times`
     91       argument (depending on which was specified).
     92   Returns:
     93     A dictionary with model-specific predictions (typically having keys "mean"
     94     and "covariance") and a feature_keys.PredictionResults.TIMES key indicating
     95     the times for which the predictions were computed.
     96   Raises:
     97     ValueError: If `times` or `steps` are misspecified.
     98   """
     99   if exogenous_features is None:
    100     exogenous_features = {}
    101   predict_times = _model_utils.canonicalize_times_or_steps_from_output(
    102       times=times, steps=steps, previous_model_output=continue_from)
    103   features = {_feature_keys.PredictionFeatures.TIMES: predict_times}
    104   features.update(exogenous_features)
    105   predict_signature = signatures.signature_def[
    106       _feature_keys.SavedModelLabels.PREDICT]
    107   output_tensors_by_name, feed_dict = _colate_features_to_feeds_and_fetches(
    108       continue_from=continue_from,
    109       signature=predict_signature,
    110       features=features,
    111       graph=session.graph)
    112   output = session.run(output_tensors_by_name, feed_dict=feed_dict)
    113   output[_feature_keys.PredictionResults.TIMES] = features[
    114       _feature_keys.PredictionFeatures.TIMES]
    115   return output
    116 
    117 
    118 def filter_continuation(continue_from, signatures, session, features):
    119   """Perform filtering using an exported saved model.
    120 
    121   Filtering refers to updating model state based on new observations.
    122   Predictions based on the returned model state will be conditioned on these
    123   observations.
    124 
    125   Args:
    126     continue_from: A dictionary containing the results of either an Estimator's
    127       evaluate method or a previous filter_continuation. Used to determine the
    128       model state to start filtering from.
    129     signatures: The `MetaGraphDef` protocol buffer returned from
    130       `tf.saved_model.loader.load`. Used to determine the names of Tensors to
    131       feed and fetch. Must be from the same model as `continue_from`.
    132     session: The session to use. The session's graph must be the one into which
    133       `tf.saved_model.loader.load` loaded the model.
    134     features: A dictionary mapping keys to Numpy arrays, with several possible
    135       shapes (requires keys `FilteringFeatures.TIMES` and
    136       `FilteringFeatures.VALUES`):
    137         Single example; `TIMES` is a scalar and `VALUES` is either a scalar or a
    138           vector of length [number of features].
    139         Sequence; `TIMES` is a vector of shape [series length], `VALUES` either
    140           has shape [series length] (univariate) or [series length x number of
    141           features] (multivariate).
    142         Batch of sequences; `TIMES` is a vector of shape [batch size x series
    143           length], `VALUES` has shape [batch size x series length] or [batch
    144           size x series length x number of features].
    145       In any case, `VALUES` and any exogenous features must have their shapes
    146       prefixed by the shape of the value corresponding to the `TIMES` key.
    147   Returns:
    148     A dictionary containing model state updated to account for the observations
    149     in `features`.
    150   """
    151   filter_signature = signatures.signature_def[
    152       _feature_keys.SavedModelLabels.FILTER]
    153   features = _input_pipeline._canonicalize_numpy_data(  # pylint: disable=protected-access
    154       data=features,
    155       require_single_batch=False)
    156   output_tensors_by_name, feed_dict = _colate_features_to_feeds_and_fetches(
    157       continue_from=continue_from,
    158       signature=filter_signature,
    159       features=features,
    160       graph=session.graph)
    161   output = session.run(output_tensors_by_name, feed_dict=feed_dict)
    162   # Make it easier to chain filter -> predict by keeping track of the current
    163   # time.
    164   output[_feature_keys.FilteringResults.TIMES] = features[
    165       _feature_keys.FilteringFeatures.TIMES]
    166   return output
    167 
# Restrict this module's public API to the symbols listed via `@@` in the
# module docstring (predict_continuation, filter_continuation); other
# module-level names are removed from the exported namespace.
remove_undocumented(module_name=__name__)
    169