1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 // RoutingFunction returns the probability of reaching each leaf node 16 // in a soft decision tree. 17 18 #include <stdlib.h> 19 #include <time.h> 20 #include <algorithm> 21 #include <cmath> 22 #include <memory> 23 #include <unordered_map> 24 #include <unordered_set> 25 #include <utility> 26 #include <vector> 27 28 #include "tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h" 29 #include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h" 30 #include "tensorflow/core/framework/op.h" 31 #include "tensorflow/core/framework/op_kernel.h" 32 #include "tensorflow/core/framework/shape_inference.h" 33 #include "tensorflow/core/framework/tensor.h" 34 #include "tensorflow/core/lib/gtl/top_n.h" 35 #include "tensorflow/core/platform/types.h" 36 #include "tensorflow/core/util/work_sharder.h" 37 38 namespace tensorflow { 39 40 using shape_inference::InferenceContext; 41 using shape_inference::ShapeHandle; 42 43 using tensorforest::CheckTensorBounds; 44 using tensorforest::LeftProbabilityK; 45 46 // The term 'routing function' is synonymous with 'the probability 47 // that an instance is routed to each leaf node.' It is defined in 48 // 'Deep Neural Decision Forests' by Kontschieder et al. 49 REGISTER_OP("KFeatureRoutingFunction") 50 .Attr("layer_num: int") 51 .Attr("max_nodes: int") 52 .Attr("num_features_per_node: int") 53 .Attr("random_seed: int") 54 .Input("input_data: float") 55 .Input("tree_parameters: float") 56 .Input("tree_biases: float") 57 .Output("probabilities: float") 58 .SetShapeFn([](InferenceContext* c) { 59 ShapeHandle input, params; 60 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input)); 61 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, ¶ms)); 62 63 c->set_output(0, c->Matrix(c->Dim(input, 0), c->Dim(params, 0))); 64 return Status::OK(); 65 }) 66 .Doc(R"doc( 67 68 Returns the probability that each input will reach each leaf node. Each 69 decision is made based on k features. 70 71 layer_num: The layer number of this tree. 72 max_nodes: The number of nodes in the tree. 73 num_features_per_node: The number of features each node can use to make a 74 decision. 75 random_seed: The base random seed. 76 77 input_data: The training batch's features as a 2-d tensor; `input_data[i][j]` 78 gives the j-th feature of the i-th input. 79 tree_parameters: `tree_parameters[i]` gives the weight of 80 the logistic regression model that translates from node features to 81 probabilities. 82 tree_biases: `tree_biases[i]` gives the bias of the logistic 83 regression model that translates from node features to 84 probabilities. 85 tree_features: `tree_features[i]` gives the decision feature for node i. 86 87 probabilities: `probabilities[i][j]` is the probability that input i 88 will reach node j. 89 )doc"); 90 91 class KFeatureRoutingFunction : public OpKernel { 92 public: 93 explicit KFeatureRoutingFunction(OpKernelConstruction* context) 94 : OpKernel(context) { 95 OP_REQUIRES_OK(context, context->GetAttr("max_nodes", &max_nodes_)); 96 OP_REQUIRES_OK(context, context->GetAttr("num_features_per_node", 97 &num_features_per_node_)); 98 OP_REQUIRES_OK(context, context->GetAttr("layer_num", &layer_num_)); 99 OP_REQUIRES_OK(context, context->GetAttr("random_seed", &random_seed_)); 100 } 101 102 void Compute(OpKernelContext* context) override { 103 const Tensor& input_data = context->input(0); 104 const Tensor& tree_parameters_tensor = context->input(1); 105 const Tensor& tree_biases_tensor = context->input(2); 106 107 if (input_data.shape().dim_size(0) > 0) { 108 OP_REQUIRES( 109 context, input_data.shape().dims() == 2, 110 errors::InvalidArgument("input_data should be two-dimensional")); 111 } 112 113 // Check tensor bounds. 114 if (!CheckTensorBounds(context, input_data)) return; 115 116 const int32 num_data = static_cast<int32>(input_data.shape().dim_size(0)); 117 const int32 num_features = 118 static_cast<int32>(input_data.shape().dim_size(1)); 119 120 Tensor* output_probabilities = nullptr; 121 TensorShape output_shape; 122 output_shape.AddDim(num_data); 123 output_shape.AddDim(max_nodes_); 124 125 OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, 126 &output_probabilities)); 127 128 auto out_probs = output_probabilities->tensor<float, 2>(); 129 const auto tree_biases = tree_biases_tensor.tensor<float, 1>(); 130 131 // Iteratively compute the probability of reaching each leaf. 132 std::vector<int32> feature_set; 133 for (int i = 0; i < num_data; i++) { 134 const Tensor point = input_data.Slice(i, i + 1); 135 136 out_probs(i, 0) = 1.0f; 137 138 for (int j = 0; j < max_nodes_ / 2; j++) { 139 feature_set.clear(); 140 tensorforest::GetFeatureSet(layer_num_, i, random_seed_, num_features, 141 num_features_per_node_, &feature_set); 142 143 int32 left_child = 2 * j + 1; 144 int32 right_child = left_child + 1; 145 146 float prob = out_probs(i, j); 147 float left_prob = LeftProbabilityK( 148 point, feature_set, tree_parameters_tensor.Slice(j, j + 1), 149 tree_biases(j), num_features, num_features_per_node_); 150 151 out_probs(i, left_child) = prob * left_prob; 152 out_probs(i, right_child) = prob * (1.0f - left_prob); 153 } 154 } 155 } 156 157 private: 158 int32 layer_num_; 159 int32 max_nodes_; 160 int32 num_features_per_node_; 161 int32 random_seed_; 162 }; 163 164 REGISTER_KERNEL_BUILDER(Name("KFeatureRoutingFunction").Device(DEVICE_CPU), 165 KFeatureRoutingFunction); 166 } // namespace tensorflow 167