// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <algorithm>
#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/data_spec.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace tensorforest {

// Creates a tree variable.
36 class CreateTreeVariableOp : public OpKernel { 37 public: 38 explicit CreateTreeVariableOp(OpKernelConstruction* context) 39 : OpKernel(context) { 40 string serialized_params; 41 OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params)); 42 ParseProtoUnlimited(¶m_proto_, serialized_params); 43 } 44 45 void Compute(OpKernelContext* context) override { 46 const Tensor* tree_config_t; 47 OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t)); 48 OP_REQUIRES(context, TensorShapeUtils::IsScalar(tree_config_t->shape()), 49 errors::InvalidArgument("Tree config must be a scalar.")); 50 51 auto* result = new DecisionTreeResource(param_proto_); 52 if (!ParseProtoUnlimited(result->mutable_decision_tree(), 53 tree_config_t->scalar<string>()())) { 54 result->Unref(); 55 OP_REQUIRES(context, false, 56 errors::InvalidArgument("Unable to parse tree config.")); 57 } 58 59 result->MaybeInitialize(); 60 61 // Only create one, if one does not exist already. Report status for all 62 // other exceptions. 63 auto status = CreateResource(context, HandleFromInput(context, 0), result); 64 if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) { 65 OP_REQUIRES(context, false, status); 66 } 67 } 68 69 private: 70 TensorForestParams param_proto_; 71 }; 72 73 // Op for serializing a model. 
74 class TreeSerializeOp : public OpKernel { 75 public: 76 explicit TreeSerializeOp(OpKernelConstruction* context) : OpKernel(context) {} 77 78 void Compute(OpKernelContext* context) override { 79 DecisionTreeResource* decision_tree_resource; 80 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 81 &decision_tree_resource)); 82 mutex_lock l(*decision_tree_resource->get_mutex()); 83 core::ScopedUnref unref_me(decision_tree_resource); 84 Tensor* output_config_t = nullptr; 85 OP_REQUIRES_OK( 86 context, context->allocate_output(0, TensorShape(), &output_config_t)); 87 output_config_t->scalar<string>()() = 88 decision_tree_resource->decision_tree().SerializeAsString(); 89 } 90 }; 91 92 // Op for deserializing a tree variable from a checkpoint. 93 class TreeDeserializeOp : public OpKernel { 94 public: 95 explicit TreeDeserializeOp(OpKernelConstruction* context) 96 : OpKernel(context) { 97 string serialized_params; 98 OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params)); 99 ParseProtoUnlimited(¶m_proto_, serialized_params); 100 } 101 102 void Compute(OpKernelContext* context) override { 103 DecisionTreeResource* decision_tree_resource; 104 auto handle = HandleFromInput(context, 0); 105 OP_REQUIRES_OK(context, 106 LookupResource(context, handle, &decision_tree_resource)); 107 mutex_lock l(*decision_tree_resource->get_mutex()); 108 core::ScopedUnref unref_me(decision_tree_resource); 109 110 const Tensor* tree_config_t; 111 OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t)); 112 OP_REQUIRES(context, TensorShapeUtils::IsScalar(tree_config_t->shape()), 113 errors::InvalidArgument("Tree config must be a scalar.")); 114 // Deallocate all the previous objects on the resource. 
115 decision_tree_resource->Reset(); 116 decision_trees::Model* config = 117 decision_tree_resource->mutable_decision_tree(); 118 OP_REQUIRES(context, 119 ParseProtoUnlimited(config, tree_config_t->scalar<string>()()), 120 errors::InvalidArgument("Unable to parse tree config.")); 121 decision_tree_resource->MaybeInitialize(); 122 } 123 124 private: 125 TensorForestParams param_proto_; 126 }; 127 128 // Op for getting tree size. 129 class TreeSizeOp : public OpKernel { 130 public: 131 explicit TreeSizeOp(OpKernelConstruction* context) : OpKernel(context) {} 132 133 void Compute(OpKernelContext* context) override { 134 DecisionTreeResource* decision_tree_resource; 135 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 136 &decision_tree_resource)); 137 mutex_lock l(*decision_tree_resource->get_mutex()); 138 core::ScopedUnref unref_me(decision_tree_resource); 139 Tensor* output_t = nullptr; 140 OP_REQUIRES_OK(context, 141 context->allocate_output(0, TensorShape(), &output_t)); 142 output_t->scalar<int32>()() = 143 decision_tree_resource->decision_tree().decision_tree().nodes_size(); 144 } 145 }; 146 147 void TraverseTree(const DecisionTreeResource* tree_resource, 148 const std::unique_ptr<TensorDataSet>& data, int32 start, 149 int32 end, 150 const std::function<void(int32, int32)>& set_leaf_id, 151 std::vector<TreePath>* tree_paths) { 152 for (int i = start; i < end; ++i) { 153 const int32 id = tree_resource->TraverseTree( 154 data, i, nullptr, 155 (tree_paths == nullptr) ? nullptr : &(*tree_paths)[i]); 156 set_leaf_id(i, id); 157 } 158 } 159 160 // Op for tree inference. 
161 class TreePredictionsV4Op : public OpKernel { 162 public: 163 explicit TreePredictionsV4Op(OpKernelConstruction* context) 164 : OpKernel(context) { 165 string serialized_params; 166 OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params)); 167 ParseProtoUnlimited(¶m_proto_, serialized_params); 168 169 string serialized_proto; 170 OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto)); 171 input_spec_.ParseFromString(serialized_proto); 172 model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(param_proto_); 173 } 174 175 void Compute(OpKernelContext* context) override { 176 const Tensor& input_data = context->input(1); 177 const Tensor& sparse_input_indices = context->input(2); 178 const Tensor& sparse_input_values = context->input(3); 179 const Tensor& sparse_input_shape = context->input(4); 180 181 std::unique_ptr<TensorDataSet> data_set(new TensorDataSet(input_spec_, 0)); 182 data_set->set_input_tensors(input_data, sparse_input_indices, 183 sparse_input_values, sparse_input_shape); 184 185 DecisionTreeResource* decision_tree_resource; 186 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 187 &decision_tree_resource)); 188 mutex_lock l(*decision_tree_resource->get_mutex()); 189 core::ScopedUnref unref_me(decision_tree_resource); 190 191 const int num_data = data_set->NumItems(); 192 const int32 num_outputs = param_proto_.num_outputs(); 193 194 Tensor* output_predictions = nullptr; 195 TensorShape output_shape; 196 output_shape.AddDim(num_data); 197 output_shape.AddDim(num_outputs); 198 OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, 199 &output_predictions)); 200 TTypes<float, 2>::Tensor out = output_predictions->tensor<float, 2>(); 201 202 std::vector<TreePath> tree_paths( 203 param_proto_.inference_tree_paths() ? 
num_data : 0); 204 205 auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); 206 int num_threads = worker_threads->num_threads; 207 const int64 costPerTraverse = 500; 208 auto traverse = [this, &out, &data_set, decision_tree_resource, num_data, 209 &tree_paths](int64 start, int64 end) { 210 CHECK(start <= end); 211 CHECK(end <= num_data); 212 TraverseTree(decision_tree_resource, data_set, static_cast<int32>(start), 213 static_cast<int32>(end), 214 std::bind(&TreePredictionsV4Op::set_output_value, this, 215 std::placeholders::_1, std::placeholders::_2, 216 decision_tree_resource, &out), 217 param_proto_.inference_tree_paths() ? &tree_paths : nullptr); 218 }; 219 Shard(num_threads, worker_threads->workers, num_data, costPerTraverse, 220 traverse); 221 222 Tensor* output_tree_paths = nullptr; 223 TensorShape output_paths_shape; 224 output_paths_shape.AddDim(param_proto_.inference_tree_paths() ? num_data 225 : 0); 226 OP_REQUIRES_OK(context, context->allocate_output(1, output_paths_shape, 227 &output_tree_paths)); 228 auto out_paths = output_tree_paths->unaligned_flat<string>(); 229 230 // TODO(gilberth): If this slows down inference too much, consider having 231 // a filter that only serializes paths for the predicted label that we're 232 // interested in. 
233 for (int i = 0; i < tree_paths.size(); ++i) { 234 out_paths(i) = tree_paths[i].SerializeAsString(); 235 } 236 } 237 238 void set_output_value(int32 i, int32 id, 239 DecisionTreeResource* decision_tree_resource, 240 TTypes<float, 2>::Tensor* out) { 241 const decision_trees::Leaf& leaf = decision_tree_resource->get_leaf(id); 242 243 float sum = 0; 244 for (int j = 0; j < param_proto_.num_outputs(); ++j) { 245 const float count = model_op_->GetOutputValue(leaf, j); 246 (*out)(i, j) = count; 247 sum += count; 248 } 249 250 if (!param_proto_.is_regression() && sum > 0 && sum != 1) { 251 for (int j = 0; j < param_proto_.num_outputs(); ++j) { 252 (*out)(i, j) /= sum; 253 } 254 } 255 } 256 257 private: 258 tensorforest::TensorForestDataSpec input_spec_; 259 std::unique_ptr<LeafModelOperator> model_op_; 260 TensorForestParams param_proto_; 261 }; 262 263 // Outputs leaf ids for the given examples. 264 class TraverseTreeV4Op : public OpKernel { 265 public: 266 explicit TraverseTreeV4Op(OpKernelConstruction* context) : OpKernel(context) { 267 string serialized_params; 268 OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params)); 269 ParseProtoUnlimited(¶m_proto_, serialized_params); 270 271 string serialized_proto; 272 OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto)); 273 input_spec_.ParseFromString(serialized_proto); 274 } 275 276 void Compute(OpKernelContext* context) override { 277 const Tensor& input_data = context->input(1); 278 const Tensor& sparse_input_indices = context->input(2); 279 const Tensor& sparse_input_values = context->input(3); 280 const Tensor& sparse_input_shape = context->input(4); 281 282 std::unique_ptr<TensorDataSet> data_set(new TensorDataSet(input_spec_, 0)); 283 data_set->set_input_tensors(input_data, sparse_input_indices, 284 sparse_input_values, sparse_input_shape); 285 286 DecisionTreeResource* decision_tree_resource; 287 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 288 
&decision_tree_resource)); 289 mutex_lock l(*decision_tree_resource->get_mutex()); 290 core::ScopedUnref unref_me(decision_tree_resource); 291 292 const int num_data = data_set->NumItems(); 293 294 Tensor* output_predictions = nullptr; 295 TensorShape output_shape; 296 output_shape.AddDim(num_data); 297 OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, 298 &output_predictions)); 299 300 auto leaf_ids = output_predictions->tensor<int32, 1>(); 301 302 auto set_leaf_ids = [&leaf_ids](int32 i, int32 id) { leaf_ids(i) = id; }; 303 304 auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); 305 int num_threads = worker_threads->num_threads; 306 const int64 costPerTraverse = 500; 307 auto traverse = [this, &set_leaf_ids, &data_set, decision_tree_resource, 308 num_data](int64 start, int64 end) { 309 CHECK(start <= end); 310 CHECK(end <= num_data); 311 TraverseTree(decision_tree_resource, data_set, static_cast<int32>(start), 312 static_cast<int32>(end), set_leaf_ids, nullptr); 313 }; 314 Shard(num_threads, worker_threads->workers, num_data, costPerTraverse, 315 traverse); 316 } 317 318 private: 319 tensorforest::TensorForestDataSpec input_spec_; 320 TensorForestParams param_proto_; 321 }; 322 323 // Update the given leaf models using the batch of labels. 
324 class UpdateModelV4Op : public OpKernel { 325 public: 326 explicit UpdateModelV4Op(OpKernelConstruction* context) : OpKernel(context) { 327 string serialized_params; 328 OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params)); 329 ParseProtoUnlimited(¶m_proto_, serialized_params); 330 331 model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(param_proto_); 332 } 333 334 void Compute(OpKernelContext* context) override { 335 const Tensor& leaf_ids = context->input(1); 336 const Tensor& input_labels = context->input(2); 337 const Tensor& input_weights = context->input(3); 338 339 DecisionTreeResource* decision_tree_resource; 340 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 341 &decision_tree_resource)); 342 mutex_lock l(*decision_tree_resource->get_mutex()); 343 core::ScopedUnref unref_me(decision_tree_resource); 344 345 const int num_data = input_labels.shape().dim_size(0); 346 const int32 label_dim = 347 input_labels.shape().dims() <= 1 348 ? 0 349 : static_cast<int>(input_labels.shape().dim_size(1)); 350 const int32 num_targets = 351 param_proto_.is_regression() ? (std::max(1, label_dim)) : 1; 352 353 TensorInputTarget target(input_labels, input_weights, num_targets); 354 355 // TODO(gilberth): Make this thread safe and multi-thread. 356 UpdateModel(leaf_ids, target, 0, num_data, decision_tree_resource); 357 } 358 359 void UpdateModel(const Tensor& leaf_ids, const TensorInputTarget& target, 360 int32 start, int32 end, 361 DecisionTreeResource* decision_tree_resource) { 362 const auto leaves = leaf_ids.unaligned_flat<int32>(); 363 for (int i = start; i < end; ++i) { 364 model_op_->UpdateModel( 365 decision_tree_resource->get_mutable_tree_node(leaves(i)) 366 ->mutable_leaf(), 367 &target, i); 368 } 369 } 370 371 private: 372 std::unique_ptr<LeafModelOperator> model_op_; 373 TensorForestParams param_proto_; 374 }; 375 376 // Op for getting feature usage counts. 
377 class FeatureUsageCountsOp : public OpKernel { 378 public: 379 explicit FeatureUsageCountsOp(OpKernelConstruction* context) 380 : OpKernel(context) { 381 string serialized_params; 382 OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params)); 383 ParseProtoUnlimited(¶m_proto_, serialized_params); 384 } 385 386 void Compute(OpKernelContext* context) override { 387 DecisionTreeResource* decision_tree_resource; 388 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 389 &decision_tree_resource)); 390 mutex_lock l(*decision_tree_resource->get_mutex()); 391 core::ScopedUnref unref_me(decision_tree_resource); 392 393 const auto& tree = decision_tree_resource->decision_tree(); 394 395 Tensor* output_counts = nullptr; 396 TensorShape output_shape; 397 output_shape.AddDim(param_proto_.num_features()); 398 OP_REQUIRES_OK(context, 399 context->allocate_output(0, output_shape, &output_counts)); 400 401 auto counts = output_counts->unaligned_flat<int32>(); 402 counts.setZero(); 403 404 for (const auto& node : tree.decision_tree().nodes()) { 405 if (node.has_custom_node_type()) { 406 LOG(WARNING) << "Can't count feature usage for custom nodes."; 407 } else if (node.has_binary_node()) { 408 const auto& bnode = node.binary_node(); 409 if (bnode.has_custom_left_child_test()) { 410 decision_trees::MatchingValuesTest test; 411 if (!bnode.custom_left_child_test().UnpackTo(&test)) { 412 LOG(WARNING) << "Unknown custom child test"; 413 continue; 414 } 415 int32 feat; 416 safe_strto32(test.feature_id().id().value(), &feat); 417 ++counts(feat); 418 } else { 419 const auto& test = bnode.inequality_left_child_test(); 420 if (test.has_feature_id()) { 421 int32 feat; 422 safe_strto32(test.feature_id().id().value(), &feat); 423 ++counts(feat); 424 } else if (test.has_oblique()) { 425 for (const auto& featid : test.oblique().features()) { 426 int32 feat; 427 safe_strto32(featid.id().value(), &feat); 428 ++counts(feat); 429 } 430 } 431 } 432 } 433 } 434 
} 435 436 private: 437 TensorForestParams param_proto_; 438 }; 439 440 REGISTER_RESOURCE_HANDLE_KERNEL(DecisionTreeResource); 441 442 REGISTER_KERNEL_BUILDER(Name("TreeIsInitializedOp").Device(DEVICE_CPU), 443 IsResourceInitialized<DecisionTreeResource>); 444 445 REGISTER_KERNEL_BUILDER(Name("CreateTreeVariable").Device(DEVICE_CPU), 446 CreateTreeVariableOp); 447 448 REGISTER_KERNEL_BUILDER(Name("TreeSerialize").Device(DEVICE_CPU), 449 TreeSerializeOp); 450 451 REGISTER_KERNEL_BUILDER(Name("TreeDeserialize").Device(DEVICE_CPU), 452 TreeDeserializeOp); 453 454 REGISTER_KERNEL_BUILDER(Name("TreeSize").Device(DEVICE_CPU), TreeSizeOp); 455 456 REGISTER_KERNEL_BUILDER(Name("TreePredictionsV4").Device(DEVICE_CPU), 457 TreePredictionsV4Op); 458 459 REGISTER_KERNEL_BUILDER(Name("TraverseTreeV4").Device(DEVICE_CPU), 460 TraverseTreeV4Op); 461 462 REGISTER_KERNEL_BUILDER(Name("FeatureUsageCounts").Device(DEVICE_CPU), 463 FeatureUsageCountsOp); 464 465 REGISTER_KERNEL_BUILDER(Name("UpdateModelV4").Device(DEVICE_CPU), 466 UpdateModelV4Op); 467 468 } // namespace tensorforest 469 } // namespace tensorflow 470