1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for the structural state space ensembles.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy 22 23 from tensorflow.contrib import layers 24 from tensorflow.contrib.layers.python.layers import feature_column 25 26 from tensorflow.contrib.timeseries.python.timeseries import estimators 27 from tensorflow.contrib.timeseries.python.timeseries import input_pipeline 28 from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures 29 from tensorflow.contrib.timeseries.python.timeseries.state_space_models import state_space_model 30 from tensorflow.contrib.timeseries.python.timeseries.state_space_models import structural_ensemble 31 32 from tensorflow.python.estimator import estimator_lib 33 from tensorflow.python.framework import dtypes 34 from tensorflow.python.platform import test 35 36 37 class StructuralEnsembleEstimatorTests(test.TestCase): 38 39 def simple_data(self, sample_every, dtype, period, num_samples, num_features): 40 time = sample_every * numpy.arange(num_samples) 41 noise = numpy.random.normal( 42 scale=0.01, size=[num_samples, num_features]) 43 values = noise + numpy.sin( 44 numpy.arange(num_features)[None, ...] 45 + time[..., None] / float(period) * 2.0 * numpy.pi).astype( 46 dtype.as_numpy_dtype) 47 return {TrainEvalFeatures.TIMES: numpy.reshape(time, [1, -1]), 48 TrainEvalFeatures.VALUES: numpy.reshape( 49 values, [1, -1, num_features])} 50 51 def dry_run_train_helper( 52 self, sample_every, period, num_samples, model_type, model_args, 53 num_features=1): 54 numpy.random.seed(1) 55 dtype = dtypes.float32 56 features = self.simple_data( 57 sample_every, dtype=dtype, period=period, num_samples=num_samples, 58 num_features=num_features) 59 model = model_type( 60 configuration=( 61 state_space_model.StateSpaceModelConfiguration( 62 num_features=num_features, 63 dtype=dtype, 64 covariance_prior_fn=lambda _: 0.)), 65 **model_args) 66 67 class _RunConfig(estimator_lib.RunConfig): 68 69 @property 70 def tf_random_seed(self): 71 return 4 72 73 estimator = estimators.StateSpaceRegressor(model, config=_RunConfig()) 74 train_input_fn = input_pipeline.RandomWindowInputFn( 75 input_pipeline.NumpyReader(features), num_threads=1, shuffle_seed=1, 76 batch_size=16, window_size=16) 77 eval_input_fn = input_pipeline.WholeDatasetInputFn( 78 input_pipeline.NumpyReader(features)) 79 estimator.train(input_fn=train_input_fn, max_steps=1) 80 first_evaluation = estimator.evaluate(input_fn=eval_input_fn, steps=1) 81 estimator.train(input_fn=train_input_fn, max_steps=3) 82 second_evaluation = estimator.evaluate(input_fn=eval_input_fn, steps=1) 83 self.assertLess(second_evaluation["loss"], first_evaluation["loss"]) 84 85 def test_structural_multivariate(self): 86 self.dry_run_train_helper( 87 sample_every=3, 88 period=5, 89 num_samples=100, 90 num_features=3, 91 model_type=structural_ensemble.StructuralEnsemble, 92 model_args={ 93 "periodicities": 2, 94 "moving_average_order": 2, 95 "autoregressive_order": 1 96 }) 97 98 def test_exogenous_input(self): 99 """Test that no errors are raised when using exogenous features.""" 100 dtype = dtypes.float64 101 times = [1, 2, 3, 4, 5, 6] 102 values = [[0.01], [5.10], [5.21], [0.30], [5.41], [0.50]] 103 feature_a = [["off"], ["on"], ["on"], ["off"], ["on"], ["off"]] 104 sparse_column_a = feature_column.sparse_column_with_keys( 105 column_name="feature_a", keys=["on", "off"]) 106 one_hot_a = layers.one_hot_column(sparse_id_column=sparse_column_a) 107 regressor = estimators.StructuralEnsembleRegressor( 108 periodicities=[], 109 num_features=1, 110 moving_average_order=0, 111 exogenous_feature_columns=[one_hot_a], 112 dtype=dtype) 113 features = {TrainEvalFeatures.TIMES: times, 114 TrainEvalFeatures.VALUES: values, 115 "feature_a": feature_a} 116 train_input_fn = input_pipeline.RandomWindowInputFn( 117 input_pipeline.NumpyReader(features), 118 window_size=6, batch_size=1) 119 regressor.train(input_fn=train_input_fn, steps=1) 120 eval_input_fn = input_pipeline.WholeDatasetInputFn( 121 input_pipeline.NumpyReader(features)) 122 evaluation = regressor.evaluate(input_fn=eval_input_fn, steps=1) 123 predict_input_fn = input_pipeline.predict_continuation_input_fn( 124 evaluation, times=[[7, 8, 9]], 125 exogenous_features={"feature_a": [[["on"], ["off"], ["on"]]]}) 126 regressor.predict(input_fn=predict_input_fn) 127 128 def test_no_periodicity(self): 129 """Test that no errors are raised when periodicites is None.""" 130 dtype = dtypes.float64 131 times = [1, 2, 3, 4, 5, 6] 132 values = [[0.01], [5.10], [5.21], [0.30], [5.41], [0.50]] 133 regressor = estimators.StructuralEnsembleRegressor( 134 periodicities=None, 135 num_features=1, 136 moving_average_order=0, 137 dtype=dtype) 138 features = {TrainEvalFeatures.TIMES: times, 139 TrainEvalFeatures.VALUES: values} 140 train_input_fn = input_pipeline.RandomWindowInputFn( 141 input_pipeline.NumpyReader(features), 142 window_size=6, batch_size=1) 143 regressor.train(input_fn=train_input_fn, steps=1) 144 eval_input_fn = input_pipeline.WholeDatasetInputFn( 145 input_pipeline.NumpyReader(features)) 146 evaluation = regressor.evaluate(input_fn=eval_input_fn, steps=1) 147 predict_input_fn = input_pipeline.predict_continuation_input_fn( 148 evaluation, times=[[7, 8, 9]]) 149 regressor.predict(input_fn=predict_input_fn) 150 151 if __name__ == "__main__": 152 test.main() 153