# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convenience functions for working with time series saved_models.

@@predict_continuation
@@filter_continuation
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.timeseries.python.timeseries import feature_keys as _feature_keys
from tensorflow.contrib.timeseries.python.timeseries import head as _head
from tensorflow.contrib.timeseries.python.timeseries import input_pipeline as _input_pipeline
from tensorflow.contrib.timeseries.python.timeseries import model_utils as _model_utils

from tensorflow.python.util.all_util import remove_undocumented


def _colate_features_to_feeds_and_fetches(continue_from, signature, features,
                                          graph):
  """Uses a saved model signature to construct feed and fetch dictionaries.

  Args:
    continue_from: A dictionary of model state. If it has a
      `FilteringResults.STATE_TUPLE` entry, that entry is flattened with
      `head.state_to_dictionary` and the result is fed; otherwise every
      key/value pair of `continue_from` itself is fed.
    signature: A signature (from `MetaGraphDef.signature_def`) whose `inputs`
      and `outputs` map keys to entries carrying a tensor `name`.
    features: A dictionary mapping signature input keys to values to feed.
    graph: The graph the saved model was loaded into; used to resolve tensor
      names to graph elements.
  Returns:
    A tuple `(output_tensors_by_name, feed_dict)`. The first maps each
    signature output key to its graph element (usable as `fetches` for
    `Session.run`); the second maps input graph elements to the values to
    feed for them.
  Raises:
    KeyError: If a state key or feature key is not an input of `signature`.
  """
  if _feature_keys.FilteringResults.STATE_TUPLE in continue_from:
    # We're continuing from an evaluation, so we need to unpack/flatten state.
    state_values = _head.state_to_dictionary(
        continue_from[_feature_keys.FilteringResults.STATE_TUPLE])
  else:
    state_values = continue_from
  input_feed_tensors_by_name = {
      input_key: graph.as_graph_element(input_value.name)
      for input_key, input_value in signature.inputs.items()
  }
  output_tensors_by_name = {
      output_key: graph.as_graph_element(output_value.name)
      for output_key, output_value in signature.outputs.items()
  }

  def _input_tensor(key):
    """Returns the input graph element for `key`, with a descriptive error."""
    if key not in input_feed_tensors_by_name:
      # A bare KeyError('<key>') gives no hint about what the signature
      # expects; name the offending key and list the valid inputs instead.
      raise KeyError(
          "'{}' is not an input of the saved model signature. Available "
          "inputs: {}".format(key, sorted(input_feed_tensors_by_name)))
    return input_feed_tensors_by_name[key]

  feed_dict = {}
  for state_key, state_value in state_values.items():
    feed_dict[_input_tensor(state_key)] = state_value
  # Features are written after state, so if a feature key and a state key
  # resolve to the same input tensor, the feature value is the one fed.
  for feature_key, feature_value in features.items():
    feed_dict[_input_tensor(feature_key)] = feature_value
  return output_tensors_by_name, feed_dict


