1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for TensorForestTrainer.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.contrib.learn.python.learn.datasets import base 24 from tensorflow.contrib.tensor_forest.client import random_forest 25 from tensorflow.contrib.tensor_forest.python import tensor_forest 26 from tensorflow.python.platform import test 27 28 29 class TensorForestTrainerTests(test.TestCase): 30 31 def testClassification(self): 32 """Tests multi-class classification using matrix data as input.""" 33 hparams = tensor_forest.ForestHParams( 34 num_trees=3, 35 max_nodes=1000, 36 num_classes=3, 37 num_features=4, 38 split_after_samples=20, 39 inference_tree_paths=True) 40 classifier = random_forest.TensorForestEstimator(hparams.fill()) 41 42 iris = base.load_iris() 43 data = iris.data.astype(np.float32) 44 labels = iris.target.astype(np.int32) 45 46 classifier.fit(x=data, y=labels, steps=100, batch_size=50) 47 classifier.evaluate(x=data, y=labels, steps=10) 48 49 def testRegression(self): 50 """Tests multi-class classification using matrix data as input.""" 51 52 hparams = tensor_forest.ForestHParams( 53 num_trees=3, 54 max_nodes=1000, 55 num_classes=1, 56 num_features=13, 57 regression=True, 58 split_after_samples=20) 59 60 regressor = random_forest.TensorForestEstimator(hparams.fill()) 61 62 boston = base.load_boston() 63 data = boston.data.astype(np.float32) 64 labels = boston.target.astype(np.int32) 65 66 regressor.fit(x=data, y=labels, steps=100, batch_size=50) 67 regressor.evaluate(x=data, y=labels, steps=10) 68 69 70 if __name__ == "__main__": 71 test.main() 72