Home | History | Annotate | Download | only in ops
      1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
#include <stdlib.h>
#include <time.h>
#include <algorithm>
#include <cmath>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h"
#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"
     33 
     34 namespace tensorflow {
     35 
     36 using tensorforest::LeftProbabilityK;
     37 
     38 REGISTER_OP("KFeatureGradient")
     39     .Attr("layer_num: int")
     40     .Attr("random_seed: int")
     41     .Input("input_data: float")
     42     .Input("tree_parameters: float")
     43     .Input("tree_biases: float")
     44     .Input("routes: float")
     45     .Output("routing_gradient: float")
     46     .Output("data_gradient: float")
     47     .Output("weight_gradient: float")
     48     .Doc(R"doc(
     49     Computes the derivative of the routing loss with respect to each decision
     50     node.  Each decision node is constrained to make a decision based on only
     51     k features.
     52 
     53     layer_num: The layer number of this tree.
     54     random_seed: The base random seed.
     55 
     56     input_data: The training batch's features as a 2-d tensor;
     57      `input_data[i][j]` gives the j-th feature of the i-th input.
     58     tree_parameters: `tree_parameters[i]` gives the weight of
     59      the logistic regression model that translates from node features to
     60      probabilities.
     61     tree_biases: `tree_biases[i]` gives the bias of the logistic
     62      regression model that translates from node features to
     63      probabilities.
     64     routes: The routes computed by routing_function_op.
     65 
     66     routing_gradient: `routing_gradient` provides du / df, where u is the
     67      routing function and f is the (vector of) decision functions.  A decision
     68      function f_i computes the routing decision at node i.
     69 
     70     data_gradient: `data_gradient` provides df / dx, where f is the (vector
     71      of) decision functions and x is a batch of data.
     72 
     73     weights_gradient: `weights_gradient` provides df / dw, where f is the
     74      (vector of) decision functions and w is the matrix of parameters that
     75      determine how instances are routed through a tree.
     76 
     77     f_i, the decision function at node i, is parameterized by t_i (parameters)
     78     and b_i (bias) and takes data x as input.  This op is called in
     79     training_ops.py to compute du / df, and we use that to compute
     80 
     81     du / dx = du / df * df / dx,
     82     du / dt = du / df * df / dt, and
     83     du / db = du / df * df / db.
     84 )doc");
     85 
     86 class KFeatureGradient : public OpKernel {
     87  public:
     88   explicit KFeatureGradient(OpKernelConstruction* context) : OpKernel(context) {
     89     OP_REQUIRES_OK(context, context->GetAttr("layer_num", &layer_num_));
     90     OP_REQUIRES_OK(context, context->GetAttr("random_seed", &random_seed_));
     91   }
     92 
     93   void Compute(OpKernelContext* context) override {
     94     // Gather input.
     95     const Tensor& input_data_tensor = context->input(0);
     96     const Tensor& tree_parameters_tensor = context->input(1);
     97     const Tensor& tree_biases_tensor = context->input(2);
     98     const Tensor& routing_tensor = context->input(3);
     99 
    100     // Extract dimensions from input tensors.
    101     const int32 num_data =
    102         static_cast<int32>(input_data_tensor.shape().dim_size(0));
    103     const int32 num_features =
    104         static_cast<int32>(input_data_tensor.shape().dim_size(1));
    105     const int32 num_nodes =
    106         static_cast<int32>(tree_parameters_tensor.shape().dim_size(0));
    107     const int32 num_features_per_node =
    108         static_cast<int32>(tree_parameters_tensor.shape().dim_size(1));
    109 
    110     // Construct output tensors.
    111     Tensor* out_routes = nullptr;
    112     TensorShape out_routes_shape;
    113     out_routes_shape.AddDim(num_data);
    114     out_routes_shape.AddDim(num_nodes);
    115 
    116     Tensor* out_data = nullptr;
    117     TensorShape out_data_shape;
    118     out_data_shape.AddDim(num_nodes);
    119     out_data_shape.AddDim(num_features);
    120 
    121     Tensor* out_weights = nullptr;
    122     TensorShape out_weights_shape;
    123     out_weights_shape.AddDim(num_data);
    124     out_weights_shape.AddDim(num_nodes);
    125     out_weights_shape.AddDim(num_features_per_node);
    126 
    127     OP_REQUIRES_OK(context,
    128                    context->allocate_output(0, out_routes_shape, &out_routes));
    129     OP_REQUIRES_OK(context,
    130                    context->allocate_output(1, out_data_shape, &out_data));
    131     OP_REQUIRES_OK(
    132         context, context->allocate_output(2, out_weights_shape, &out_weights));
    133 
    134     tensorforest::Initialize(*out_data, 0.0f);
    135 
    136     // Compute output.
    137     const auto input_data = input_data_tensor.tensor<float, 2>();
    138     const auto tree_parameters = tree_parameters_tensor.tensor<float, 2>();
    139     const auto tree_biases = tree_biases_tensor.tensor<float, 1>();
    140     const auto routes = routing_tensor.tensor<float, 2>();
    141 
    142     auto routes_grad = out_routes->tensor<float, 2>();
    143     auto data_grad = out_data->tensor<float, 2>();
    144     auto weights_grad = out_weights->tensor<float, 3>();
    145 
    146     std::vector<int32> feature_set;
    147     for (int i = 0; i < num_data; i++) {
    148       const Tensor point = input_data_tensor.Slice(i, i + 1);
    149       feature_set.clear();
    150 
    151       // Traverse the tree from the bottom up.
    152       for (int j = num_nodes - 1; j >= 0; j--) {
    153         tensorforest::GetFeatureSet(layer_num_, j, random_seed_, num_features,
    154                                     num_features_per_node, &feature_set);
    155 
    156         // Compute routing gradient.
    157         // j is a leaf node.
    158         if (j >= num_nodes / 2) {
    159           routes_grad(i, j) = routes(i, j);
    160         } else {  // j is not a leaf node
    161           int32 left_child = 2 * j + 1;
    162           int32 right_child = left_child + 1;
    163 
    164           float left_prob = LeftProbabilityK(
    165               point, feature_set, tree_parameters_tensor.Slice(j, j + 1),
    166               tree_biases(j), num_features, num_features_per_node);
    167 
    168           float right_prob = 1.0f - left_prob;
    169 
    170           routes_grad(i, j) = (right_prob * routes(i, left_child) +
    171                                left_prob * routes(i, right_child));
    172         }
    173         // Compute data and weight gradient.
    174         for (int k = 0; k < num_features_per_node; k++) {
    175           CHECK_LT(feature_set[k], num_features);
    176           data_grad(j, feature_set[k]) = tree_parameters(j, k);
    177           weights_grad(i, j, k) = input_data(i, feature_set[k]);
    178         }
    179       }
    180     }
    181   }
    182 
    183  private:
    184   int32 layer_num_;
    185   int32 random_seed_;
    186 };
    187 
    188 REGISTER_KERNEL_BUILDER(Name("KFeatureGradient").Device(DEVICE_CPU),
    189                         KFeatureGradient);
    190 }  // namespace tensorflow
    191