Home | History | Annotate | Download | only in v4
      1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
     15 #include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h"
     16 #include "tensorflow/core/lib/strings/numbers.h"
     17 
     18 namespace tensorflow {
     19 namespace tensorforest {
     20 
     21 std::unique_ptr<DecisionNodeEvaluator> CreateDecisionNodeEvaluator(
     22     const decision_trees::TreeNode& node) {
     23   const decision_trees::BinaryNode& bnode = node.binary_node();
     24   return CreateBinaryDecisionNodeEvaluator(bnode, bnode.left_child_id().value(),
     25                                            bnode.right_child_id().value());
     26 }
     27 
     28 std::unique_ptr<DecisionNodeEvaluator> CreateBinaryDecisionNodeEvaluator(
     29     const decision_trees::BinaryNode& bnode, int32 left, int32 right) {
     30   if (bnode.has_inequality_left_child_test()) {
     31     const auto& test = bnode.inequality_left_child_test();
     32     if (test.has_oblique()) {
     33       return std::unique_ptr<ObliqueInequalityDecisionNodeEvaluator>(
     34           new ObliqueInequalityDecisionNodeEvaluator(test, left, right));
     35     } else {
     36       return std::unique_ptr<InequalityDecisionNodeEvaluator>(
     37           new InequalityDecisionNodeEvaluator(test, left, right));
     38     }
     39   } else {
     40     decision_trees::MatchingValuesTest test;
     41     if (bnode.custom_left_child_test().UnpackTo(&test)) {
     42       return std::unique_ptr<MatchingValuesDecisionNodeEvaluator>(
     43           new MatchingValuesDecisionNodeEvaluator(test, left, right));
     44     } else {
     45       LOG(ERROR) << "Unknown split test: " << bnode.DebugString();
     46       return nullptr;
     47     }
     48   }
     49 }
     50 
     51 InequalityDecisionNodeEvaluator::InequalityDecisionNodeEvaluator(
     52     const decision_trees::InequalityTest& test, int32 left, int32 right)
     53     : BinaryDecisionNodeEvaluator(left, right) {
     54   safe_strto32(test.feature_id().id().value(), &feature_num_);
     55   threshold_ = test.threshold().float_value();
     56   include_equals_ =
     57       test.type() == decision_trees::InequalityTest::LESS_OR_EQUAL;
     58 }
     59 
     60 int32 InequalityDecisionNodeEvaluator::Decide(
     61     const std::unique_ptr<TensorDataSet>& dataset, int example) const {
     62   const float val = dataset->GetExampleValue(example, feature_num_);
     63   if (val < threshold_ || (include_equals_ && val == threshold_)) {
     64     return left_child_id_;
     65   } else {
     66     return right_child_id_;
     67   }
     68 }
     69 
     70 ObliqueInequalityDecisionNodeEvaluator::ObliqueInequalityDecisionNodeEvaluator(
     71     const decision_trees::InequalityTest& test, int32 left, int32 right)
     72     : BinaryDecisionNodeEvaluator(left, right) {
     73   for (int i = 0; i < test.oblique().features_size(); ++i) {
     74     int32 val;
     75     safe_strto32(test.oblique().features(i).id().value(), &val);
     76     feature_num_.push_back(val);
     77     feature_weights_.push_back(test.oblique().weights(i));
     78   }
     79   threshold_ = test.threshold().float_value();
     80 }
     81 
     82 int32 ObliqueInequalityDecisionNodeEvaluator::Decide(
     83     const std::unique_ptr<TensorDataSet>& dataset, int example) const {
     84   float val = 0;
     85   for (int i = 0; i < feature_num_.size(); ++i) {
     86     val += feature_weights_[i] *
     87            dataset->GetExampleValue(example, feature_num_[i]);
     88   }
     89 
     90   if (val <= threshold_) {
     91     return left_child_id_;
     92   } else {
     93     return right_child_id_;
     94   }
     95 }
     96 
     97 MatchingValuesDecisionNodeEvaluator::MatchingValuesDecisionNodeEvaluator(
     98     const decision_trees::MatchingValuesTest& test, int32 left, int32 right)
     99     : BinaryDecisionNodeEvaluator(left, right) {
    100   safe_strto32(test.feature_id().id().value(), &feature_num_);
    101   for (const auto& val : test.value()) {
    102     values_.push_back(val.float_value());
    103   }
    104   inverse_ = test.inverse();
    105 }
    106 
    107 int32 MatchingValuesDecisionNodeEvaluator::Decide(
    108     const std::unique_ptr<TensorDataSet>& dataset, int example) const {
    109   const float val = dataset->GetExampleValue(example, feature_num_);
    110   for (float testval : values_) {
    111     if (val == testval) {
    112       return inverse_ ? right_child_id_ : left_child_id_;
    113     }
    114   }
    115 
    116   return inverse_ ? left_child_id_ : right_child_id_;
    117 }
    118 
    119 }  // namespace tensorforest
    120 }  // namespace tensorflow
    121