/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.

#include <limits.h>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/priority_queue.h"
#include "tensorflow/core/kernels/queue_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
// Needed for batch_util::CopyElementToSlice below.
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {

namespace barrier {

class Barrier : public ResourceBase {
 public:
  typedef std::vector<Tensor> Tuple;
  typedef std::function<void()> DoneCallback;
  typedef std::function<void(const Tensor&, const Tensor&, const Tuple&)>
      IndicesKeysValuesCallback;

  Barrier(const DataTypeVector& value_component_types,
          const std::vector<TensorShape>& value_component_shapes,
          const string& name)
      : closed_(false),
        queue_closed_(false),
        queue_cancelled_(false),
        cancel_pending_enqueues_(false),
        value_component_types_(value_component_types),
        value_component_shapes_(value_component_shapes),
        name_(name),
        input_index_(std::numeric_limits<int64>::min()) {
    DataTypeVector queue_component_types;
    std::vector<TensorShape> queue_component_shapes;

    // The first queue component is the input index, the second queue
    // component is the key, and the remaining queue components are the
    // value components.
    queue_component_types.push_back(DT_INT64);
    queue_component_types.push_back(DT_STRING);
    for (DataType dt : value_component_types) {
      queue_component_types.push_back(dt);
    }

    // NOTE(mrry): PriorityQueue expects all shapes to be specified
    // because we'll be issuing TakeMany.
    queue_component_shapes.push_back(TensorShape({}));
    queue_component_shapes.push_back(TensorShape({}));
    queue_component_shapes.insert(queue_component_shapes.end(),
                                  value_component_shapes.begin(),
                                  value_component_shapes.end());

    ready_queue_ = new PriorityQueue(
        QueueBase::kUnbounded /* capacity */, queue_component_types,
        queue_component_shapes, strings::StrCat(name_, "_queue"));
  }

  Status Initialize() { return ready_queue_->Initialize(); }
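  // TryInsertMany fills component `component_index` of the incomplete
  // tuple for each key in `keys`. A key's tuple becomes complete once
  // every value component has been inserted for it; complete tuples are
  // enqueued on the ready queue with their insertion index as the
  // priority, so keys that entered the barrier earlier are taken first.
  // For example, with value component types (DT_FLOAT, DT_INT32), key
  // "k" becomes ready only after an insert at component 0 (the float)
  // and an insert at component 1 (the int32) have both been seen for "k".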
  template <typename T>
  void TryInsertMany(const Tensor& keys, int component_index,
                     const Tensor& values, OpKernelContext* ctx,
                     const DoneCallback& callback) {
    TensorShape element_shape = values.shape();
    OP_REQUIRES_ASYNC(
        ctx, keys.NumElements() == 0 || element_shape.num_elements() > 0,
        errors::InvalidArgument("Tensors with no elements are not supported ",
                                name_, ": received shape ",
                                element_shape.DebugString()),
        callback);
    if (element_shape.dims() > 0) element_shape.RemoveDim(0);
    const std::size_t num_inserted = keys.NumElements();

    // For each key, update the corresponding incomplete tuple with the
    // given value at component_index. Completed tuples are enqueued at
    // the very end.
    bool new_elements = false;

    // Will be used for the final insert into the queue.
    Tuple insert_tuple;

    {
      mutex_lock lock(mu_);
      if (closed_) {
        OP_REQUIRES_ASYNC(
            ctx,
            !cancel_pending_enqueues_ &&
                (num_inserted == 0 || !incomplete_.empty()),
            errors::Cancelled(
                "Barrier ", name_, " is closed. Pending enqueues cancelled: ",
                cancel_pending_enqueues_,
                ". Number of new insertions: ", num_inserted,
                ". Number of incomplete keys: ", incomplete_.size(), "."),
            callback);
      }

      // Step 1: insert into the incomplete map and identify which
      // entries are, in fact, complete and ready for enqueueing. Store
      // them in a vector.
      std::vector<Tuple> ready_tuples;

      for (int i = 0; i < num_inserted; ++i) {
        OP_REQUIRES_OK_ASYNC(
            ctx,
            InsertOneLocked<T>(ctx, keys, values, element_shape,
                               component_index, i, &ready_tuples,
                               &new_elements),
            callback);
      }

      if (new_elements) ++input_index_;

      // This probably won't happen before the heat death of the
      // universe, but who knows? Moore's law FTW.
      OP_REQUIRES_ASYNC(
          ctx, input_index_ != std::numeric_limits<int64>::max(),
          errors::Internal(
              "Barrier has had ", input_index_,
              " insertions and can no longer keep track of new ones."),
          callback);

      if (ready_tuples.empty()) {
        // Nothing to insert into the queue - so return early.
        callback();
        return;
      }

      // Step 2: we have something to enqueue. Convert the per-key Tuples
      // into a single batched tuple by slicing entries into new Tensors.
      // This part is slow but seems the cleanest solution for now.
      insert_tuple.reserve(2 + num_components());  // indices, keys, rest
      int insertion_size = ready_tuples.size();
      for (int i = 0; i < 2 + num_components(); ++i) {
        TensorShape component_shape(ready_tuples[0][i].shape());
        component_shape.InsertDim(0, insertion_size);
        Tensor component(ready_tuples[0][i].dtype(), component_shape);
        for (int b = 0; b < insertion_size; ++b) {
          OP_REQUIRES_OK_ASYNC(
              ctx,
              batch_util::CopyElementToSlice(std::move(ready_tuples[b][i]),
                                             &component, b),
              callback);
        }
        insert_tuple.push_back(component);
      }
    }
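    // Step 3: enqueue all newly completed tuples in a single
    // TryEnqueueMany call. The completion callback below additionally
    // closes the ready queue once the barrier has been closed and all
    // incomplete tuples have drained.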
    ready_queue_->TryEnqueueMany(
        insert_tuple, ctx,
        // To avoid early closing of the queue, only close it if the
        // barrier is closed, nothing is left in the incomplete set,
        // the queue is not already marked as closed, and (most
        // importantly), the queue has entries in it.
        [this, ctx, callback]() {
          if (!ctx->status().ok()) {
            callback();
            return;
          }
          {
            mutex_lock lock(mu_);
            int32 ready = ready_size();
            if (closed_ && incomplete_.empty() && !queue_closed_ && ready > 0) {
              CloseQueueLocked(ctx, false, callback);
            } else {
              callback();
            }
            return;
          }
        });
  }

  void TryTakeMany(int num_elements, bool allow_small_batch, int64 timeout,
                   OpKernelContext* ctx,
                   const IndicesKeysValuesCallback& callback) {
    int num_elements_to_deliver = num_elements;
    {
      mutex_lock lock(mu_);
      if (closed_) {
        int available_elements = ready_size();
        if (allow_small_batch) {
          // We want to deliver at most num_elements; if fewer elements
          // are available, we deliver at most available_elements. If no
          // elements are available, a call to TryTakeMany should fail
          // with OutOfRange; the comparison against
          // std::max(num_elements_to_deliver, 1) below triggers that
          // error.
          num_elements_to_deliver = std::min(num_elements, available_elements);
        } else {
          // We're happy to wait for additional elements to be completed.
          available_elements += incomplete_.size();
        }
        // If fewer elements are available than the number we must
        // deliver, then we are done.
        if (available_elements < std::max(num_elements_to_deliver, 1)) {
          ctx->SetStatus(errors::OutOfRange(
              "Barrier '", name_, "' is closed and has ",
              "insufficient elements (requested ", num_elements_to_deliver,
              ", total size ", available_elements, ")"));
          callback(Tensor(DT_INT64), Tensor(DT_STRING), Tuple());
          return;
        }
      }
    }

    ready_queue_->TryDequeueMany(
        num_elements_to_deliver, ctx, allow_small_batch,
        [this, ctx, callback](const Tuple& t) {
          Tensor indices(DT_INT64);
          Tensor keys(DT_STRING);
          Tuple values;

          if (!ctx->status().ok()) {
            callback(indices, keys, values);
            return;
          }

          CHECK_EQ(t.size(), 2 + num_components());
          indices = t[0];
          keys = t[1];
          values.insert(values.begin(), t.begin() + 2, t.end());
          callback(indices, keys, values);
          return;
        });
  }
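  // Close() marks the barrier closed. Closing twice is an error, with
  // one exception: a non-cancelling close may be upgraded to a
  // cancelling one. That is, Close(cancel_pending_enqueues=false)
  // followed by Close(cancel_pending_enqueues=true) is accepted, while
  // repeating the same close (or downgrading a cancelling close) fails
  // with a Cancelled error.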
  void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
             const DoneCallback& callback) {
    mutex_lock lock(mu_);
    // We're allowed to close twice if the first close wasn't a
    // cancel but the second one is.
    if (closed_ && (cancel_pending_enqueues_ || !cancel_pending_enqueues)) {
      ctx->SetStatus(
          errors::Cancelled("Barrier '", name_, "' is already closed."));
      callback();
      return;
    }
    cancel_pending_enqueues_ = cancel_pending_enqueues;
    closed_ = true;
    if (cancel_pending_enqueues_ || incomplete_.empty()) {
      incomplete_.clear();
      // CloseQueueLocked runs the callback.
      CloseQueueLocked(ctx, cancel_pending_enqueues_, callback);
      return;
    }
    callback();
  }

  int32 ready_size() { return ready_queue_->size(); }

  int32 incomplete_size() {
    mutex_lock lock(mu_);
    return incomplete_.size();
  }

  const string& name() const { return name_; }
  int num_components() const { return value_component_types_.size(); }
  DataType component_type(int i) const {
    CHECK_GE(i, 0);
    CHECK_LT(static_cast<size_t>(i), value_component_types_.size());
    return value_component_types_[i];
  }
  const DataTypeVector component_types() const {
    return value_component_types_;
  }
  const gtl::ArraySlice<TensorShape> component_shapes() const {
    return value_component_shapes_;
  }

  ~Barrier() override EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    mutex_lock lock(mu_);
    incomplete_.clear();
    ready_queue_->Unref();
  }

  string DebugString() override { return "A barrier"; }

 protected:
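  // Each incomplete tuple is stored as a PersistentTuple with the
  // following layout:
  //   element[0]      -- int64 scalar holding the insertion-order
  //                      priority (input_index_);
  //   element[1 + j]  -- value component j, which starts out as an
  //                      uninitialized Tensor and is filled in by an
  //                      insert at component_index == j.
  // A tuple is complete once every slot holds an initialized, non-empty
  // tensor, at which point it is moved to the ready queue.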
  template <typename T>
  Status InsertOneLocked(OpKernelContext* ctx, const Tensor& keys,
                         const Tensor& values, const TensorShape& element_shape,
                         int component_index, int i,
                         std::vector<Tuple>* ready_tuples, bool* new_elements)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto keys_vec = keys.flat<string>();
    auto values_matrix = values.flat_outer_dims<T>();

    PersistentTuple* element_ptr;
    if (closed_) {
      element_ptr = gtl::FindOrNull(incomplete_, keys_vec(i));
      if (element_ptr == nullptr) {
        return errors::Cancelled(
            "Barrier ", name_,
            " is closed, but attempted to insert a brand new key: ",
            keys_vec(i),
            ". Pending enqueues cancelled: ", cancel_pending_enqueues_,
            ". Insertion index: ", i,
            ". Number of incomplete keys: ", incomplete_.size(), ".");
      }
    } else {
      element_ptr =
          gtl::LookupOrInsert(&incomplete_, keys_vec(i), PersistentTuple());
    }
    PersistentTuple& element = *element_ptr;

    if (element.empty()) {  // Key was never seen before.
      // Added a new element, for keeping track of the insertion index.
      *new_elements = true;

      // Initialize the incomplete tuple for a new key.
      element.reserve(1 + num_components());

      // The first entry in element is the priority: the input_index_,
      // so that tensors that entered the Barrier earlier have higher
      // priority in the queue.
      PersistentTensor index_persistent_tensor;
      Tensor* allocate_index_tensor;
      TF_RETURN_IF_ERROR(ctx->allocate_persistent(DT_INT64, TensorShape({}),
                                                  &index_persistent_tensor,
                                                  &allocate_index_tensor));
      allocate_index_tensor->scalar<int64>()() = input_index_;
      element.push_back(index_persistent_tensor);

      // The rest of the element stores uninitialized Tensors with
      // the appropriate dtype.
      for (int j = 0; j < num_components(); ++j) {
        Tensor uninitialized(component_type(j));
        element.push_back(PersistentTensor(uninitialized));
      }
    }
    const PersistentTensor& component = element[1 + component_index];
    if (component.IsInitialized() && component.NumElements() > 0) {
      return errors::InvalidArgument("Key ", keys_vec(i),
                                     " already has a value for component ",
                                     component_index, " in barrier ", name());
    }

    // Extract the slice corresponding to the value from the value Tensor,
    // and store it in the incomplete tuple at component_index.
    PersistentTensor next_element;
    Tensor* allocated_element;
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        values.dtype(), element_shape, &next_element, &allocated_element));
    element[1 + component_index] = next_element;
    allocated_element->flat<T>() = values_matrix.template chip<0>(i);

    // Check the components of the tuple to see if it has become complete
    // (i.e. all of its components are initialized). If so, add it to the
    // ready queue.
    bool is_complete = true;
    for (int j = 0; is_complete && j < element.size(); ++j) {
      is_complete = element[j].IsInitialized() && element[j].NumElements() > 0;
    }
    if (is_complete) {
      // Add the tuple to the ready queue. A queue tuple has the index
      // as the first element and the key as the second element,
      // followed by the value components.
      Tuple ready_tuple;
      ready_tuple.reserve(2 + num_components());  // index, key, rest
      // Build a tensor for the key. TODO(mrry): Something more efficient.
      PersistentTensor key;
      Tensor* allocated_key;
      TF_RETURN_IF_ERROR(ctx->allocate_persistent(DT_STRING, TensorShape({}),
                                                  &key, &allocated_key));
      ready_tuple.push_back(*element[0].AccessTensor(ctx));  // index
      ready_tuple.push_back(*allocated_key);                 // key
      ready_tuple[1].scalar<string>()() = keys_vec(i);       // set the key
      for (int j = 1; j < num_components() + 1; ++j) {
        ready_tuple.push_back(*element[j].AccessTensor(ctx));
      }
      incomplete_.erase(incomplete_.find(keys_vec(i)));
      TF_RETURN_IF_ERROR(ready_queue_->ValidateTuple(ready_tuple));
      ready_tuples->push_back(ready_tuple);
    }
    return Status::OK();
  }
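  // CloseQueueLocked closes the underlying ready queue at most once per
  // kind of close: a repeated plain close, or a repeated cancelling
  // close, is a no-op that simply runs the callback. As with Close()
  // above, a cancelling close may still follow a plain close.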
  void CloseQueueLocked(OpKernelContext* ctx, bool cancel_pending_enqueues,
                        const DoneCallback& callback)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    // CloseQueueLocked may only be called with mu_ held.
    if (!cancel_pending_enqueues && queue_closed_) {
      callback();
      return;
    }
    if (cancel_pending_enqueues && queue_cancelled_) {
      callback();
      return;
    }
    queue_closed_ = true;
    if (cancel_pending_enqueues) queue_cancelled_ = true;
    if (!ready_queue_->is_closed()) {
      ready_queue_->Close(ctx, cancel_pending_enqueues, callback);
    } else {
      // The queue was already closed; still run the done callback.
      callback();
    }
  }

 private:
  typedef std::vector<PersistentTensor> PersistentTuple;
  mutex mu_;
  bool closed_ GUARDED_BY(mu_);
  bool queue_closed_ GUARDED_BY(mu_);
  bool queue_cancelled_ GUARDED_BY(mu_);
  bool cancel_pending_enqueues_ GUARDED_BY(mu_);
  const DataTypeVector value_component_types_;
  // Stored by value: the Barrier may outlive the op kernel that created
  // it, so holding a reference to the kernel's attr vector could dangle.
  const std::vector<TensorShape> value_component_shapes_;
  const string name_;
  int64 input_index_ GUARDED_BY(mu_);
  std::unordered_map<string, PersistentTuple> incomplete_ GUARDED_BY(mu_);
  PriorityQueue* ready_queue_;

  TF_DISALLOW_COPY_AND_ASSIGN(Barrier);
};

class BarrierOp : public ResourceOpKernel<Barrier> {
 public:
  explicit BarrierOp(OpKernelConstruction* context)
      : ResourceOpKernel(context) {
    OP_REQUIRES_OK(
        context, context->GetAttr("component_types", &value_component_types_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shapes", &value_component_shapes_));
    OP_REQUIRES(context,
                value_component_shapes_.size() == value_component_types_.size(),
                errors::InvalidArgument(
                    "All of the component shapes must be specified"));

    int32 value_capacity;
    OP_REQUIRES_OK(context, context->GetAttr("capacity", &value_capacity));
    OP_REQUIRES(context, value_capacity == -1,
                errors::InvalidArgument(
                    "Barrier only accepts capacity=-1. Feed the "
                    "inputs to your Barrier through a queue to enforce a "
                    "limited capacity."));
  }

 private:
  Status CreateResource(Barrier** barrier) override
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    *barrier = new Barrier(value_component_types_, value_component_shapes_,
                           cinfo_.name());
    if (*barrier == nullptr) {
      return errors::ResourceExhausted("Failed to allocate barrier");
    }
    return (*barrier)->Initialize();
  }

  Status VerifyResource(Barrier* barrier) override
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (barrier->component_types() != value_component_types_) {
      return errors::InvalidArgument(
          "Shared barrier '", cinfo_.name(), "' has component types ",
          DataTypeSliceString(barrier->component_types()),
          " but requested component types were ",
          DataTypeSliceString(value_component_types_));
    }
    if (barrier->component_shapes() != value_component_shapes_) {
      return errors::InvalidArgument(
          "Shared barrier '", cinfo_.name(), "' has component shapes ",
          TensorShapeUtils::ShapeListString(barrier->component_shapes()),
          " but requested component shapes were ",
          TensorShapeUtils::ShapeListString(value_component_shapes_));
    }
    return Status::OK();
  }

  DataTypeVector value_component_types_;
  std::vector<TensorShape> value_component_shapes_;

  TF_DISALLOW_COPY_AND_ASSIGN(BarrierOp);
};

REGISTER_KERNEL_BUILDER(Name("Barrier").Device(DEVICE_CPU), BarrierOp);
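// Common base class for the ops below: it resolves the Barrier resource
// from the "handle" input, hands it to the subclass's ComputeAsync, and
// guarantees that the resource is Unref'd exactly once when the
// subclass's asynchronous work completes.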
class BarrierOpKernel : public AsyncOpKernel {
 public:
  explicit BarrierOpKernel(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final {
    Barrier* barrier = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &barrier),
                         callback);
    ComputeAsync(ctx, barrier, [callback, barrier]() {
      barrier->Unref();
      callback();
    });
  }

 protected:
  virtual void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                            DoneCallback callback) = 0;
};

template <typename T>
class InsertManyOp : public BarrierOpKernel {
 public:
  explicit InsertManyOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("component_index", &component_index_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    OP_REQUIRES_ASYNC(
        ctx, component_index_ < barrier->num_components(),
        errors::InvalidArgument("The component ID is out of range ",
                                component_index_, " >= num_components",
                                " (= ", barrier->num_components(), ")"),
        callback);
    OP_REQUIRES_OK_ASYNC(
        ctx,
        ctx->MatchSignature({DT_STRING_REF, DT_STRING,
                             barrier->component_type(component_index_)},
                            {}),
        callback);

    const Tensor* keys;
    const Tensor* values;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("keys", &keys), callback);
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("values", &values), callback);
    barrier->TryInsertMany<T>(*keys, component_index_, *values, ctx, callback);
  }

 private:
  int component_index_;
  TF_DISALLOW_COPY_AND_ASSIGN(InsertManyOp);
};

#define REGISTER_INSERTMANY(T)                                             \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("BarrierInsertMany").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      InsertManyOp<T>);

TF_CALL_ALL_TYPES(REGISTER_INSERTMANY);
#undef REGISTER_INSERTMANY

class TakeManyOp : public BarrierOpKernel {
 public:
  explicit TakeManyOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
    // TODO(keveman): Enable timeout.
    OP_REQUIRES(context, timeout_ == -1,
                errors::InvalidArgument("Timeout not supported yet."));

    OP_REQUIRES_OK(context,
                   context->GetAttr("allow_small_batch", &allow_small_batch_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    const Tensor* Tnum_elements;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("num_elements", &Tnum_elements),
                         callback);
    OP_REQUIRES_ASYNC(ctx, TensorShapeUtils::IsScalar(Tnum_elements->shape()),
                      errors::InvalidArgument("num_elements must be a scalar."),
                      callback);
    const int32 num_elements = Tnum_elements->scalar<int32>()();

    DataTypeVector expected_inputs = {DT_STRING_REF, DT_INT32};
    // The first output is the insertion index, the second output is the key.
    DataTypeVector expected_outputs = {DT_INT64, DT_STRING};
    for (DataType dt : barrier->component_types()) {
      expected_outputs.push_back(dt);
    }
    OP_REQUIRES_OK_ASYNC(
        ctx, ctx->MatchSignature(expected_inputs, expected_outputs), callback);

    barrier->TryTakeMany(
        num_elements, allow_small_batch_, timeout_, ctx,
        [ctx, callback](const Tensor& indices, const Tensor& keys,
                        const Barrier::Tuple& values) {
          if (!ctx->status().ok()) {
            callback();
            return;
          }
          // At this point, indices, keys, and values have all been
          // produced successfully.
          OP_REQUIRES_OK_ASYNC(ctx, ctx->set_output("indices", indices),
                               callback);
          OP_REQUIRES_OK_ASYNC(ctx, ctx->set_output("keys", keys), callback);
          OpOutputList values_output;
          OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("values", &values_output),
                               callback);
          for (size_t i = 0; i < values.size(); ++i) {
            values_output.set(i, values[i]);
          }
          callback();
          return;
        });
  }

 private:
  int64 timeout_;
  bool allow_small_batch_;
  TF_DISALLOW_COPY_AND_ASSIGN(TakeManyOp);
};

REGISTER_KERNEL_BUILDER(Name("BarrierTakeMany").Device(DEVICE_CPU), TakeManyOp);
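// The remaining kernels are thin wrappers: BarrierClose forwards to
// Barrier::Close(), and the two size ops report Barrier::incomplete_size()
// and Barrier::ready_size() as scalar int32 outputs.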
class BarrierCloseOp : public BarrierOpKernel {
 public:
  explicit BarrierCloseOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
                                             &cancel_pending_enqueues_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    barrier->Close(ctx, cancel_pending_enqueues_, callback);
  }

 private:
  bool cancel_pending_enqueues_;
  TF_DISALLOW_COPY_AND_ASSIGN(BarrierCloseOp);
};

REGISTER_KERNEL_BUILDER(Name("BarrierClose").Device(DEVICE_CPU),
                        BarrierCloseOp);

class BarrierIncompleteSizeOp : public BarrierOpKernel {
 public:
  explicit BarrierIncompleteSizeOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    Tensor* Tsize = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &Tsize),
                         callback);
    Tsize->scalar<int32>().setConstant(barrier->incomplete_size());
    callback();
  }
};

REGISTER_KERNEL_BUILDER(Name("BarrierIncompleteSize").Device(DEVICE_CPU),
                        BarrierIncompleteSizeOp);

class BarrierReadySizeOp : public BarrierOpKernel {
 public:
  explicit BarrierReadySizeOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    Tensor* Tsize = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &Tsize),
                         callback);
    Tsize->scalar<int32>().setConstant(barrier->ready_size());
    callback();
  }
};

REGISTER_KERNEL_BUILDER(Name("BarrierReadySize").Device(DEVICE_CPU),
                        BarrierReadySizeOp);

}  // namespace barrier

}  // namespace tensorflow