/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/data_flow_ops.cc.

#include <limits.h>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/priority_queue.h"
#include "tensorflow/core/kernels/queue_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
// Needed for batch_util::CopyElementToSlice (path per current TensorFlow;
// older trees declare it in tensorflow/core/kernels/batch_util.h).
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {

namespace barrier {

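// A Barrier aggregates the components of value tuples, keyed by string, as
// they arrive via BarrierInsertMany.  Partially supplied tuples live in an
// "incomplete" map; once every component of a key has been provided, the
// completed tuple is moved into an internal PriorityQueue (prioritized by
// insertion order) from which BarrierTakeMany dequeues.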
class Barrier : public ResourceBase {
 public:
  typedef std::vector<Tensor> Tuple;
  typedef std::function<void()> DoneCallback;
  typedef std::function<void(const Tensor&, const Tensor&, const Tuple&)>
      IndicesKeysValuesCallback;

  Barrier(const DataTypeVector& value_component_types,
          const std::vector<TensorShape>& value_component_shapes,
          const string& name)
      : closed_(false),
        queue_closed_(false),
        queue_cancelled_(false),
        cancel_pending_enqueues_(false),
        value_component_types_(value_component_types),
        value_component_shapes_(value_component_shapes),
        name_(name),
        input_index_(std::numeric_limits<int64>::min()) {
    DataTypeVector queue_component_types;
    std::vector<TensorShape> queue_component_shapes;

    // First queue component is for the input index;
    // Second queue component is for the key;
    // remaining queue components are for the value.
    queue_component_types.push_back(DT_INT64);
    queue_component_types.push_back(DT_STRING);
    for (DataType dt : value_component_types) {
      queue_component_types.push_back(dt);
    }

    // NOTE(mrry): PriorityQueue expects all shapes specified because
    // we'll be issuing TakeMany.
    queue_component_shapes.push_back(TensorShape({}));
    queue_component_shapes.push_back(TensorShape({}));
    queue_component_shapes.insert(queue_component_shapes.end(),
                                  value_component_shapes.begin(),
                                  value_component_shapes.end());

    ready_queue_ = new PriorityQueue(
        QueueBase::kUnbounded /* capacity */, queue_component_types,
        queue_component_shapes, strings::StrCat(name_, "_queue"));
  }

  Status Initialize() { return ready_queue_->Initialize(); }

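  // Inserts the slices of `values` (one per key) as component
  // `component_index` of the tuples identified by `keys`.  Any tuple that
  // becomes complete as a result is enqueued onto the ready queue.  The
  // shared input_index_ is advanced once per call that introduces new keys,
  // so that tuples inserted earlier have higher priority in the ready queue.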
  template <typename T>
  void TryInsertMany(const Tensor& keys, int component_index,
                     const Tensor& values, OpKernelContext* ctx,
                     const DoneCallback& callback) {
    TensorShape element_shape = values.shape();
    OP_REQUIRES_ASYNC(
        ctx, keys.NumElements() == 0 || element_shape.num_elements() > 0,
        errors::InvalidArgument("Tensors with no elements are not supported ",
                                name_, ": received shape ",
                                element_shape.DebugString()),
        callback);
    if (element_shape.dims() > 0) element_shape.RemoveDim(0);
    const std::size_t num_inserted = keys.NumElements();

    // For each key, update the corresponding incomplete tuple with the
    // given value at component_index.

    // Set to true if this call introduces any brand-new keys; used to
    // advance input_index_ below.
    bool new_elements = false;

    // Will be used for the final insert into the queue.
    Tuple insert_tuple;

    {
      mutex_lock lock(mu_);
      if (closed_) {
        OP_REQUIRES_ASYNC(
            ctx,
            !cancel_pending_enqueues_ &&
                (num_inserted == 0 || !incomplete_.empty()),
            errors::Cancelled(
                "Barrier ", name_, " is closed.  Pending enqueues cancelled: ",
                cancel_pending_enqueues_,
                ".  Number of new insertions: ", num_inserted,
                ".  Number of incomplete keys: ", incomplete_.size(), "."),
            callback);
      }

      // Step 1: insert into the incomplete map and identify which
      // entries are, in fact, complete and ready for enqueueing.  Store
      // them in a vector.
      std::vector<Tuple> ready_tuples;

      for (int i = 0; i < num_inserted; ++i) {
        OP_REQUIRES_OK_ASYNC(
            ctx,
            InsertOneLocked<T>(ctx, keys, values, element_shape,
                               component_index, i, &ready_tuples,
                               &new_elements),
            callback);
      }

      if (new_elements) ++input_index_;

      // This probably won't happen before the heat death of the
      // universe, but who knows?  Moore's law FTW.
      OP_REQUIRES_ASYNC(
          ctx, input_index_ != std::numeric_limits<int64>::max(),
          errors::Internal(
              "Barrier has had ", input_index_,
              " insertions and can no longer keep track of new ones."),
          callback);

      if (ready_tuples.empty()) {
        // Nothing to insert into the queue - so return early.
        callback();
        return;
      }

      // We have something to Enqueue.  Convert the Tuples into a single
      // tuple by slicing entries into new Tensors.  This part is slow
      // but seems the cleanest solution for now.
      insert_tuple.reserve(2 + num_components());  // indices, keys, rest
      int insertion_size = ready_tuples.size();
      for (int i = 0; i < 2 + num_components(); ++i) {
        TensorShape component_shape(ready_tuples[0][i].shape());
        component_shape.InsertDim(0, insertion_size);
        Tensor component(ready_tuples[0][i].dtype(), component_shape);
        for (int b = 0; b < insertion_size; ++b) {
          OP_REQUIRES_OK_ASYNC(
              ctx,
              batch_util::CopyElementToSlice(std::move(ready_tuples[b][i]),
                                             &component, b),
              callback);
        }
        insert_tuple.push_back(component);
      }
    }

    // Enqueue the completed tuples onto the ready queue.
    ready_queue_->TryEnqueueMany(
        insert_tuple, ctx,
        // To avoid early closing of the queue, only close it if the
        // barrier is closed, nothing is left in the incomplete set,
        // the queue is not already marked as closed, and (most
        // importantly), the queue has entries in it.
        [this, ctx, callback, component_index]() {
          if (!ctx->status().ok()) {
            callback();
            return;
          }
          {
            mutex_lock lock(mu_);
            int32 ready = ready_size();
            if (closed_ && incomplete_.empty() && queue_closed_ && ready > 0) {
              CloseQueueLocked(ctx, false, callback);
            } else {
              callback();
            }
            return;
          }
        });
  }

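  // Attempts to dequeue up to num_elements completed tuples, invoking
  // `callback` with their insertion indices, keys, and value components.  If
  // the barrier is closed and can never satisfy the request, the op fails
  // with OutOfRange.  (`timeout` is currently unused.)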
  void TryTakeMany(int num_elements, bool allow_small_batch, int64 timeout,
                   OpKernelContext* ctx,
                   const IndicesKeysValuesCallback& callback) {
    int num_elements_to_deliver = num_elements;
    {
      mutex_lock lock(mu_);
      if (closed_) {
        int available_elements = ready_size();
        if (allow_small_batch) {
          // We want to deliver a maximum of num_elements; if fewer elements
          // are available, we deliver at most available_elements.  If there
          // are no elements available, a call to TryTakeMany should fail with
          // OutOfRange.  We trigger this error by treating the request below
          // as being for at least one element.
          num_elements_to_deliver = std::min(num_elements, available_elements);
        } else {
          // We're happy to wait for additional elements to be completed.
          available_elements += incomplete_.size();
        }
        // If there are no available elements, or fewer elements than the
        // number we need to deliver, fail with OutOfRange.
        if (available_elements < std::max(num_elements_to_deliver, 1)) {
          ctx->SetStatus(errors::OutOfRange(
              "Barrier '", name_, "' is closed and has ",
              "insufficient elements (requested ", num_elements_to_deliver,
              ", total size ", available_elements, ")"));
          callback(Tensor(DT_INT64), Tensor(DT_STRING), Tuple());
          return;
        }
      }
    }

    ready_queue_->TryDequeueMany(
        num_elements_to_deliver, ctx, allow_small_batch,
        [this, ctx, callback](const Tuple& t) {
          Tensor indices(DT_INT64);
          Tensor keys(DT_STRING);
          Tuple values;

          if (!ctx->status().ok()) {
            callback(indices, keys, values);
            return;
          }

          CHECK_EQ(t.size(), 2 + num_components());
          indices = t[0];
          keys = t[1];
          values.insert(values.begin(), t.begin() + 2, t.end());
          callback(indices, keys, values);
          return;
        });
  }

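  // Marks the barrier as closed.  If cancel_pending_enqueues is true,
  // incomplete tuples are discarded and pending enqueues are cancelled;
  // otherwise keys that already have some components may still be completed.
  // A second Close is only accepted if it upgrades a non-cancelling close to
  // a cancelling one.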
  void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
             const DoneCallback& callback) {
    mutex_lock lock(mu_);
    // We're allowed to close twice if the first close wasn't a
    // cancel but the second one is.
    if (closed_ && (cancel_pending_enqueues_ || !cancel_pending_enqueues)) {
      ctx->SetStatus(
          errors::Cancelled("Barrier '", name_, "' is already closed."));
      callback();
      return;
    }
    cancel_pending_enqueues_ = cancel_pending_enqueues;
    closed_ = true;
    if (cancel_pending_enqueues_ || incomplete_.empty()) {
      incomplete_.clear();
      // CloseQueueLocked runs the callback.
      CloseQueueLocked(ctx, cancel_pending_enqueues_, callback);
      return;
    }
    callback();
  }

  int32 ready_size() { return ready_queue_->size(); }

  int32 incomplete_size() {
    mutex_lock lock(mu_);
    return incomplete_.size();
  }

  const string& name() const { return name_; }
  int num_components() const { return value_component_types_.size(); }
  DataType component_type(int i) const {
    CHECK_GE(i, 0);
    CHECK_LT(static_cast<size_t>(i), value_component_types_.size());
    return value_component_types_[i];
  }
  const DataTypeVector component_types() const {
    return value_component_types_;
  }
  const gtl::ArraySlice<TensorShape> component_shapes() const {
    return value_component_shapes_;
  }

  ~Barrier() override EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    mutex_lock lock(mu_);
    incomplete_.clear();
    ready_queue_->Unref();
  }

  string DebugString() override { return "A barrier"; }

 protected:
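  // Inserts the i-th slice of `values` as component `component_index` of the
  // tuple identified by the i-th key.  Must be called with mu_ held.  Sets
  // *new_elements to true if a brand-new key was created; if the tuple
  // becomes complete, it is moved from the incomplete map into *ready_tuples.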
  template <typename T>
  Status InsertOneLocked(OpKernelContext* ctx, const Tensor& keys,
                         const Tensor& values, const TensorShape& element_shape,
                         int component_index, int i,
                         std::vector<Tuple>* ready_tuples, bool* new_elements)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto keys_vec = keys.flat<string>();
    auto values_matrix = values.flat_outer_dims<T>();

    PersistentTuple* element_ptr;
    if (closed_) {
      element_ptr = gtl::FindOrNull(incomplete_, keys_vec(i));
      if (element_ptr == nullptr) {
        return errors::Cancelled(
            "Barrier ", name_,
            " is closed, but attempted to insert a brand new key: ",
            keys_vec(i),
            ".  Pending enqueues cancelled: ", cancel_pending_enqueues_,
            ".  Insertion index: ", i,
            ".  Number of incomplete keys: ", incomplete_.size(), ".");
      }
    } else {
      element_ptr =
          &gtl::LookupOrInsert(&incomplete_, keys_vec(i), PersistentTuple());
    }
    PersistentTuple& element = *element_ptr;

    if (element.empty()) {  // Key not seen before.
      // A new element was added; record this so the insertion index is
      // advanced once for this batch of new keys.
      *new_elements = true;

      // Initialize the incomplete tuple for a new key.
      element.reserve(1 + num_components());

      // The first entry in element is the priority: the
      // input_index_, so that tensors that entered the Barrier
      // earlier have higher priority in the queue.
      PersistentTensor index_persistent_tensor;
      Tensor* allocate_index_tensor;
      TF_RETURN_IF_ERROR(ctx->allocate_persistent(DT_INT64, TensorShape({}),
                                                  &index_persistent_tensor,
                                                  &allocate_index_tensor));

      Tensor index_tensor(DT_INT64, TensorShape({}));
      allocate_index_tensor->scalar<int64>()() = input_index_;
      element.push_back(index_persistent_tensor);

      // The rest of the element stores uninitialized Tensors with
      // the appropriate dtype.
      for (int j = 0; j < num_components(); ++j) {
        Tensor uninitialized(component_type(j));
        element.push_back(PersistentTensor(uninitialized));
      }
    }
    const PersistentTensor& component = element[1 + component_index];
    if (component.IsInitialized() && component.NumElements() > 0) {
      return errors::InvalidArgument("Key ", keys_vec(i),
                                     " already has a value for component ",
                                     component_index, " in barrier ", name());
    }

    // Extract the slice corresponding to the value from the value Tensor,
    // and store it in the incomplete tuple at component_index.
    PersistentTensor next_element;
    Tensor* allocated_element;
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        values.dtype(), element_shape, &next_element, &allocated_element));
    element[1 + component_index] = next_element;
    allocated_element->flat<T>() = values_matrix.template chip<0>(i);

    // Check the components of the tuple to see if it has become complete
    // (i.e. all of its components are initialized). If so, add it to the
    // ready queue.
    bool is_complete = true;
    for (int j = 0; is_complete && j < element.size(); ++j) {
      is_complete = element[j].IsInitialized() && element[j].NumElements() > 0;
    }
    if (is_complete) {
      // Add tuple to the ready queue. A queue tuple has the index
      // as the first element and the key as the second element,
      // followed by the value components.
      Tuple ready_tuple;
      ready_tuple.reserve(2 + num_components());  // index, key, rest
      // Build a tensor for the key. TODO(mrry): Something more efficient.
      PersistentTensor key;
      Tensor* allocated_key;
      TF_RETURN_IF_ERROR(ctx->allocate_persistent(DT_STRING, TensorShape({}),
                                                  &key, &allocated_key));
      ready_tuple.push_back(*element[0].AccessTensor(ctx));  // index
      ready_tuple.push_back(*allocated_key);                 // key
      ready_tuple[1].scalar<string>()() = keys_vec(i);       // set the key
      for (int j = 1; j < num_components() + 1; ++j) {
        ready_tuple.push_back(*element[j].AccessTensor(ctx));
      }
      incomplete_.erase(incomplete_.find(keys_vec(i)));
      TF_RETURN_IF_ERROR(ready_queue_->ValidateTuple(ready_tuple));
      ready_tuples->push_back(ready_tuple);
    }
    return Status::OK();
  }

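  // Closes (and, if cancel_pending_enqueues is set, cancels) the underlying
  // ready queue, if that has not already happened.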
  void CloseQueueLocked(OpKernelContext* ctx, bool cancel_pending_enqueues,
                        const DoneCallback& callback)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    // CloseQueueLocked may only be called with mu_ held.
    if (!cancel_pending_enqueues && queue_closed_) {
      callback();
      return;
    }
    if (cancel_pending_enqueues && queue_cancelled_) {
      callback();
      return;
    }
    queue_closed_ = true;
    if (cancel_pending_enqueues) queue_cancelled_ = true;
    if (!ready_queue_->is_closed()) {
      ready_queue_->Close(ctx, cancel_pending_enqueues, callback);
    }
  }

 private:
  typedef std::vector<PersistentTensor> PersistentTuple;
  mutex mu_;
  bool closed_ GUARDED_BY(mu_);
  bool queue_closed_ GUARDED_BY(mu_);
  bool queue_cancelled_ GUARDED_BY(mu_);
  bool cancel_pending_enqueues_ GUARDED_BY(mu_);
  const DataTypeVector value_component_types_;
  const std::vector<TensorShape>& value_component_shapes_;
  const string name_;
  int64 input_index_ GUARDED_BY(mu_);
  std::unordered_map<string, PersistentTuple> incomplete_ GUARDED_BY(mu_);
  PriorityQueue* ready_queue_;

  TF_DISALLOW_COPY_AND_ASSIGN(Barrier);
};

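// BarrierOp creates (or looks up) the shared Barrier resource and emits a
// reference handle to it.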
class BarrierOp : public ResourceOpKernel<Barrier> {
 public:
  explicit BarrierOp(OpKernelConstruction* context)
      : ResourceOpKernel(context) {
    OP_REQUIRES_OK(
        context, context->GetAttr("component_types", &value_component_types_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shapes", &value_component_shapes_));
    OP_REQUIRES(context,
                value_component_shapes_.size() == value_component_types_.size(),
                errors::InvalidArgument(
                    "All of the component shapes must be specified"));

    int32 value_capacity;
    OP_REQUIRES_OK(context, context->GetAttr("capacity", &value_capacity));
    OP_REQUIRES(context, value_capacity == -1,
                errors::InvalidArgument(
                    "Barrier only accepts capacity=-1.  Feed the "
                    "inputs to your Barrier through a queue to enforce a "
                    "limited capacity."));
  }

 private:
  Status CreateResource(Barrier** barrier) override
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    *barrier = new Barrier(value_component_types_, value_component_shapes_,
                           cinfo_.name());
    if (*barrier == nullptr) {
      return errors::ResourceExhausted("Failed to allocate barrier");
    }
    return (*barrier)->Initialize();
  }

  Status VerifyResource(Barrier* barrier) override
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (barrier->component_types() != value_component_types_) {
      return errors::InvalidArgument(
          "Shared barrier '", cinfo_.name(), "' has component types ",
          DataTypeSliceString(barrier->component_types()),
          " but requested component types were ",
          DataTypeSliceString(value_component_types_));
    }
    if (barrier->component_shapes() != value_component_shapes_) {
      return errors::InvalidArgument(
          "Shared barrier '", cinfo_.name(), "' has component shapes ",
          TensorShapeUtils::ShapeListString(barrier->component_shapes()),
          " but requested component shapes were ",
          TensorShapeUtils::ShapeListString(value_component_shapes_));
    }
    return Status::OK();
  }

  DataTypeVector value_component_types_;
  std::vector<TensorShape> value_component_shapes_;

  TF_DISALLOW_COPY_AND_ASSIGN(BarrierOp);
};

REGISTER_KERNEL_BUILDER(Name("Barrier").Device(DEVICE_CPU), BarrierOp);

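// Base class for async kernels that operate on an existing Barrier.  Resolves
// the "handle" input to the Barrier resource and unrefs it when the
// subclass's ComputeAsync signals completion.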
class BarrierOpKernel : public AsyncOpKernel {
 public:
  explicit BarrierOpKernel(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final {
    Barrier* barrier = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &barrier),
                         callback);
    ComputeAsync(ctx, barrier, [this, callback, barrier]() {
      barrier->Unref();
      callback();
    });
  }

 protected:
  virtual void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                            DoneCallback callback) = 0;
};

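// BarrierInsertMany: inserts a batch of string-keyed values for one
// component index of the barrier's tuples.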
template <typename T>
class InsertManyOp : public BarrierOpKernel {
 public:
  explicit InsertManyOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("component_index", &component_index_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    OP_REQUIRES_ASYNC(
        ctx, component_index_ < barrier->num_components(),
        errors::InvalidArgument("The component ID is out of range ",
                                component_index_, " >= num_components",
                                " (= ", barrier->num_components(), ")"),
        callback);
    OP_REQUIRES_OK_ASYNC(
        ctx,
        ctx->MatchSignature({DT_STRING_REF, DT_STRING,
                             barrier->component_type(component_index_)},
                            {}),
        callback);

    const Tensor* keys;
    const Tensor* values;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("keys", &keys), callback);
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("values", &values), callback);
    barrier->TryInsertMany<T>(*keys, component_index_, *values, ctx, callback);
  }

 private:
  int component_index_;
  TF_DISALLOW_COPY_AND_ASSIGN(InsertManyOp);
};

#define REGISTER_INSERTMANY(T)                                             \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("BarrierInsertMany").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      InsertManyOp<T>);

TF_CALL_ALL_TYPES(REGISTER_INSERTMANY);
#undef REGISTER_INSERTMANY

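// BarrierTakeMany: dequeues up to num_elements completed tuples, outputting
// their insertion indices, keys, and value components.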
class TakeManyOp : public BarrierOpKernel {
 public:
  explicit TakeManyOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
    // TODO(keveman): Enable timeout.
    OP_REQUIRES(context, timeout_ == -1,
                errors::InvalidArgument("Timeout not supported yet."));

    OP_REQUIRES_OK(context,
                   context->GetAttr("allow_small_batch", &allow_small_batch_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    const Tensor* Tnum_elements;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("num_elements", &Tnum_elements),
                         callback);
    OP_REQUIRES_ASYNC(ctx, TensorShapeUtils::IsScalar(Tnum_elements->shape()),
                      errors::InvalidArgument("num_elements must be a scalar."),
                      callback);
    const int32 num_elements = Tnum_elements->scalar<int32>()();

    DataTypeVector expected_inputs = {DT_STRING_REF, DT_INT32};
    // The first output is the insertion index, the second output is the key.
    DataTypeVector expected_outputs = {DT_INT64, DT_STRING};
    for (DataType dt : barrier->component_types()) {
      expected_outputs.push_back(dt);
    }
    OP_REQUIRES_OK_ASYNC(
        ctx, ctx->MatchSignature(expected_inputs, expected_outputs), callback);

    barrier->TryTakeMany(
        num_elements, allow_small_batch_, timeout_, ctx,
        [ctx, callback](const Tensor& indices, const Tensor& keys,
                        const Barrier::Tuple& values) {
          if (!ctx->status().ok()) {
            callback();
            return;
          }
          // At this point, indices, keys, and values
          // have all been written to successfully.
          OP_REQUIRES_OK_ASYNC(ctx, ctx->set_output("indices", indices),
                               callback);
          OP_REQUIRES_OK_ASYNC(ctx, ctx->set_output("keys", keys), callback);
          OpOutputList values_output;
          OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("values", &values_output),
                               callback);
          for (size_t i = 0; i < values.size(); ++i) {
            values_output.set(i, values[i]);
          }
          callback();
          return;
        });
  }

 private:
  int64 timeout_;
  bool allow_small_batch_;
  TF_DISALLOW_COPY_AND_ASSIGN(TakeManyOp);
};

REGISTER_KERNEL_BUILDER(Name("BarrierTakeMany").Device(DEVICE_CPU), TakeManyOp);

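// BarrierClose: closes the barrier, optionally cancelling pending enqueues
// and discarding incomplete tuples.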
class BarrierCloseOp : public BarrierOpKernel {
 public:
  explicit BarrierCloseOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
                                             &cancel_pending_enqueues_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    barrier->Close(ctx, cancel_pending_enqueues_, callback);
  }

 private:
  bool cancel_pending_enqueues_;
  TF_DISALLOW_COPY_AND_ASSIGN(BarrierCloseOp);
};

REGISTER_KERNEL_BUILDER(Name("BarrierClose").Device(DEVICE_CPU),
                        BarrierCloseOp);

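// BarrierIncompleteSize: outputs the number of keys whose tuples are still
// missing at least one component.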
class BarrierIncompleteSizeOp : public BarrierOpKernel {
 public:
  explicit BarrierIncompleteSizeOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    Tensor* Tsize = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &Tsize),
                         callback);
    Tsize->scalar<int32>().setConstant(barrier->incomplete_size());
    callback();
  }
};

REGISTER_KERNEL_BUILDER(Name("BarrierIncompleteSize").Device(DEVICE_CPU),
                        BarrierIncompleteSizeOp);

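// BarrierReadySize: outputs the number of completed tuples waiting in the
// ready queue.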
class BarrierReadySizeOp : public BarrierOpKernel {
 public:
  explicit BarrierReadySizeOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    Tensor* Tsize = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &Tsize),
                         callback);
    Tsize->scalar<int32>().setConstant(barrier->ready_size());
    callback();
  }
};

REGISTER_KERNEL_BUILDER(Name("BarrierReadySize").Device(DEVICE_CPU),
                        BarrierReadySizeOp);

}  // namespace barrier

}  // namespace tensorflow