1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #include <algorithm> 16 #include <iterator> 17 #include <map> 18 #include <string> 19 #include <vector> 20 21 #include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h" 22 #include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h" 23 #include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h" 24 #include "tensorflow/core/framework/op_kernel.h" 25 #include "tensorflow/core/framework/resource_mgr.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/framework/tensor_shape.h" 28 #include "tensorflow/core/framework/types.h" 29 #include "tensorflow/core/lib/core/errors.h" 30 #include "tensorflow/core/lib/core/status.h" 31 #include "tensorflow/core/platform/types.h" 32 #include "tensorflow/core/util/work_sharder.h" 33 34 namespace tensorflow { 35 namespace boosted_trees { 36 37 namespace { 38 const char* const kStampTokenName = "stamp_token"; 39 const char* const kNextStampTokenName = "next_stamp_token"; 40 41 struct PartitionKey { 42 PartitionKey() : partition_id(-1), feature_id(-1), dimension(-1) {} 43 44 PartitionKey(int32 p, int64 f, int32 d) 45 : partition_id(p), feature_id(f), dimension(d) {} 46 47 bool operator==(const PartitionKey& other) const { 48 return (partition_id == other.partition_id) && 49 (dimension == other.dimension) && (feature_id == other.feature_id); 50 } 51 52 // Compare for PartitionKey. 53 struct Less { 54 bool operator()(const PartitionKey& a, const PartitionKey& b) const { 55 if (a.partition_id < b.partition_id) { 56 return true; 57 } 58 if ((a.partition_id == b.partition_id) && (a.dimension < b.dimension)) { 59 return true; 60 } 61 if ((a.partition_id == b.partition_id) && (a.dimension == b.dimension) && 62 (a.feature_id < b.feature_id)) { 63 return true; 64 } 65 return false; 66 } 67 }; 68 69 // Tree partition defined by traversing the tree to the leaf. 70 int32 partition_id; 71 72 // Feature column id. 73 int64 feature_id; 74 75 // Dimension within feature column. 76 int32 dimension; 77 }; 78 79 template <typename GradientType, typename HessianType> 80 class StatsAccumulatorResource : public boosted_trees::StampedResource { 81 using StatsByPartition = 82 std::map<PartitionKey, std::pair<GradientType, HessianType>, 83 PartitionKey::Less>; 84 85 public: 86 StatsAccumulatorResource(const TensorShape& gradient_shape, 87 const TensorShape& hessian_shape) 88 : gradient_shape_(gradient_shape), 89 hessian_shape_(hessian_shape), 90 num_updates_(0) { 91 // If GradientType/HessianType is scalar float then the shapes should be 92 // scalar and vice versa. 93 CHECK_EQ((std::is_same<GradientType, float>::value), 94 TensorShapeUtils::IsScalar(gradient_shape)); 95 CHECK_EQ((std::is_same<HessianType, float>::value), 96 TensorShapeUtils::IsScalar(hessian_shape)); 97 } 98 99 string DebugString() const override { 100 return strings::StrCat("StatsAccumulatorResource[size=", values_.size(), 101 "]"); 102 } 103 104 void Clear() { 105 values_.clear(); 106 num_updates_ = 0; 107 } 108 109 tensorflow::mutex* mutex() { return &mu_; } 110 StatsByPartition* mutable_values() { return &values_; } 111 const StatsByPartition& values() const { return values_; } 112 const int64& num_updates() const { return num_updates_; } 113 void set_num_updates(int64 val) { num_updates_ = val; } 114 const TensorShape& gradient_shape() const { return gradient_shape_; } 115 const TensorShape& hessian_shape() const { return hessian_shape_; } 116 117 private: 118 // Key into a specific partition to accumulate stats for the specified feature 119 // id. 120 StatsByPartition values_; 121 const TensorShape gradient_shape_; 122 const TensorShape hessian_shape_; 123 int64 num_updates_; 124 tensorflow::mutex mu_; 125 TF_DISALLOW_COPY_AND_ASSIGN(StatsAccumulatorResource); 126 }; 127 128 using StatsAccumulatorScalarResource = StatsAccumulatorResource<float, float>; 129 using StatsAccumulatorTensorResource = 130 StatsAccumulatorResource<std::vector<float>, std::vector<float>>; 131 132 void SerializeScalarAccumulatorToOutput( 133 const StatsAccumulatorScalarResource& accumulator_resource, 134 OpKernelContext* context) { 135 int64 num_slots = accumulator_resource.values().size(); 136 Tensor* partition_ids_t = nullptr; 137 OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids", 138 TensorShape({num_slots}), 139 &partition_ids_t)); 140 auto partition_ids = partition_ids_t->vec<int32>(); 141 142 // Feature ids tensor has ids of feature columns and their dimensions. 143 Tensor* feature_ids_t = nullptr; 144 OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids", 145 TensorShape({num_slots, 2}), 146 &feature_ids_t)); 147 auto feature_ids = feature_ids_t->matrix<int64>(); 148 149 Tensor* gradients_t = nullptr; 150 OP_REQUIRES_OK( 151 context, context->allocate_output( 152 "output_gradients", TensorShape({num_slots}), &gradients_t)); 153 auto gradients = gradients_t->vec<float>(); 154 155 Tensor* hessians_t = nullptr; 156 OP_REQUIRES_OK( 157 context, context->allocate_output("output_hessians", 158 TensorShape({num_slots}), &hessians_t)); 159 auto hessians = hessians_t->vec<float>(); 160 161 int i = 0; 162 for (const auto& iter : accumulator_resource.values()) { 163 partition_ids(i) = iter.first.partition_id; 164 feature_ids(i, 0) = iter.first.feature_id; 165 feature_ids(i, 1) = iter.first.dimension; 166 167 gradients(i) = iter.second.first; 168 hessians(i) = iter.second.second; 169 ++i; 170 } 171 } 172 173 void SerializeTensorAccumulatorToOutput( 174 const StatsAccumulatorTensorResource& accumulator_resource, 175 OpKernelContext* context) { 176 int64 num_slots = accumulator_resource.values().size(); 177 Tensor* partition_ids_t = nullptr; 178 OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids", 179 TensorShape({num_slots}), 180 &partition_ids_t)); 181 auto partition_ids = partition_ids_t->vec<int32>(); 182 183 Tensor* feature_ids_t = nullptr; 184 OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids", 185 TensorShape({num_slots, 2}), 186 &feature_ids_t)); 187 auto feature_ids = feature_ids_t->matrix<int64>(); 188 189 TensorShape gradient_shape = accumulator_resource.gradient_shape(); 190 int64 num_gradient_elements = gradient_shape.num_elements(); 191 gradient_shape.InsertDim(0, num_slots); 192 Tensor* gradients_t = nullptr; 193 OP_REQUIRES_OK(context, 194 context->allocate_output("output_gradients", gradient_shape, 195 &gradients_t)); 196 auto gradients = gradients_t->flat_outer_dims<float>(); 197 198 TensorShape hessian_shape = accumulator_resource.hessian_shape(); 199 int64 num_hessian_elements = hessian_shape.num_elements(); 200 hessian_shape.InsertDim(0, num_slots); 201 Tensor* hessians_t = nullptr; 202 OP_REQUIRES_OK(context, context->allocate_output("output_hessians", 203 hessian_shape, &hessians_t)); 204 auto hessians = hessians_t->flat_outer_dims<float>(); 205 206 int i = 0; 207 for (const auto& iter : accumulator_resource.values()) { 208 partition_ids(i) = iter.first.partition_id; 209 feature_ids(i, 0) = iter.first.feature_id; 210 feature_ids(i, 1) = iter.first.dimension; 211 212 for (int j = 0; j < num_gradient_elements; ++j) { 213 gradients(i, j) = iter.second.first[j]; 214 } 215 for (int j = 0; j < num_hessian_elements; ++j) { 216 hessians(i, j) = iter.second.second[j]; 217 } 218 ++i; 219 } 220 } 221 222 void AddToScalarAccumulator( 223 StatsAccumulatorScalarResource* accumulator_resource, 224 const Tensor& partition_ids_t, const Tensor& feature_ids_t, 225 const Tensor& gradients_t, const Tensor& hessians_t) { 226 accumulator_resource->set_num_updates(accumulator_resource->num_updates() + 227 1); 228 const TensorShape& partition_ids_shape = partition_ids_t.shape(); 229 const auto& partition_ids = partition_ids_t.vec<int32>(); 230 const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>(); 231 const auto& gradients = gradients_t.vec<float>(); 232 const auto& hessians = hessians_t.vec<float>(); 233 234 int64 num_updates = partition_ids_shape.dim_size(0); 235 auto stats_map = accumulator_resource->mutable_values(); 236 for (int64 i = 0; i < num_updates; ++i) { 237 const auto key = 238 PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0), 239 feature_ids_and_dimensions(i, 1)); 240 auto itr = stats_map->find(key); 241 if (itr != stats_map->end()) { 242 itr->second.first += gradients(i); 243 itr->second.second += hessians(i); 244 } else { 245 (*stats_map)[key] = {gradients(i), hessians(i)}; 246 } 247 } 248 } 249 250 void AddToScalarAccumulator( 251 StatsAccumulatorScalarResource* accumulator_resource, 252 OpKernelContext* context) { 253 const Tensor* partition_ids_t; 254 OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t)); 255 const Tensor* feature_ids_t; 256 OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t)); 257 const Tensor* gradients_t; 258 OP_REQUIRES_OK(context, context->input("gradients", &gradients_t)); 259 const Tensor* hessians_t; 260 OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); 261 AddToScalarAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t, 262 *gradients_t, *hessians_t); 263 } 264 265 void AddToTensorAccumulator( 266 StatsAccumulatorTensorResource* accumulator_resource, 267 const Tensor& partition_ids_t, const Tensor& feature_ids_t, 268 const Tensor& gradients_t, const Tensor& hessians_t, 269 OpKernelContext* context) { 270 accumulator_resource->set_num_updates(accumulator_resource->num_updates() + 271 1); 272 273 const TensorShape& partition_ids_shape = partition_ids_t.shape(); 274 const auto& partition_ids = partition_ids_t.vec<int32>(); 275 const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>(); 276 TensorShape gradients_shape = gradients_t.shape(); 277 const auto& gradients = gradients_t.flat_outer_dims<float>(); 278 TensorShape hessians_shape = hessians_t.shape(); 279 const auto& hessians = hessians_t.flat_outer_dims<float>(); 280 281 gradients_shape.RemoveDim(0); 282 hessians_shape.RemoveDim(0); 283 284 // TODO(soroush): Move gradient and hessian shape check to ShapeFn. 285 OP_REQUIRES( 286 context, gradients_shape == accumulator_resource->gradient_shape(), 287 errors::InvalidArgument(strings::StrCat( 288 "Gradients dimensions must match: ", gradients_shape.DebugString(), 289 ", ", accumulator_resource->gradient_shape().DebugString()))); 290 291 OP_REQUIRES( 292 context, hessians_shape == accumulator_resource->hessian_shape(), 293 errors::InvalidArgument(strings::StrCat( 294 "Hessian dimensions must match: ", hessians_shape.DebugString(), ", ", 295 accumulator_resource->hessian_shape().DebugString()))); 296 297 int64 num_updates = partition_ids_shape.dim_size(0); 298 auto stats_map = accumulator_resource->mutable_values(); 299 for (int64 i = 0; i < num_updates; ++i) { 300 const auto key = 301 PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0), 302 feature_ids_and_dimensions(i, 1)); 303 auto itr = stats_map->find(key); 304 if (itr == stats_map->end()) { 305 std::vector<float> new_gradients(gradients_shape.num_elements()); 306 for (int j = 0; j < gradients_shape.num_elements(); ++j) { 307 new_gradients[j] = gradients(i, j); 308 } 309 std::vector<float> new_hessians(hessians_shape.num_elements()); 310 for (int j = 0; j < hessians_shape.num_elements(); ++j) { 311 new_hessians[j] = hessians(i, j); 312 } 313 (*stats_map)[key] = {new_gradients, new_hessians}; 314 } else { 315 auto& stored_gradients = itr->second.first; 316 for (int j = 0; j < gradients_shape.num_elements(); ++j) { 317 stored_gradients[j] += gradients(i, j); 318 } 319 auto& stored_hessians = itr->second.second; 320 for (int j = 0; j < hessians_shape.num_elements(); ++j) { 321 stored_hessians[j] += hessians(i, j); 322 } 323 } 324 } 325 } 326 327 void AddToTensorAccumulator( 328 StatsAccumulatorTensorResource* accumulator_resource, 329 OpKernelContext* context) { 330 const Tensor* partition_ids_t; 331 OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t)); 332 const Tensor* feature_ids_t; 333 OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t)); 334 const Tensor* gradients_t; 335 OP_REQUIRES_OK(context, context->input("gradients", &gradients_t)); 336 const Tensor* hessians_t; 337 OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); 338 AddToTensorAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t, 339 *gradients_t, *hessians_t, context); 340 } 341 342 } // namespace 343 344 REGISTER_RESOURCE_HANDLE_KERNEL(StatsAccumulatorScalarResource); 345 REGISTER_RESOURCE_HANDLE_KERNEL(StatsAccumulatorTensorResource); 346 347 REGISTER_KERNEL_BUILDER( 348 Name("StatsAccumulatorScalarIsInitialized").Device(DEVICE_CPU), 349 IsResourceInitialized<StatsAccumulatorScalarResource>); 350 351 REGISTER_KERNEL_BUILDER( 352 Name("StatsAccumulatorTensorIsInitialized").Device(DEVICE_CPU), 353 IsResourceInitialized<StatsAccumulatorTensorResource>); 354 355 class CreateStatsAccumulatorScalarOp : public OpKernel { 356 public: 357 explicit CreateStatsAccumulatorScalarOp(OpKernelConstruction* context) 358 : OpKernel(context) {} 359 360 void Compute(OpKernelContext* context) override { 361 const Tensor* stamp_token_t; 362 OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t)); 363 364 TensorShape gradient_shape = TensorShape({}); 365 TensorShape hessian_shape = TensorShape({}); 366 367 auto* result = 368 new StatsAccumulatorScalarResource(gradient_shape, hessian_shape); 369 result->set_stamp(stamp_token_t->scalar<int64>()()); 370 // Only create one, if one does not exist already. Report status for all 371 // other exceptions. If one already exists, it unrefs the new one. 372 auto status = CreateResource(context, HandleFromInput(context, 0), result); 373 if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) { 374 OP_REQUIRES(context, false, status); 375 } 376 } 377 }; 378 379 REGISTER_KERNEL_BUILDER(Name("CreateStatsAccumulatorScalar").Device(DEVICE_CPU), 380 CreateStatsAccumulatorScalarOp); 381 382 class CreateStatsAccumulatorTensorOp : public OpKernel { 383 public: 384 explicit CreateStatsAccumulatorTensorOp(OpKernelConstruction* context) 385 : OpKernel(context) {} 386 387 void Compute(OpKernelContext* context) override { 388 const Tensor* stamp_token_t; 389 OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t)); 390 391 const Tensor* gradient_shape_t; 392 OP_REQUIRES_OK( 393 context, context->input("per_slot_gradient_shape", &gradient_shape_t)); 394 395 const Tensor* hessian_shape_t; 396 OP_REQUIRES_OK(context, 397 context->input("per_slot_hessian_shape", &hessian_shape_t)); 398 TensorShape gradient_shape = TensorShape(gradient_shape_t->vec<int64>()); 399 TensorShape hessian_shape = TensorShape(hessian_shape_t->vec<int64>()); 400 auto* result = 401 new StatsAccumulatorTensorResource(gradient_shape, hessian_shape); 402 result->set_stamp(stamp_token_t->scalar<int64>()()); 403 404 // Only create one, if one does not exist already. Report status for all 405 // other exceptions. If one already exists, it unrefs the new one. 406 auto status = CreateResource(context, HandleFromInput(context, 0), result); 407 if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) { 408 OP_REQUIRES(context, false, status); 409 } 410 } 411 }; 412 413 REGISTER_KERNEL_BUILDER(Name("CreateStatsAccumulatorTensor").Device(DEVICE_CPU), 414 CreateStatsAccumulatorTensorOp); 415 416 class StatsAccumulatorScalarAddOp : public OpKernel { 417 public: 418 explicit StatsAccumulatorScalarAddOp(OpKernelConstruction* context) 419 : OpKernel(context) {} 420 421 void Compute(OpKernelContext* context) override { 422 OpInputList resource_handle_list; 423 OP_REQUIRES_OK(context, context->input_list("stats_accumulator_handles", 424 &resource_handle_list)); 425 OpInputList partition_ids_list; 426 OP_REQUIRES_OK(context, 427 context->input_list("partition_ids", &partition_ids_list)); 428 429 OpInputList feature_ids_list; 430 OP_REQUIRES_OK(context, 431 context->input_list("feature_ids", &feature_ids_list)); 432 OpInputList gradients_list; 433 OP_REQUIRES_OK(context, context->input_list("gradients", &gradients_list)); 434 OpInputList hessians_list; 435 OP_REQUIRES_OK(context, context->input_list("hessians", &hessians_list)); 436 437 const Tensor* stamp_token_t; 438 OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); 439 int64 stamp_token = stamp_token_t->scalar<int64>()(); 440 441 thread::ThreadPool* const worker_threads = 442 context->device()->tensorflow_cpu_worker_threads()->workers; 443 boosted_trees::utils::ParallelFor( 444 resource_handle_list.size(), worker_threads->NumThreads(), 445 worker_threads, 446 [&context, &resource_handle_list, &partition_ids_list, 447 &feature_ids_list, &gradients_list, &hessians_list, 448 stamp_token](int64 start, int64 end) { 449 for (int resource_handle_idx = start; resource_handle_idx < end; 450 ++resource_handle_idx) { 451 const ResourceHandle& handle = 452 resource_handle_list[resource_handle_idx] 453 .flat<ResourceHandle>()(0); 454 455 StatsAccumulatorScalarResource* accumulator_resource; 456 OP_REQUIRES_OK(context, LookupResource(context, handle, 457 &accumulator_resource)); 458 mutex_lock l(*accumulator_resource->mutex()); 459 core::ScopedUnref unref_me(accumulator_resource); 460 461 // If the stamp is invalid we drop the update. 462 if (!accumulator_resource->is_stamp_valid(stamp_token)) { 463 VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. " 464 << "Passed stamp token: " << stamp_token << " " 465 << "Current token: " << accumulator_resource->stamp(); 466 return; 467 } 468 AddToScalarAccumulator(accumulator_resource, 469 partition_ids_list[resource_handle_idx], 470 feature_ids_list[resource_handle_idx], 471 gradients_list[resource_handle_idx], 472 hessians_list[resource_handle_idx]); 473 } 474 }); 475 } 476 }; 477 478 REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorScalarAdd").Device(DEVICE_CPU), 479 StatsAccumulatorScalarAddOp); 480 481 class StatsAccumulatorTensorAddOp : public OpKernel { 482 public: 483 explicit StatsAccumulatorTensorAddOp(OpKernelConstruction* context) 484 : OpKernel(context) {} 485 486 void Compute(OpKernelContext* context) override { 487 OpInputList resource_handle_list; 488 OP_REQUIRES_OK(context, context->input_list("stats_accumulator_handles", 489 &resource_handle_list)); 490 OpInputList partition_ids_list; 491 OP_REQUIRES_OK(context, 492 context->input_list("partition_ids", &partition_ids_list)); 493 494 OpInputList feature_ids_list; 495 OP_REQUIRES_OK(context, 496 context->input_list("feature_ids", &feature_ids_list)); 497 OpInputList gradients_list; 498 OP_REQUIRES_OK(context, context->input_list("gradients", &gradients_list)); 499 OpInputList hessians_list; 500 OP_REQUIRES_OK(context, context->input_list("hessians", &hessians_list)); 501 502 const Tensor* stamp_token_t; 503 OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); 504 int64 stamp_token = stamp_token_t->scalar<int64>()(); 505 506 thread::ThreadPool* const worker_threads = 507 context->device()->tensorflow_cpu_worker_threads()->workers; 508 boosted_trees::utils::ParallelFor( 509 resource_handle_list.size(), worker_threads->NumThreads(), 510 worker_threads, 511 [&context, &resource_handle_list, &partition_ids_list, 512 &feature_ids_list, &gradients_list, &hessians_list, 513 stamp_token](int64 start, int64 end) { 514 for (int resource_handle_idx = start; resource_handle_idx < end; 515 ++resource_handle_idx) { 516 const ResourceHandle& handle = 517 resource_handle_list[resource_handle_idx] 518 .flat<ResourceHandle>()(0); 519 520 StatsAccumulatorTensorResource* accumulator_resource; 521 OP_REQUIRES_OK(context, LookupResource(context, handle, 522 &accumulator_resource)); 523 mutex_lock l(*accumulator_resource->mutex()); 524 core::ScopedUnref unref_me(accumulator_resource); 525 526 // If the stamp is invalid we drop the update. 527 if (!accumulator_resource->is_stamp_valid(stamp_token)) { 528 VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. " 529 << "Passed stamp token: " << stamp_token << " " 530 << "Current token: " << accumulator_resource->stamp(); 531 return; 532 } 533 AddToTensorAccumulator(accumulator_resource, 534 partition_ids_list[resource_handle_idx], 535 feature_ids_list[resource_handle_idx], 536 gradients_list[resource_handle_idx], 537 hessians_list[resource_handle_idx], context); 538 } 539 }); 540 } 541 }; 542 543 REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorTensorAdd").Device(DEVICE_CPU), 544 StatsAccumulatorTensorAddOp); 545 546 class StatsAccumulatorScalarFlushOp : public OpKernel { 547 public: 548 explicit StatsAccumulatorScalarFlushOp(OpKernelConstruction* context) 549 : OpKernel(context) {} 550 551 void Compute(OpKernelContext* context) override { 552 StatsAccumulatorScalarResource* accumulator_resource; 553 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 554 &accumulator_resource)); 555 mutex_lock l(*accumulator_resource->mutex()); 556 core::ScopedUnref unref_me(accumulator_resource); 557 558 const Tensor* stamp_token_t; 559 OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); 560 int64 stamp_token = stamp_token_t->scalar<int64>()(); 561 562 // If the stamp is invalid we restart the PS. It shouldn't happen since 563 // only Chief should call this function and chief is guaranteed to be in 564 // a consistent state. 565 CHECK(accumulator_resource->is_stamp_valid(stamp_token)); 566 567 const Tensor* next_stamp_token_t; 568 OP_REQUIRES_OK(context, 569 context->input(kNextStampTokenName, &next_stamp_token_t)); 570 int64 next_stamp_token = next_stamp_token_t->scalar<int64>()(); 571 CHECK(stamp_token != next_stamp_token); 572 573 SerializeScalarAccumulatorToOutput(*accumulator_resource, context); 574 Tensor* num_updates_t = nullptr; 575 OP_REQUIRES_OK(context, 576 context->allocate_output("num_updates", TensorShape({}), 577 &num_updates_t)); 578 num_updates_t->scalar<int64>()() = accumulator_resource->num_updates(); 579 580 accumulator_resource->Clear(); 581 accumulator_resource->set_stamp(next_stamp_token); 582 } 583 }; 584 585 REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorScalarFlush").Device(DEVICE_CPU), 586 StatsAccumulatorScalarFlushOp); 587 588 class StatsAccumulatorTensorFlushOp : public OpKernel { 589 public: 590 explicit StatsAccumulatorTensorFlushOp(OpKernelConstruction* context) 591 : OpKernel(context) {} 592 593 void Compute(OpKernelContext* context) override { 594 StatsAccumulatorTensorResource* accumulator_resource; 595 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 596 &accumulator_resource)); 597 mutex_lock l(*accumulator_resource->mutex()); 598 core::ScopedUnref unref_me(accumulator_resource); 599 600 const Tensor* stamp_token_t; 601 OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); 602 int64 stamp_token = stamp_token_t->scalar<int64>()(); 603 604 const Tensor* next_stamp_token_t; 605 OP_REQUIRES_OK(context, 606 context->input(kNextStampTokenName, &next_stamp_token_t)); 607 int64 next_stamp_token = next_stamp_token_t->scalar<int64>()(); 608 609 // If the stamp is invalid we restart the PS. It shouldn't happen since 610 // only Chief should call this function and chief is guaranteed to be in 611 // a consistent state. 612 CHECK(accumulator_resource->is_stamp_valid(stamp_token)); 613 CHECK(stamp_token != next_stamp_token); 614 SerializeTensorAccumulatorToOutput(*accumulator_resource, context); 615 Tensor* num_updates_t = nullptr; 616 OP_REQUIRES_OK(context, 617 context->allocate_output("num_updates", TensorShape({}), 618 &num_updates_t)); 619 num_updates_t->scalar<int64>()() = accumulator_resource->num_updates(); 620 accumulator_resource->Clear(); 621 accumulator_resource->set_stamp(next_stamp_token); 622 } 623 }; 624 625 REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorTensorFlush").Device(DEVICE_CPU), 626 StatsAccumulatorTensorFlushOp); 627 628 class StatsAccumulatorScalarDeserializeOp : public OpKernel { 629 public: 630 explicit StatsAccumulatorScalarDeserializeOp(OpKernelConstruction* context) 631 : OpKernel(context) {} 632 633 void Compute(OpKernelContext* context) override { 634 StatsAccumulatorScalarResource* accumulator_resource; 635 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 636 &accumulator_resource)); 637 mutex_lock l(*accumulator_resource->mutex()); 638 core::ScopedUnref unref_me(accumulator_resource); 639 640 // Check the stamp token. 641 const Tensor* stamp_token_t; 642 OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); 643 int64 stamp_token = stamp_token_t->scalar<int64>()(); 644 accumulator_resource->Clear(); 645 accumulator_resource->set_stamp(stamp_token); 646 AddToScalarAccumulator(accumulator_resource, context); 647 const Tensor* num_updates_t; 648 OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t)); 649 accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()()); 650 } 651 }; 652 653 REGISTER_KERNEL_BUILDER( 654 Name("StatsAccumulatorScalarDeserialize").Device(DEVICE_CPU), 655 StatsAccumulatorScalarDeserializeOp); 656 657 class StatsAccumulatorTensorDeserializeOp : public OpKernel { 658 public: 659 explicit StatsAccumulatorTensorDeserializeOp(OpKernelConstruction* context) 660 : OpKernel(context) {} 661 662 void Compute(OpKernelContext* context) override { 663 StatsAccumulatorTensorResource* accumulator_resource; 664 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 665 &accumulator_resource)); 666 mutex_lock l(*accumulator_resource->mutex()); 667 core::ScopedUnref unref_me(accumulator_resource); 668 669 // Check the stamp token. 670 const Tensor* stamp_token_t; 671 OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); 672 int64 stamp_token = stamp_token_t->scalar<int64>()(); 673 accumulator_resource->Clear(); 674 accumulator_resource->set_stamp(stamp_token); 675 AddToTensorAccumulator(accumulator_resource, context); 676 const Tensor* num_updates_t; 677 OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t)); 678 accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()()); 679 } 680 }; 681 682 REGISTER_KERNEL_BUILDER( 683 Name("StatsAccumulatorTensorDeserialize").Device(DEVICE_CPU), 684 StatsAccumulatorTensorDeserializeOp); 685 686 class StatsAccumulatorScalarSerializeOp : public OpKernel { 687 public: 688 explicit StatsAccumulatorScalarSerializeOp(OpKernelConstruction* context) 689 : OpKernel(context) {} 690 691 void Compute(OpKernelContext* context) override { 692 StatsAccumulatorScalarResource* accumulator_resource; 693 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 694 &accumulator_resource)); 695 mutex_lock l(*accumulator_resource->mutex()); 696 core::ScopedUnref unref_me(accumulator_resource); 697 SerializeScalarAccumulatorToOutput(*accumulator_resource, context); 698 Tensor* stamp_token_t = nullptr; 699 OP_REQUIRES_OK(context, 700 context->allocate_output("stamp_token", TensorShape({}), 701 &stamp_token_t)); 702 stamp_token_t->scalar<int64>()() = accumulator_resource->stamp(); 703 704 Tensor* num_updates_t = nullptr; 705 OP_REQUIRES_OK(context, 706 context->allocate_output("num_updates", TensorShape({}), 707 &num_updates_t)); 708 num_updates_t->scalar<int64>()() = accumulator_resource->num_updates(); 709 } 710 }; 711 712 REGISTER_KERNEL_BUILDER( 713 Name("StatsAccumulatorScalarSerialize").Device(DEVICE_CPU), 714 StatsAccumulatorScalarSerializeOp); 715 716 class StatsAccumulatorTensorSerializeOp : public OpKernel { 717 public: 718 explicit StatsAccumulatorTensorSerializeOp(OpKernelConstruction* context) 719 : OpKernel(context) {} 720 721 void Compute(OpKernelContext* context) override { 722 StatsAccumulatorTensorResource* accumulator_resource; 723 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 724 &accumulator_resource)); 725 mutex_lock l(*accumulator_resource->mutex()); 726 core::ScopedUnref unref_me(accumulator_resource); 727 SerializeTensorAccumulatorToOutput(*accumulator_resource, context); 728 Tensor* stamp_token_t = nullptr; 729 OP_REQUIRES_OK(context, 730 context->allocate_output("stamp_token", TensorShape({}), 731 &stamp_token_t)); 732 stamp_token_t->scalar<int64>()() = accumulator_resource->stamp(); 733 734 Tensor* num_updates_t = nullptr; 735 OP_REQUIRES_OK(context, 736 context->allocate_output("num_updates", TensorShape({}), 737 &num_updates_t)); 738 num_updates_t->scalar<int64>()() = accumulator_resource->num_updates(); 739 } 740 }; 741 742 REGISTER_KERNEL_BUILDER( 743 Name("StatsAccumulatorTensorSerialize").Device(DEVICE_CPU), 744 StatsAccumulatorTensorSerializeOp); 745 746 class StatsAccumulatorScalarMakeSummaryOp : public OpKernel { 747 public: 748 explicit StatsAccumulatorScalarMakeSummaryOp(OpKernelConstruction* context) 749 : OpKernel(context) {} 750 751 void Compute(OpKernelContext* context) override { 752 TensorShape gradient_shape = TensorShape({}); 753 TensorShape hessian_shape = TensorShape({}); 754 StatsAccumulatorScalarResource* accumulator_resource = 755 new StatsAccumulatorScalarResource(gradient_shape, hessian_shape); 756 core::ScopedUnref unref_me(accumulator_resource); 757 // Check the stamp token. 758 AddToScalarAccumulator(accumulator_resource, context); 759 SerializeScalarAccumulatorToOutput(*accumulator_resource, context); 760 } 761 }; 762 763 REGISTER_KERNEL_BUILDER( 764 Name("StatsAccumulatorScalarMakeSummary").Device(DEVICE_CPU), 765 StatsAccumulatorScalarMakeSummaryOp); 766 767 class StatsAccumulatorTensorMakeSummaryOp : public OpKernel { 768 public: 769 explicit StatsAccumulatorTensorMakeSummaryOp(OpKernelConstruction* context) 770 : OpKernel(context) {} 771 772 void Compute(OpKernelContext* context) override { 773 const Tensor* gradients_t; 774 OP_REQUIRES_OK(context, context->input("gradients", &gradients_t)); 775 TensorShape gradients_shape = gradients_t->shape(); 776 gradients_shape.RemoveDim(0); 777 778 const Tensor* hessians_t; 779 OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); 780 TensorShape hessians_shape = hessians_t->shape(); 781 hessians_shape.RemoveDim(0); 782 783 StatsAccumulatorTensorResource* accumulator_resource = 784 new StatsAccumulatorTensorResource(gradients_shape, hessians_shape); 785 core::ScopedUnref unref_me(accumulator_resource); 786 // Check the stamp token. 787 AddToTensorAccumulator(accumulator_resource, context); 788 SerializeTensorAccumulatorToOutput(*accumulator_resource, context); 789 } 790 }; 791 792 REGISTER_KERNEL_BUILDER( 793 Name("StatsAccumulatorTensorMakeSummary").Device(DEVICE_CPU), 794 StatsAccumulatorTensorMakeSummaryOp); 795 796 } // namespace boosted_trees 797 } // namespace tensorflow 798