Home | History | Annotate | Download | only in kernels
      1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
     15 #include <algorithm>
     16 #include <iterator>
     17 #include <map>
     18 #include <string>
     19 #include <vector>
     20 
     21 #include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"
     22 #include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
     23 #include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h"
     24 #include "tensorflow/core/framework/op_kernel.h"
     25 #include "tensorflow/core/framework/resource_mgr.h"
     26 #include "tensorflow/core/framework/tensor.h"
     27 #include "tensorflow/core/framework/tensor_shape.h"
     28 #include "tensorflow/core/framework/types.h"
     29 #include "tensorflow/core/lib/core/errors.h"
     30 #include "tensorflow/core/lib/core/status.h"
     31 #include "tensorflow/core/platform/types.h"
     32 #include "tensorflow/core/util/work_sharder.h"
     33 
     34 namespace tensorflow {
     35 namespace boosted_trees {
     36 
     37 namespace {
     38 const char* const kStampTokenName = "stamp_token";
     39 const char* const kNextStampTokenName = "next_stamp_token";
     40 
     41 struct PartitionKey {
     42   PartitionKey() : partition_id(-1), feature_id(-1), dimension(-1) {}
     43 
     44   PartitionKey(int32 p, int64 f, int32 d)
     45       : partition_id(p), feature_id(f), dimension(d) {}
     46 
     47   bool operator==(const PartitionKey& other) const {
     48     return (partition_id == other.partition_id) &&
     49            (dimension == other.dimension) && (feature_id == other.feature_id);
     50   }
     51 
     52   // Compare for PartitionKey.
     53   struct Less {
     54     bool operator()(const PartitionKey& a, const PartitionKey& b) const {
     55       if (a.partition_id < b.partition_id) {
     56         return true;
     57       }
     58       if ((a.partition_id == b.partition_id) && (a.dimension < b.dimension)) {
     59         return true;
     60       }
     61       if ((a.partition_id == b.partition_id) && (a.dimension == b.dimension) &&
     62           (a.feature_id < b.feature_id)) {
     63         return true;
     64       }
     65       return false;
     66     }
     67   };
     68 
     69   // Tree partition defined by traversing the tree to the leaf.
     70   int32 partition_id;
     71 
     72   // Feature column id.
     73   int64 feature_id;
     74 
     75   // Dimension within feature column.
     76   int32 dimension;
     77 };
     78 
// Stamped resource holding accumulated (gradient, hessian) statistics keyed
// by PartitionKey. GradientType/HessianType are either plain float (the
// "scalar" accumulator) or std::vector<float> (the "tensor" accumulator).
// Callers serialize access through mutex(); the class itself does no locking.
template <typename GradientType, typename HessianType>
class StatsAccumulatorResource : public boosted_trees::StampedResource {
  using StatsByPartition =
      std::map<PartitionKey, std::pair<GradientType, HessianType>,
               PartitionKey::Less>;

 public:
  // gradient_shape/hessian_shape describe a single slot's stats. They are
  // fixed for the lifetime of the resource.
  StatsAccumulatorResource(const TensorShape& gradient_shape,
                           const TensorShape& hessian_shape)
      : gradient_shape_(gradient_shape),
        hessian_shape_(hessian_shape),
        num_updates_(0) {
    // If GradientType/HessianType is scalar float then the shapes should be
    // scalar and vice versa.
    CHECK_EQ((std::is_same<GradientType, float>::value),
             TensorShapeUtils::IsScalar(gradient_shape));
    CHECK_EQ((std::is_same<HessianType, float>::value),
             TensorShapeUtils::IsScalar(hessian_shape));
  }

  string DebugString() const override {
    return strings::StrCat("StatsAccumulatorResource[size=", values_.size(),
                           "]");
  }

  // Drops all accumulated stats and resets the update counter.
  void Clear() {
    values_.clear();
    num_updates_ = 0;
  }

  tensorflow::mutex* mutex() { return &mu_; }
  StatsByPartition* mutable_values() { return &values_; }
  const StatsByPartition& values() const { return values_; }
  const int64& num_updates() const { return num_updates_; }
  void set_num_updates(int64 val) { num_updates_ = val; }
  const TensorShape& gradient_shape() const { return gradient_shape_; }
  const TensorShape& hessian_shape() const { return hessian_shape_; }

 private:
  // Key into a specific partition to accumulate stats for the specified feature
  // id.
  StatsByPartition values_;
  const TensorShape gradient_shape_;
  const TensorShape hessian_shape_;
  // Number of Add/Deserialize batches applied since the last Clear().
  int64 num_updates_;
  tensorflow::mutex mu_;
  TF_DISALLOW_COPY_AND_ASSIGN(StatsAccumulatorResource);
};

using StatsAccumulatorScalarResource = StatsAccumulatorResource<float, float>;
using StatsAccumulatorTensorResource =
    StatsAccumulatorResource<std::vector<float>, std::vector<float>>;
    131 
    132 void SerializeScalarAccumulatorToOutput(
    133     const StatsAccumulatorScalarResource& accumulator_resource,
    134     OpKernelContext* context) {
    135   int64 num_slots = accumulator_resource.values().size();
    136   Tensor* partition_ids_t = nullptr;
    137   OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
    138                                                    TensorShape({num_slots}),
    139                                                    &partition_ids_t));
    140   auto partition_ids = partition_ids_t->vec<int32>();
    141 
    142   // Feature ids tensor has ids of feature columns and their dimensions.
    143   Tensor* feature_ids_t = nullptr;
    144   OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids",
    145                                                    TensorShape({num_slots, 2}),
    146                                                    &feature_ids_t));
    147   auto feature_ids = feature_ids_t->matrix<int64>();
    148 
    149   Tensor* gradients_t = nullptr;
    150   OP_REQUIRES_OK(
    151       context, context->allocate_output(
    152                    "output_gradients", TensorShape({num_slots}), &gradients_t));
    153   auto gradients = gradients_t->vec<float>();
    154 
    155   Tensor* hessians_t = nullptr;
    156   OP_REQUIRES_OK(
    157       context, context->allocate_output("output_hessians",
    158                                         TensorShape({num_slots}), &hessians_t));
    159   auto hessians = hessians_t->vec<float>();
    160 
    161   int i = 0;
    162   for (const auto& iter : accumulator_resource.values()) {
    163     partition_ids(i) = iter.first.partition_id;
    164     feature_ids(i, 0) = iter.first.feature_id;
    165     feature_ids(i, 1) = iter.first.dimension;
    166 
    167     gradients(i) = iter.second.first;
    168     hessians(i) = iter.second.second;
    169     ++i;
    170   }
    171 }
    172 
    173 void SerializeTensorAccumulatorToOutput(
    174     const StatsAccumulatorTensorResource& accumulator_resource,
    175     OpKernelContext* context) {
    176   int64 num_slots = accumulator_resource.values().size();
    177   Tensor* partition_ids_t = nullptr;
    178   OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
    179                                                    TensorShape({num_slots}),
    180                                                    &partition_ids_t));
    181   auto partition_ids = partition_ids_t->vec<int32>();
    182 
    183   Tensor* feature_ids_t = nullptr;
    184   OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids",
    185                                                    TensorShape({num_slots, 2}),
    186                                                    &feature_ids_t));
    187   auto feature_ids = feature_ids_t->matrix<int64>();
    188 
    189   TensorShape gradient_shape = accumulator_resource.gradient_shape();
    190   int64 num_gradient_elements = gradient_shape.num_elements();
    191   gradient_shape.InsertDim(0, num_slots);
    192   Tensor* gradients_t = nullptr;
    193   OP_REQUIRES_OK(context,
    194                  context->allocate_output("output_gradients", gradient_shape,
    195                                           &gradients_t));
    196   auto gradients = gradients_t->flat_outer_dims<float>();
    197 
    198   TensorShape hessian_shape = accumulator_resource.hessian_shape();
    199   int64 num_hessian_elements = hessian_shape.num_elements();
    200   hessian_shape.InsertDim(0, num_slots);
    201   Tensor* hessians_t = nullptr;
    202   OP_REQUIRES_OK(context, context->allocate_output("output_hessians",
    203                                                    hessian_shape, &hessians_t));
    204   auto hessians = hessians_t->flat_outer_dims<float>();
    205 
    206   int i = 0;
    207   for (const auto& iter : accumulator_resource.values()) {
    208     partition_ids(i) = iter.first.partition_id;
    209     feature_ids(i, 0) = iter.first.feature_id;
    210     feature_ids(i, 1) = iter.first.dimension;
    211 
    212     for (int j = 0; j < num_gradient_elements; ++j) {
    213       gradients(i, j) = iter.second.first[j];
    214     }
    215     for (int j = 0; j < num_hessian_elements; ++j) {
    216       hessians(i, j) = iter.second.second[j];
    217     }
    218     ++i;
    219   }
    220 }
    221 
    222 void AddToScalarAccumulator(
    223     StatsAccumulatorScalarResource* accumulator_resource,
    224     const Tensor& partition_ids_t, const Tensor& feature_ids_t,
    225     const Tensor& gradients_t, const Tensor& hessians_t) {
    226   accumulator_resource->set_num_updates(accumulator_resource->num_updates() +
    227                                         1);
    228   const TensorShape& partition_ids_shape = partition_ids_t.shape();
    229   const auto& partition_ids = partition_ids_t.vec<int32>();
    230   const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>();
    231   const auto& gradients = gradients_t.vec<float>();
    232   const auto& hessians = hessians_t.vec<float>();
    233 
    234   int64 num_updates = partition_ids_shape.dim_size(0);
    235   auto stats_map = accumulator_resource->mutable_values();
    236   for (int64 i = 0; i < num_updates; ++i) {
    237     const auto key =
    238         PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
    239                      feature_ids_and_dimensions(i, 1));
    240     auto itr = stats_map->find(key);
    241     if (itr != stats_map->end()) {
    242       itr->second.first += gradients(i);
    243       itr->second.second += hessians(i);
    244     } else {
    245       (*stats_map)[key] = {gradients(i), hessians(i)};
    246     }
    247   }
    248 }
    249 
    250 void AddToScalarAccumulator(
    251     StatsAccumulatorScalarResource* accumulator_resource,
    252     OpKernelContext* context) {
    253   const Tensor* partition_ids_t;
    254   OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
    255   const Tensor* feature_ids_t;
    256   OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t));
    257   const Tensor* gradients_t;
    258   OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
    259   const Tensor* hessians_t;
    260   OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
    261   AddToScalarAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t,
    262                          *gradients_t, *hessians_t);
    263 }
    264 
    265 void AddToTensorAccumulator(
    266     StatsAccumulatorTensorResource* accumulator_resource,
    267     const Tensor& partition_ids_t, const Tensor& feature_ids_t,
    268     const Tensor& gradients_t, const Tensor& hessians_t,
    269     OpKernelContext* context) {
    270   accumulator_resource->set_num_updates(accumulator_resource->num_updates() +
    271                                         1);
    272 
    273   const TensorShape& partition_ids_shape = partition_ids_t.shape();
    274   const auto& partition_ids = partition_ids_t.vec<int32>();
    275   const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>();
    276   TensorShape gradients_shape = gradients_t.shape();
    277   const auto& gradients = gradients_t.flat_outer_dims<float>();
    278   TensorShape hessians_shape = hessians_t.shape();
    279   const auto& hessians = hessians_t.flat_outer_dims<float>();
    280 
    281   gradients_shape.RemoveDim(0);
    282   hessians_shape.RemoveDim(0);
    283 
    284   // TODO(soroush): Move gradient and hessian shape check to ShapeFn.
    285   OP_REQUIRES(
    286       context, gradients_shape == accumulator_resource->gradient_shape(),
    287       errors::InvalidArgument(strings::StrCat(
    288           "Gradients dimensions must match: ", gradients_shape.DebugString(),
    289           ", ", accumulator_resource->gradient_shape().DebugString())));
    290 
    291   OP_REQUIRES(
    292       context, hessians_shape == accumulator_resource->hessian_shape(),
    293       errors::InvalidArgument(strings::StrCat(
    294           "Hessian dimensions must match: ", hessians_shape.DebugString(), ", ",
    295           accumulator_resource->hessian_shape().DebugString())));
    296 
    297   int64 num_updates = partition_ids_shape.dim_size(0);
    298   auto stats_map = accumulator_resource->mutable_values();
    299   for (int64 i = 0; i < num_updates; ++i) {
    300     const auto key =
    301         PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
    302                      feature_ids_and_dimensions(i, 1));
    303     auto itr = stats_map->find(key);
    304     if (itr == stats_map->end()) {
    305       std::vector<float> new_gradients(gradients_shape.num_elements());
    306       for (int j = 0; j < gradients_shape.num_elements(); ++j) {
    307         new_gradients[j] = gradients(i, j);
    308       }
    309       std::vector<float> new_hessians(hessians_shape.num_elements());
    310       for (int j = 0; j < hessians_shape.num_elements(); ++j) {
    311         new_hessians[j] = hessians(i, j);
    312       }
    313       (*stats_map)[key] = {new_gradients, new_hessians};
    314     } else {
    315       auto& stored_gradients = itr->second.first;
    316       for (int j = 0; j < gradients_shape.num_elements(); ++j) {
    317         stored_gradients[j] += gradients(i, j);
    318       }
    319       auto& stored_hessians = itr->second.second;
    320       for (int j = 0; j < hessians_shape.num_elements(); ++j) {
    321         stored_hessians[j] += hessians(i, j);
    322       }
    323     }
    324   }
    325 }
    326 
    327 void AddToTensorAccumulator(
    328     StatsAccumulatorTensorResource* accumulator_resource,
    329     OpKernelContext* context) {
    330   const Tensor* partition_ids_t;
    331   OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
    332   const Tensor* feature_ids_t;
    333   OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t));
    334   const Tensor* gradients_t;
    335   OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
    336   const Tensor* hessians_t;
    337   OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
    338   AddToTensorAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t,
    339                          *gradients_t, *hessians_t, context);
    340 }
    341 
    342 }  // namespace
    343 
