Home | History | Annotate | Download | only in ops
      1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
// KFeatureRoutingFunction returns the probability of reaching each node of a
// soft decision tree in which every decision is made from k input features.
     17 
     18 #include <stdlib.h>
     19 #include <time.h>
     20 #include <algorithm>
     21 #include <cmath>
     22 #include <memory>
     23 #include <unordered_map>
     24 #include <unordered_set>
     25 #include <utility>
     26 #include <vector>
     27 
     28 #include "tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h"
     29 #include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
     30 #include "tensorflow/core/framework/op.h"
     31 #include "tensorflow/core/framework/op_kernel.h"
     32 #include "tensorflow/core/framework/shape_inference.h"
     33 #include "tensorflow/core/framework/tensor.h"
     34 #include "tensorflow/core/lib/gtl/top_n.h"
     35 #include "tensorflow/core/platform/types.h"
     36 #include "tensorflow/core/util/work_sharder.h"
     37 
     38 namespace tensorflow {
     39 
     40 using shape_inference::InferenceContext;
     41 using shape_inference::ShapeHandle;
     42 
     43 using tensorforest::CheckTensorBounds;
     44 using tensorforest::LeftProbabilityK;
     45 
     46 // The term 'routing function' is synonymous with 'the probability
     47 // that an instance is routed to each leaf node.'  It is defined in
     48 // 'Deep Neural Decision Forests' by Kontschieder et al.
     49 REGISTER_OP("KFeatureRoutingFunction")
     50     .Attr("layer_num: int")
     51     .Attr("max_nodes: int")
     52     .Attr("num_features_per_node: int")
     53     .Attr("random_seed: int")
     54     .Input("input_data: float")
     55     .Input("tree_parameters: float")
     56     .Input("tree_biases: float")
     57     .Output("probabilities: float")
     58     .SetShapeFn([](InferenceContext* c) {
     59       ShapeHandle input, params;
     60       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
     61       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &params));
     62 
     63       c->set_output(0, c->Matrix(c->Dim(input, 0), c->Dim(params, 0)));
     64       return Status::OK();
     65     })
     66     .Doc(R"doc(
     67 
     68   Returns the probability that each input will reach each leaf node.  Each
     69   decision is made based on k features.
     70 
     71   layer_num: The layer number of this tree.
     72   max_nodes: The number of nodes in the tree.
     73   num_features_per_node: The number of features each node can use to make a
     74    decision.
     75   random_seed: The base random seed.
     76 
     77   input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
     78    gives the j-th feature of the i-th input.
     79   tree_parameters: `tree_parameters[i]` gives the weight of
     80    the logistic regression model that translates from node features to
     81    probabilities.
     82   tree_biases: `tree_biases[i]` gives the bias of the logistic
     83    regression model that translates from node features to
     84    probabilities.
     85   tree_features: `tree_features[i]` gives the decision feature for node i.
     86 
     87   probabilities: `probabilities[i][j]` is the probability that input i
     88    will reach node j.
     89 )doc");
     90 
     91 class KFeatureRoutingFunction : public OpKernel {
     92  public:
     93   explicit KFeatureRoutingFunction(OpKernelConstruction* context)
     94       : OpKernel(context) {
     95     OP_REQUIRES_OK(context, context->GetAttr("max_nodes", &max_nodes_));
     96     OP_REQUIRES_OK(context, context->GetAttr("num_features_per_node",
     97                                              &num_features_per_node_));
     98     OP_REQUIRES_OK(context, context->GetAttr("layer_num", &layer_num_));
     99     OP_REQUIRES_OK(context, context->GetAttr("random_seed", &random_seed_));
    100   }
    101 
    102   void Compute(OpKernelContext* context) override {
    103     const Tensor& input_data = context->input(0);
    104     const Tensor& tree_parameters_tensor = context->input(1);
    105     const Tensor& tree_biases_tensor = context->input(2);
    106 
    107     if (input_data.shape().dim_size(0) > 0) {
    108       OP_REQUIRES(
    109           context, input_data.shape().dims() == 2,
    110           errors::InvalidArgument("input_data should be two-dimensional"));
    111     }
    112 
    113     // Check tensor bounds.
    114     if (!CheckTensorBounds(context, input_data)) return;
    115 
    116     const int32 num_data = static_cast<int32>(input_data.shape().dim_size(0));
    117     const int32 num_features =
    118         static_cast<int32>(input_data.shape().dim_size(1));
    119 
    120     Tensor* output_probabilities = nullptr;
    121     TensorShape output_shape;
    122     output_shape.AddDim(num_data);
    123     output_shape.AddDim(max_nodes_);
    124 
    125     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape,
    126                                                      &output_probabilities));
    127 
    128     auto out_probs = output_probabilities->tensor<float, 2>();
    129     const auto tree_biases = tree_biases_tensor.tensor<float, 1>();
    130 
    131     // Iteratively compute the probability of reaching each leaf.
    132     std::vector<int32> feature_set;
    133     for (int i = 0; i < num_data; i++) {
    134       const Tensor point = input_data.Slice(i, i + 1);
    135 
    136       out_probs(i, 0) = 1.0f;
    137 
    138       for (int j = 0; j < max_nodes_ / 2; j++) {
    139         feature_set.clear();
    140         tensorforest::GetFeatureSet(layer_num_, i, random_seed_, num_features,
    141                                     num_features_per_node_, &feature_set);
    142 
    143         int32 left_child = 2 * j + 1;
    144         int32 right_child = left_child + 1;
    145 
    146         float prob = out_probs(i, j);
    147         float left_prob = LeftProbabilityK(
    148             point, feature_set, tree_parameters_tensor.Slice(j, j + 1),
    149             tree_biases(j), num_features, num_features_per_node_);
    150 
    151         out_probs(i, left_child) = prob * left_prob;
    152         out_probs(i, right_child) = prob * (1.0f - left_prob);
    153       }
    154     }
    155   }
    156 
    157  private:
    158   int32 layer_num_;
    159   int32 max_nodes_;
    160   int32 num_features_per_node_;
    161   int32 random_seed_;
    162 };
    163 
    164 REGISTER_KERNEL_BUILDER(Name("KFeatureRoutingFunction").Device(DEVICE_CPU),
    165                         KFeatureRoutingFunction);
    166 }  // namespace tensorflow
    167