Home | History | Annotate | Download | only in examples
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
r"""Demonstrates binary MNIST classification with TF Boosted Trees.
     16 
     17   This example demonstrates how to run experiments with TF Boosted Trees on
     18   a binary dataset. We use digits 4 and 9 from the original MNIST dataset.
     19 
     20   Example Usage:
     21   python tensorflow/contrib/boosted_trees/examples/binary_mnist.py \
     22   --output_dir="/tmp/binary_mnist" --depth=4 --learning_rate=0.3 \
     23   --batch_size=10761 --examples_per_layer=10761 --eval_batch_size=1030 \
     24   --num_eval_steps=1 --num_trees=10 --l2=1 --vmodule=training_ops=1
     25 
     26   When training is done, accuracy on eval data is reported. Point tensorboard
     27   to the directory for the run to see how the training progresses:
     28 
     29   tensorboard --logdir=/tmp/binary_mnist
     30 
     31 """
     32 from __future__ import absolute_import
     33 from __future__ import division
     34 from __future__ import print_function
     35 
     36 import argparse
     37 import sys
     38 
     39 import numpy as np
     40 import tensorflow as tf
     41 from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeClassifier
     42 from tensorflow.contrib.boosted_trees.proto import learner_pb2
     43 from tensorflow.contrib.learn import learn_runner
     44 
     45 
     46 def get_input_fn(data,
     47                  batch_size,
     48                  capacity=10000,
     49                  min_after_dequeue=3000):
     50   """Input function over MNIST data."""
     51   # Keep only 4 and 9 digits.
     52   ids = np.where((data.labels == 4) | (data.labels == 9))
     53   images = data.images[ids]
     54   labels = data.labels[ids]
     55   # Make digit 4 label 1, 9 is 0.
     56   labels = labels == 4
     57 
     58   def _input_fn():
     59     """Prepare features and labels."""
     60     images_batch, labels_batch = tf.train.shuffle_batch(
     61         tensors=[images,
     62                  labels.astype(np.int32)],
     63         batch_size=batch_size,
     64         capacity=capacity,
     65         min_after_dequeue=min_after_dequeue,
     66         enqueue_many=True,
     67         num_threads=4)
     68     features_map = {"images": images_batch}
     69     return features_map, labels_batch
     70 
     71   return _input_fn
     72 
     73 
     74 # Main config - creates a TF Boosted Trees Estimator based on flags.
     75 def _get_tfbt(output_dir):
     76   """Configures TF Boosted Trees estimator based on flags."""
     77   learner_config = learner_pb2.LearnerConfig()
     78 
     79   learner_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate
     80   learner_config.regularization.l1 = 0.0
     81   learner_config.regularization.l2 = FLAGS.l2 / FLAGS.examples_per_layer
     82   learner_config.constraints.max_tree_depth = FLAGS.depth
     83 
     84   growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER
     85   learner_config.growing_mode = growing_mode
     86   run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300)
     87 
     88   # Create a TF Boosted trees estimator that can take in custom loss.
     89   estimator = GradientBoostedDecisionTreeClassifier(
     90       learner_config=learner_config,
     91       examples_per_layer=FLAGS.examples_per_layer,
     92       model_dir=output_dir,
     93       num_trees=FLAGS.num_trees,
     94       center_bias=False,
     95       config=run_config)
     96   return estimator
     97 
     98 
     99 def _make_experiment_fn(output_dir):
    100   """Creates experiment for gradient boosted decision trees."""
    101   data = tf.contrib.learn.datasets.mnist.load_mnist()
    102   train_input_fn = get_input_fn(data.train, FLAGS.batch_size)
    103   eval_input_fn = get_input_fn(data.validation, FLAGS.eval_batch_size)
    104 
    105   return tf.contrib.learn.Experiment(
    106       estimator=_get_tfbt(output_dir),
    107       train_input_fn=train_input_fn,
    108       eval_input_fn=eval_input_fn,
    109       train_steps=None,
    110       eval_steps=FLAGS.num_eval_steps,
    111       eval_metrics=None)
    112 
    113 
    114 def main(unused_argv):
    115   learn_runner.run(
    116       experiment_fn=_make_experiment_fn,
    117       output_dir=FLAGS.output_dir,
    118       schedule="train_and_evaluate")
    119 
    120 
    121 if __name__ == "__main__":
    122   tf.logging.set_verbosity(tf.logging.INFO)
    123   parser = argparse.ArgumentParser()
    124   # Define the list of flags that users can change.
    125   parser.add_argument(
    126       "--output_dir",
    127       type=str,
    128       required=True,
    129       help="Choose the dir for the output.")
    130   parser.add_argument(
    131       "--batch_size",
    132       type=int,
    133       default=1000,
    134       help="The batch size for reading data.")
    135   parser.add_argument(
    136       "--eval_batch_size",
    137       type=int,
    138       default=1000,
    139       help="Size of the batch for eval.")
    140   parser.add_argument(
    141       "--num_eval_steps",
    142       type=int,
    143       default=1,
    144       help="The number of steps to run evaluation for.")
    145   # Flags for gradient boosted trees config.
    146   parser.add_argument(
    147       "--depth", type=int, default=4, help="Maximum depth of weak learners.")
    148   parser.add_argument(
    149       "--l2", type=float, default=1.0, help="l2 regularization per batch.")
    150   parser.add_argument(
    151       "--learning_rate",
    152       type=float,
    153       default=0.1,
    154       help="Learning rate (shrinkage weight) with which each new tree is added."
    155   )
    156   parser.add_argument(
    157       "--examples_per_layer",
    158       type=int,
    159       default=1000,
    160       help="Number of examples to accumulate stats for per layer.")
    161   parser.add_argument(
    162       "--num_trees",
    163       type=int,
    164       default=None,
    165       required=True,
    166       help="Number of trees to grow before stopping.")
    167 
    168   FLAGS, unparsed = parser.parse_known_args()
    169   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
    170