// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <algorithm>
#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/data_spec.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace tensorforest {

// Creates a tree variable.
class CreateTreeVariableOp : public OpKernel {
 public:
  explicit CreateTreeVariableOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string serialized_params;
    OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
    OP_REQUIRES(context, ParseProtoUnlimited(&param_proto_, serialized_params),
                errors::InvalidArgument("Unable to parse params proto."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor* tree_config_t;
    OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(tree_config_t->shape()),
                errors::InvalidArgument("Tree config must be a scalar."));

    auto* result = new DecisionTreeResource(param_proto_);
    if (!ParseProtoUnlimited(result->mutable_decision_tree(),
                             tree_config_t->scalar<string>()())) {
      result->Unref();
      OP_REQUIRES(context, false,
                  errors::InvalidArgument("Unable to parse tree config."));
    }

    result->MaybeInitialize();

    // Only create the resource if one does not already exist; report the
    // status for all other errors.
    auto status = CreateResource(context, HandleFromInput(context, 0), result);
    if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
      OP_REQUIRES(context, false, status);
    }
  }

 private:
  TensorForestParams param_proto_;
};

// Op for serializing a model.
class TreeSerializeOp : public OpKernel {
 public:
  explicit TreeSerializeOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    DecisionTreeResource* decision_tree_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &decision_tree_resource));
    mutex_lock l(*decision_tree_resource->get_mutex());
    core::ScopedUnref unref_me(decision_tree_resource);
    Tensor* output_config_t = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(0, TensorShape(), &output_config_t));
    output_config_t->scalar<string>()() =
        decision_tree_resource->decision_tree().SerializeAsString();
  }
};

// Op for deserializing a tree variable from a checkpoint.
class TreeDeserializeOp : public OpKernel {
 public:
  explicit TreeDeserializeOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string serialized_params;
    OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
    OP_REQUIRES(context, ParseProtoUnlimited(&param_proto_, serialized_params),
                errors::InvalidArgument("Unable to parse params proto."));
  }

  void Compute(OpKernelContext* context) override {
    DecisionTreeResource* decision_tree_resource;
    auto handle = HandleFromInput(context, 0);
    OP_REQUIRES_OK(context,
                   LookupResource(context, handle, &decision_tree_resource));
    mutex_lock l(*decision_tree_resource->get_mutex());
    core::ScopedUnref unref_me(decision_tree_resource);

    const Tensor* tree_config_t;
    OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(tree_config_t->shape()),
                errors::InvalidArgument("Tree config must be a scalar."));
    // Deallocate all the previous objects on the resource.
    decision_tree_resource->Reset();
    decision_trees::Model* config =
        decision_tree_resource->mutable_decision_tree();
    OP_REQUIRES(context,
                ParseProtoUnlimited(config, tree_config_t->scalar<string>()()),
                errors::InvalidArgument("Unable to parse tree config."));
    decision_tree_resource->MaybeInitialize();
  }

 private:
  TensorForestParams param_proto_;
};

// Op for getting tree size.
class TreeSizeOp : public OpKernel {
 public:
  explicit TreeSizeOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    DecisionTreeResource* decision_tree_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &decision_tree_resource));
    mutex_lock l(*decision_tree_resource->get_mutex());
    core::ScopedUnref unref_me(decision_tree_resource);
    Tensor* output_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape(), &output_t));
    output_t->scalar<int32>()() =
        decision_tree_resource->decision_tree().decision_tree().nodes_size();
  }
};

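// Computes the leaf that each example in [start, end) lands in, reporting it
// via set_leaf_id(example_index, leaf_id). If tree_paths is non-null it must
// hold one entry per example; the path traversed through the tree is recorded
// there. Intended to be invoked by Shard() on disjoint subranges so multiple
// threads can traverse the tree in parallel.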
void TraverseTree(const DecisionTreeResource* tree_resource,
                  const std::unique_ptr<TensorDataSet>& data, int32 start,
                  int32 end,
                  const std::function<void(int32, int32)>& set_leaf_id,
                  std::vector<TreePath>* tree_paths) {
  for (int i = start; i < end; ++i) {
    const int32 id = tree_resource->TraverseTree(
        data, i, nullptr,
        (tree_paths == nullptr) ? nullptr : &(*tree_paths)[i]);
    set_leaf_id(i, id);
  }
}

// Op for tree inference. Input 0 is the tree resource handle; inputs 1-4 are
// the dense data tensor and the sparse (indices, values, shape) triple.
class TreePredictionsV4Op : public OpKernel {
 public:
  explicit TreePredictionsV4Op(OpKernelConstruction* context)
      : OpKernel(context) {
    string serialized_params;
    OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
    OP_REQUIRES(context, ParseProtoUnlimited(&param_proto_, serialized_params),
                errors::InvalidArgument("Unable to parse params proto."));

    string serialized_proto;
    OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto));
    OP_REQUIRES(context, input_spec_.ParseFromString(serialized_proto),
                errors::InvalidArgument("Unable to parse input spec proto."));
    model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(param_proto_);
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input_data = context->input(1);
    const Tensor& sparse_input_indices = context->input(2);
    const Tensor& sparse_input_values = context->input(3);
    const Tensor& sparse_input_shape = context->input(4);

    std::unique_ptr<TensorDataSet> data_set(new TensorDataSet(input_spec_, 0));
    data_set->set_input_tensors(input_data, sparse_input_indices,
                                sparse_input_values, sparse_input_shape);

    DecisionTreeResource* decision_tree_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &decision_tree_resource));
    mutex_lock l(*decision_tree_resource->get_mutex());
    core::ScopedUnref unref_me(decision_tree_resource);

    const int num_data = data_set->NumItems();
    const int32 num_outputs = param_proto_.num_outputs();

    Tensor* output_predictions = nullptr;
    TensorShape output_shape;
    output_shape.AddDim(num_data);
    output_shape.AddDim(num_outputs);
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape,
                                                     &output_predictions));
    TTypes<float, 2>::Tensor out = output_predictions->tensor<float, 2>();

    std::vector<TreePath> tree_paths(
        param_proto_.inference_tree_paths() ? num_data : 0);

    // Shard the traversal over the CPU worker threads.
    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
    int num_threads = worker_threads->num_threads;
    const int64 costPerTraverse = 500;
    auto traverse = [this, &out, &data_set, decision_tree_resource, num_data,
                     &tree_paths](int64 start, int64 end) {
      CHECK(start <= end);
      CHECK(end <= num_data);
      TraverseTree(decision_tree_resource, data_set, static_cast<int32>(start),
                   static_cast<int32>(end),
                   std::bind(&TreePredictionsV4Op::set_output_value, this,
                             std::placeholders::_1, std::placeholders::_2,
                             decision_tree_resource, &out),
                   param_proto_.inference_tree_paths() ? &tree_paths : nullptr);
    };
    Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
          traverse);

    Tensor* output_tree_paths = nullptr;
    TensorShape output_paths_shape;
    output_paths_shape.AddDim(param_proto_.inference_tree_paths() ? num_data
                                                                  : 0);
    OP_REQUIRES_OK(context, context->allocate_output(1, output_paths_shape,
                                                     &output_tree_paths));
    auto out_paths = output_tree_paths->unaligned_flat<string>();

    // TODO(gilberth): If this slows down inference too much, consider having
    // a filter that only serializes paths for the predicted label that we're
    // interested in.
    for (size_t i = 0; i < tree_paths.size(); ++i) {
      out_paths(i) = tree_paths[i].SerializeAsString();
    }
  }

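  // Writes the output values stored in leaf `id` into row `i` of the output
  // tensor. For classification, counts that do not already form a probability
  // distribution are normalized to sum to 1.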
  void set_output_value(int32 i, int32 id,
                        DecisionTreeResource* decision_tree_resource,
                        TTypes<float, 2>::Tensor* out) {
    const decision_trees::Leaf& leaf = decision_tree_resource->get_leaf(id);

    float sum = 0;
    for (int j = 0; j < param_proto_.num_outputs(); ++j) {
      const float count = model_op_->GetOutputValue(leaf, j);
      (*out)(i, j) = count;
      sum += count;
    }

    if (!param_proto_.is_regression() && sum > 0 && sum != 1) {
      for (int j = 0; j < param_proto_.num_outputs(); ++j) {
        (*out)(i, j) /= sum;
      }
    }
  }

 private:
  tensorforest::TensorForestDataSpec input_spec_;
  std::unique_ptr<LeafModelOperator> model_op_;
  TensorForestParams param_proto_;
};

// Outputs leaf ids for the given examples.
class TraverseTreeV4Op : public OpKernel {
 public:
  explicit TraverseTreeV4Op(OpKernelConstruction* context) : OpKernel(context) {
    string serialized_params;
    OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
    OP_REQUIRES(context, ParseProtoUnlimited(&param_proto_, serialized_params),
                errors::InvalidArgument("Unable to parse params proto."));

    string serialized_proto;
    OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto));
    OP_REQUIRES(context, input_spec_.ParseFromString(serialized_proto),
                errors::InvalidArgument("Unable to parse input spec proto."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input_data = context->input(1);
    const Tensor& sparse_input_indices = context->input(2);
    const Tensor& sparse_input_values = context->input(3);
    const Tensor& sparse_input_shape = context->input(4);

    std::unique_ptr<TensorDataSet> data_set(new TensorDataSet(input_spec_, 0));
    data_set->set_input_tensors(input_data, sparse_input_indices,
                                sparse_input_values, sparse_input_shape);

    DecisionTreeResource* decision_tree_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &decision_tree_resource));
    mutex_lock l(*decision_tree_resource->get_mutex());
    core::ScopedUnref unref_me(decision_tree_resource);

    const int num_data = data_set->NumItems();

    Tensor* output_predictions = nullptr;
    TensorShape output_shape;
    output_shape.AddDim(num_data);
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape,
                                                     &output_predictions));

    auto leaf_ids = output_predictions->tensor<int32, 1>();

    auto set_leaf_ids = [&leaf_ids](int32 i, int32 id) { leaf_ids(i) = id; };

    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
    int num_threads = worker_threads->num_threads;
    const int64 costPerTraverse = 500;
    auto traverse = [this, &set_leaf_ids, &data_set, decision_tree_resource,
                     num_data](int64 start, int64 end) {
      CHECK(start <= end);
      CHECK(end <= num_data);
      TraverseTree(decision_tree_resource, data_set, static_cast<int32>(start),
                   static_cast<int32>(end), set_leaf_ids, nullptr);
    };
    Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
          traverse);
  }

 private:
  tensorforest::TensorForestDataSpec input_spec_;
  TensorForestParams param_proto_;
};

// Updates the given leaf models using the batch of labels.
class UpdateModelV4Op : public OpKernel {
 public:
  explicit UpdateModelV4Op(OpKernelConstruction* context) : OpKernel(context) {
    string serialized_params;
    OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
    OP_REQUIRES(context, ParseProtoUnlimited(&param_proto_, serialized_params),
                errors::InvalidArgument("Unable to parse params proto."));

    model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(param_proto_);
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& leaf_ids = context->input(1);
    const Tensor& input_labels = context->input(2);
    const Tensor& input_weights = context->input(3);

    DecisionTreeResource* decision_tree_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &decision_tree_resource));
    mutex_lock l(*decision_tree_resource->get_mutex());
    core::ScopedUnref unref_me(decision_tree_resource);

    const int num_data = input_labels.shape().dim_size(0);
    const int32 label_dim =
        input_labels.shape().dims() <= 1
            ? 0
            : static_cast<int>(input_labels.shape().dim_size(1));
    const int32 num_targets =
        param_proto_.is_regression() ? (std::max(1, label_dim)) : 1;

    TensorInputTarget target(input_labels, input_weights, num_targets);

    // TODO(gilberth): Make this thread safe and multi-thread.
    UpdateModel(leaf_ids, target, 0, num_data, decision_tree_resource);
  }

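  // Adds each example's target in [start, end) to the leaf model of the leaf
  // that the example was assigned to in `leaf_ids`.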
  void UpdateModel(const Tensor& leaf_ids, const TensorInputTarget& target,
                   int32 start, int32 end,
                   DecisionTreeResource* decision_tree_resource) {
    const auto leaves = leaf_ids.unaligned_flat<int32>();
    for (int i = start; i < end; ++i) {
      model_op_->UpdateModel(
          decision_tree_resource->get_mutable_tree_node(leaves(i))
              ->mutable_leaf(),
          &target, i);
    }
  }

 private:
  std::unique_ptr<LeafModelOperator> model_op_;
  TensorForestParams param_proto_;
};

// Op for getting feature usage counts.
class FeatureUsageCountsOp : public OpKernel {
 public:
  explicit FeatureUsageCountsOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string serialized_params;
    OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
    OP_REQUIRES(context, ParseProtoUnlimited(&param_proto_, serialized_params),
                errors::InvalidArgument("Unable to parse params proto."));
  }

  void Compute(OpKernelContext* context) override {
    DecisionTreeResource* decision_tree_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &decision_tree_resource));
    mutex_lock l(*decision_tree_resource->get_mutex());
    core::ScopedUnref unref_me(decision_tree_resource);

    const auto& tree = decision_tree_resource->decision_tree();

    Tensor* output_counts = nullptr;
    TensorShape output_shape;
    output_shape.AddDim(param_proto_.num_features());
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output_counts));

    auto counts = output_counts->unaligned_flat<int32>();
    counts.setZero();

    // Walk the tree and count how often each feature id appears in a split.
    for (const auto& node : tree.decision_tree().nodes()) {
      if (node.has_custom_node_type()) {
        LOG(WARNING) << "Can't count feature usage for custom nodes.";
      } else if (node.has_binary_node()) {
        const auto& bnode = node.binary_node();
        if (bnode.has_custom_left_child_test()) {
          decision_trees::MatchingValuesTest test;
          if (!bnode.custom_left_child_test().UnpackTo(&test)) {
            LOG(WARNING) << "Unknown custom child test";
            continue;
          }
          int32 feat;
          if (!safe_strto32(test.feature_id().id().value(), &feat)) {
            LOG(WARNING) << "Unparseable feature id: "
                         << test.feature_id().id().value();
            continue;
          }
          ++counts(feat);
        } else {
          const auto& test = bnode.inequality_left_child_test();
          if (test.has_feature_id()) {
            int32 feat;
            if (safe_strto32(test.feature_id().id().value(), &feat)) {
              ++counts(feat);
            }
          } else if (test.has_oblique()) {
            for (const auto& featid : test.oblique().features()) {
              int32 feat;
              if (safe_strto32(featid.id().value(), &feat)) {
                ++counts(feat);
              }
            }
          }
        }
      }
    }
  }

 private:
  TensorForestParams param_proto_;
};

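// Kernel registrations. REGISTER_RESOURCE_HANDLE_KERNEL generates the op that
// produces handles to a DecisionTreeResource; all of the kernels in this file
// run on CPU only.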
REGISTER_RESOURCE_HANDLE_KERNEL(DecisionTreeResource);

REGISTER_KERNEL_BUILDER(Name("TreeIsInitializedOp").Device(DEVICE_CPU),
                        IsResourceInitialized<DecisionTreeResource>);

REGISTER_KERNEL_BUILDER(Name("CreateTreeVariable").Device(DEVICE_CPU),
                        CreateTreeVariableOp);

REGISTER_KERNEL_BUILDER(Name("TreeSerialize").Device(DEVICE_CPU),
                        TreeSerializeOp);

REGISTER_KERNEL_BUILDER(Name("TreeDeserialize").Device(DEVICE_CPU),
                        TreeDeserializeOp);

REGISTER_KERNEL_BUILDER(Name("TreeSize").Device(DEVICE_CPU), TreeSizeOp);

REGISTER_KERNEL_BUILDER(Name("TreePredictionsV4").Device(DEVICE_CPU),
                        TreePredictionsV4Op);

REGISTER_KERNEL_BUILDER(Name("TraverseTreeV4").Device(DEVICE_CPU),
                        TraverseTreeV4Op);

REGISTER_KERNEL_BUILDER(Name("FeatureUsageCounts").Device(DEVICE_CPU),
                        FeatureUsageCountsOp);

REGISTER_KERNEL_BUILDER(Name("UpdateModelV4").Device(DEVICE_CPU),
                        UpdateModelV4Op);

}  // namespace tensorforest
}  // namespace tensorflow