1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h" 16 #include "tensorflow/core/lib/strings/numbers.h" 17 18 namespace tensorflow { 19 namespace tensorforest { 20 21 std::unique_ptr<DecisionNodeEvaluator> CreateDecisionNodeEvaluator( 22 const decision_trees::TreeNode& node) { 23 const decision_trees::BinaryNode& bnode = node.binary_node(); 24 return CreateBinaryDecisionNodeEvaluator(bnode, bnode.left_child_id().value(), 25 bnode.right_child_id().value()); 26 } 27 28 std::unique_ptr<DecisionNodeEvaluator> CreateBinaryDecisionNodeEvaluator( 29 const decision_trees::BinaryNode& bnode, int32 left, int32 right) { 30 if (bnode.has_inequality_left_child_test()) { 31 const auto& test = bnode.inequality_left_child_test(); 32 if (test.has_oblique()) { 33 return std::unique_ptr<ObliqueInequalityDecisionNodeEvaluator>( 34 new ObliqueInequalityDecisionNodeEvaluator(test, left, right)); 35 } else { 36 return std::unique_ptr<InequalityDecisionNodeEvaluator>( 37 new InequalityDecisionNodeEvaluator(test, left, right)); 38 } 39 } else { 40 decision_trees::MatchingValuesTest test; 41 if (bnode.custom_left_child_test().UnpackTo(&test)) { 42 return std::unique_ptr<MatchingValuesDecisionNodeEvaluator>( 43 new MatchingValuesDecisionNodeEvaluator(test, left, right)); 44 } else { 45 LOG(ERROR) << "Unknown split test: " << bnode.DebugString(); 46 return nullptr; 47 } 48 } 49 } 50 51 InequalityDecisionNodeEvaluator::InequalityDecisionNodeEvaluator( 52 const decision_trees::InequalityTest& test, int32 left, int32 right) 53 : BinaryDecisionNodeEvaluator(left, right) { 54 safe_strto32(test.feature_id().id().value(), &feature_num_); 55 threshold_ = test.threshold().float_value(); 56 include_equals_ = 57 test.type() == decision_trees::InequalityTest::LESS_OR_EQUAL; 58 } 59 60 int32 InequalityDecisionNodeEvaluator::Decide( 61 const std::unique_ptr<TensorDataSet>& dataset, int example) const { 62 const float val = dataset->GetExampleValue(example, feature_num_); 63 if (val < threshold_ || (include_equals_ && val == threshold_)) { 64 return left_child_id_; 65 } else { 66 return right_child_id_; 67 } 68 } 69 70 ObliqueInequalityDecisionNodeEvaluator::ObliqueInequalityDecisionNodeEvaluator( 71 const decision_trees::InequalityTest& test, int32 left, int32 right) 72 : BinaryDecisionNodeEvaluator(left, right) { 73 for (int i = 0; i < test.oblique().features_size(); ++i) { 74 int32 val; 75 safe_strto32(test.oblique().features(i).id().value(), &val); 76 feature_num_.push_back(val); 77 feature_weights_.push_back(test.oblique().weights(i)); 78 } 79 threshold_ = test.threshold().float_value(); 80 } 81 82 int32 ObliqueInequalityDecisionNodeEvaluator::Decide( 83 const std::unique_ptr<TensorDataSet>& dataset, int example) const { 84 float val = 0; 85 for (int i = 0; i < feature_num_.size(); ++i) { 86 val += feature_weights_[i] * 87 dataset->GetExampleValue(example, feature_num_[i]); 88 } 89 90 if (val <= threshold_) { 91 return left_child_id_; 92 } else { 93 return right_child_id_; 94 } 95 } 96 97 MatchingValuesDecisionNodeEvaluator::MatchingValuesDecisionNodeEvaluator( 98 const decision_trees::MatchingValuesTest& test, int32 left, int32 right) 99 : BinaryDecisionNodeEvaluator(left, right) { 100 safe_strto32(test.feature_id().id().value(), &feature_num_); 101 for (const auto& val : test.value()) { 102 values_.push_back(val.float_value()); 103 } 104 inverse_ = test.inverse(); 105 } 106 107 int32 MatchingValuesDecisionNodeEvaluator::Decide( 108 const std::unique_ptr<TensorDataSet>& dataset, int example) const { 109 const float val = dataset->GetExampleValue(example, feature_num_); 110 for (float testval : values_) { 111 if (val == testval) { 112 return inverse_ ? right_child_id_ : left_child_id_; 113 } 114 } 115 116 return inverse_ ? left_child_id_ : right_child_id_; 117 } 118 119 } // namespace tensorforest 120 } // namespace tensorflow 121