// Handle-producing ops and "is initialized" predicate kernels for both
// accumulator flavors (scalar-valued and tensor-valued stats).
REGISTER_RESOURCE_HANDLE_KERNEL(StatsAccumulatorScalarResource);
REGISTER_RESOURCE_HANDLE_KERNEL(StatsAccumulatorTensorResource);

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorScalarIsInitialized").Device(DEVICE_CPU),
    IsResourceInitialized<StatsAccumulatorScalarResource>);

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorTensorIsInitialized").Device(DEVICE_CPU),
    IsResourceInitialized<StatsAccumulatorTensorResource>);
    354 
    355 class CreateStatsAccumulatorScalarOp : public OpKernel {
    356  public:
    357   explicit CreateStatsAccumulatorScalarOp(OpKernelConstruction* context)
    358       : OpKernel(context) {}
    359 
    360   void Compute(OpKernelContext* context) override {
    361     const Tensor* stamp_token_t;
    362     OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
    363 
    364     TensorShape gradient_shape = TensorShape({});
    365     TensorShape hessian_shape = TensorShape({});
    366 
    367     auto* result =
    368         new StatsAccumulatorScalarResource(gradient_shape, hessian_shape);
    369     result->set_stamp(stamp_token_t->scalar<int64>()());
    370     // Only create one, if one does not exist already. Report status for all
    371     // other exceptions. If one already exists, it unrefs the new one.
    372     auto status = CreateResource(context, HandleFromInput(context, 0), result);
    373     if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
    374       OP_REQUIRES(context, false, status);
    375     }
    376   }
    377 };
    378 
    379 REGISTER_KERNEL_BUILDER(Name("CreateStatsAccumulatorScalar").Device(DEVICE_CPU),
    380                         CreateStatsAccumulatorScalarOp);
    381 
    382 class CreateStatsAccumulatorTensorOp : public OpKernel {
    383  public:
    384   explicit CreateStatsAccumulatorTensorOp(OpKernelConstruction* context)
    385       : OpKernel(context) {}
    386 
    387   void Compute(OpKernelContext* context) override {
    388     const Tensor* stamp_token_t;
    389     OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
    390 
    391     const Tensor* gradient_shape_t;
    392     OP_REQUIRES_OK(
    393         context, context->input("per_slot_gradient_shape", &gradient_shape_t));
    394 
    395     const Tensor* hessian_shape_t;
    396     OP_REQUIRES_OK(context,
    397                    context->input("per_slot_hessian_shape", &hessian_shape_t));
    398     TensorShape gradient_shape = TensorShape(gradient_shape_t->vec<int64>());
    399     TensorShape hessian_shape = TensorShape(hessian_shape_t->vec<int64>());
    400     auto* result =
    401         new StatsAccumulatorTensorResource(gradient_shape, hessian_shape);
    402     result->set_stamp(stamp_token_t->scalar<int64>()());
    403 
    404     // Only create one, if one does not exist already. Report status for all
    405     // other exceptions. If one already exists, it unrefs the new one.
    406     auto status = CreateResource(context, HandleFromInput(context, 0), result);
    407     if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
    408       OP_REQUIRES(context, false, status);
    409     }
    410   }
    411 };
    412 
    413 REGISTER_KERNEL_BUILDER(Name("CreateStatsAccumulatorTensor").Device(DEVICE_CPU),
    414                         CreateStatsAccumulatorTensorOp);
    415 
    416 class StatsAccumulatorScalarAddOp : public OpKernel {
    417  public:
    418   explicit StatsAccumulatorScalarAddOp(OpKernelConstruction* context)
    419       : OpKernel(context) {}
    420 
    421   void Compute(OpKernelContext* context) override {
    422     OpInputList resource_handle_list;
    423     OP_REQUIRES_OK(context, context->input_list("stats_accumulator_handles",
    424                                                 &resource_handle_list));
    425     OpInputList partition_ids_list;
    426     OP_REQUIRES_OK(context,
    427                    context->input_list("partition_ids", &partition_ids_list));
    428 
    429     OpInputList feature_ids_list;
    430     OP_REQUIRES_OK(context,
    431                    context->input_list("feature_ids", &feature_ids_list));
    432     OpInputList gradients_list;
    433     OP_REQUIRES_OK(context, context->input_list("gradients", &gradients_list));
    434     OpInputList hessians_list;
    435     OP_REQUIRES_OK(context, context->input_list("hessians", &hessians_list));
    436 
    437     const Tensor* stamp_token_t;
    438     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    439     int64 stamp_token = stamp_token_t->scalar<int64>()();
    440 
    441     thread::ThreadPool* const worker_threads =
    442         context->device()->tensorflow_cpu_worker_threads()->workers;
    443     boosted_trees::utils::ParallelFor(
    444         resource_handle_list.size(), worker_threads->NumThreads(),
    445         worker_threads,
    446         [&context, &resource_handle_list, &partition_ids_list,
    447          &feature_ids_list, &gradients_list, &hessians_list,
    448          stamp_token](int64 start, int64 end) {
    449           for (int resource_handle_idx = start; resource_handle_idx < end;
    450                ++resource_handle_idx) {
    451             const ResourceHandle& handle =
    452                 resource_handle_list[resource_handle_idx]
    453                     .flat<ResourceHandle>()(0);
    454 
    455             StatsAccumulatorScalarResource* accumulator_resource;
    456             OP_REQUIRES_OK(context, LookupResource(context, handle,
    457                                                    &accumulator_resource));
    458             mutex_lock l(*accumulator_resource->mutex());
    459             core::ScopedUnref unref_me(accumulator_resource);
    460 
    461             // If the stamp is invalid we drop the update.
    462             if (!accumulator_resource->is_stamp_valid(stamp_token)) {
    463               VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
    464                       << "Passed stamp token: " << stamp_token << " "
    465                       << "Current token: " << accumulator_resource->stamp();
    466               return;
    467             }
    468             AddToScalarAccumulator(accumulator_resource,
    469                                    partition_ids_list[resource_handle_idx],
    470                                    feature_ids_list[resource_handle_idx],
    471                                    gradients_list[resource_handle_idx],
    472                                    hessians_list[resource_handle_idx]);
    473           }
    474         });
    475   }
    476 };
    477 
    478 REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorScalarAdd").Device(DEVICE_CPU),
    479                         StatsAccumulatorScalarAddOp);
    480 
    481 class StatsAccumulatorTensorAddOp : public OpKernel {
    482  public:
    483   explicit StatsAccumulatorTensorAddOp(OpKernelConstruction* context)
    484       : OpKernel(context) {}
    485 
    486   void Compute(OpKernelContext* context) override {
    487     OpInputList resource_handle_list;
    488     OP_REQUIRES_OK(context, context->input_list("stats_accumulator_handles",
    489                                                 &resource_handle_list));
    490     OpInputList partition_ids_list;
    491     OP_REQUIRES_OK(context,
    492                    context->input_list("partition_ids", &partition_ids_list));
    493 
    494     OpInputList feature_ids_list;
    495     OP_REQUIRES_OK(context,
    496                    context->input_list("feature_ids", &feature_ids_list));
    497     OpInputList gradients_list;
    498     OP_REQUIRES_OK(context, context->input_list("gradients", &gradients_list));
    499     OpInputList hessians_list;
    500     OP_REQUIRES_OK(context, context->input_list("hessians", &hessians_list));
    501 
    502     const Tensor* stamp_token_t;
    503     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    504     int64 stamp_token = stamp_token_t->scalar<int64>()();
    505 
    506     thread::ThreadPool* const worker_threads =
    507         context->device()->tensorflow_cpu_worker_threads()->workers;
    508     boosted_trees::utils::ParallelFor(
    509         resource_handle_list.size(), worker_threads->NumThreads(),
    510         worker_threads,
    511         [&context, &resource_handle_list, &partition_ids_list,
    512          &feature_ids_list, &gradients_list, &hessians_list,
    513          stamp_token](int64 start, int64 end) {
    514           for (int resource_handle_idx = start; resource_handle_idx < end;
    515                ++resource_handle_idx) {
    516             const ResourceHandle& handle =
    517                 resource_handle_list[resource_handle_idx]
    518                     .flat<ResourceHandle>()(0);
    519 
    520             StatsAccumulatorTensorResource* accumulator_resource;
    521             OP_REQUIRES_OK(context, LookupResource(context, handle,
    522                                                    &accumulator_resource));
    523             mutex_lock l(*accumulator_resource->mutex());
    524             core::ScopedUnref unref_me(accumulator_resource);
    525 
    526             // If the stamp is invalid we drop the update.
    527             if (!accumulator_resource->is_stamp_valid(stamp_token)) {
    528               VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
    529                       << "Passed stamp token: " << stamp_token << " "
    530                       << "Current token: " << accumulator_resource->stamp();
    531               return;
    532             }
    533             AddToTensorAccumulator(accumulator_resource,
    534                                    partition_ids_list[resource_handle_idx],
    535                                    feature_ids_list[resource_handle_idx],
    536                                    gradients_list[resource_handle_idx],
    537                                    hessians_list[resource_handle_idx], context);
    538           }
    539         });
    540   }
    541 };
    542 
    543 REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorTensorAdd").Device(DEVICE_CPU),
    544                         StatsAccumulatorTensorAddOp);
    545 
    546 class StatsAccumulatorScalarFlushOp : public OpKernel {
    547  public:
    548   explicit StatsAccumulatorScalarFlushOp(OpKernelConstruction* context)
    549       : OpKernel(context) {}
    550 
    551   void Compute(OpKernelContext* context) override {
    552     StatsAccumulatorScalarResource* accumulator_resource;
    553     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
    554                                            &accumulator_resource));
    555     mutex_lock l(*accumulator_resource->mutex());
    556     core::ScopedUnref unref_me(accumulator_resource);
    557 
    558     const Tensor* stamp_token_t;
    559     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    560     int64 stamp_token = stamp_token_t->scalar<int64>()();
    561 
    562     // If the stamp is invalid we restart the PS. It shouldn't happen since
    563     // only Chief should call this function and chief is guaranteed to be in
    564     // a consistent state.
    565     CHECK(accumulator_resource->is_stamp_valid(stamp_token));
    566 
    567     const Tensor* next_stamp_token_t;
    568     OP_REQUIRES_OK(context,
    569                    context->input(kNextStampTokenName, &next_stamp_token_t));
    570     int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
    571     CHECK(stamp_token != next_stamp_token);
    572 
    573     SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
    574     Tensor* num_updates_t = nullptr;
    575     OP_REQUIRES_OK(context,
    576                    context->allocate_output("num_updates", TensorShape({}),
    577                                             &num_updates_t));
    578     num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
    579 
    580     accumulator_resource->Clear();
    581     accumulator_resource->set_stamp(next_stamp_token);
    582   }
    583 };
    584 
    585 REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorScalarFlush").Device(DEVICE_CPU),
    586                         StatsAccumulatorScalarFlushOp);
    587 
    588 class StatsAccumulatorTensorFlushOp : public OpKernel {
    589  public:
    590   explicit StatsAccumulatorTensorFlushOp(OpKernelConstruction* context)
    591       : OpKernel(context) {}
    592 
    593   void Compute(OpKernelContext* context) override {
    594     StatsAccumulatorTensorResource* accumulator_resource;
    595     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
    596                                            &accumulator_resource));
    597     mutex_lock l(*accumulator_resource->mutex());
    598     core::ScopedUnref unref_me(accumulator_resource);
    599 
    600     const Tensor* stamp_token_t;
    601     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    602     int64 stamp_token = stamp_token_t->scalar<int64>()();
    603 
    604     const Tensor* next_stamp_token_t;
    605     OP_REQUIRES_OK(context,
    606                    context->input(kNextStampTokenName, &next_stamp_token_t));
    607     int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
    608 
    609     // If the stamp is invalid we restart the PS. It shouldn't happen since
    610     // only Chief should call this function and chief is guaranteed to be in
    611     // a consistent state.
    612     CHECK(accumulator_resource->is_stamp_valid(stamp_token));
    613     CHECK(stamp_token != next_stamp_token);
    614     SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
    615     Tensor* num_updates_t = nullptr;
    616     OP_REQUIRES_OK(context,
    617                    context->allocate_output("num_updates", TensorShape({}),
    618                                             &num_updates_t));
    619     num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
    620     accumulator_resource->Clear();
    621     accumulator_resource->set_stamp(next_stamp_token);
    622   }
    623 };
    624 
    625 REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorTensorFlush").Device(DEVICE_CPU),
    626                         StatsAccumulatorTensorFlushOp);
    627 
    628 class StatsAccumulatorScalarDeserializeOp : public OpKernel {
    629  public:
    630   explicit StatsAccumulatorScalarDeserializeOp(OpKernelConstruction* context)
    631       : OpKernel(context) {}
    632 
    633   void Compute(OpKernelContext* context) override {
    634     StatsAccumulatorScalarResource* accumulator_resource;
    635     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
    636                                            &accumulator_resource));
    637     mutex_lock l(*accumulator_resource->mutex());
    638     core::ScopedUnref unref_me(accumulator_resource);
    639 
    640     // Check the stamp token.
    641     const Tensor* stamp_token_t;
    642     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    643     int64 stamp_token = stamp_token_t->scalar<int64>()();
    644     accumulator_resource->Clear();
    645     accumulator_resource->set_stamp(stamp_token);
    646     AddToScalarAccumulator(accumulator_resource, context);
    647     const Tensor* num_updates_t;
    648     OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
    649     accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()());
    650   }
    651 };
    652 
    653 REGISTER_KERNEL_BUILDER(
    654     Name("StatsAccumulatorScalarDeserialize").Device(DEVICE_CPU),
    655     StatsAccumulatorScalarDeserializeOp);
    656 
    657 class StatsAccumulatorTensorDeserializeOp : public OpKernel {
    658  public:
    659   explicit StatsAccumulatorTensorDeserializeOp(OpKernelConstruction* context)
    660       : OpKernel(context) {}
    661 
    662   void Compute(OpKernelContext* context) override {
    663     StatsAccumulatorTensorResource* accumulator_resource;
    664     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
    665                                            &accumulator_resource));
    666     mutex_lock l(*accumulator_resource->mutex());
    667     core::ScopedUnref unref_me(accumulator_resource);
    668 
    669     // Check the stamp token.
    670     const Tensor* stamp_token_t;
    671     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    672     int64 stamp_token = stamp_token_t->scalar<int64>()();
    673     accumulator_resource->Clear();
    674     accumulator_resource->set_stamp(stamp_token);
    675     AddToTensorAccumulator(accumulator_resource, context);
    676     const Tensor* num_updates_t;
    677     OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
    678     accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()());
    679   }
    680 };
    681 
    682 REGISTER_KERNEL_BUILDER(
    683     Name("StatsAccumulatorTensorDeserialize").Device(DEVICE_CPU),
    684     StatsAccumulatorTensorDeserializeOp);
    685 
    686 class StatsAccumulatorScalarSerializeOp : public OpKernel {
    687  public:
    688   explicit StatsAccumulatorScalarSerializeOp(OpKernelConstruction* context)
    689       : OpKernel(context) {}
    690 
    691   void Compute(OpKernelContext* context) override {
    692     StatsAccumulatorScalarResource* accumulator_resource;
    693     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
    694                                            &accumulator_resource));
    695     mutex_lock l(*accumulator_resource->mutex());
    696     core::ScopedUnref unref_me(accumulator_resource);
    697     SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
    698     Tensor* stamp_token_t = nullptr;
    699     OP_REQUIRES_OK(context,
    700                    context->allocate_output("stamp_token", TensorShape({}),
    701                                             &stamp_token_t));
    702     stamp_token_t->scalar<int64>()() = accumulator_resource->stamp();
    703 
    704     Tensor* num_updates_t = nullptr;
    705     OP_REQUIRES_OK(context,
    706                    context->allocate_output("num_updates", TensorShape({}),
    707                                             &num_updates_t));
    708     num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
    709   }
    710 };
    711 
    712 REGISTER_KERNEL_BUILDER(
    713     Name("StatsAccumulatorScalarSerialize").Device(DEVICE_CPU),
    714     StatsAccumulatorScalarSerializeOp);
    715 
    716 class StatsAccumulatorTensorSerializeOp : public OpKernel {
    717  public:
    718   explicit StatsAccumulatorTensorSerializeOp(OpKernelConstruction* context)
    719       : OpKernel(context) {}
    720 
    721   void Compute(OpKernelContext* context) override {
    722     StatsAccumulatorTensorResource* accumulator_resource;
    723     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
    724                                            &accumulator_resource));
    725     mutex_lock l(*accumulator_resource->mutex());
    726     core::ScopedUnref unref_me(accumulator_resource);
    727     SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
    728     Tensor* stamp_token_t = nullptr;
    729     OP_REQUIRES_OK(context,
    730                    context->allocate_output("stamp_token", TensorShape({}),
    731                                             &stamp_token_t));
    732     stamp_token_t->scalar<int64>()() = accumulator_resource->stamp();
    733 
    734     Tensor* num_updates_t = nullptr;
    735     OP_REQUIRES_OK(context,
    736                    context->allocate_output("num_updates", TensorShape({}),
    737                                             &num_updates_t));
    738     num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
    739   }
    740 };
    741 
    742 REGISTER_KERNEL_BUILDER(
    743     Name("StatsAccumulatorTensorSerialize").Device(DEVICE_CPU),
    744     StatsAccumulatorTensorSerializeOp);
    745 
    746 class StatsAccumulatorScalarMakeSummaryOp : public OpKernel {
    747  public:
    748   explicit StatsAccumulatorScalarMakeSummaryOp(OpKernelConstruction* context)
    749       : OpKernel(context) {}
    750 
    751   void Compute(OpKernelContext* context) override {
    752     TensorShape gradient_shape = TensorShape({});
    753     TensorShape hessian_shape = TensorShape({});
    754     StatsAccumulatorScalarResource* accumulator_resource =
    755         new StatsAccumulatorScalarResource(gradient_shape, hessian_shape);
    756     core::ScopedUnref unref_me(accumulator_resource);
    757     // Check the stamp token.
    758     AddToScalarAccumulator(accumulator_resource, context);
    759     SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
    760   }
    761 };
    762 
    763 REGISTER_KERNEL_BUILDER(
    764     Name("StatsAccumulatorScalarMakeSummary").Device(DEVICE_CPU),
    765     StatsAccumulatorScalarMakeSummaryOp);
    766 
    767 class StatsAccumulatorTensorMakeSummaryOp : public OpKernel {
    768  public:
    769   explicit StatsAccumulatorTensorMakeSummaryOp(OpKernelConstruction* context)
    770       : OpKernel(context) {}
    771 
    772   void Compute(OpKernelContext* context) override {
    773     const Tensor* gradients_t;
    774     OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
    775     TensorShape gradients_shape = gradients_t->shape();
    776     gradients_shape.RemoveDim(0);
    777 
    778     const Tensor* hessians_t;
    779     OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
    780     TensorShape hessians_shape = hessians_t->shape();
    781     hessians_shape.RemoveDim(0);
    782 
    783     StatsAccumulatorTensorResource* accumulator_resource =
    784         new StatsAccumulatorTensorResource(gradients_shape, hessians_shape);
    785     core::ScopedUnref unref_me(accumulator_resource);
    786     // Check the stamp token.
    787     AddToTensorAccumulator(accumulator_resource, context);
    788     SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
    789   }
    790 };
    791 
    792 REGISTER_KERNEL_BUILDER(
    793     Name("StatsAccumulatorTensorMakeSummary").Device(DEVICE_CPU),
    794     StatsAccumulatorTensorMakeSummaryOp);
    795 
    796 }  // namespace boosted_trees
    797 }  // namespace tensorflow
    798