1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/op_kernel.h" 17 #include "tensorflow/core/framework/register_types.h" 18 #include "tensorflow/core/framework/resource_mgr.h" 19 #include "tensorflow/core/framework/tensor.h" 20 #include "tensorflow/core/framework/tensor_util.h" 21 #include "tensorflow/core/framework/types.h" 22 #include "tensorflow/core/kernels/batching_util/periodic_function.h" 23 #include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h" 24 #include "tensorflow/core/kernels/concat_lib.h" 25 #include "tensorflow/core/kernels/ops_util.h" 26 #include "tensorflow/core/kernels/split_lib.h" 27 #include "tensorflow/core/lib/random/random.h" 28 #include "tensorflow/core/platform/macros.h" 29 30 namespace tensorflow { 31 32 typedef Eigen::ThreadPoolDevice CPUDevice; 33 typedef Eigen::GpuDevice GPUDevice; 34 #ifdef TENSORFLOW_USE_SYCL 35 typedef Eigen::SyclDevice SYCLDevice; 36 #endif // TENSORFLOW_USE_SYCL 37 38 // Concatenates 'inputs' into a single tensor along the zeroth dimension. 39 // Requires that all elements of 'inputs' have element type T. Writes to the 40 // op's output at position 'output_index', using 'context' for the allocation to 41 // ensure proper device placement. 42 template <typename T> 43 Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor>& inputs, 44 int output_index) { 45 const int input_dims = inputs[0].dims(); 46 const TensorShape& input_shape = inputs[0].shape(); 47 48 // Note that we reduce the concat of k-dimensional tensors into a two 49 // dimensional concat. Assuming the dimensions of any input tensor are 50 // {y0, y1,...,ym-1}, we flatten it to {1, y}, where y = Prod_i(yi). 51 std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>> inputs_flat; 52 inputs_flat.reserve(inputs.size()); 53 int64 output_dim0 = 0; 54 for (size_t i = 0; i < inputs.size(); ++i) { 55 const Tensor& input = inputs[i]; 56 if (input.dims() != input_dims) { 57 return errors::InvalidArgument( 58 "Ranks of all input tensors should match: shape[0] = ", 59 input_shape.DebugString(), " vs. shape[", i, 60 "] = ", input.shape().DebugString()); 61 } 62 for (int j = 1; j < input_dims; ++j) { 63 if (input.dim_size(j) != input_shape.dim_size(j)) { 64 return errors::InvalidArgument( 65 "Dimensions of inputs should match: shape[0] = ", 66 input_shape.DebugString(), " vs. shape[", i, 67 "] = ", input.shape().DebugString()); 68 } 69 } 70 if (input.NumElements() > 0) { 71 inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix( 72 input.shaped<T, 2>({1, input.NumElements()}))); 73 } 74 output_dim0 += input.dim_size(0); 75 } 76 77 TensorShape output_shape(input_shape); 78 output_shape.set_dim(0, output_dim0); 79 Tensor* output = nullptr; 80 TF_RETURN_IF_ERROR( 81 context->allocate_output(output_index, output_shape, &output)); 82 if (output->NumElements() > 0) { 83 auto output_flat = output->shaped<T, 2>({1, output->NumElements()}); 84 #if GOOGLE_CUDA 85 if (std::is_same<Device, GPUDevice>::value) { 86 ConcatGPU<T>(context, inputs_flat, output, &output_flat); 87 return Status::OK(); 88 } 89 #endif // GOOGLE_CUDA 90 ConcatCPU<T>(context->device(), inputs_flat, &output_flat); 91 } 92 93 return Status::OK(); 94 } 95 96 // The Split*() functions split 'input' with element type T into 'sizes.size()' 97 // tensors along the zeroth dimension, with the ith split having zeroth- 98 // dimension size 'sizes[i]'. They allocate the output tensors using 'context', 99 // for proper device placement. 100 101 // Handles special cases that are cheap. Sets 'done==true' iff it found an 102 // applicable special case and wrote to the outputs. Otherwise acts as a no-op. 103 template <typename T> 104 Status SplitEasyCases(OpKernelContext* context, const Tensor& input, 105 const gtl::ArraySlice<int64>& sizes, 106 std::vector<Tensor>* outputs, bool* done) { 107 *done = false; 108 109 int64 total_size = 0; 110 for (const int64 size : sizes) { 111 total_size += size; 112 } 113 if (total_size > input.shape().dim_size(0)) { 114 return errors::InvalidArgument( 115 "Sum of split sizes must not exceed dim0-size of input tensor"); 116 } 117 118 // Special case 0: trivial 1-way split. 119 if (sizes.size() == 1 && sizes.at(0) == input.shape().dim_size(0)) { 120 outputs->push_back(input); 121 *done = true; 122 return Status::OK(); 123 } 124 125 // Special case 1: input is aligned. 126 if (IsInnerDimsSizeAligned<T>(input.shape())) { 127 int64 position = 0; 128 for (const int64 size : sizes) { 129 outputs->emplace_back(input.Slice(position, position + size)); 130 position += size; 131 } 132 *done = true; 133 return Status::OK(); 134 } 135 136 return Status::OK(); 137 } 138 139 // Handles the general case, on CPU. 140 template <typename T> 141 Status SplitCPU(OpKernelContext* context, const Tensor& input, 142 const gtl::ArraySlice<int64>& sizes, 143 std::vector<Tensor>* outputs) { 144 int64 suffix_dim_size = 1; 145 for (int i = 1; i < input.shape().dims(); ++i) { 146 suffix_dim_size *= input.shape().dim_size(i); 147 } 148 auto input_reshaped = 149 input.shaped<T, 3>({1, input.shape().dim_size(0), suffix_dim_size}); 150 151 int64 position = 0; 152 for (const int64 size : sizes) { 153 TensorShape output_shape = input.shape(); 154 output_shape.set_dim(0, size); 155 Tensor output; 156 TF_RETURN_IF_ERROR( 157 context->allocate_temp(input.dtype(), output_shape, &output)); 158 auto output_shaped = output.shaped<T, 3>({1, size, suffix_dim_size}); 159 160 Eigen::DSizes<Eigen::DenseIndex, 3> slice_indices{0, position, 0}; 161 Eigen::DSizes<Eigen::DenseIndex, 3> slice_sizes{1, size, suffix_dim_size}; 162 functor::Split<CPUDevice, T>()(context->eigen_device<CPUDevice>(), 163 output_shaped, input_reshaped, slice_indices, 164 slice_sizes); 165 166 outputs->emplace_back(output); 167 168 position += size; 169 } 170 171 return Status::OK(); 172 } 173 174 #if GOOGLE_CUDA 175 176 // Handles the general case, on GPU. 177 template <typename T> 178 Status SplitGPU(OpKernelContext* context, const Tensor& input, 179 const gtl::ArraySlice<int64>& sizes, 180 std::vector<Tensor>* outputs) { 181 // TODO(olston, apassos): Implement this. 182 LOG(FATAL) << "Not yet implemented"; // Crash ok 183 } 184 185 #endif // GOOGLE_CUDA 186 187 // The outer function that dispatches to the various Split*() functions above. 188 template <typename T> 189 Status Split(OpKernelContext* context, const Tensor& input, 190 const gtl::ArraySlice<int64>& sizes, 191 std::vector<Tensor>* outputs) { 192 bool easy_cases_done; 193 TF_RETURN_IF_ERROR( 194 SplitEasyCases<T>(context, input, sizes, outputs, &easy_cases_done)); 195 if (easy_cases_done) { 196 return Status::OK(); 197 } 198 199 #if GOOGLE_CUDA 200 // TODO(olston, apassos): Handle non-CPU cases. 201 // return SplitGPU<T>(context, input, sizes, outputs); 202 #endif // GOOGLE_CUDA 203 return SplitCPU<T>(context, input, sizes, outputs); 204 } 205 206 // A class encapsulating the state and logic for batching tensors. 207 class BatchResource : public ResourceBase { 208 public: 209 static Status Create(int32 num_batch_threads, int32 max_batch_size, 210 int32 batch_timeout_micros, int32 max_enqueued_batches, 211 const std::vector<int32>& allowed_batch_sizes, 212 std::unique_ptr<BatchResource>* resource) { 213 std::unique_ptr<BatchResource> new_resource(new BatchResource); 214 215 Batcher::Options batcher_options; 216 batcher_options.num_batch_threads = num_batch_threads; 217 TF_RETURN_IF_ERROR( 218 Batcher::Create(batcher_options, &new_resource->batcher_)); 219 220 new_resource->batcher_queue_options_.max_batch_size = max_batch_size; 221 new_resource->batcher_queue_options_.max_enqueued_batches = 222 max_enqueued_batches; 223 new_resource->batcher_queue_options_.batch_timeout_micros = 224 batch_timeout_micros; 225 226 new_resource->allowed_batch_sizes_ = allowed_batch_sizes; 227 228 *resource = std::move(new_resource); 229 return Status::OK(); 230 } 231 232 string DebugString() final { return "BatchResource"; } 233 234 // Ingests data from one invocation of the batch op. The data is enqueued to 235 // be combined with others into a batch, asynchronously. 236 Status RegisterInput(int64 guid, OpKernelContext* context, 237 const string& batcher_queue_name, 238 AsyncOpKernel::DoneCallback done_callback) { 239 std::unique_ptr<BatchTask> batch_components(new BatchTask); 240 batch_components->guid = guid; 241 OpInputList tensors; 242 TF_RETURN_IF_ERROR(context->input_list("in_tensors", &tensors)); 243 for (int i = 0; i < tensors.size(); ++i) { 244 const Tensor& tensor = tensors[i]; 245 if (tensor.shape().dims() == 0) { 246 return errors::InvalidArgument( 247 "Batching input tensors must have at least one dimension"); 248 } 249 if (tensors.size() >= 2 && 250 tensor.shape().dim_size(0) != tensors[0].shape().dim_size(0)) { 251 return errors::InvalidArgument( 252 "Batching input tensors supplied in a given op invocation must " 253 "have equal 0th-dimension size"); 254 } 255 batch_components->inputs.push_back(tensor); 256 } 257 batch_components->context = context; 258 batch_components->done_callback = std::move(done_callback); 259 260 BatcherQueue* batcher_queue; 261 TF_RETURN_IF_ERROR( 262 LookupOrCreateBatcherQueue(batcher_queue_name, &batcher_queue)); 263 return batcher_queue->Schedule(&batch_components); 264 } 265 266 private: 267 BatchResource() = default; 268 269 // One input to be batched. Corresponds to one invocation of the batch op. 270 struct BatchTask : public serving::BatchTask { 271 // A unique ID to identify this invocation of Batch. 272 int64 guid; 273 274 std::vector<Tensor> inputs; 275 OpKernelContext* context; 276 AsyncOpKernel::DoneCallback done_callback; 277 278 size_t size() const override { return inputs[0].shape().dim_size(0); } 279 }; 280 281 using Batcher = serving::SharedBatchScheduler<BatchTask>; 282 using BatcherQueue = serving::BatchScheduler<BatchTask>; 283 using Batch = serving::Batch<BatchTask>; 284 285 // Validates that it's legal to combine the tasks in 'batch' into a batch. 286 // Assumes the batch is non-empty. 287 static Status ValidateBatch(const Batch& batch) { 288 for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) { 289 const BatchTask& task = batch.task(task_idx); 290 291 if (task.inputs.size() != batch.task(0).inputs.size()) { 292 return errors::InvalidArgument( 293 "Batching inputs must have equal number of edges"); 294 } 295 } 296 297 return Status::OK(); 298 } 299 300 // Returns the smallest entry in 'allowed_batch_sizes_' that is greater than 301 // or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty, simply 302 // returns 'batch_size'. 303 int RoundToLowestAllowedBatchSize(int batch_size) const { 304 if (allowed_batch_sizes_.empty()) { 305 return batch_size; 306 } 307 for (int allowed_size : allowed_batch_sizes_) { 308 if (allowed_size >= batch_size) { 309 return allowed_size; 310 } 311 } 312 LOG(ERROR) << "Maximum batch size greater than largest allowed size; " 313 "ignoring allowed sizes constraint"; 314 return batch_size; 315 } 316 317 // Processes a batch of one or more BatchTask entries. 318 void ProcessBatch(std::unique_ptr<Batch> batch) const { 319 if (batch->empty()) { 320 return; 321 } 322 const int padded_batch_size = RoundToLowestAllowedBatchSize(batch->size()); 323 const int padding_amount = padded_batch_size - batch->size(); 324 325 OpKernelContext* last_task_context = 326 batch->task(batch->num_tasks() - 1).context; 327 AsyncOpKernel::DoneCallback last_task_callback = 328 batch->task(batch->num_tasks() - 1).done_callback; 329 330 OP_REQUIRES_OK_ASYNC(last_task_context, ValidateBatch(*batch), 331 last_task_callback); 332 333 // All tasks should have the same number of input edges. 334 const int num_input_edges = batch->task(0).inputs.size(); 335 336 // Process each input edge one at a time (the typical case has just one). 337 for (int i = 0; i < num_input_edges; ++i) { 338 // Emit batch->num_tasks() - 1 empty output tensors. 339 for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) { 340 const BatchTask& task = batch->task(task_idx); 341 TensorShape output_shape(task.inputs.at(i).shape()); 342 output_shape.set_dim(0, 0); 343 Tensor* output = nullptr; 344 OP_REQUIRES_OK_ASYNC( 345 task.context, 346 task.context->allocate_output(i, output_shape, &output), 347 task.done_callback); 348 } 349 350 // Concatenate the tasks ith input tensors into a big output tensor. 351 std::vector<Tensor> to_concatenate; 352 to_concatenate.reserve(batch->num_tasks()); 353 for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) { 354 to_concatenate.push_back(batch->task(task_idx).inputs.at(i)); 355 } 356 357 // Add padding as needed. Use the first row of the first task's tensor as 358 // the data for padding. 359 if (padding_amount > 0) { 360 const Tensor& padding_source = batch->task(0).inputs.at(i); 361 Tensor padding; 362 if (padding_source.shape().dim_size(0) == 1) { 363 padding = padding_source; 364 } else { 365 const std::vector<int64> slice_sizes = {1}; 366 const DataType type = padding_source.dtype(); 367 Status slice_status; 368 std::vector<Tensor> slices; 369 switch (type) { 370 #define CASE(type) \ 371 case DataTypeToEnum<type>::value: \ 372 slice_status = SplitCPU<type>(last_task_context, padding_source, \ 373 slice_sizes, &slices); \ 374 break; 375 TF_CALL_ALL_TYPES(CASE); 376 #undef CASE 377 default: 378 slice_status = 379 errors::InvalidArgument("Unsupported data type: ", type); 380 break; 381 } 382 OP_REQUIRES_OK_ASYNC(last_task_context, slice_status, 383 last_task_callback); 384 padding = slices.at(0); 385 } 386 for (int i = 0; i < padding_amount; ++i) { 387 to_concatenate.push_back(padding); 388 } 389 } 390 391 const DataType type = to_concatenate[0].dtype(); 392 Status concat_status; 393 switch (type) { 394 #define CASE(type) \ 395 case DataTypeToEnum<type>::value: \ 396 concat_status = Concat<type>(last_task_context, to_concatenate, i); \ 397 break; 398 TF_CALL_ALL_TYPES(CASE); 399 #undef CASE 400 default: 401 concat_status = 402 errors::InvalidArgument("Unsupported data type: ", type); 403 break; 404 } 405 OP_REQUIRES_OK_ASYNC(last_task_context, concat_status, 406 last_task_callback); 407 } 408 409 // Emit batch->num_tasks() - 1 empty index tensors. 410 for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) { 411 const BatchTask& task = batch->task(task_idx); 412 TensorShape index_shape({0, 3}); 413 Tensor* output = nullptr; 414 OP_REQUIRES_OK_ASYNC( 415 task.context, 416 task.context->allocate_output(num_input_edges, index_shape, &output), 417 task.done_callback); 418 } 419 // Emit all ID tensors. 420 for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) { 421 const BatchTask& task = batch->task(task_idx); 422 Tensor* id; 423 OP_REQUIRES_OK_ASYNC(task.context, 424 task.context->allocate_output(num_input_edges + 1, 425 TensorShape({}), &id), 426 task.done_callback); 427 id->scalar<int64>()() = task.guid; 428 } 429 OP_REQUIRES_OK_ASYNC( 430 last_task_context, 431 EmitIndexTensor(last_task_context, *batch, num_input_edges), 432 last_task_callback); 433 434 // Signal done for each element of the batch. (At this point, the contexts 435 // are no longer guaranteed to remain live.) 436 for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) { 437 batch->mutable_task(task_idx)->done_callback(); 438 } 439 } 440 441 // Emits an index tensor, which the Unbatch op will use to un-concatenate 442 // the tensor and attribute the pieces to the right batch keys. The index 443 // tensor contains, for each input: [batch_key, start_offset, end_offset] 444 // where start_offset and end_offset represent the range of entries in the 445 // concatenated tensors that belong to that input. 446 // 447 // Emits the result to the output at 'output_index' using 'context'. 448 static Status EmitIndexTensor(OpKernelContext* context, const Batch& batch, 449 int output_index) { 450 const TensorShape index_shape({batch.num_tasks(), 3}); 451 Tensor* index = nullptr; 452 TF_RETURN_IF_ERROR( 453 context->allocate_output(output_index, index_shape, &index)); 454 auto index_flat = index->shaped<int64, 2>({batch.num_tasks(), 3}); 455 size_t offset = 0; 456 for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) { 457 const BatchTask& task = batch.task(task_idx); 458 index_flat(task_idx, 0) = task.guid; 459 index_flat(task_idx, 1) = offset; 460 index_flat(task_idx, 2) = offset + task.size(); 461 offset += task.size(); 462 } 463 return Status::OK(); 464 } 465 466 // Looks up the batcher queue for 'queue_name'. If it didn't previously exist, 467 // creates it. 468 Status LookupOrCreateBatcherQueue(const string& queue_name, 469 BatcherQueue** queue) { 470 mutex_lock l(batcher_queues_mu_); 471 472 auto it = batcher_queues_.find(queue_name); 473 if (it != batcher_queues_.end()) { 474 *queue = it->second.get(); 475 return Status::OK(); 476 } 477 478 std::unique_ptr<BatcherQueue> new_queue; 479 auto process_batch_callback = [this](std::unique_ptr<Batch> batch) { 480 ProcessBatch(std::move(batch)); 481 }; 482 TF_RETURN_IF_ERROR(batcher_->AddQueue(batcher_queue_options_, 483 process_batch_callback, &new_queue)); 484 *queue = new_queue.get(); 485 batcher_queues_[queue_name] = std::move(new_queue); 486 return Status::OK(); 487 } 488 489 // A batch scheduler, and options for creating queues. 490 std::shared_ptr<Batcher> batcher_; 491 Batcher::QueueOptions batcher_queue_options_; 492 493 // A collection of batcher queues, keyed on queue name. 494 // TODO(olston): Garbage-collect unused queues (perhaps simply remove empty 495 // ones (with a time delay?); it's okay if they get recreated later). 496 mutable mutex batcher_queues_mu_; 497 std::map<string, std::unique_ptr<BatcherQueue>> batcher_queues_ 498 GUARDED_BY(batcher_queues_mu_); 499 500 std::vector<int32> allowed_batch_sizes_; 501 }; 502 503 class BatchKernel : public AsyncOpKernel { 504 public: 505 explicit BatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) { 506 OP_REQUIRES_OK(c, c->GetAttr("container", &container_)); 507 OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_)); 508 // If shared_name is not supplied, use name instead (prevent collisions by 509 // default). 510 if (shared_name_.empty()) { 511 shared_name_ = name(); 512 } 513 OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_)); 514 OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_)); 515 OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_)); 516 OP_REQUIRES_OK(c, 517 c->GetAttr("batch_timeout_micros", &batch_timeout_micros_)); 518 OP_REQUIRES_OK(c, 519 c->GetAttr("max_enqueued_batches", &max_enqueued_batches_)); 520 OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_)); 521 OP_REQUIRES_OK(c, ValidateAllowedBatchSizes()); 522 } 523 524 void ComputeAsync(OpKernelContext* c, DoneCallback done) final { 525 BatchResource* br; 526 std::function<Status(BatchResource * *r)> creator = 527 [this](BatchResource** r) { 528 std::unique_ptr<BatchResource> new_resource; 529 TF_RETURN_IF_ERROR(BatchResource::Create( 530 num_batch_threads_, max_batch_size_, batch_timeout_micros_, 531 max_enqueued_batches_, allowed_batch_sizes_, &new_resource)); 532 *r = new_resource.release(); 533 return Status::OK(); 534 }; 535 OP_REQUIRES_OK_ASYNC(c, 536 c->resource_manager()->LookupOrCreate( 537 container_, shared_name_, &br, creator), 538 done); 539 const Status status = 540 br->RegisterInput(random::New64(), c, batcher_queue_, done); 541 br->Unref(); 542 if (!status.ok()) { 543 OP_REQUIRES_OK_ASYNC(c, status, done); 544 } 545 // Assume br calls done, so nothing to do here. 546 } 547 548 // Validates 'allowed_batch_sizes_'. The entries must increase monotonically, 549 // and the last one must equal 'max_batch_size_'. 550 Status ValidateAllowedBatchSizes() const { 551 if (allowed_batch_sizes_.empty()) { 552 return Status::OK(); 553 } 554 int32 last_size = 0; 555 for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) { 556 const int32 size = allowed_batch_sizes_.at(i); 557 if (i > 0 && size <= last_size) { 558 return errors::InvalidArgument( 559 "allowed_batch_sizes entries must be monotonically increasing"); 560 } 561 if (i == allowed_batch_sizes_.size() - 1 && size != max_batch_size_) { 562 return errors::InvalidArgument( 563 "final entry in allowed_batch_sizes must equal max_batch_size"); 564 } 565 last_size = size; 566 } 567 return Status::OK(); 568 } 569 570 private: 571 string container_; 572 string shared_name_; 573 string batcher_queue_; 574 int32 num_batch_threads_; 575 int32 max_batch_size_; 576 int32 batch_timeout_micros_; 577 int32 max_enqueued_batches_; 578 std::vector<int32> allowed_batch_sizes_; 579 }; 580 581 REGISTER_KERNEL_BUILDER(Name("Batch").Device(DEVICE_CPU), BatchKernel); 582 583 // A class encapsulating the state and logic for unbatching tensors. 584 // 585 // UnbatchResource keeps two data structures indexed by batch-key: one which has 586 // the continuations for all concurrent kernels which are waiting for tensors 587 // and another which has tensors which are waiting for their corresponding 588 // kernels to run. Whenever a kernel runs, we either grab its tensor if it's 589 // waiting already, or we insert it in the queue and then look at its tensor to 590 // see if it can be used to dispatch any stored continuations. 591 class UnbatchResource : public ResourceBase { 592 public: 593 explicit UnbatchResource(int32 timeout_micros) 594 : timeout_micros_(timeout_micros), 595 timeout_enforcer_(new serving::PeriodicFunction( 596 [this] { EnforceTimeout(); }, 1000 /* 1 ms */)) {} 597 598 ~UnbatchResource() override { 599 // Tear down 'timeout_enforcer_' first, since it accesses other state in 600 // this class. 601 timeout_enforcer_ = nullptr; 602 } 603 604 string DebugString() final { return "UnbatchResource"; } 605 606 Status Compute(OpKernelContext* context, AsyncOpKernel::DoneCallback done) { 607 const Tensor& data_t = context->input(0); 608 const Tensor& batch_index_t = context->input(1); 609 610 if (batch_index_t.shape().dim_size(0) > data_t.shape().dim_size(0)) { 611 return errors::InvalidArgument( 612 "Wrong shape for index tensor. Expected 0th dimension size to be no " 613 "greater than ", 614 data_t.shape().dim_size(0), 615 "; Got: ", batch_index_t.shape().dim_size(0), "."); 616 } 617 if (batch_index_t.shape().dim_size(1) != 3) { 618 return errors::InvalidArgument( 619 "Wrong shape for index tensor. Expected 1st dimension size to be 3 ; " 620 "Got: ", 621 batch_index_t.shape().dim_size(1), "."); 622 } 623 624 const int64 batch_key = context->input(2).scalar<int64>()(); 625 const bool nonempty_input = batch_index_t.dim_size(0) > 0; 626 627 // If we have a non-empty tensor, slice it up. 628 // (It is important to do this outside of the critical section below.) 629 // The following variables are populated iff 'nonempty_input==true'. 630 std::vector<int64> sizes; 631 std::vector<int64> batch_keys; 632 std::vector<Tensor> split_inputs; 633 if (nonempty_input) { 634 auto batch_indices = 635 batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3}); 636 for (int i = 0; i < batch_index_t.dim_size(0); ++i) { 637 sizes.push_back(batch_indices(i, 2) - batch_indices(i, 1)); 638 batch_keys.push_back(batch_indices(i, 0)); 639 } 640 641 const DataType type = data_t.dtype(); 642 switch (type) { 643 #define CASE(type) \ 644 case DataTypeToEnum<type>::value: \ 645 TF_RETURN_IF_ERROR(Split<type>(context, data_t, sizes, &split_inputs)); \ 646 break; 647 TF_CALL_ALL_TYPES(CASE); 648 #undef CASE 649 default: 650 return errors::InvalidArgument("Unsupported data type: ", type); 651 } 652 } 653 654 // Critical section. 655 std::vector<AsyncOpKernel::DoneCallback> done_callbacks_to_call; 656 Status status = [&]() -> Status { 657 mutex_lock ml(mu_); 658 659 // Check to see whether the tensor we want is already ready. 660 auto tensor_it = waiting_tensors_.find(batch_key); 661 if (tensor_it != waiting_tensors_.end()) { 662 context->set_output(0, tensor_it->second.tensor); 663 waiting_tensors_.erase(tensor_it); 664 done_callbacks_to_call.push_back(done); 665 return Status::OK(); 666 } 667 668 const uint64 deadline_micros = 669 Env::Default()->NowMicros() + timeout_micros_; 670 671 // Add ourselves to the waitlist for tensors. 672 if (!waiting_callbacks_ 673 .emplace(batch_key, 674 WaitingCallback{deadline_micros, context, done}) 675 .second) { 676 return errors::AlreadyExists( 677 "Multiple session runs with the same batch key."); 678 } 679 680 // If we have a non-empty tensor, finish the waitlisted runs, 681 // and store any remaining pieces. 682 if (nonempty_input) { 683 for (size_t i = 0; i < batch_keys.size(); ++i) { 684 auto runs_it = waiting_callbacks_.find(batch_keys[i]); 685 if (runs_it != waiting_callbacks_.end()) { 686 runs_it->second.context->set_output(0, split_inputs[i]); 687 done_callbacks_to_call.push_back(runs_it->second.done); 688 waiting_callbacks_.erase(runs_it); 689 } else { 690 // Note: the deadline here is in case we are arriving late and the 691 // kernel that should rendezvous with this tensor has already waited 692 // and timed out. 693 if (!waiting_tensors_ 694 .emplace(batch_keys[i], 695 WaitingTensor{deadline_micros, split_inputs[i]}) 696 .second) { 697 return errors::AlreadyExists( 698 "Multiple tensors returned for same batch key."); 699 } 700 } 701 } 702 } 703 704 return Status::OK(); 705 }(); 706 707 for (const AsyncOpKernel::DoneCallback& done_callback : 708 done_callbacks_to_call) { 709 done_callback(); 710 } 711 712 return status; 713 } 714 715 private: 716 // Evicts waiting tensors and callbacks that have exceeded their deadline. 717 void EnforceTimeout() { 718 const uint64 now = Env::Default()->NowMicros(); 719 std::vector<WaitingCallback> evicted_callbacks; 720 721 { 722 mutex_lock ml(mu_); 723 724 for (auto it = waiting_tensors_.begin(); it != waiting_tensors_.end();) { 725 const WaitingTensor& waiting_tensor = it->second; 726 if (waiting_tensor.deadline_micros < now) { 727 it = waiting_tensors_.erase(it); 728 } else { 729 ++it; 730 } 731 } 732 733 for (auto it = waiting_callbacks_.begin(); 734 it != waiting_callbacks_.end();) { 735 const WaitingCallback& waiting_callback = it->second; 736 if (waiting_callback.deadline_micros < now) { 737 evicted_callbacks.push_back(waiting_callback); 738 it = waiting_callbacks_.erase(it); 739 } else { 740 ++it; 741 } 742 } 743 } 744 745 for (const WaitingCallback& evicted_callback : evicted_callbacks) { 746 evicted_callback.context->CtxFailureWithWarning(errors::DeadlineExceeded( 747 "Batched data did not arrive within timeout window.")); 748 evicted_callback.done(); 749 } 750 } 751 752 struct WaitingTensor { 753 uint64 deadline_micros; 754 Tensor tensor; 755 }; 756 757 struct WaitingCallback { 758 uint64 deadline_micros; 759 OpKernelContext* context; 760 AsyncOpKernel::DoneCallback done; 761 }; 762 763 const int32 timeout_micros_; 764 765 mutex mu_; 766 767 // Maps keyed by BatchKey of tensors waiting for callbacks and callbacks 768 // waiting for tensors. 769 std::unordered_map<int64, WaitingTensor> waiting_tensors_ GUARDED_BY(mu_); 770 std::unordered_map<int64, WaitingCallback> waiting_callbacks_ GUARDED_BY(mu_); 771 772 // A thread that evicts waiting tensors and callbacks that have exceeded their 773 // deadline. 774 std::unique_ptr<serving::PeriodicFunction> timeout_enforcer_; 775 }; 776 777 class UnbatchKernel : public AsyncOpKernel { 778 public: 779 explicit UnbatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) { 780 OP_REQUIRES_OK(c, c->GetAttr("container", &container_)); 781 OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_)); 782 // If shared_name is not supplied, use name instead (prevent collisions by 783 // default). 784 if (shared_name_.empty()) { 785 shared_name_ = name(); 786 } 787 OP_REQUIRES_OK(c, c->GetAttr("timeout_micros", &timeout_micros_)); 788 } 789 790 void ComputeAsync(OpKernelContext* c, DoneCallback done) final { 791 UnbatchResource* ubr; 792 std::function<Status(UnbatchResource * *r)> creator = 793 [this](UnbatchResource** r) { 794 *r = new UnbatchResource(timeout_micros_); 795 return Status::OK(); 796 }; 797 OP_REQUIRES_OK_ASYNC(c, 798 c->resource_manager()->LookupOrCreate( 799 container_, shared_name_, &ubr, creator), 800 done); 801 auto status = ubr->Compute(c, done); 802 ubr->Unref(); 803 if (!status.ok()) { 804 OP_REQUIRES_OK_ASYNC(c, status, done); 805 } 806 // Assume ubr calls done, so nothing to do here. 807 } 808 809 private: 810 string container_; 811 string shared_name_; 812 int32 timeout_micros_; 813 }; 814 REGISTER_KERNEL_BUILDER(Name("Unbatch").Device(DEVICE_CPU), UnbatchKernel); 815 816 // A class encapsulating the state and logic for batching tensors 817 // deterministically for the gradient of unbatch. 818 class UnbatchGradResource : public ResourceBase { 819 public: 820 UnbatchGradResource() {} 821 822 string DebugString() final { return "UnbatchGradResource"; } 823 824 // Flushes the information for one batch, given its context and done 825 // callback. Clears all information about it from the available_tensors_. 826 Status OutputBatch(OpKernelContext* context, 827 const AsyncOpKernel::DoneCallback& done) 828 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 829 const Tensor& batch_index_t = context->input(1); 830 auto batch_index = 831 batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3}); 832 std::vector<Tensor> tensors; 833 for (int i = 0; i < batch_index_t.dim_size(0); ++i) { 834 auto available_it = available_tensors_.find(batch_index(i, 0)); 835 if (available_it == available_tensors_.end()) { 836 return errors::Internal("bad bookkeeping of available tensors."); 837 } 838 tensors.push_back(available_it->second); 839 available_tensors_.erase(available_it); 840 } 841 842 const DataType type = tensors[0].dtype(); 843 switch (type) { 844 #define CASE(type) \ 845 case DataTypeToEnum<type>::value: \ 846 TF_RETURN_IF_ERROR(Concat<type>(context, tensors, 0)); \ 847 break; 848 TF_CALL_ALL_TYPES(CASE); 849 #undef CASE 850 default: 851 return errors::InvalidArgument("Unsupported data type: ", type); 852 } 853 done(); 854 return Status::OK(); 855 } 856 857 // Ingests data from one invocation of the op. 858 Status Compute(OpKernelContext* context, 859 const AsyncOpKernel::DoneCallback& done) { 860 const Tensor& data_t = context->input(0); 861 const Tensor& batch_index_t = context->input(1); 862 const Tensor& grad_t = context->input(2); 863 864 mutex_lock ml(mu_); 865 866 const int64 batch_key = context->input(3).scalar<int64>()(); 867 // Mark our tensor as available. 868 if (!available_tensors_.emplace(batch_key, grad_t).second) { 869 return errors::InvalidArgument("Two runs with the same batch key."); 870 } 871 872 // Check whether we have a valid input tensor and, if so, create its 873 // dispatch logic. 874 if (data_t.NumElements() > 0) { 875 if (batch_index_t.NumElements() == 0) { 876 return errors::InvalidArgument( 877 "batch_index is empty while the tensor isn't."); 878 } 879 std::unordered_set<int64> missing_tensors; 880 const auto batch_index = 881 batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3}); 882 for (int i = 0; i < batch_index_t.dim_size(0); ++i) { 883 const int64 batch_key = batch_index(i, 0); 884 if (available_tensors_.find(batch_key) == available_tensors_.end()) { 885 missing_tensors.emplace(batch_key); 886 } 887 } 888 if (missing_tensors.empty()) { 889 return OutputBatch(context, done); 890 } 891 if (!available_batches_ 892 .emplace(batch_key, Batch{missing_tensors, context, done}) 893 .second) { 894 return errors::InvalidArgument( 895 "Batch key with valid batch used twice."); 896 } 897 for (const int64 i : missing_tensors) { 898 if (!desired_tensor_to_batch_map_.emplace(i, batch_key).second) { 899 return errors::InvalidArgument( 900 "Missing tensor wanted by more than one batch."); 901 } 902 } 903 } else { 904 // If we don't have a valid input tensor we can output an empty tensor and 905 // call our done closure. 906 TensorShape output_shape(grad_t.shape()); 907 output_shape.set_dim(0, 0); 908 Tensor* output = nullptr; 909 TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, &output)); 910 done(); 911 } 912 913 // Search to see whether our tensor is desired by any existing batch. 914 auto desire_it = desired_tensor_to_batch_map_.find(batch_key); 915 if (desire_it != desired_tensor_to_batch_map_.end()) { 916 // Mark our tensor as no longer missing. 917 auto batch_it = available_batches_.find(desire_it->second); 918 desired_tensor_to_batch_map_.erase(desire_it); 919 if (batch_it == available_batches_.end()) { 920 return errors::InvalidArgument("Batch no longer exists."); 921 } 922 batch_it->second.missing_tensors.erase(batch_key); 923 // If all tensors are available we should concatenate them and dispatch 924 // the batch. 925 if (batch_it->second.missing_tensors.empty()) { 926 TF_RETURN_IF_ERROR( 927 OutputBatch(batch_it->second.context, batch_it->second.done)); 928 available_batches_.erase(batch_it); 929 } 930 } 931 return Status::OK(); 932 } 933 934 private: 935 mutex mu_; 936 937 // Represents a still-incomplete batch of tensors. When all tensors become 938 // available they will be concatenated in the right order and sent through the 939 // context. 940 struct Batch { 941 // Batch keys for tensors which are still missing from this batch. When this 942 // is empty the Tensors can be concatenated and forwarded. 943 std::unordered_set<int64> missing_tensors; 944 945 // Context and callback for the session responsible for finishing this 946 // batch. 947 OpKernelContext* context; 948 AsyncOpKernel::DoneCallback done; 949 }; 950 951 // Map from batch key of the session which will output the batched gradients 952 // to still-incomplete batches. 953 std::unordered_map<int64, Batch> available_batches_; 954 955 // Map from batch key to tensors which are waiting for their batches to be 956 // available. 957 std::unordered_map<int64, Tensor> available_tensors_; 958 959 // Map from batch key of a tensor which is not yet available to the batch key 960 // of the batch to which it belongs. 961 std::unordered_map<int64, int64> desired_tensor_to_batch_map_; 962 }; 963 964 class UnbatchGradKernel : public AsyncOpKernel { 965 public: 966 explicit UnbatchGradKernel(OpKernelConstruction* c) : AsyncOpKernel(c) { 967 OP_REQUIRES_OK(c, c->GetAttr("container", &container_)); 968 OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_)); 969 // If shared_name is not supplied, use name instead (prevent collisions by 970 // default). 971 if (shared_name_.empty()) { 972 shared_name_ = name(); 973 } 974 } 975 976 void ComputeAsync(OpKernelContext* c, DoneCallback done) final { 977 UnbatchGradResource* ubr; 978 std::function<Status(UnbatchGradResource * *r)> creator = 979 [this](UnbatchGradResource** r) { 980 *r = new UnbatchGradResource(); 981 return Status::OK(); 982 }; 983 OP_REQUIRES_OK_ASYNC(c, 984 c->resource_manager()->LookupOrCreate( 985 container_, shared_name_, &ubr, creator), 986 done); 987 Status status = ubr->Compute(c, done); 988 ubr->Unref(); 989 if (!status.ok()) { 990 OP_REQUIRES_OK_ASYNC(c, status, done); 991 } 992 // Assume ubr calls done, so nothing to do here. 993 } 994 995 private: 996 string container_; 997 string shared_name_; 998 }; 999 REGISTER_KERNEL_BUILDER(Name("UnbatchGrad").Device(DEVICE_CPU), 1000 UnbatchGradKernel); 1001 1002 } // namespace tensorflow 1003