Home | History | Annotate | Download | only in v4
      1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
     15 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_
     16 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_
     17 
#include <memory>

#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/params.h"
#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h"
#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h"
     23 
     24 namespace tensorflow {
     25 namespace tensorforest {
     26 
     27 // Abstract base class for classes that can initialize, get, and update leaf
     28 // models.
     29 class LeafModelOperator {
     30  public:
     31   // Number of outputs is interpreted differently for classification and
     32   // regression.  For classification, it's the number of possible classes.
     33   // For regression, it's the target dimensions.
     34   explicit LeafModelOperator(const TensorForestParams& params)
     35       : params_(params) {}
     36   virtual ~LeafModelOperator() {}
     37 
     38   // Returns the value of the requested output, which should be
     39   // in [0, num_outputs_).  For classification, it's the class count (weighted
     40   // number of instances seen).  For regression, it's e.g. the average value.
     41   virtual float GetOutputValue(const decision_trees::Leaf& leaf,
     42                                int32 o) const = 0;
     43 
     44   // Update the given Leaf's model with the given example.
     45   virtual void UpdateModel(decision_trees::Leaf* leaf,
     46                            const InputTarget* target, int example) const = 0;
     47 
     48   // Initialize an empty Leaf model.
     49   virtual void InitModel(decision_trees::Leaf* leaf) const = 0;
     50 
     51   virtual void ExportModel(const LeafStat& stat,
     52                            decision_trees::Leaf* leaf) const = 0;
     53 
     54  protected:
     55   const TensorForestParams& params_;
     56 };
     57 
     58 // LeafModelOperator that stores class counts in a dense vector.
     59 class DenseClassificationLeafModelOperator : public LeafModelOperator {
     60  public:
     61   explicit DenseClassificationLeafModelOperator(
     62       const TensorForestParams& params)
     63       : LeafModelOperator(params) {}
     64   float GetOutputValue(const decision_trees::Leaf& leaf,
     65                        int32 o) const override;
     66 
     67   void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
     68                    int example) const override;
     69 
     70   void InitModel(decision_trees::Leaf* leaf) const override;
     71 
     72   void ExportModel(const LeafStat& stat,
     73                    decision_trees::Leaf* leaf) const override;
     74 };
     75 
     76 // LeafModelOperator that stores class counts sparsely in a map. Assumes default
     77 // value for yet-unseen classes is 0.
     78 class SparseClassificationLeafModelOperator : public LeafModelOperator {
     79  public:
     80   explicit SparseClassificationLeafModelOperator(
     81       const TensorForestParams& params)
     82       : LeafModelOperator(params) {}
     83   float GetOutputValue(const decision_trees::Leaf& leaf,
     84                        int32 o) const override;
     85 
     86   void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
     87                    int example) const override;
     88 
     89   void InitModel(decision_trees::Leaf* leaf) const override {}
     90 
     91   void ExportModel(const LeafStat& stat,
     92                    decision_trees::Leaf* leaf) const override;
     93 };
     94 
     95 class SparseOrDenseClassificationLeafModelOperator : public LeafModelOperator {
     96  public:
     97   explicit SparseOrDenseClassificationLeafModelOperator(
     98       const TensorForestParams& params)
     99       : LeafModelOperator(params),
    100         dense_(new DenseClassificationLeafModelOperator(params)),
    101         sparse_(new SparseClassificationLeafModelOperator(params)) {}
    102   float GetOutputValue(const decision_trees::Leaf& leaf,
    103                        int32 o) const override;
    104 
    105   void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
    106                    int example) const override;
    107 
    108   void InitModel(decision_trees::Leaf* leaf) const override {}
    109 
    110   void ExportModel(const LeafStat& stat,
    111                    decision_trees::Leaf* leaf) const override;
    112 
    113  protected:
    114   std::unique_ptr<DenseClassificationLeafModelOperator> dense_;
    115   std::unique_ptr<SparseClassificationLeafModelOperator> sparse_;
    116 };
    117 
    118 // LeafModelOperator that stores regression leaf models with constant-value
    119 // prediction.
    120 class RegressionLeafModelOperator : public LeafModelOperator {
    121  public:
    122   explicit RegressionLeafModelOperator(const TensorForestParams& params)
    123       : LeafModelOperator(params) {}
    124   float GetOutputValue(const decision_trees::Leaf& leaf,
    125                        int32 o) const override;
    126 
    127   // TODO(gilberth): Quick experimentation suggests it's not even worth
    128   // updating model and just using the seeded values.  Can add this in
    129   // with additional_data, though protobuf::Any is slow.  Maybe make it
    130   // optional.  Maybe make any update optional.
    131   void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
    132                    int example) const override {}
    133 
    134   void InitModel(decision_trees::Leaf* leaf) const override;
    135 
    136   void ExportModel(const LeafStat& stat,
    137                    decision_trees::Leaf* leaf) const override;
    138 };
    139 
    140 class LeafModelOperatorFactory {
    141  public:
    142   static std::unique_ptr<LeafModelOperator> CreateLeafModelOperator(
    143       const TensorForestParams& params);
    144 };
    145 
    146 }  // namespace tensorforest
    147 }  // namespace tensorflow
    148 
    149 #endif  // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_
    150