# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Demonstrates binary MNIST TF Boosted trees example.

This example demonstrates how to run experiments with TF Boosted Trees on
a binary dataset. We use digits 4 and 9 from the original MNIST dataset.

Example Usage:
python tensorflow/contrib/boosted_trees/examples/binary_mnist.py \
--output_dir="/tmp/binary_mnist" --depth=4 --learning_rate=0.3 \
--batch_size=10761 --examples_per_layer=10761 --eval_batch_size=1030 \
--num_eval_steps=1 --num_trees=10 --l2=1 --vmodule=training_ops=1

When training is done, accuracy on eval data is reported. Point tensorboard
to the directory for the run to see how the training progresses:

tensorboard --logdir=/tmp/binary_mnist

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import numpy as np
import tensorflow as tf
from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeClassifier
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.learn import learn_runner


def get_input_fn(data,
                 batch_size,
                 capacity=10000,
                 min_after_dequeue=3000):
  """Input function over MNIST data.

  Args:
    data: A dataset split (e.g. `data.train`) exposing `images` and `labels`
      numpy arrays.
    batch_size: Number of examples per produced batch.
    capacity: Maximum number of elements held by the shuffle queue.
    min_after_dequeue: Minimum number of elements kept in the queue after a
      dequeue, to guarantee a good level of shuffling.

  Returns:
    An input function yielding a `(features, labels)` tuple, where `features`
    is a dict with a single "images" key and labels are int32 with value 1
    for digit 4 and 0 for digit 9.
  """
  # Keep only 4 and 9 digits.
  ids = np.where((data.labels == 4) | (data.labels == 9))
  images = data.images[ids]
  labels = data.labels[ids]
  # Make digit 4 label 1, 9 is 0.
  labels = labels == 4

  def _input_fn():
    """Prepare features and labels."""
    images_batch, labels_batch = tf.train.shuffle_batch(
        tensors=[images,
                 labels.astype(np.int32)],
        batch_size=batch_size,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
        enqueue_many=True,
        num_threads=4)
    features_map = {"images": images_batch}
    return features_map, labels_batch

  return _input_fn


# Main config - creates a TF Boosted Trees Estimator based on flags.
def _get_tfbt(output_dir):
  """Configures TF Boosted Trees estimator based on flags."""
  # Translate the command-line flags into a learner configuration proto.
  learner_config = learner_pb2.LearnerConfig()
  learner_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate
  learner_config.regularization.l1 = 0.0
  # Normalize the L2 flag by the number of examples accumulated per layer.
  learner_config.regularization.l2 = FLAGS.l2 / FLAGS.examples_per_layer
  learner_config.constraints.max_tree_depth = FLAGS.depth
  learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER

  run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300)

  # Create a TF Boosted trees estimator that can take in custom loss.
  return GradientBoostedDecisionTreeClassifier(
      learner_config=learner_config,
      examples_per_layer=FLAGS.examples_per_layer,
      model_dir=output_dir,
      num_trees=FLAGS.num_trees,
      center_bias=False,
      config=run_config)


def _make_experiment_fn(output_dir):
  """Creates experiment for gradient boosted decision trees."""
  mnist_data = tf.contrib.learn.datasets.mnist.load_mnist()
  return tf.contrib.learn.Experiment(
      estimator=_get_tfbt(output_dir),
      train_input_fn=get_input_fn(mnist_data.train, FLAGS.batch_size),
      eval_input_fn=get_input_fn(mnist_data.validation, FLAGS.eval_batch_size),
      train_steps=None,
      eval_steps=FLAGS.num_eval_steps,
      eval_metrics=None)


def main(unused_argv):
  """Runs the train-and-evaluate schedule using the experiment above."""
  learn_runner.run(
      experiment_fn=_make_experiment_fn,
      output_dir=FLAGS.output_dir,
      schedule="train_and_evaluate")


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  parser = argparse.ArgumentParser()
  # Define the list of flags that users can change.
125 parser.add_argument( 126 "--output_dir", 127 type=str, 128 required=True, 129 help="Choose the dir for the output.") 130 parser.add_argument( 131 "--batch_size", 132 type=int, 133 default=1000, 134 help="The batch size for reading data.") 135 parser.add_argument( 136 "--eval_batch_size", 137 type=int, 138 default=1000, 139 help="Size of the batch for eval.") 140 parser.add_argument( 141 "--num_eval_steps", 142 type=int, 143 default=1, 144 help="The number of steps to run evaluation for.") 145 # Flags for gradient boosted trees config. 146 parser.add_argument( 147 "--depth", type=int, default=4, help="Maximum depth of weak learners.") 148 parser.add_argument( 149 "--l2", type=float, default=1.0, help="l2 regularization per batch.") 150 parser.add_argument( 151 "--learning_rate", 152 type=float, 153 default=0.1, 154 help="Learning rate (shrinkage weight) with which each new tree is added." 155 ) 156 parser.add_argument( 157 "--examples_per_layer", 158 type=int, 159 default=1000, 160 help="Number of examples to accumulate stats for per layer.") 161 parser.add_argument( 162 "--num_trees", 163 type=int, 164 default=None, 165 required=True, 166 help="Number of trees to grow before stopping.") 167 168 FLAGS, unparsed = parser.parse_known_args() 169 tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 170