1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_ 16 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_ 17 18 #include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" 19 #include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h" 20 #include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h" 21 22 namespace tensorflow { 23 namespace tensorforest { 24 25 // Base class for evaluators of decision nodes that effectively copy proto 26 // contents into C++ structures for faster execution. 27 class DecisionNodeEvaluator { 28 public: 29 virtual ~DecisionNodeEvaluator() {} 30 31 // Returns the index of the child node. 32 virtual int32 Decide(const std::unique_ptr<TensorDataSet>& dataset, 33 int example) const = 0; 34 }; 35 36 // An evaluator for binary decisions with left and right children. 37 class BinaryDecisionNodeEvaluator : public DecisionNodeEvaluator { 38 protected: 39 BinaryDecisionNodeEvaluator(int32 left, int32 right) 40 : left_child_id_(left), right_child_id_(right) {} 41 42 int32 left_child_id_; 43 int32 right_child_id_; 44 }; 45 46 // Evaluator for basic inequality decisions (f[x] <= T). 47 class InequalityDecisionNodeEvaluator : public BinaryDecisionNodeEvaluator { 48 public: 49 InequalityDecisionNodeEvaluator(const decision_trees::InequalityTest& test, 50 int32 left, int32 right); 51 52 int32 Decide(const std::unique_ptr<TensorDataSet>& dataset, 53 int example) const override; 54 55 protected: 56 int32 feature_num_; 57 float threshold_; 58 59 // If decision is '<=' as opposed to '<'. 60 bool include_equals_; 61 }; 62 63 // Evalutor for splits with multiple weighted features. 64 class ObliqueInequalityDecisionNodeEvaluator 65 : public BinaryDecisionNodeEvaluator { 66 public: 67 ObliqueInequalityDecisionNodeEvaluator( 68 const decision_trees::InequalityTest& test, int32 left, int32 right); 69 70 int32 Decide(const std::unique_ptr<TensorDataSet>& dataset, 71 int example) const override; 72 73 protected: 74 std::vector<int32> feature_num_; 75 std::vector<float> feature_weights_; 76 float threshold_; 77 }; 78 79 // Evaluator for contains-in-set decisions. Also supports inverse (not-in-set). 80 class MatchingValuesDecisionNodeEvaluator : public BinaryDecisionNodeEvaluator { 81 public: 82 MatchingValuesDecisionNodeEvaluator( 83 const decision_trees::MatchingValuesTest& test, int32 left, int32 right); 84 85 int32 Decide(const std::unique_ptr<TensorDataSet>& dataset, 86 int example) const override; 87 88 protected: 89 int32 feature_num_; 90 std::vector<float> values_; 91 bool inverse_; 92 }; 93 94 std::unique_ptr<DecisionNodeEvaluator> CreateDecisionNodeEvaluator( 95 const decision_trees::TreeNode& node); 96 std::unique_ptr<DecisionNodeEvaluator> CreateBinaryDecisionNodeEvaluator( 97 const decision_trees::BinaryNode& node, int32 left, int32 right); 98 99 struct CandidateEvalatorCollection { 100 std::vector<std::unique_ptr<DecisionNodeEvaluator>> splits; 101 }; 102 103 } // namespace tensorforest 104 } // namespace tensorflow 105 106 #endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_ 107