def predict_continuation(continue_from,
                         signatures,
                         session,
                         steps=None,
                         times=None,
                         exogenous_features=None):
  """Perform prediction using an exported saved model.

  Analogous to _input_pipeline.predict_continuation_input_fn, but operates on a
  saved model rather than feeding into Estimator's predict method.

  Args:
    continue_from: A dictionary containing the results of either an Estimator's
      evaluate method or filter_continuation. Used to determine the model
      state to make predictions starting from.
    signatures: The `MetaGraphDef` protocol buffer returned from
      `tf.saved_model.loader.load`. Used to determine the names of Tensors to
      feed and fetch. Must be from the same model as `continue_from`.
    session: The session to use. The session's graph must be the one into which
      `tf.saved_model.loader.load` loaded the model.
    steps: The number of steps to predict (scalar), starting after the
      evaluation or filtering. If `times` is specified, `steps` must not be; one
      is required.
    times: A [batch_size x window_size] array of integers (not a Tensor)
      indicating times to make predictions for. These times must be after the
      corresponding evaluation or filtering. If `steps` is specified, `times`
      must not be; one is required. If the batch dimension is omitted, it is
      assumed to be 1.
    exogenous_features: Optional dictionary. If specified, indicates exogenous
      features for the model to use while making the predictions. Values must
      have shape [batch_size x window_size x ...], where `batch_size` matches
      the batch dimension used when creating `continue_from`, and `window_size`
      is either the `steps` argument or the `window_size` of the `times`
      argument (depending on which was specified).
  Returns:
    A dictionary with model-specific predictions (typically having keys "mean"
    and "covariance") and a feature_keys.PredictionResults.TIMES key indicating
    the times for which the predictions were computed.
  Raises:
    ValueError: If `times` or `steps` are misspecified.
    KeyError: If the state in `continue_from` or a key of `exogenous_features`
      is not an input of the saved model's predict signature.
  """
  if exogenous_features is None:
    exogenous_features = {}
  predict_times = _model_utils.canonicalize_times_or_steps_from_output(
      times=times, steps=steps, previous_model_output=continue_from)
  # Build a fresh dict so the caller's `exogenous_features` is not mutated.
  features = {_feature_keys.PredictionFeatures.TIMES: predict_times}
  features.update(exogenous_features)
  predict_signature = signatures.signature_def[
      _feature_keys.SavedModelLabels.PREDICT]
  output_tensors_by_name, feed_dict = _colate_features_to_feeds_and_fetches(
      continue_from=continue_from,
      signature=predict_signature,
      features=features,
      graph=session.graph)
  output = session.run(output_tensors_by_name, feed_dict=feed_dict)
  output[_feature_keys.PredictionResults.TIMES] = features[
      _feature_keys.PredictionFeatures.TIMES]
  return output


def filter_continuation(continue_from, signatures, session, features):
  """Perform filtering using an exported saved model.

  Filtering refers to updating model state based on new observations.
  Predictions based on the returned model state will be conditioned on these
  observations.

  Args:
    continue_from: A dictionary containing the results of either an Estimator's
      evaluate method or a previous filter_continuation. Used to determine the
      model state to start filtering from.
    signatures: The `MetaGraphDef` protocol buffer returned from
      `tf.saved_model.loader.load`. Used to determine the names of Tensors to
      feed and fetch. Must be from the same model as `continue_from`.
    session: The session to use. The session's graph must be the one into which
      `tf.saved_model.loader.load` loaded the model.
    features: A dictionary mapping keys to Numpy arrays, with several possible
      shapes (requires keys `FilteringFeatures.TIMES` and
      `FilteringFeatures.VALUES`):
        Single example; `TIMES` is a scalar and `VALUES` is either a scalar or a
          vector of length [number of features].
        Sequence; `TIMES` is a vector of shape [series length], `VALUES` either
          has shape [series length] (univariate) or [series length x number of
          features] (multivariate).
        Batch of sequences; `TIMES` is a matrix of shape [batch size x series
          length], `VALUES` has shape [batch size x series length] or [batch
          size x series length x number of features].
      In any case, `VALUES` and any exogenous features must have their shapes
      prefixed by the shape of the value corresponding to the `TIMES` key.
  Returns:
    A dictionary containing model state updated to account for the observations
    in `features`.
  Raises:
    KeyError: If the state in `continue_from` or a key of `features` is not an
      input of the saved model's filter signature.
  """
  filter_signature = signatures.signature_def[
      _feature_keys.SavedModelLabels.FILTER]
  features = _input_pipeline._canonicalize_numpy_data(  # pylint: disable=protected-access
      data=features,
      require_single_batch=False)
  output_tensors_by_name, feed_dict = _colate_features_to_feeds_and_fetches(
      continue_from=continue_from,
      signature=filter_signature,
      features=features,
      graph=session.graph)
  output = session.run(output_tensors_by_name, feed_dict=feed_dict)
  # Make it easier to chain filter -> predict by keeping track of the current
  # time.
  output[_feature_keys.FilteringResults.TIMES] = features[
      _feature_keys.FilteringFeatures.TIMES]
  return output


remove_undocumented(module_name=__name__)