Home | History | Annotate | Download | only in kernels
      1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
#include <algorithm>
#include <iterator>
#include <map>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
     20 
     21 #include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"
     22 #include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
     23 #include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h"
     24 #include "tensorflow/core/framework/op_kernel.h"
     25 #include "tensorflow/core/framework/resource_mgr.h"
     26 #include "tensorflow/core/framework/tensor.h"
     27 #include "tensorflow/core/framework/tensor_shape.h"
     28 #include "tensorflow/core/framework/types.h"
     29 #include "tensorflow/core/lib/core/errors.h"
     30 #include "tensorflow/core/lib/core/status.h"
     31 #include "tensorflow/core/platform/types.h"
     32 #include "tensorflow/core/util/work_sharder.h"
     33 
     34 namespace tensorflow {
     35 namespace boosted_trees {
     36 
     37 namespace {
     38 const char* const kStampTokenName = "stamp_token";
     39 const char* const kNextStampTokenName = "next_stamp_token";
     40 
     41 struct PartitionKey {
     42   PartitionKey() : partition_id(-1), feature_id(-1), dimension(-1) {}
     43 
     44   PartitionKey(int32 p, int64 f, int32 d)
     45       : partition_id(p), feature_id(f), dimension(d) {}
     46 
     47   bool operator==(const PartitionKey& other) const {
     48     return (partition_id == other.partition_id) &&
     49            (dimension == other.dimension) && (feature_id == other.feature_id);
     50   }
     51 
     52   // Compare for PartitionKey.
     53   struct Less {
     54     bool operator()(const PartitionKey& a, const PartitionKey& b) const {
     55       if (a.partition_id < b.partition_id) {
     56         return true;
     57       }
     58       if ((a.partition_id == b.partition_id) && (a.dimension < b.dimension)) {
     59         return true;
     60       }
     61       if ((a.partition_id == b.partition_id) && (a.dimension == b.dimension) &&
     62           (a.feature_id < b.feature_id)) {
     63         return true;
     64       }
     65       return false;
     66     }
     67   };
     68 
     69   // Tree partition defined by traversing the tree to the leaf.
     70   int32 partition_id;
     71 
     72   // Feature column id.
     73   int64 feature_id;
     74 
     75   // Dimension within feature column.
     76   int32 dimension;
     77 };
     78 
     79 template <typename GradientType, typename HessianType>
     80 class StatsAccumulatorResource : public boosted_trees::StampedResource {
     81   using StatsByPartition =
     82       std::map<PartitionKey, std::pair<GradientType, HessianType>,
     83                PartitionKey::Less>;
     84 
     85  public:
     86   StatsAccumulatorResource(const TensorShape& gradient_shape,
     87                            const TensorShape& hessian_shape)
     88       : gradient_shape_(gradient_shape),
     89         hessian_shape_(hessian_shape),
     90         num_updates_(0) {
     91     // If GradientType/HessianType is scalar float then the shapes should be
     92     // scalar and vice versa.
     93     CHECK_EQ((std::is_same<GradientType, float>::value),
     94              TensorShapeUtils::IsScalar(gradient_shape));
     95     CHECK_EQ((std::is_same<HessianType, float>::value),
     96              TensorShapeUtils::IsScalar(hessian_shape));
     97   }
     98 
     99   string DebugString() override {
    100     return strings::StrCat("StatsAccumulatorResource[size=", values_.size(),
    101                            "]");
    102   }
    103 
    104   void Clear() {
    105     values_.clear();
    106     num_updates_ = 0;
    107   }
    108 
    109   tensorflow::mutex* mutex() { return &mu_; }
    110   StatsByPartition* mutable_values() { return &values_; }
    111   const StatsByPartition& values() const { return values_; }
    112   const int64& num_updates() const { return num_updates_; }
    113   void set_num_updates(int64 val) { num_updates_ = val; }
    114   const TensorShape& gradient_shape() const { return gradient_shape_; }
    115   const TensorShape& hessian_shape() const { return hessian_shape_; }
    116 
    117  private:
    118   // Key into a specific partition to accumulate stats for the specified feature
    119   // id.
    120   StatsByPartition values_;
    121   const TensorShape gradient_shape_;
    122   const TensorShape hessian_shape_;
    123   int64 num_updates_;
    124   tensorflow::mutex mu_;
    125   TF_DISALLOW_COPY_AND_ASSIGN(StatsAccumulatorResource);
    126 };
    127 
    128 using StatsAccumulatorScalarResource = StatsAccumulatorResource<float, float>;
    129 using StatsAccumulatorTensorResource =
    130     StatsAccumulatorResource<std::vector<float>, std::vector<float>>;
    131 
    132 void SerializeScalarAccumulatorToOutput(
    133     const StatsAccumulatorScalarResource& accumulator_resource,
    134     OpKernelContext* context) {
    135   int64 num_slots = accumulator_resource.values().size();
    136   Tensor* partition_ids_t = nullptr;
    137   OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
    138                                                    TensorShape({num_slots}),
    139                                                    &partition_ids_t));
    140   auto partition_ids = partition_ids_t->vec<int32>();
    141 
    142   // Feature ids tensor has ids of feature columns and their dimensions.
    143   Tensor* feature_ids_t = nullptr;
    144   OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids",
    145                                                    TensorShape({num_slots, 2}),
    146                                                    &feature_ids_t));
    147   auto feature_ids = feature_ids_t->matrix<int64>();
    148 
    149   Tensor* gradients_t = nullptr;
    150   OP_REQUIRES_OK(
    151       context, context->allocate_output(
    152                    "output_gradients", TensorShape({num_slots}), &gradients_t));
    153   auto gradients = gradients_t->vec<float>();
    154 
    155   Tensor* hessians_t = nullptr;
    156   OP_REQUIRES_OK(
    157       context, context->allocate_output("output_hessians",
    158                                         TensorShape({num_slots}), &hessians_t));
    159   auto hessians = hessians_t->vec<float>();
    160 
    161   int i = 0;
    162   for (const auto& iter : accumulator_resource.values()) {
    163     partition_ids(i) = iter.first.partition_id;
    164     feature_ids(i, 0) = iter.first.feature_id;
    165     feature_ids(i, 1) = iter.first.dimension;
    166 
    167     gradients(i) = iter.second.first;
    168     hessians(i) = iter.second.second;
    169     ++i;
    170   }
    171 }
    172 
    173 void SerializeTensorAccumulatorToOutput(
    174     const StatsAccumulatorTensorResource& accumulator_resource,
    175     OpKernelContext* context) {
    176   int64 num_slots = accumulator_resource.values().size();
    177   Tensor* partition_ids_t = nullptr;
    178   OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
    179                                                    TensorShape({num_slots}),
    180                                                    &partition_ids_t));
    181   auto partition_ids = partition_ids_t->vec<int32>();
    182 
    183   Tensor* feature_ids_t = nullptr;
    184   OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids",
    185                                                    TensorShape({num_slots, 2}),
    186                                                    &feature_ids_t));
    187   auto feature_ids = feature_ids_t->matrix<int64>();
    188 
    189   TensorShape gradient_shape = accumulator_resource.gradient_shape();
    190   int64 num_gradient_elements = gradient_shape.num_elements();
    191   gradient_shape.InsertDim(0, num_slots);
    192   Tensor* gradients_t = nullptr;
    193   OP_REQUIRES_OK(context,
    194                  context->allocate_output("output_gradients", gradient_shape,
    195                                           &gradients_t));
    196   auto gradients = gradients_t->flat_outer_dims<float>();
    197 
    198   TensorShape hessian_shape = accumulator_resource.hessian_shape();
    199   int64 num_hessian_elements = hessian_shape.num_elements();
    200   hessian_shape.InsertDim(0, num_slots);
    201   Tensor* hessians_t = nullptr;
    202   OP_REQUIRES_OK(context, context->allocate_output("output_hessians",
    203                                                    hessian_shape, &hessians_t));
    204   auto hessians = hessians_t->flat_outer_dims<float>();
    205 
    206   int i = 0;
    207   for (const auto& iter : accumulator_resource.values()) {
    208     partition_ids(i) = iter.first.partition_id;
    209     feature_ids(i, 0) = iter.first.feature_id;
    210     feature_ids(i, 1) = iter.first.dimension;
    211 
    212     for (int j = 0; j < num_gradient_elements; ++j) {
    213       gradients(i, j) = iter.second.first[j];
    214     }
    215     for (int j = 0; j < num_hessian_elements; ++j) {
    216       hessians(i, j) = iter.second.second[j];
    217     }
    218     ++i;
    219   }
    220 }
    221 
    222 void AddToScalarAccumulator(
    223     StatsAccumulatorScalarResource* accumulator_resource,
    224     const Tensor& partition_ids_t, const Tensor& feature_ids_t,
    225     const Tensor& gradients_t, const Tensor& hessians_t) {
    226   accumulator_resource->set_num_updates(accumulator_resource->num_updates() +
    227                                         1);
    228   const TensorShape& partition_ids_shape = partition_ids_t.shape();
    229   const auto& partition_ids = partition_ids_t.vec<int32>();
    230   const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>();
    231   const auto& gradients = gradients_t.vec<float>();
    232   const auto& hessians = hessians_t.vec<float>();
    233 
    234   int64 num_updates = partition_ids_shape.dim_size(0);
    235   auto stats_map = accumulator_resource->mutable_values();
    236   for (int64 i = 0; i < num_updates; ++i) {
    237     const auto key =
    238         PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
    239                      feature_ids_and_dimensions(i, 1));
    240     auto itr = stats_map->find(key);
    241     if (itr != stats_map->end()) {
    242       itr->second.first += gradients(i);
    243       itr->second.second += hessians(i);
    244     } else {
    245       (*stats_map)[key] = {gradients(i), hessians(i)};
    246     }
    247   }
    248 }
    249 
    250 void AddToScalarAccumulator(
    251     StatsAccumulatorScalarResource* accumulator_resource,
    252     OpKernelContext* context) {
    253   const Tensor* partition_ids_t;
    254   OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
    255   const Tensor* feature_ids_t;
    256   OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t));
    257   const Tensor* gradients_t;
    258   OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
    259   const Tensor* hessians_t;
    260   OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
    261   AddToScalarAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t,
    262                          *gradients_t, *hessians_t);
    263 }
    264 
    265 void AddToTensorAccumulator(
    266     StatsAccumulatorTensorResource* accumulator_resource,
    267     const Tensor& partition_ids_t, const Tensor& feature_ids_t,
    268     const Tensor& gradients_t, const Tensor& hessians_t,
    269     OpKernelContext* context) {
    270   accumulator_resource->set_num_updates(accumulator_resource->num_updates() +
    271                                         1);
    272 
    273   const TensorShape& partition_ids_shape = partition_ids_t.shape();
    274   const auto& partition_ids = partition_ids_t.vec<int32>();
    275   const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>();
    276   TensorShape gradients_shape = gradients_t.shape();
    277   const auto& gradients = gradients_t.flat_outer_dims<float>();
    278   TensorShape hessians_shape = hessians_t.shape();
    279   const auto& hessians = hessians_t.flat_outer_dims<float>();
    280 
    281   gradients_shape.RemoveDim(0);
    282   hessians_shape.RemoveDim(0);
    283 
    284   // TODO(soroush): Move gradient and hessian shape check to ShapeFn.
    285   OP_REQUIRES(
    286       context, gradients_shape == accumulator_resource->gradient_shape(),
    287       errors::InvalidArgument(strings::StrCat(
    288           "Gradients dimensions must match: ", gradients_shape.DebugString(),
    289           ", ", accumulator_resource->gradient_shape().DebugString())));
    290 
    291   OP_REQUIRES(
    292       context, hessians_shape == accumulator_resource->hessian_shape(),
    293       errors::InvalidArgument(strings::StrCat(
    294           "Hessian dimensions must match: ", hessians_shape.DebugString(), ", ",
    295           accumulator_resource->hessian_shape().DebugString())));
    296 
    297   int64 num_updates = partition_ids_shape.dim_size(0);
    298   auto stats_map = accumulator_resource->mutable_values();
    299   for (int64 i = 0; i < num_updates; ++i) {
    300     const auto key =
    301         PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
    302                      feature_ids_and_dimensions(i, 1));
    303     auto itr = stats_map->find(key);
    304     if (itr == stats_map->end()) {
    305       std::vector<float> new_gradients(gradients_shape.num_elements());
    306       for (int j = 0; j < gradients_shape.num_elements(); ++j) {
    307         new_gradients[j] = gradients(i, j);
    308       }
    309       std::vector<float> new_hessians(hessians_shape.num_elements());
    310       for (int j = 0; j < hessians_shape.num_elements(); ++j) {
    311         new_hessians[j] = hessians(i, j);
    312       }
    313       (*stats_map)[key] = {new_gradients, new_hessians};
    314     } else {
    315       auto& stored_gradients = itr->second.first;
    316       for (int j = 0; j < gradients_shape.num_elements(); ++j) {
    317         stored_gradients[j] += gradients(i, j);
    318       }
    319       auto& stored_hessians = itr->second.second;
    320       for (int j = 0; j < hessians_shape.num_elements(); ++j) {
    321         stored_hessians[j] += hessians(i, j);
    322       }
    323     }
    324   }
    325 }
    326 
    327 void AddToTensorAccumulator(
    328     StatsAccumulatorTensorResource* accumulator_resource,
    329     OpKernelContext* context) {
    330   const Tensor* partition_ids_t;
    331   OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
    332   const Tensor* feature_ids_t;
    333   OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t));
    334   const Tensor* gradients_t;
    335   OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
    336   const Tensor* hessians_t;
    337   OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
    338   AddToTensorAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t,
    339                          *gradients_t, *hessians_t, context);
    340 }
    341 
    342 }  // namespace
    343 
    344 REGISTER_RESOURCE_HANDLE_KERNEL(StatsAccumulatorScalarResource);
    345 REGISTER_RESOURCE_HANDLE_KERNEL(StatsAccumulatorTensorResource);
    346 
    347 REGISTER_KERNEL_BUILDER(
    348     Name("StatsAccumulatorScalarIsInitialized").Device(DEVICE_CPU),
    349     IsResourceInitialized<StatsAccumulatorScalarResource>);
    350 
    351 REGISTER_KERNEL_BUILDER(
    352     Name("StatsAccumulatorTensorIsInitialized").Device(DEVICE_CPU),
    353     IsResourceInitialized<StatsAccumulatorTensorResource>);
    354 
    355 class CreateStatsAccumulatorScalarOp : public OpKernel {
    356  public:
    357   explicit CreateStatsAccumulatorScalarOp(OpKernelConstruction* context)
    358       : OpKernel(context) {}
    359 
    360   void Compute(OpKernelContext* context) override {
    361     const Tensor* stamp_token_t;
    362     OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
    363 
    364     TensorShape gradient_shape = TensorShape({});
    365     TensorShape hessian_shape = TensorShape({});
    366 
    367     auto* result =
    368         new StatsAccumulatorScalarResource(gradient_shape, hessian_shape);
    369     result->set_stamp(stamp_token_t->scalar<int64>()());
    370     // Only create one, if one does not exist already. Report status for all
    371     // other exceptions. If one already exists, it unrefs the new one.
    372     auto status = CreateResource(context, HandleFromInput(context, 0), result);
    373     if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
    374       OP_REQUIRES(context, false, status);
    375     }
    376   }
    377 };
    378 
    379 REGISTER_KERNEL_BUILDER(Name("CreateStatsAccumulatorScalar").Device(DEVICE_CPU),
    380                         CreateStatsAccumulatorScalarOp);
    381 
    382 class CreateStatsAccumulatorTensorOp : public OpKernel {
    383  public:
    384   explicit CreateStatsAccumulatorTensorOp(OpKernelConstruction* context)
    385       : OpKernel(context) {}
    386 
    387   void Compute(OpKernelContext* context) override {
    388     const Tensor* stamp_token_t;
    389     OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
    390 
    391     const Tensor* gradient_shape_t;
    392     OP_REQUIRES_OK(
    393         context, context->input("per_slot_gradient_shape", &gradient_shape_t));
    394 
    395     const Tensor* hessian_shape_t;
    396     OP_REQUIRES_OK(context,
    397                    context->input("per_slot_hessian_shape", &hessian_shape_t));
    398     TensorShape gradient_shape = TensorShape(gradient_shape_t->vec<int64>());
    399     TensorShape hessian_shape = TensorShape(hessian_shape_t->vec<int64>());
    400     auto* result =
    401         new StatsAccumulatorTensorResource(gradient_shape, hessian_shape);
    402     result->set_stamp(stamp_token_t->scalar<int64>()());
    403 
    404     // Only create one, if one does not exist already. Report status for all
    405     // other exceptions. If one already exists, it unrefs the new one.
    406     auto status = CreateResource(context, HandleFromInput(context, 0), result);
    407     if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
    408       OP_REQUIRES(context, false, status);
    409     }
    410   }
    411 };
    412 
    413 REGISTER_KERNEL_BUILDER(Name("CreateStatsAccumulatorTensor").Device(DEVICE_CPU),
    414                         CreateStatsAccumulatorTensorOp);
    415 
