1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_ 16 #define TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_ 17 18 #include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h" 19 #include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h" 20 #include "tensorflow/core/framework/resource_mgr.h" 21 #include "tensorflow/core/platform/mutex.h" 22 #include "tensorflow/core/platform/protobuf.h" 23 24 namespace tensorflow { 25 namespace boosted_trees { 26 namespace models { 27 28 // Keep a tree ensemble in memory for efficient evaluation and mutation. 29 class DecisionTreeEnsembleResource : public StampedResource { 30 public: 31 // Constructor. 32 explicit DecisionTreeEnsembleResource() 33 : decision_tree_ensemble_( 34 protobuf::Arena::CreateMessage< 35 boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_)) {} 36 37 string DebugString() override { 38 return strings::StrCat("GTFlowDecisionTreeEnsemble[size=", 39 decision_tree_ensemble_->trees_size(), "]"); 40 } 41 42 const boosted_trees::trees::DecisionTreeEnsembleConfig& 43 decision_tree_ensemble() const { 44 return *decision_tree_ensemble_; 45 } 46 47 int32 num_trees() const { return decision_tree_ensemble_->trees_size(); } 48 49 bool InitFromSerialized(const string& serialized, const int64 stamp_token) { 50 CHECK_EQ(stamp(), -1) << "Must Reset before Init."; 51 if (ParseProtoUnlimited(decision_tree_ensemble_, serialized)) { 52 set_stamp(stamp_token); 53 return true; 54 } 55 return false; 56 } 57 58 string SerializeAsString() const { 59 return decision_tree_ensemble_->SerializeAsString(); 60 } 61 62 // Increment num_layers_attempted and num_trees_attempted in growing_metadata 63 // if the tree is finalized. 64 void IncrementAttempts() { 65 boosted_trees::trees::GrowingMetadata* const growing_metadata = 66 decision_tree_ensemble_->mutable_growing_metadata(); 67 growing_metadata->set_num_layers_attempted( 68 growing_metadata->num_layers_attempted() + 1); 69 const int num_trees = decision_tree_ensemble_->trees_size(); 70 if (num_trees <= 0 || LastTreeMetadata()->is_finalized()) { 71 growing_metadata->set_num_trees_attempted( 72 growing_metadata->num_trees_attempted() + 1); 73 } 74 } 75 76 boosted_trees::trees::DecisionTreeConfig* AddNewTree(const float weight) { 77 // Adding a tree as well as a weight and a tree_metadata. 78 decision_tree_ensemble_->add_tree_weights(weight); 79 boosted_trees::trees::DecisionTreeMetadata* const metadata = 80 decision_tree_ensemble_->add_tree_metadata(); 81 metadata->set_num_layers_grown(1); 82 return decision_tree_ensemble_->add_trees(); 83 } 84 85 void RemoveLastTree() { 86 QCHECK_GT(decision_tree_ensemble_->trees_size(), 0); 87 decision_tree_ensemble_->mutable_trees()->RemoveLast(); 88 decision_tree_ensemble_->mutable_tree_weights()->RemoveLast(); 89 decision_tree_ensemble_->mutable_tree_metadata()->RemoveLast(); 90 } 91 92 boosted_trees::trees::DecisionTreeConfig* LastTree() { 93 const int32 tree_size = decision_tree_ensemble_->trees_size(); 94 QCHECK_GT(tree_size, 0); 95 return decision_tree_ensemble_->mutable_trees(tree_size - 1); 96 } 97 98 boosted_trees::trees::DecisionTreeMetadata* LastTreeMetadata() { 99 const int32 metadata_size = decision_tree_ensemble_->tree_metadata_size(); 100 QCHECK_GT(metadata_size, 0); 101 return decision_tree_ensemble_->mutable_tree_metadata(metadata_size - 1); 102 } 103 104 // Retrieves tree weights and returns as a vector. 105 std::vector<float> GetTreeWeights() const { 106 return {decision_tree_ensemble_->tree_weights().begin(), 107 decision_tree_ensemble_->tree_weights().end()}; 108 } 109 110 float GetTreeWeight(const int32 index) const { 111 return decision_tree_ensemble_->tree_weights(index); 112 } 113 114 void MaybeAddUsedHandler(const int32 handler_id) { 115 protobuf::RepeatedField<protobuf_int64>* used_ids = 116 decision_tree_ensemble_->mutable_growing_metadata() 117 ->mutable_used_handler_ids(); 118 protobuf::RepeatedField<protobuf_int64>::iterator first = 119 std::lower_bound(used_ids->begin(), used_ids->end(), handler_id); 120 if (first == used_ids->end()) { 121 used_ids->Add(handler_id); 122 return; 123 } 124 if (handler_id == *first) { 125 // It is a duplicate entry. 126 return; 127 } 128 used_ids->Add(handler_id); 129 std::rotate(first, used_ids->end() - 1, used_ids->end()); 130 } 131 132 std::vector<int64> GetUsedHandlers() const { 133 std::vector<int64> result; 134 result.reserve( 135 decision_tree_ensemble_->growing_metadata().used_handler_ids().size()); 136 for (int64 h : 137 decision_tree_ensemble_->growing_metadata().used_handler_ids()) { 138 result.push_back(h); 139 } 140 return result; 141 } 142 143 // Sets the weight of i'th tree, and increment num_updates in tree_metadata. 144 void SetTreeWeight(const int32 index, const float weight, 145 const int32 increment_num_updates) { 146 QCHECK_GE(index, 0); 147 QCHECK_LT(index, num_trees()); 148 decision_tree_ensemble_->set_tree_weights(index, weight); 149 if (increment_num_updates != 0) { 150 const int32 num_updates = decision_tree_ensemble_->tree_metadata(index) 151 .num_tree_weight_updates(); 152 decision_tree_ensemble_->mutable_tree_metadata(index) 153 ->set_num_tree_weight_updates(num_updates + increment_num_updates); 154 } 155 } 156 157 // Resets the resource and frees the protos in arena. 158 // Caller needs to hold the mutex lock while calling this. 159 virtual void Reset() { 160 // Reset stamp. 161 set_stamp(-1); 162 163 // Clear tree ensemle. 164 arena_.Reset(); 165 CHECK_EQ(0, arena_.SpaceAllocated()); 166 decision_tree_ensemble_ = protobuf::Arena::CreateMessage< 167 boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_); 168 } 169 170 mutex* get_mutex() { return &mu_; } 171 172 protected: 173 protobuf::Arena arena_; 174 mutex mu_; 175 boosted_trees::trees::DecisionTreeEnsembleConfig* decision_tree_ensemble_; 176 }; 177 178 } // namespace models 179 } // namespace boosted_trees 180 } // namespace tensorflow 181 182 #endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_ 183