Home | History | Annotate | Download | only in resources
      1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
     15 #ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
     16 #define TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
     17 
     18 #include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
     19 #include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h"
     20 #include "tensorflow/core/framework/resource_mgr.h"
     21 #include "tensorflow/core/platform/mutex.h"
     22 #include "tensorflow/core/platform/protobuf.h"
     23 
     24 namespace tensorflow {
     25 namespace boosted_trees {
     26 namespace models {
     27 
     28 // Keep a tree ensemble in memory for efficient evaluation and mutation.
     29 class DecisionTreeEnsembleResource : public StampedResource {
     30  public:
     31   // Constructor.
     32   explicit DecisionTreeEnsembleResource()
     33       : decision_tree_ensemble_(
     34             protobuf::Arena::CreateMessage<
     35                 boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_)) {}
     36 
     37   string DebugString() override {
     38     return strings::StrCat("GTFlowDecisionTreeEnsemble[size=",
     39                            decision_tree_ensemble_->trees_size(), "]");
     40   }
     41 
     42   const boosted_trees::trees::DecisionTreeEnsembleConfig&
     43   decision_tree_ensemble() const {
     44     return *decision_tree_ensemble_;
     45   }
     46 
     47   int32 num_trees() const { return decision_tree_ensemble_->trees_size(); }
     48 
     49   bool InitFromSerialized(const string& serialized, const int64 stamp_token) {
     50     CHECK_EQ(stamp(), -1) << "Must Reset before Init.";
     51     if (ParseProtoUnlimited(decision_tree_ensemble_, serialized)) {
     52       set_stamp(stamp_token);
     53       return true;
     54     }
     55     return false;
     56   }
     57 
     58   string SerializeAsString() const {
     59     return decision_tree_ensemble_->SerializeAsString();
     60   }
     61 
     62   // Increment num_layers_attempted and num_trees_attempted in growing_metadata
     63   // if the tree is finalized.
     64   void IncrementAttempts() {
     65     boosted_trees::trees::GrowingMetadata* const growing_metadata =
     66         decision_tree_ensemble_->mutable_growing_metadata();
     67     growing_metadata->set_num_layers_attempted(
     68         growing_metadata->num_layers_attempted() + 1);
     69     const int num_trees = decision_tree_ensemble_->trees_size();
     70     if (num_trees <= 0 || LastTreeMetadata()->is_finalized()) {
     71       growing_metadata->set_num_trees_attempted(
     72           growing_metadata->num_trees_attempted() + 1);
     73     }
     74   }
     75 
     76   boosted_trees::trees::DecisionTreeConfig* AddNewTree(const float weight) {
     77     // Adding a tree as well as a weight and a tree_metadata.
     78     decision_tree_ensemble_->add_tree_weights(weight);
     79     boosted_trees::trees::DecisionTreeMetadata* const metadata =
     80         decision_tree_ensemble_->add_tree_metadata();
     81     metadata->set_num_layers_grown(1);
     82     return decision_tree_ensemble_->add_trees();
     83   }
     84 
     85   void RemoveLastTree() {
     86     QCHECK_GT(decision_tree_ensemble_->trees_size(), 0);
     87     decision_tree_ensemble_->mutable_trees()->RemoveLast();
     88     decision_tree_ensemble_->mutable_tree_weights()->RemoveLast();
     89     decision_tree_ensemble_->mutable_tree_metadata()->RemoveLast();
     90   }
     91 
     92   boosted_trees::trees::DecisionTreeConfig* LastTree() {
     93     const int32 tree_size = decision_tree_ensemble_->trees_size();
     94     QCHECK_GT(tree_size, 0);
     95     return decision_tree_ensemble_->mutable_trees(tree_size - 1);
     96   }
     97 
     98   boosted_trees::trees::DecisionTreeMetadata* LastTreeMetadata() {
     99     const int32 metadata_size = decision_tree_ensemble_->tree_metadata_size();
    100     QCHECK_GT(metadata_size, 0);
    101     return decision_tree_ensemble_->mutable_tree_metadata(metadata_size - 1);
    102   }
    103 
    104   // Retrieves tree weights and returns as a vector.
    105   std::vector<float> GetTreeWeights() const {
    106     return {decision_tree_ensemble_->tree_weights().begin(),
    107             decision_tree_ensemble_->tree_weights().end()};
    108   }
    109 
    110   float GetTreeWeight(const int32 index) const {
    111     return decision_tree_ensemble_->tree_weights(index);
    112   }
    113 
    114   void MaybeAddUsedHandler(const int32 handler_id) {
    115     protobuf::RepeatedField<protobuf_int64>* used_ids =
    116         decision_tree_ensemble_->mutable_growing_metadata()
    117             ->mutable_used_handler_ids();
    118     protobuf::RepeatedField<protobuf_int64>::iterator first =
    119         std::lower_bound(used_ids->begin(), used_ids->end(), handler_id);
    120     if (first == used_ids->end()) {
    121       used_ids->Add(handler_id);
    122       return;
    123     }
    124     if (handler_id == *first) {
    125       // It is a duplicate entry.
    126       return;
    127     }
    128     used_ids->Add(handler_id);
    129     std::rotate(first, used_ids->end() - 1, used_ids->end());
    130   }
    131 
    132   std::vector<int64> GetUsedHandlers() const {
    133     std::vector<int64> result;
    134     result.reserve(
    135         decision_tree_ensemble_->growing_metadata().used_handler_ids().size());
    136     for (int64 h :
    137          decision_tree_ensemble_->growing_metadata().used_handler_ids()) {
    138       result.push_back(h);
    139     }
    140     return result;
    141   }
    142 
    143   // Sets the weight of i'th tree, and increment num_updates in tree_metadata.
    144   void SetTreeWeight(const int32 index, const float weight,
    145                      const int32 increment_num_updates) {
    146     QCHECK_GE(index, 0);
    147     QCHECK_LT(index, num_trees());
    148     decision_tree_ensemble_->set_tree_weights(index, weight);
    149     if (increment_num_updates != 0) {
    150       const int32 num_updates = decision_tree_ensemble_->tree_metadata(index)
    151                                     .num_tree_weight_updates();
    152       decision_tree_ensemble_->mutable_tree_metadata(index)
    153           ->set_num_tree_weight_updates(num_updates + increment_num_updates);
    154     }
    155   }
    156 
    157   // Resets the resource and frees the protos in arena.
    158   // Caller needs to hold the mutex lock while calling this.
    159   virtual void Reset() {
    160     // Reset stamp.
    161     set_stamp(-1);
    162 
    163     // Clear tree ensemle.
    164     arena_.Reset();
    165     CHECK_EQ(0, arena_.SpaceAllocated());
    166     decision_tree_ensemble_ = protobuf::Arena::CreateMessage<
    167         boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_);
    168   }
    169 
    170   mutex* get_mutex() { return &mu_; }
    171 
    172  protected:
    173   protobuf::Arena arena_;
    174   mutex mu_;
    175   boosted_trees::trees::DecisionTreeEnsembleConfig* decision_tree_ensemble_;
    176 };
    177 
    178 }  // namespace models
    179 }  // namespace boosted_trees
    180 }  // namespace tensorflow
    181 
    182 #endif  // TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
    183