1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #include <algorithm> 16 #include <iterator> 17 #include <map> 18 #include <string> 19 #include <vector> 20 21 #include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h" 22 #include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h" 23 #include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h" 24 #include "tensorflow/core/framework/op_kernel.h" 25 #include "tensorflow/core/framework/resource_mgr.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/framework/tensor_shape.h" 28 #include "tensorflow/core/framework/types.h" 29 #include "tensorflow/core/lib/core/errors.h" 30 #include "tensorflow/core/lib/core/status.h" 31 #include "tensorflow/core/platform/types.h" 32 #include "tensorflow/core/util/work_sharder.h" 33 34 namespace tensorflow { 35 namespace boosted_trees { 36 37 namespace { 38 const char* const kStampTokenName = "stamp_token"; 39 const char* const kNextStampTokenName = "next_stamp_token"; 40 41 struct PartitionKey { 42 PartitionKey() : partition_id(-1), feature_id(-1), dimension(-1) {} 43 44 PartitionKey(int32 p, int64 f, int32 d) 45 : partition_id(p), feature_id(f), dimension(d) {} 46 47 bool operator==(const PartitionKey& other) const { 48 return (partition_id == other.partition_id) && 49 (dimension == other.dimension) && (feature_id == other.feature_id); 50 } 51 52 // Compare for PartitionKey. 53 struct Less { 54 bool operator()(const PartitionKey& a, const PartitionKey& b) const { 55 if (a.partition_id < b.partition_id) { 56 return true; 57 } 58 if ((a.partition_id == b.partition_id) && (a.dimension < b.dimension)) { 59 return true; 60 } 61 if ((a.partition_id == b.partition_id) && (a.dimension == b.dimension) && 62 (a.feature_id < b.feature_id)) { 63 return true; 64 } 65 return false; 66 } 67 }; 68 69 // Tree partition defined by traversing the tree to the leaf. 70 int32 partition_id; 71 72 // Feature column id. 73 int64 feature_id; 74 75 // Dimension within feature column. 76 int32 dimension; 77 }; 78 79 template <typename GradientType, typename HessianType> 80 class StatsAccumulatorResource : public boosted_trees::StampedResource { 81 using StatsByPartition = 82 std::map<PartitionKey, std::pair<GradientType, HessianType>, 83 PartitionKey::Less>; 84 85 public: 86 StatsAccumulatorResource(const TensorShape& gradient_shape, 87 const TensorShape& hessian_shape) 88 : gradient_shape_(gradient_shape), 89 hessian_shape_(hessian_shape), 90 num_updates_(0) { 91 // If GradientType/HessianType is scalar float then the shapes should be 92 // scalar and vice versa. 93 CHECK_EQ((std::is_same<GradientType, float>::value), 94 TensorShapeUtils::IsScalar(gradient_shape)); 95 CHECK_EQ((std::is_same<HessianType, float>::value), 96 TensorShapeUtils::IsScalar(hessian_shape)); 97 } 98 99 string DebugString() override { 100 return strings::StrCat("StatsAccumulatorResource[size=", values_.size(), 101 "]"); 102 } 103 104 void Clear() { 105 values_.clear(); 106 num_updates_ = 0; 107 } 108 109 tensorflow::mutex* mutex() { return &mu_; } 110 StatsByPartition* mutable_values() { return &values_; } 111 const StatsByPartition& values() const { return values_; } 112 const int64& num_updates() const { return num_updates_; } 113 void set_num_updates(int64 val) { num_updates_ = val; } 114 const TensorShape& gradient_shape() const { return gradient_shape_; } 115 const TensorShape& hessian_shape() const { return hessian_shape_; } 116 117 private: 118 // Key into a specific partition to accumulate stats for the specified feature 119 // id. 120 StatsByPartition values_; 121 const TensorShape gradient_shape_; 122 const TensorShape hessian_shape_; 123 int64 num_updates_; 124 tensorflow::mutex mu_; 125 TF_DISALLOW_COPY_AND_ASSIGN(StatsAccumulatorResource); 126 }; 127 128 using StatsAccumulatorScalarResource = StatsAccumulatorResource<float, float>; 129 using StatsAccumulatorTensorResource = 130 StatsAccumulatorResource<std::vector<float>, std::vector<float>>; 131 132 void SerializeScalarAccumulatorToOutput( 133 const StatsAccumulatorScalarResource& accumulator_resource, 134 OpKernelContext* context) { 135 int64 num_slots = accumulator_resource.values().size(); 136 Tensor* partition_ids_t = nullptr; 137 OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids", 138 TensorShape({num_slots}), 139 &partition_ids_t)); 140 auto partition_ids = partition_ids_t->vec<int32>(); 141 142 // Feature ids tensor has ids of feature columns and their dimensions. 143 Tensor* feature_ids_t = nullptr; 144 OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids", 145 TensorShape({num_slots, 2}), 146 &feature_ids_t)); 147 auto feature_ids = feature_ids_t->matrix<int64>(); 148 149 Tensor* gradients_t = nullptr; 150 OP_REQUIRES_OK( 151 context, context->allocate_output( 152 "output_gradients", TensorShape({num_slots}), &gradients_t)); 153 auto gradients = gradients_t->vec<float>(); 154 155 Tensor* hessians_t = nullptr; 156 OP_REQUIRES_OK( 157 context, context->allocate_output("output_hessians", 158 TensorShape({num_slots}), &hessians_t)); 159 auto hessians = hessians_t->vec<float>(); 160 161 int i = 0; 162 for (const auto& iter : accumulator_resource.values()) { 163 partition_ids(i) = iter.first.partition_id; 164 feature_ids(i, 0) = iter.first.feature_id; 165 feature_ids(i, 1) = iter.first.dimension; 166 167 gradients(i) = iter.second.first; 168 hessians(i) = iter.second.second; 169 ++i; 170 } 171 } 172 173 void SerializeTensorAccumulatorToOutput( 174 const StatsAccumulatorTensorResource& accumulator_resource, 175 OpKernelContext* context) { 176 int64 num_slots = accumulator_resource.values().size(); 177 Tensor* partition_ids_t = nullptr; 178 OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids", 179 TensorShape({num_slots}), 180 &partition_ids_t)); 181 auto partition_ids = partition_ids_t->vec<int32>(); 182 183 Tensor* feature_ids_t = nullptr; 184 OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids", 185 TensorShape({num_slots, 2}), 186 &feature_ids_t)); 187 auto feature_ids = feature_ids_t->matrix<int64>(); 188 189 TensorShape gradient_shape = accumulator_resource.gradient_shape(); 190 int64 num_gradient_elements = gradient_shape.num_elements(); 191 gradient_shape.InsertDim(0, num_slots); 192 Tensor* gradients_t = nullptr; 193 OP_REQUIRES_OK(context, 194 context->allocate_output("output_gradients", gradient_shape, 195 &gradients_t)); 196 auto gradients = gradients_t->flat_outer_dims<float>(); 197 198 TensorShape hessian_shape = accumulator_resource.hessian_shape(); 199 int64 num_hessian_elements = hessian_shape.num_elements(); 200 hessian_shape.InsertDim(0, num_slots); 201 Tensor* hessians_t = nullptr; 202 OP_REQUIRES_OK(context, context->allocate_output("output_hessians", 203 hessian_shape, &hessians_t)); 204 auto hessians = hessians_t->flat_outer_dims<float>(); 205 206 int i = 0; 207 for (const auto& iter : accumulator_resource.values()) { 208 partition_ids(i) = iter.first.partition_id; 209 feature_ids(i, 0) = iter.first.feature_id; 210 feature_ids(i, 1) = iter.first.dimension; 211 212 for (int j = 0; j < num_gradient_elements; ++j) { 213 gradients(i, j) = iter.second.first[j]; 214 } 215 for (int j = 0; j < num_hessian_elements; ++j) { 216 hessians(i, j) = iter.second.second[j]; 217 } 218 ++i; 219 } 220 } 221 222 void AddToScalarAccumulator( 223 StatsAccumulatorScalarResource* accumulator_resource, 224 const Tensor& partition_ids_t, const Tensor& feature_ids_t, 225 const Tensor& gradients_t, const Tensor& hessians_t) { 226 accumulator_resource->set_num_updates(accumulator_resource->num_updates() + 227 1); 228 const TensorShape& partition_ids_shape = partition_ids_t.shape(); 229 const auto& partition_ids = partition_ids_t.vec<int32>(); 230 const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>(); 231 const auto& gradients = gradients_t.vec<float>(); 232 const auto& hessians = hessians_t.vec<float>(); 233 234 int64 num_updates = partition_ids_shape.dim_size(0); 235 auto stats_map = accumulator_resource->mutable_values(); 236 for (int64 i = 0; i < num_updates; ++i) { 237 const auto key = 238 PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0), 239 feature_ids_and_dimensions(i, 1)); 240 auto itr = stats_map->find(key); 241 if (itr != stats_map->end()) { 242 itr->second.first += gradients(i); 243 itr->second.second += hessians(i); 244 } else { 245 (*stats_map)[key] = {gradients(i), hessians(i)}; 246 } 247 } 248 } 249 250 void AddToScalarAccumulator( 251 StatsAccumulatorScalarResource* accumulator_resource, 252 OpKernelContext* context) { 253 const Tensor* partition_ids_t; 254 OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t)); 255 const Tensor* feature_ids_t; 256 OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t)); 257 const Tensor* gradients_t; 258 OP_REQUIRES_OK(context, context->input("gradients", &gradients_t)); 259 const Tensor* hessians_t; 260 OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); 261 AddToScalarAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t, 262 *gradients_t, *hessians_t); 263 } 264 265 void AddToTensorAccumulator( 266 StatsAccumulatorTensorResource* accumulator_resource, 267 const Tensor& partition_ids_t, const Tensor& feature_ids_t, 268 const Tensor& gradients_t, const Tensor& hessians_t, 269 OpKernelContext* context) { 270 accumulator_resource->set_num_updates(accumulator_resource->num_updates() + 271 1); 272 273 const TensorShape& partition_ids_shape = partition_ids_t.shape(); 274 const auto& partition_ids = partition_ids_t.vec<int32>(); 275 const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>(); 276 TensorShape gradients_shape = gradients_t.shape(); 277 const auto& gradients = gradients_t.flat_outer_dims<float>(); 278 TensorShape hessians_shape = hessians_t.shape(); 279 const auto& hessians = hessians_t.flat_outer_dims<float>(); 280 281 gradients_shape.RemoveDim(0); 282 hessians_shape.RemoveDim(0); 283 284 // TODO(soroush): Move gradient and hessian shape check to ShapeFn. 285 OP_REQUIRES( 286 context, gradients_shape == accumulator_resource->gradient_shape(), 287 errors::InvalidArgument(strings::StrCat( 288 "Gradients dimensions must match: ", gradients_shape.DebugString(), 289 ", ", accumulator_resource->gradient_shape().DebugString()))); 290 291 OP_REQUIRES( 292 context, hessians_shape == accumulator_resource->hessian_shape(), 293 errors::InvalidArgument(strings::StrCat( 294 "Hessian dimensions must match: ", hessians_shape.DebugString(), ", ", 295 accumulator_resource->hessian_shape().DebugString()))); 296 297 int64 num_updates = partition_ids_shape.dim_size(0); 298 auto stats_map = accumulator_resource->mutable_values(); 299 for (int64 i = 0; i < num_updates; ++i) { 300 const auto key = 301 PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0), 302 feature_ids_and_dimensions(i, 1)); 303 auto itr = stats_map->find(key); 304 if (itr == stats_map->end()) { 305 std::vector<float> new_gradients(gradients_shape.num_elements()); 306 for (int j = 0; j < gradients_shape.num_elements(); ++j) { 307 new_gradients[j] = gradients(i, j); 308 } 309 std::vector<float> new_hessians(hessians_shape.num_elements()); 310 for (int j = 0; j < hessians_shape.num_elements(); ++j) { 311 new_hessians[j] = hessians(i, j); 312 } 313 (*stats_map)[key] = {new_gradients, new_hessians}; 314 } else { 315 auto& stored_gradients = itr->second.first; 316 for (int j = 0; j < gradients_shape.num_elements(); ++j) { 317 stored_gradients[j] += gradients(i, j); 318 } 319 auto& stored_hessians = itr->second.second; 320 for (int j = 0; j < hessians_shape.num_elements(); ++j) { 321 stored_hessians[j] += hessians(i, j); 322 } 323 } 324 } 325 } 326 327 void AddToTensorAccumulator( 328 StatsAccumulatorTensorResource* accumulator_resource, 329 OpKernelContext* context) { 330 const Tensor* partition_ids_t; 331 OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t)); 332 const Tensor* feature_ids_t; 333 OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t)); 334 const Tensor* gradients_t; 335 OP_REQUIRES_OK(context, context->input("gradients", &gradients_t)); 336 const Tensor* hessians_t; 337 OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); 338 AddToTensorAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t, 339 *gradients_t, *hessians_t, context); 340 } 341 342 } // namespace 343 344 REGISTER_RESOURCE_HANDLE_KERNEL(StatsAccumulatorScalarResource); 345 REGISTER_RESOURCE_HANDLE_KERNEL(StatsAccumulatorTensorResource); 346 347 REGISTER_KERNEL_BUILDER( 348 Name("StatsAccumulatorScalarIsInitialized").Device(DEVICE_CPU), 349 IsResourceInitialized<StatsAccumulatorScalarResource>); 350 351 REGISTER_KERNEL_BUILDER( 352 Name("StatsAccumulatorTensorIsInitialized").Device(DEVICE_CPU), 353 IsResourceInitialized<StatsAccumulatorTensorResource>); 354 355 class CreateStatsAccumulatorScalarOp : public OpKernel { 356 public: 357 explicit CreateStatsAccumulatorScalarOp(OpKernelConstruction* context) 358 : OpKernel(context) {} 359 360 void Compute(OpKernelContext* context) override { 361 const Tensor* stamp_token_t; 362 OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t)); 363 364 TensorShape gradient_shape = TensorShape({}); 365 TensorShape hessian_shape = TensorShape({}); 366 367 auto* result = 368 new StatsAccumulatorScalarResource(gradient_shape, hessian_shape); 369 result->set_stamp(stamp_token_t->scalar<int64>()()); 370 // Only create one, if one does not exist already. Report status for all 371 // other exceptions. If one already exists, it unrefs the new one. 372 auto status = CreateResource(context, HandleFromInput(context, 0), result); 373 if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) { 374 OP_REQUIRES(context, false, status); 375 } 376 } 377 }; 378 379 REGISTER_KERNEL_BUILDER(Name("CreateStatsAccumulatorScalar").Device(DEVICE_CPU), 380 CreateStatsAccumulatorScalarOp); 381 382 class CreateStatsAccumulatorTensorOp : public OpKernel { 383 public: 384 explicit CreateStatsAccumulatorTensorOp(OpKernelConstruction* context) 385 : OpKernel(context) {} 386 387 void Compute(OpKernelContext* context) override { 388 const Tensor* stamp_token_t; 389 OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t)); 390 391 const Tensor* gradient_shape_t; 392 OP_REQUIRES_OK( 393 context, context->input("per_slot_gradient_shape", &gradient_shape_t)); 394 395 const Tensor* hessian_shape_t; 396 OP_REQUIRES_OK(context, 397 context->input("per_slot_hessian_shape", &hessian_shape_t)); 398 TensorShape gradient_shape = TensorShape(gradient_shape_t->vec<int64>()); 399 TensorShape hessian_shape = TensorShape(hessian_shape_t->vec<int64>()); 400 auto* result = 401 new StatsAccumulatorTensorResource(gradient_shape, hessian_shape); 402 result->set_stamp(stamp_token_t->scalar<int64>()()); 403 404 // Only create one, if one does not exist already. Report status for all 405 // other exceptions. If one already exists, it unrefs the new one. 406 auto status = CreateResource(context, HandleFromInput(context, 0), result); 407 if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) { 408 OP_REQUIRES(context, false, status); 409 } 410 } 411 }; 412 413 REGISTER_KERNEL_BUILDER(Name("CreateStatsAccumulatorTensor").Device(DEVICE_CPU), 414 CreateStatsAccumulatorTensorOp); 415 416 class StatsAccumulatorScalarAddOp : public OpKernel { 417 public: 418 explicit StatsAccumulatorScalarAddOp(OpKernelConstruction* context) 419 : OpKernel(context) {} 420 421 void Compute(OpKernelContext* context) override { 422 OpInputList resource_handle_list; 423 OP_REQUIRES_OK(context, context->input_list("stats_accumulator_handles", 424 &resource_handle_list)); 425 OpInputList partition_ids_list; 426 OP_REQUIRES_OK(context, 427 context->input_list("partition_ids", &partition_ids_list)); 428 429 OpInputList feature_ids_list; 430 OP_REQUIRES_OK(context, 431 context->input_list("feature_ids", &feature_ids_list)); 432 OpInputList gradients_list; 433 OP_REQUIRES_OK(context, context->input_list("gradients", &gradients_list)); 434 OpInputList hessians_list; 435 OP_REQUIRES_OK(context, context->input_list("hessians", &hessians_list)); 436 437 const Tensor* stamp_token_t; 438 OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); 439 int64 stamp_token = stamp_token_t->scalar<int64>()(); 440 441 thread::ThreadPool* const worker_threads = 442 context->device()->tensorflow_cpu_worker_threads()->workers; 443 boosted_trees::utils::ParallelFor( 444 resource_handle_list.size(), worker_threads->NumThreads(), 445 worker_threads, 446 [&context, &resource_handle_list, &partition_ids_list, 447 &feature_ids_list, &gradients_list, &hessians_list, 448 stamp_token](int64 start, int64 end) { 449 for (int resource_handle_idx = start; resource_handle_idx < end; 450 ++resource_handle_idx) { 451 ResourceHandle handle = resource_handle_list[resource_handle_idx] 452 .flat<ResourceHandle>()(0); 453 454 StatsAccumulatorScalarResource* accumulator_resource; 455 OP_REQUIRES_OK(context, LookupResource(context, handle, 456 &accumulator_resource)); 457 mutex_lock l(*accumulator_resource->mutex()); 458 core::ScopedUnref unref_me(accumulator_resource); 459 460 // If the stamp is invalid we drop the update. 461 if (!accumulator_resource->is_stamp_valid(stamp_token)) { 462 VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. " 463 << "Passed stamp token: " << stamp_token << " " 464 << "Current token: " << accumulator_resource->stamp(); 465 return; 466 } 467 AddToScalarAccumulator(accumulator_resource, 468 partition_ids_list[resource_handle_idx], 469 feature_ids_list[resource_handle_idx], 470 gradients_list[resource_handle_idx], 471 hessians_list[resource_handle_idx]); 472 } 473 }); 474 } 475 }; 476 477 REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorScalarAdd").Device(DEVICE_CPU), 478 StatsAccumulatorScalarAddOp); 479 480 class StatsAccumulatorTensorAddOp : public OpKernel { 481 public: 482 explicit StatsAccumulatorTensorAddOp(OpKernelConstruction* context) 483 : OpKernel(context) {} 484 485 void Compute(OpKernelContext* context) override { 486 OpInputList resource_handle_list; 487 OP_REQUIRES_OK(context, context->input_list("stats_accumulator_handles", 488 &resource_handle_list)); 489 OpInputList partition_ids_list; 490 OP_REQUIRES_OK(context, 491 context->input_list("partition_ids", &partition_ids_list)); 492 493 OpInputList feature_ids_list; 494 OP_REQUIRES_OK(context, 495 context->input_list("feature_ids", &feature_ids_list)); 496 OpInputList gradients_list; 497 OP_REQUIRES_OK(context, context->input_list("gradients", &gradients_list)); 498 OpInputList hessians_list; 499 OP_REQUIRES_OK(context, context->input_list("hessians", &hessians_list)); 500 501 const Tensor* stamp_token_t; 502 OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); 503 int64 stamp_token = stamp_token_t->scalar<int64>()(); 504 505 thread::ThreadPool* const worker_threads = 506 context->device()->tensorflow_cpu_worker_threads()->workers; 507 boosted_trees::utils::ParallelFor( 508 resource_handle_list.size(), worker_threads->NumThreads(), 509 worker_threads, 510 [&context, &resource_handle_list, &partition_ids_list, 511 &feature_ids_list, &gradients_list, &hessians_list, 512 stamp_token](int64 start, int64 end) { 513 for (int resource_handle_idx = start; resource_handle_idx < end; 514 ++resource_handle_idx) { 515 ResourceHandle handle = resource_handle_list[resource_handle_idx] 516 .flat<ResourceHandle>()(0); 517 518 StatsAccumulatorTensorResource* accumulator_resource; 519 OP_REQUIRES_OK(context, LookupResource(context, handle, 520 &accumulator_resource)); 521 mutex_lock l(*accumulator_resource->mutex()); 522 core::ScopedUnref unref_me(accumulator_resource); 523 524 // If the stamp is invalid we drop the update. 525 if (!accumulator_resource->is_stamp_valid(stamp_token)) { 526 VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. " 527 << "Passed stamp token: " << stamp_token << " " 528 << "Current token: " << accumulator_resource->stamp(); 529 return; 530 } 531 AddToTensorAccumulator(accumulator_resource, 532 partition_ids_list[resource_handle_idx], 533 feature_ids_list[resource_handle_idx], 534 gradients_list[resource_handle_idx], 535 hessians_list[resource_handle_idx], context); 536 } 537 }); 538 } 539 }; 540 541 REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorTensorAdd").Device(DEVICE_CPU), 542 StatsAccumulatorTensorAddOp); 543 544 class StatsAccumulatorScalarFlushOp : public OpKernel { 545 public: 546 explicit StatsAccumulatorScalarFlushOp(OpKernelConstruction* context) 547 : OpKernel(context) {} 548 549 void Compute(OpKernelContext* context) override { 550 StatsAccumulatorScalarResource* accumulator_resource; 551 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 552 &accumulator_resource)); 553 mutex_lock l(*accumulator_resource->mutex()); 554 core::ScopedUnref unref_me(accumulator_resource); 555 556 const Tensor* stamp_token_t; 557 OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); 558 int64 stamp_token = stamp_token_t->scalar<int64>()(); 559 560 // If the stamp is invalid we restart the PS. It shouldn't happen since 561 // only Chief should call this function and chief is guaranteed to be in 562 // a consistent state. 563 CHECK(accumulator_resource->is_stamp_valid(stamp_token)); 564 565 const Tensor* next_stamp_token_t; 566 OP_REQUIRES_OK(context, 567 context->input(kNextStampTokenName, &next_stamp_token_t)); 568 int64 next_stamp_token = next_stamp_token_t->scalar<int64>()(); 569 CHECK(stamp_token != next_stamp_token); 570 571 SerializeScalarAccumulatorToOutput(*accumulator_resource, context); 572 Tensor* num_updates_t = nullptr; 573 OP_REQUIRES_OK(context, 574 context->allocate_output("num_updates", TensorShape({}), 575 &num_updates_t)); 576 num_updates_t->scalar<int64>()() = accumulator_resource->num_updates(); 577 578 accumulator_resource->Clear(); 579 accumulator_resource->set_stamp(next_stamp_token); 580 } 581 }; 582 583 REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorScalarFlush").Device(DEVICE_CPU), 584 StatsAccumulatorScalarFlushOp); 585 586 class StatsAccumulatorTensorFlushOp : public OpKernel { 587 public: 588 explicit StatsAccumulatorTensorFlushOp(OpKernelConstruction* context) 589 : OpKernel(context) {} 590 591 void Compute(OpKernelContext* context) override { 592 StatsAccumulatorTensorResource* accumulator_resource; 593 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 594 &accumulator_resource)); 595 mutex_lock l(*accumulator_resource->mutex()); 596 core::ScopedUnref unref_me(accumulator_resource); 597 598 const Tensor* stamp_token_t; 599 OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); 600 int64 stamp_token = stamp_token_t->scalar<int64>()(); 601 602 const Tensor* next_stamp_token_t; 603 OP_REQUIRES_OK(context, 604 context->input(kNextStampTokenName, &next_stamp_token_t)); 605 int64 next_stamp_token = next_stamp_token_t->scalar<int64>()(); 606 607 // If the stamp is invalid we restart the PS. It shouldn't happen since 608 // only Chief should call this function and chief is guaranteed to be in 609 // a consistent state. 610 CHECK(accumulator_resource->is_stamp_valid(stamp_token)); 611 CHECK(stamp_token != next_stamp_token); 612 SerializeTensorAccumulatorToOutput(*accumulator_resource, context); 613 Tensor* num_updates_t = nullptr; 614 OP_REQUIRES_OK(context, 615 context->allocate_output("num_updates", TensorShape({}), 616 &num_updates_t)); 617 num_updates_t->scalar<int64>()() = accumulator_resource->num_updates(); 618 accumulator_resource->Clear(); 619 accumulator_resource->set_stamp(next_stamp_token); 620 } 621 }; 622 623 REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorTensorFlush").Device(DEVICE_CPU), 624 StatsAccumulatorTensorFlushOp); 625 626 class StatsAccumulatorScalarDeserializeOp : public OpKernel { 627 public: 628 explicit StatsAccumulatorScalarDeserializeOp(OpKernelConstruction* context) 629 : OpKernel(context) {} 630 631 void Compute(OpKernelContext* context) override { 632 StatsAccumulatorScalarResource* accumulator_resource; 633 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 634 &accumulator_resource)); 635 mutex_lock l(*accumulator_resource->mutex()); 636 core::ScopedUnref unref_me(accumulator_resource); 637 638 // Check the stamp token. 639 const Tensor* stamp_token_t; 640 OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); 641 int64 stamp_token = stamp_token_t->scalar<int64>()(); 642 accumulator_resource->Clear(); 643 accumulator_resource->set_stamp(stamp_token); 644 AddToScalarAccumulator(accumulator_resource, context); 645 const Tensor* num_updates_t; 646 OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t)); 647 accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()()); 648 } 649 }; 650 651 REGISTER_KERNEL_BUILDER( 652 Name("StatsAccumulatorScalarDeserialize").Device(DEVICE_CPU), 653 StatsAccumulatorScalarDeserializeOp); 654 655 class StatsAccumulatorTensorDeserializeOp : public OpKernel { 656 public: 657 explicit StatsAccumulatorTensorDeserializeOp(OpKernelConstruction* context) 658 : OpKernel(context) {} 659 660 void Compute(OpKernelContext* context) override { 661 StatsAccumulatorTensorResource* accumulator_resource; 662 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 663 &accumulator_resource)); 664 mutex_lock l(*accumulator_resource->mutex()); 665 core::ScopedUnref unref_me(accumulator_resource); 666 667 // Check the stamp token. 668 const Tensor* stamp_token_t; 669 OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); 670 int64 stamp_token = stamp_token_t->scalar<int64>()(); 671 accumulator_resource->Clear(); 672 accumulator_resource->set_stamp(stamp_token); 673 AddToTensorAccumulator(accumulator_resource, context); 674 const Tensor* num_updates_t; 675 OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t)); 676 accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()()); 677 } 678 }; 679 680 REGISTER_KERNEL_BUILDER( 681 Name("StatsAccumulatorTensorDeserialize").Device(DEVICE_CPU), 682 StatsAccumulatorTensorDeserializeOp); 683 684 class StatsAccumulatorScalarSerializeOp : public OpKernel { 685 public: 686 explicit StatsAccumulatorScalarSerializeOp(OpKernelConstruction* context) 687 : OpKernel(context) {} 688 689 void Compute(OpKernelContext* context) override { 690 StatsAccumulatorScalarResource* accumulator_resource; 691 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 692 &accumulator_resource)); 693 mutex_lock l(*accumulator_resource->mutex()); 694 core::ScopedUnref unref_me(accumulator_resource); 695 SerializeScalarAccumulatorToOutput(*accumulator_resource, context); 696 Tensor* stamp_token_t = nullptr; 697 OP_REQUIRES_OK(context, 698 context->allocate_output("stamp_token", TensorShape({}), 699 &stamp_token_t)); 700 stamp_token_t->scalar<int64>()() = accumulator_resource->stamp(); 701 702 Tensor* num_updates_t = nullptr; 703 OP_REQUIRES_OK(context, 704 context->allocate_output("num_updates", TensorShape({}), 705 &num_updates_t)); 706 num_updates_t->scalar<int64>()() = accumulator_resource->num_updates(); 707 } 708 }; 709 710 REGISTER_KERNEL_BUILDER( 711 Name("StatsAccumulatorScalarSerialize").Device(DEVICE_CPU), 712 StatsAccumulatorScalarSerializeOp); 713 714 class StatsAccumulatorTensorSerializeOp : public OpKernel { 715 public: 716 explicit StatsAccumulatorTensorSerializeOp(OpKernelConstruction* context) 717 : OpKernel(context) {} 718 719 void Compute(OpKernelContext* context) override { 720 StatsAccumulatorTensorResource* accumulator_resource; 721 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 722 &accumulator_resource)); 723 mutex_lock l(*accumulator_resource->mutex()); 724 core::ScopedUnref unref_me(accumulator_resource); 725 SerializeTensorAccumulatorToOutput(*accumulator_resource, context); 726 Tensor* stamp_token_t = nullptr; 727 OP_REQUIRES_OK(context, 728 context->allocate_output("stamp_token", TensorShape({}), 729 &stamp_token_t)); 730 stamp_token_t->scalar<int64>()() = accumulator_resource->stamp(); 731 732 Tensor* num_updates_t = nullptr; 733 OP_REQUIRES_OK(context, 734 context->allocate_output("num_updates", TensorShape({}), 735 &num_updates_t)); 736 num_updates_t->scalar<int64>()() = accumulator_resource->num_updates(); 737 } 738 }; 739 740 REGISTER_KERNEL_BUILDER( 741 Name("StatsAccumulatorTensorSerialize").Device(DEVICE_CPU), 742 StatsAccumulatorTensorSerializeOp); 743 744 class StatsAccumulatorScalarMakeSummaryOp : public OpKernel { 745 public: 746 explicit StatsAccumulatorScalarMakeSummaryOp(OpKernelConstruction* context) 747 : OpKernel(context) {} 748 749 void Compute(OpKernelContext* context) override { 750 TensorShape gradient_shape = TensorShape({}); 751 TensorShape hessian_shape = TensorShape({}); 752 StatsAccumulatorScalarResource* accumulator_resource = 753 new StatsAccumulatorScalarResource(gradient_shape, hessian_shape); 754 core::ScopedUnref unref_me(accumulator_resource); 755 // Check the stamp token. 756 AddToScalarAccumulator(accumulator_resource, context); 757 SerializeScalarAccumulatorToOutput(*accumulator_resource, context); 758 } 759 }; 760 761 REGISTER_KERNEL_BUILDER( 762 Name("StatsAccumulatorScalarMakeSummary").Device(DEVICE_CPU), 763 StatsAccumulatorScalarMakeSummaryOp); 764 765 class StatsAccumulatorTensorMakeSummaryOp : public OpKernel { 766 public: 767 explicit StatsAccumulatorTensorMakeSummaryOp(OpKernelConstruction* context) 768 : OpKernel(context) {} 769 770 void Compute(OpKernelContext* context) override { 771 const Tensor* gradients_t; 772 OP_REQUIRES_OK(context, context->input("gradients", &gradients_t)); 773 TensorShape gradients_shape = gradients_t->shape(); 774 gradients_shape.RemoveDim(0); 775 776 const Tensor* hessians_t; 777 OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); 778 TensorShape hessians_shape = hessians_t->shape(); 779 hessians_shape.RemoveDim(0); 780 781 StatsAccumulatorTensorResource* accumulator_resource = 782 new StatsAccumulatorTensorResource(gradients_shape, hessians_shape); 783 core::ScopedUnref unref_me(accumulator_resource); 784 // Check the stamp token. 785 AddToTensorAccumulator(accumulator_resource, context); 786 SerializeTensorAccumulatorToOutput(*accumulator_resource, context); 787 } 788 }; 789 790 REGISTER_KERNEL_BUILDER( 791 Name("StatsAccumulatorTensorMakeSummary").Device(DEVICE_CPU), 792 StatsAccumulatorTensorMakeSummaryOp); 793 794 } // namespace boosted_trees 795 } // namespace tensorflow 796