// Adds batches of scalar gradient/hessian updates to one or more scalar
// stats accumulators — one resource handle per slot of the parallel input
// lists. Updates whose stamp token does not match an accumulator's current
// stamp are dropped.
class StatsAccumulatorScalarAddOp : public OpKernel {
 public:
  explicit StatsAccumulatorScalarAddOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Parallel lists: element i of each list belongs to accumulator i.
    OpInputList resource_handle_list;
    OP_REQUIRES_OK(context, context->input_list("stats_accumulator_handles",
                                                &resource_handle_list));
    OpInputList partition_ids_list;
    OP_REQUIRES_OK(context,
                   context->input_list("partition_ids", &partition_ids_list));

    OpInputList feature_ids_list;
    OP_REQUIRES_OK(context,
                   context->input_list("feature_ids", &feature_ids_list));
    OpInputList gradients_list;
    OP_REQUIRES_OK(context, context->input_list("gradients", &gradients_list));
    OpInputList hessians_list;
    OP_REQUIRES_OK(context, context->input_list("hessians", &hessians_list));

    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    // Shard the accumulators across the CPU worker threads; each shard
    // handles the contiguous range [start, end) of resource handles.
    thread::ThreadPool* const worker_threads =
        context->device()->tensorflow_cpu_worker_threads()->workers;
    boosted_trees::utils::ParallelFor(
        resource_handle_list.size(), worker_threads->NumThreads(),
        worker_threads,
        [&context, &resource_handle_list, &partition_ids_list,
         &feature_ids_list, &gradients_list, &hessians_list,
         stamp_token](int64 start, int64 end) {
          for (int resource_handle_idx = start; resource_handle_idx < end;
               ++resource_handle_idx) {
            ResourceHandle handle = resource_handle_list[resource_handle_idx]
                                        .flat<ResourceHandle>()(0);

            StatsAccumulatorScalarResource* accumulator_resource;
            OP_REQUIRES_OK(context, LookupResource(context, handle,
                                                   &accumulator_resource));
            // NOTE(review): the lock is declared before ScopedUnref, so the
            // unref runs (in reverse destruction order) while the mutex is
            // still held — confirm this matches the resource's intended
            // locking discipline.
            mutex_lock l(*accumulator_resource->mutex());
            core::ScopedUnref unref_me(accumulator_resource);

            // If the stamp is invalid we drop the update.
            // NOTE(review): `return` exits the whole shard lambda, skipping
            // any remaining handles in [start, end) — confirm this is
            // intended rather than `continue`.
            if (!accumulator_resource->is_stamp_valid(stamp_token)) {
              VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
                      << "Passed stamp token: " << stamp_token << " "
                      << "Current token: " << accumulator_resource->stamp();
              return;
            }
            AddToScalarAccumulator(accumulator_resource,
                                   partition_ids_list[resource_handle_idx],
                                   feature_ids_list[resource_handle_idx],
                                   gradients_list[resource_handle_idx],
                                   hessians_list[resource_handle_idx]);
          }
        });
  }
};
    476 
    477 REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorScalarAdd").Device(DEVICE_CPU),
    478                         StatsAccumulatorScalarAddOp);
    479 
    480 class StatsAccumulatorTensorAddOp : public OpKernel {
    481  public:
    482   explicit StatsAccumulatorTensorAddOp(OpKernelConstruction* context)
    483       : OpKernel(context) {}
    484 
    485   void Compute(OpKernelContext* context) override {
    486     OpInputList resource_handle_list;
    487     OP_REQUIRES_OK(context, context->input_list("stats_accumulator_handles",
    488                                                 &resource_handle_list));
    489     OpInputList partition_ids_list;
    490     OP_REQUIRES_OK(context,
    491                    context->input_list("partition_ids", &partition_ids_list));
    492 
    493     OpInputList feature_ids_list;
    494     OP_REQUIRES_OK(context,
    495                    context->input_list("feature_ids", &feature_ids_list));
    496     OpInputList gradients_list;
    497     OP_REQUIRES_OK(context, context->input_list("gradients", &gradients_list));
    498     OpInputList hessians_list;
    499     OP_REQUIRES_OK(context, context->input_list("hessians", &hessians_list));
    500 
    501     const Tensor* stamp_token_t;
    502     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    503     int64 stamp_token = stamp_token_t->scalar<int64>()();
    504 
    505     thread::ThreadPool* const worker_threads =
    506         context->device()->tensorflow_cpu_worker_threads()->workers;
    507     boosted_trees::utils::ParallelFor(
    508         resource_handle_list.size(), worker_threads->NumThreads(),
    509         worker_threads,
    510         [&context, &resource_handle_list, &partition_ids_list,
    511          &feature_ids_list, &gradients_list, &hessians_list,
    512          stamp_token](int64 start, int64 end) {
    513           for (int resource_handle_idx = start; resource_handle_idx < end;
    514                ++resource_handle_idx) {
    515             ResourceHandle handle = resource_handle_list[resource_handle_idx]
    516                                         .flat<ResourceHandle>()(0);
    517 
    518             StatsAccumulatorTensorResource* accumulator_resource;
    519             OP_REQUIRES_OK(context, LookupResource(context, handle,
    520                                                    &accumulator_resource));
    521             mutex_lock l(*accumulator_resource->mutex());
    522             core::ScopedUnref unref_me(accumulator_resource);
    523 
    524             // If the stamp is invalid we drop the update.
    525             if (!accumulator_resource->is_stamp_valid(stamp_token)) {
    526               VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
    527                       << "Passed stamp token: " << stamp_token << " "
    528                       << "Current token: " << accumulator_resource->stamp();
    529               return;
    530             }
    531             AddToTensorAccumulator(accumulator_resource,
    532                                    partition_ids_list[resource_handle_idx],
    533                                    feature_ids_list[resource_handle_idx],
    534                                    gradients_list[resource_handle_idx],
    535                                    hessians_list[resource_handle_idx], context);
    536           }
    537         });
    538   }
    539 };
    540 
    541 REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorTensorAdd").Device(DEVICE_CPU),
    542                         StatsAccumulatorTensorAddOp);
    543 
    544 class StatsAccumulatorScalarFlushOp : public OpKernel {
    545  public:
    546   explicit StatsAccumulatorScalarFlushOp(OpKernelConstruction* context)
    547       : OpKernel(context) {}
    548 
    549   void Compute(OpKernelContext* context) override {
    550     StatsAccumulatorScalarResource* accumulator_resource;
    551     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
    552                                            &accumulator_resource));
    553     mutex_lock l(*accumulator_resource->mutex());
    554     core::ScopedUnref unref_me(accumulator_resource);
    555 
    556     const Tensor* stamp_token_t;
    557     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    558     int64 stamp_token = stamp_token_t->scalar<int64>()();
    559 
    560     // If the stamp is invalid we restart the PS. It shouldn't happen since
    561     // only Chief should call this function and chief is guaranteed to be in
    562     // a consistent state.
    563     CHECK(accumulator_resource->is_stamp_valid(stamp_token));
    564 
    565     const Tensor* next_stamp_token_t;
    566     OP_REQUIRES_OK(context,
    567                    context->input(kNextStampTokenName, &next_stamp_token_t));
    568     int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
    569     CHECK(stamp_token != next_stamp_token);
    570 
    571     SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
    572     Tensor* num_updates_t = nullptr;
    573     OP_REQUIRES_OK(context,
    574                    context->allocate_output("num_updates", TensorShape({}),
    575                                             &num_updates_t));
    576     num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
    577 
    578     accumulator_resource->Clear();
    579     accumulator_resource->set_stamp(next_stamp_token);
    580   }
    581 };
    582 
    583 REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorScalarFlush").Device(DEVICE_CPU),
    584                         StatsAccumulatorScalarFlushOp);
    585 
    586 class StatsAccumulatorTensorFlushOp : public OpKernel {
    587  public:
    588   explicit StatsAccumulatorTensorFlushOp(OpKernelConstruction* context)
    589       : OpKernel(context) {}
    590 
    591   void Compute(OpKernelContext* context) override {
    592     StatsAccumulatorTensorResource* accumulator_resource;
    593     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
    594                                            &accumulator_resource));
    595     mutex_lock l(*accumulator_resource->mutex());
    596     core::ScopedUnref unref_me(accumulator_resource);
    597 
    598     const Tensor* stamp_token_t;
    599     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    600     int64 stamp_token = stamp_token_t->scalar<int64>()();
    601 
    602     const Tensor* next_stamp_token_t;
    603     OP_REQUIRES_OK(context,
    604                    context->input(kNextStampTokenName, &next_stamp_token_t));
    605     int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
    606 
    607     // If the stamp is invalid we restart the PS. It shouldn't happen since
    608     // only Chief should call this function and chief is guaranteed to be in
    609     // a consistent state.
    610     CHECK(accumulator_resource->is_stamp_valid(stamp_token));
    611     CHECK(stamp_token != next_stamp_token);
    612     SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
    613     Tensor* num_updates_t = nullptr;
    614     OP_REQUIRES_OK(context,
    615                    context->allocate_output("num_updates", TensorShape({}),
    616                                             &num_updates_t));
    617     num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
    618     accumulator_resource->Clear();
    619     accumulator_resource->set_stamp(next_stamp_token);
    620   }
    621 };
    622 
    623 REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorTensorFlush").Device(DEVICE_CPU),
    624                         StatsAccumulatorTensorFlushOp);
    625 
    626 class StatsAccumulatorScalarDeserializeOp : public OpKernel {
    627  public:
    628   explicit StatsAccumulatorScalarDeserializeOp(OpKernelConstruction* context)
    629       : OpKernel(context) {}
    630 
    631   void Compute(OpKernelContext* context) override {
    632     StatsAccumulatorScalarResource* accumulator_resource;
    633     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
    634                                            &accumulator_resource));
    635     mutex_lock l(*accumulator_resource->mutex());
    636     core::ScopedUnref unref_me(accumulator_resource);
    637 
    638     // Check the stamp token.
    639     const Tensor* stamp_token_t;
    640     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    641     int64 stamp_token = stamp_token_t->scalar<int64>()();
    642     accumulator_resource->Clear();
    643     accumulator_resource->set_stamp(stamp_token);
    644     AddToScalarAccumulator(accumulator_resource, context);
    645     const Tensor* num_updates_t;
    646     OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
    647     accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()());
    648   }
    649 };
    650 
    651 REGISTER_KERNEL_BUILDER(
    652     Name("StatsAccumulatorScalarDeserialize").Device(DEVICE_CPU),
    653     StatsAccumulatorScalarDeserializeOp);
    654 
    655 class StatsAccumulatorTensorDeserializeOp : public OpKernel {
    656  public:
    657   explicit StatsAccumulatorTensorDeserializeOp(OpKernelConstruction* context)
    658       : OpKernel(context) {}
    659 
    660   void Compute(OpKernelContext* context) override {
    661     StatsAccumulatorTensorResource* accumulator_resource;
    662     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
    663                                            &accumulator_resource));
    664     mutex_lock l(*accumulator_resource->mutex());
    665     core::ScopedUnref unref_me(accumulator_resource);
    666 
    667     // Check the stamp token.
    668     const Tensor* stamp_token_t;
    669     OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    670     int64 stamp_token = stamp_token_t->scalar<int64>()();
    671     accumulator_resource->Clear();
    672     accumulator_resource->set_stamp(stamp_token);
    673     AddToTensorAccumulator(accumulator_resource, context);
    674     const Tensor* num_updates_t;
    675     OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
    676     accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()());
    677   }
    678 };
    679 
    680 REGISTER_KERNEL_BUILDER(
    681     Name("StatsAccumulatorTensorDeserialize").Device(DEVICE_CPU),
    682     StatsAccumulatorTensorDeserializeOp);
    683 
    684 class StatsAccumulatorScalarSerializeOp : public OpKernel {
    685  public:
    686   explicit StatsAccumulatorScalarSerializeOp(OpKernelConstruction* context)
    687       : OpKernel(context) {}
    688 
    689   void Compute(OpKernelContext* context) override {
    690     StatsAccumulatorScalarResource* accumulator_resource;
    691     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
    692                                            &accumulator_resource));
    693     mutex_lock l(*accumulator_resource->mutex());
    694     core::ScopedUnref unref_me(accumulator_resource);
    695     SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
    696     Tensor* stamp_token_t = nullptr;
    697     OP_REQUIRES_OK(context,
    698                    context->allocate_output("stamp_token", TensorShape({}),
    699                                             &stamp_token_t));
    700     stamp_token_t->scalar<int64>()() = accumulator_resource->stamp();
    701 
    702     Tensor* num_updates_t = nullptr;
    703     OP_REQUIRES_OK(context,
    704                    context->allocate_output("num_updates", TensorShape({}),
    705                                             &num_updates_t));
    706     num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
    707   }
    708 };
    709 
    710 REGISTER_KERNEL_BUILDER(
    711     Name("StatsAccumulatorScalarSerialize").Device(DEVICE_CPU),
    712     StatsAccumulatorScalarSerializeOp);
    713 
    714 class StatsAccumulatorTensorSerializeOp : public OpKernel {
    715  public:
    716   explicit StatsAccumulatorTensorSerializeOp(OpKernelConstruction* context)
    717       : OpKernel(context) {}
    718 
    719   void Compute(OpKernelContext* context) override {
    720     StatsAccumulatorTensorResource* accumulator_resource;
    721     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
    722                                            &accumulator_resource));
    723     mutex_lock l(*accumulator_resource->mutex());
    724     core::ScopedUnref unref_me(accumulator_resource);
    725     SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
    726     Tensor* stamp_token_t = nullptr;
    727     OP_REQUIRES_OK(context,
    728                    context->allocate_output("stamp_token", TensorShape({}),
    729                                             &stamp_token_t));
    730     stamp_token_t->scalar<int64>()() = accumulator_resource->stamp();
    731 
    732     Tensor* num_updates_t = nullptr;
    733     OP_REQUIRES_OK(context,
    734                    context->allocate_output("num_updates", TensorShape({}),
    735                                             &num_updates_t));
    736     num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
    737   }
    738 };
    739 
    740 REGISTER_KERNEL_BUILDER(
    741     Name("StatsAccumulatorTensorSerialize").Device(DEVICE_CPU),
    742     StatsAccumulatorTensorSerializeOp);
    743 
    744 class StatsAccumulatorScalarMakeSummaryOp : public OpKernel {
    745  public:
    746   explicit StatsAccumulatorScalarMakeSummaryOp(OpKernelConstruction* context)
    747       : OpKernel(context) {}
    748 
    749   void Compute(OpKernelContext* context) override {
    750     TensorShape gradient_shape = TensorShape({});
    751     TensorShape hessian_shape = TensorShape({});
    752     StatsAccumulatorScalarResource* accumulator_resource =
    753         new StatsAccumulatorScalarResource(gradient_shape, hessian_shape);
    754     core::ScopedUnref unref_me(accumulator_resource);
    755     // Check the stamp token.
    756     AddToScalarAccumulator(accumulator_resource, context);
    757     SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
    758   }
    759 };
    760 
    761 REGISTER_KERNEL_BUILDER(
    762     Name("StatsAccumulatorScalarMakeSummary").Device(DEVICE_CPU),
    763     StatsAccumulatorScalarMakeSummaryOp);
    764 
    765 class StatsAccumulatorTensorMakeSummaryOp : public OpKernel {
    766  public:
    767   explicit StatsAccumulatorTensorMakeSummaryOp(OpKernelConstruction* context)
    768       : OpKernel(context) {}
    769 
    770   void Compute(OpKernelContext* context) override {
    771     const Tensor* gradients_t;
    772     OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
    773     TensorShape gradients_shape = gradients_t->shape();
    774     gradients_shape.RemoveDim(0);
    775 
    776     const Tensor* hessians_t;
    777     OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
    778     TensorShape hessians_shape = hessians_t->shape();
    779     hessians_shape.RemoveDim(0);
    780 
    781     StatsAccumulatorTensorResource* accumulator_resource =
    782         new StatsAccumulatorTensorResource(gradients_shape, hessians_shape);
    783     core::ScopedUnref unref_me(accumulator_resource);
    784     // Check the stamp token.
    785     AddToTensorAccumulator(accumulator_resource, context);
    786     SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
    787   }
    788 };
    789 
    790 REGISTER_KERNEL_BUILDER(
    791     Name("StatsAccumulatorTensorMakeSummary").Device(DEVICE_CPU),
    792     StatsAccumulatorTensorMakeSummaryOp);
    793 
    794 }  // namespace boosted_trees
    795 }  // namespace tensorflow
    796