Home | History | Annotate | Download | only in v4
      1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
     15 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_
     16 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_
     17 
#include <memory>
#include <vector>

#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
     21 
     22 namespace tensorflow {
     23 namespace tensorforest {
     24 
     25 // Base class for evaluators of decision nodes that effectively copy proto
     26 // contents into C++ structures for faster execution.
     27 class DecisionNodeEvaluator {
     28  public:
     29   virtual ~DecisionNodeEvaluator() {}
     30 
     31   // Returns the index of the child node.
     32   virtual int32 Decide(const std::unique_ptr<TensorDataSet>& dataset,
     33                        int example) const = 0;
     34 };
     35 
     36 // An evaluator for binary decisions with left and right children.
     37 class BinaryDecisionNodeEvaluator : public DecisionNodeEvaluator {
     38  protected:
     39   BinaryDecisionNodeEvaluator(int32 left, int32 right)
     40       : left_child_id_(left), right_child_id_(right) {}
     41 
     42   int32 left_child_id_;
     43   int32 right_child_id_;
     44 };
     45 
     46 // Evaluator for basic inequality decisions (f[x] <= T).
     47 class InequalityDecisionNodeEvaluator : public BinaryDecisionNodeEvaluator {
     48  public:
     49   InequalityDecisionNodeEvaluator(const decision_trees::InequalityTest& test,
     50                                   int32 left, int32 right);
     51 
     52   int32 Decide(const std::unique_ptr<TensorDataSet>& dataset,
     53                int example) const override;
     54 
     55  protected:
     56   int32 feature_num_;
     57   float threshold_;
     58 
     59   // If decision is '<=' as opposed to '<'.
     60   bool include_equals_;
     61 };
     62 
     63 // Evalutor for splits with multiple weighted features.
     64 class ObliqueInequalityDecisionNodeEvaluator
     65     : public BinaryDecisionNodeEvaluator {
     66  public:
     67   ObliqueInequalityDecisionNodeEvaluator(
     68       const decision_trees::InequalityTest& test, int32 left, int32 right);
     69 
     70   int32 Decide(const std::unique_ptr<TensorDataSet>& dataset,
     71                int example) const override;
     72 
     73  protected:
     74   std::vector<int32> feature_num_;
     75   std::vector<float> feature_weights_;
     76   float threshold_;
     77 };
     78 
     79 // Evaluator for contains-in-set decisions.  Also supports inverse (not-in-set).
     80 class MatchingValuesDecisionNodeEvaluator : public BinaryDecisionNodeEvaluator {
     81  public:
     82   MatchingValuesDecisionNodeEvaluator(
     83       const decision_trees::MatchingValuesTest& test, int32 left, int32 right);
     84 
     85   int32 Decide(const std::unique_ptr<TensorDataSet>& dataset,
     86                int example) const override;
     87 
     88  protected:
     89   int32 feature_num_;
     90   std::vector<float> values_;
     91   bool inverse_;
     92 };
     93 
     94 std::unique_ptr<DecisionNodeEvaluator> CreateDecisionNodeEvaluator(
     95     const decision_trees::TreeNode& node);
     96 std::unique_ptr<DecisionNodeEvaluator> CreateBinaryDecisionNodeEvaluator(
     97     const decision_trees::BinaryNode& node, int32 left, int32 right);
     98 
     99 struct CandidateEvalatorCollection {
    100   std::vector<std::unique_ptr<DecisionNodeEvaluator>> splits;
    101 };
    102 
    103 }  // namespace tensorforest
    104 }  // namespace tensorflow
    105 
    106 #endif  // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_
    107