/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/split_lib.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

// Concatenates 'inputs' into a single tensor along the zeroth dimension.
// Requires that all elements of 'inputs' have element type T. Writes to the
// op's output at position 'output_index', using 'context' for the allocation to
// ensure proper device placement.
template <typename T>
Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor>& inputs,
              int output_index) {
  const int input_dims = inputs[0].dims();
  const TensorShape& input_shape = inputs[0].shape();

  // Note that we reduce the concat of k-dimensional tensors into a two
  // dimensional concat. Assuming the dimensions of any input tensor are
  // {y0, y1,...,ym-1}, we flatten it to {1, y}, where y = Prod_i(yi).
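  // For example, concatenating inputs of shapes {2, 3} and {4, 3} proceeds as
  // if concatenating matrices of shapes {1, 6} and {1, 12} along dimension 1
  // into the {1, 18} view of the {6, 3} output allocated below.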
  std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>> inputs_flat;
  inputs_flat.reserve(inputs.size());
  int64 output_dim0 = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    const Tensor& input = inputs[i];
    if (input.dims() != input_dims) {
      return errors::InvalidArgument(
          "Ranks of all input tensors should match: shape[0] = ",
          input_shape.DebugString(), " vs. shape[", i,
          "] = ", input.shape().DebugString());
    }
    for (int j = 1; j < input_dims; ++j) {
      if (input.dim_size(j) != input_shape.dim_size(j)) {
        return errors::InvalidArgument(
            "Dimensions of inputs should match: shape[0] = ",
            input_shape.DebugString(), " vs. shape[", i,
            "] = ", input.shape().DebugString());
      }
    }
    if (input.NumElements() > 0) {
      inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
          input.shaped<T, 2>({1, input.NumElements()})));
    }
    output_dim0 += input.dim_size(0);
  }

  TensorShape output_shape(input_shape);
  output_shape.set_dim(0, output_dim0);
  Tensor* output = nullptr;
  TF_RETURN_IF_ERROR(
      context->allocate_output(output_index, output_shape, &output));
  if (output->NumElements() > 0) {
    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
#if GOOGLE_CUDA
    if (std::is_same<Device, GPUDevice>::value) {
      ConcatGPU<T>(context, inputs_flat, output, &output_flat);
      return Status::OK();
    }
#endif  // GOOGLE_CUDA
    ConcatCPU<T>(context->device(), inputs_flat, &output_flat);
  }

  return Status::OK();
}

// The Split*() functions split 'input' with element type T into 'sizes.size()'
// tensors along the zeroth dimension, with the ith split having zeroth-
// dimension size 'sizes[i]'. They allocate the output tensors using 'context',
// for proper device placement.
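// For example, splitting an input of shape {6, 5} with sizes {2, 3, 1} yields
// three tensors of shapes {2, 5}, {3, 5}, and {1, 5}; the sum of 'sizes' must
// not exceed the input's zeroth-dimension size.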

// Handles special cases that are cheap. Sets 'done==true' iff it found an
// applicable special case and wrote to the outputs. Otherwise acts as a no-op.
template <typename T>
Status SplitEasyCases(OpKernelContext* context, const Tensor& input,
                      const gtl::ArraySlice<int64>& sizes,
                      std::vector<Tensor>* outputs, bool* done) {
  *done = false;

  int64 total_size = 0;
  for (const int64 size : sizes) {
    total_size += size;
  }
  if (total_size > input.shape().dim_size(0)) {
    return errors::InvalidArgument(
        "Sum of split sizes must not exceed dim0-size of input tensor");
  }

  // Special case 0: trivial 1-way split.
  if (sizes.size() == 1 && sizes.at(0) == input.shape().dim_size(0)) {
    outputs->push_back(input);
    *done = true;
    return Status::OK();
  }

  // Special case 1: input is aligned.
  if (IsInnerDimsSizeAligned<T>(input.shape())) {
    int64 position = 0;
    for (const int64 size : sizes) {
      outputs->emplace_back(input.Slice(position, position + size));
      position += size;
    }
    *done = true;
    return Status::OK();
  }

  return Status::OK();
}

// Handles the general case, on CPU.
template <typename T>
Status SplitCPU(OpKernelContext* context, const Tensor& input,
                const gtl::ArraySlice<int64>& sizes,
                std::vector<Tensor>* outputs) {
  int64 suffix_dim_size = 1;
  for (int i = 1; i < input.shape().dims(); ++i) {
    suffix_dim_size *= input.shape().dim_size(i);
  }
  auto input_reshaped =
      input.shaped<T, 3>({1, input.shape().dim_size(0), suffix_dim_size});

  int64 position = 0;
  for (const int64 size : sizes) {
    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, size);
    Tensor output;
    TF_RETURN_IF_ERROR(
        context->allocate_temp(input.dtype(), output_shape, &output));
    auto output_shaped = output.shaped<T, 3>({1, size, suffix_dim_size});

    Eigen::DSizes<Eigen::DenseIndex, 3> slice_indices{0, position, 0};
    Eigen::DSizes<Eigen::DenseIndex, 3> slice_sizes{1, size, suffix_dim_size};
    functor::Split<CPUDevice, T>()(context->eigen_device<CPUDevice>(),
                                   output_shaped, input_reshaped, slice_indices,
                                   slice_sizes);

    outputs->emplace_back(output);

    position += size;
  }

  return Status::OK();
}

#if GOOGLE_CUDA

// Handles the general case, on GPU.
template <typename T>
Status SplitGPU(OpKernelContext* context, const Tensor& input,
                const gtl::ArraySlice<int64>& sizes,
                std::vector<Tensor>* outputs) {
  // TODO(olston, apassos): Implement this.
  LOG(FATAL) << "Not yet implemented";  // Crash ok
}

#endif  // GOOGLE_CUDA

// The outer function that dispatches to the various Split*() functions above.
template <typename T>
Status Split(OpKernelContext* context, const Tensor& input,
             const gtl::ArraySlice<int64>& sizes,
             std::vector<Tensor>* outputs) {
  bool easy_cases_done;
  TF_RETURN_IF_ERROR(
      SplitEasyCases<T>(context, input, sizes, outputs, &easy_cases_done));
  if (easy_cases_done) {
    return Status::OK();
  }

#if GOOGLE_CUDA
// TODO(olston, apassos): Handle non-CPU cases.
// return SplitGPU<T>(context, input, sizes, outputs);
#endif  // GOOGLE_CUDA
  return SplitCPU<T>(context, input, sizes, outputs);
}

// A class encapsulating the state and logic for batching tensors.
class BatchResource : public ResourceBase {
 public:
  static Status Create(int32 num_batch_threads, int32 max_batch_size,
                       int32 batch_timeout_micros, int32 max_enqueued_batches,
                       const std::vector<int32>& allowed_batch_sizes,
                       std::unique_ptr<BatchResource>* resource) {
    std::unique_ptr<BatchResource> new_resource(new BatchResource);

    Batcher::Options batcher_options;
    batcher_options.num_batch_threads = num_batch_threads;
    TF_RETURN_IF_ERROR(
        Batcher::Create(batcher_options, &new_resource->batcher_));

    new_resource->batcher_queue_options_.max_batch_size = max_batch_size;
    new_resource->batcher_queue_options_.max_enqueued_batches =
        max_enqueued_batches;
    new_resource->batcher_queue_options_.batch_timeout_micros =
        batch_timeout_micros;

    new_resource->allowed_batch_sizes_ = allowed_batch_sizes;

    *resource = std::move(new_resource);
    return Status::OK();
  }

  string DebugString() final { return "BatchResource"; }

  // Ingests data from one invocation of the batch op. The data is enqueued to
  // be combined with others into a batch, asynchronously.
  Status RegisterInput(int64 guid, OpKernelContext* context,
                       const string& batcher_queue_name,
                       AsyncOpKernel::DoneCallback done_callback) {
    std::unique_ptr<BatchTask> batch_components(new BatchTask);
    batch_components->guid = guid;
    OpInputList tensors;
    TF_RETURN_IF_ERROR(context->input_list("in_tensors", &tensors));
    for (int i = 0; i < tensors.size(); ++i) {
      const Tensor& tensor = tensors[i];
      if (tensor.shape().dims() == 0) {
        return errors::InvalidArgument(
            "Batching input tensors must have at least one dimension");
      }
      if (tensors.size() >= 2 &&
          tensor.shape().dim_size(0) != tensors[0].shape().dim_size(0)) {
        return errors::InvalidArgument(
            "Batching input tensors supplied in a given op invocation must "
            "have equal 0th-dimension size");
      }
      batch_components->inputs.push_back(tensor);
    }
    batch_components->context = context;
    batch_components->done_callback = std::move(done_callback);

    BatcherQueue* batcher_queue;
    TF_RETURN_IF_ERROR(
        LookupOrCreateBatcherQueue(batcher_queue_name, &batcher_queue));
    return batcher_queue->Schedule(&batch_components);
  }

 private:
  BatchResource() = default;

  // One input to be batched. Corresponds to one invocation of the batch op.
  struct BatchTask : public serving::BatchTask {
    // A unique ID to identify this invocation of Batch.
    int64 guid;

    std::vector<Tensor> inputs;
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done_callback;

    size_t size() const override { return inputs[0].shape().dim_size(0); }
  };

  using Batcher = serving::SharedBatchScheduler<BatchTask>;
  using BatcherQueue = serving::BatchScheduler<BatchTask>;
  using Batch = serving::Batch<BatchTask>;

  // Validates that it's legal to combine the tasks in 'batch' into a batch.
  // Assumes the batch is non-empty.
  static Status ValidateBatch(const Batch& batch) {
    for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
      const BatchTask& task = batch.task(task_idx);

      if (task.inputs.size() != batch.task(0).inputs.size()) {
        return errors::InvalidArgument(
            "Batching inputs must have equal number of edges");
      }
    }

    return Status::OK();
  }

  // Returns the smallest entry in 'allowed_batch_sizes_' that is greater than
  // or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty, simply
  // returns 'batch_size'.
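  // For example, with allowed_batch_sizes_ = {2, 4, 8}: a batch size of 3 is
  // rounded up to 4, a batch size of 8 stays at 8, and a batch size of 9
  // triggers the LOG(ERROR) below and is returned unchanged.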
  int RoundToLowestAllowedBatchSize(int batch_size) const {
    if (allowed_batch_sizes_.empty()) {
      return batch_size;
    }
    for (int allowed_size : allowed_batch_sizes_) {
      if (allowed_size >= batch_size) {
        return allowed_size;
      }
    }
    LOG(ERROR) << "Maximum batch size greater than largest allowed size; "
                  "ignoring allowed sizes constraint";
    return batch_size;
  }

  // Processes a batch of one or more BatchTask entries.
  void ProcessBatch(std::unique_ptr<Batch> batch) const {
    if (batch->empty()) {
      return;
    }
    const int padded_batch_size = RoundToLowestAllowedBatchSize(batch->size());
    const int padding_amount = padded_batch_size - batch->size();

    OpKernelContext* last_task_context =
        batch->task(batch->num_tasks() - 1).context;
    AsyncOpKernel::DoneCallback last_task_callback =
        batch->task(batch->num_tasks() - 1).done_callback;

    OP_REQUIRES_OK_ASYNC(last_task_context, ValidateBatch(*batch),
                         last_task_callback);

    // All tasks should have the same number of input edges.
    const int num_input_edges = batch->task(0).inputs.size();

    // Process each input edge one at a time (the typical case has just one).
    for (int i = 0; i < num_input_edges; ++i) {
      // Emit batch->num_tasks() - 1 empty output tensors.
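      // Only the last task's context receives the concatenated batch (via the
      // Concat() call below); every other invocation gets a zero-row output of
      // the same non-batch shape, and its slice is recovered downstream from
      // the batched output and the index tensor emitted further below.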
      for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
        const BatchTask& task = batch->task(task_idx);
        TensorShape output_shape(task.inputs.at(i).shape());
        output_shape.set_dim(0, 0);
        Tensor* output = nullptr;
        OP_REQUIRES_OK_ASYNC(
            task.context,
            task.context->allocate_output(i, output_shape, &output),
            task.done_callback);
      }

      // Concatenate the tasks' ith input tensors into a big output tensor.
      std::vector<Tensor> to_concatenate;
      to_concatenate.reserve(batch->num_tasks());
      for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
        to_concatenate.push_back(batch->task(task_idx).inputs.at(i));
      }

      // Add padding as needed. Use the first row of the first task's tensor as
      // the data for padding.
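      // For example, if the tasks contribute 3 rows in total and the lowest
      // allowed batch size is 4, one copy of that first row is appended. The
      // index tensor emitted below covers only the first 3 rows, so the
      // padding rows are never attributed to any task.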
      if (padding_amount > 0) {
        const Tensor& padding_source = batch->task(0).inputs.at(i);
        Tensor padding;
        if (padding_source.shape().dim_size(0) == 1) {
          padding = padding_source;
        } else {
          const std::vector<int64> slice_sizes = {1};
          const DataType type = padding_source.dtype();
          Status slice_status;
          std::vector<Tensor> slices;
          switch (type) {
#define CASE(type)                                                   \
  case DataTypeToEnum<type>::value:                                  \
    slice_status = SplitCPU<type>(last_task_context, padding_source, \
                                  slice_sizes, &slices);             \
    break;
            TF_CALL_ALL_TYPES(CASE);
#undef CASE
            default:
              slice_status =
                  errors::InvalidArgument("Unsupported data type: ", type);
              break;
          }
          OP_REQUIRES_OK_ASYNC(last_task_context, slice_status,
                               last_task_callback);
          padding = slices.at(0);
        }
        for (int i = 0; i < padding_amount; ++i) {
          to_concatenate.push_back(padding);
        }
      }

      const DataType type = to_concatenate[0].dtype();
      Status concat_status;
      switch (type) {
#define CASE(type)                                                      \
  case DataTypeToEnum<type>::value:                                     \
    concat_status = Concat<type>(last_task_context, to_concatenate, i); \
    break;
        TF_CALL_ALL_TYPES(CASE);
#undef CASE
        default:
          concat_status =
              errors::InvalidArgument("Unsupported data type: ", type);
          break;
      }
      OP_REQUIRES_OK_ASYNC(last_task_context, concat_status,
                           last_task_callback);
    }

    // Emit batch->num_tasks() - 1 empty index tensors.
    for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
      const BatchTask& task = batch->task(task_idx);
      TensorShape index_shape({0, 3});
      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          task.context,
          task.context->allocate_output(num_input_edges, index_shape, &output),
          task.done_callback);
    }
    // Emit all ID tensors.
    for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
      const BatchTask& task = batch->task(task_idx);
      Tensor* id;
      OP_REQUIRES_OK_ASYNC(task.context,
                           task.context->allocate_output(num_input_edges + 1,
                                                         TensorShape({}), &id),
                           task.done_callback);
      id->scalar<int64>()() = task.guid;
    }
    OP_REQUIRES_OK_ASYNC(
        last_task_context,
        EmitIndexTensor(last_task_context, *batch, num_input_edges),
        last_task_callback);

    // Signal done for each element of the batch. (At this point, the contexts
    // are no longer guaranteed to remain live.)
    for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
      batch->mutable_task(task_idx)->done_callback();
    }
  }

  // Emits an index tensor, which the Unbatch op will use to un-concatenate
  // the tensor and attribute the pieces to the right batch keys. The index
  // tensor contains, for each input: [batch_key, start_offset, end_offset]
  // where start_offset and end_offset represent the range of entries in the
  // concatenated tensors that belong to that input.
  //
  // Emits the result to the output at 'output_index' using 'context'.
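  // For example, a batch with two tasks of sizes 2 and 3 and guids g0 and g1
  // yields the index tensor [[g0, 0, 2], [g1, 2, 5]].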
  static Status EmitIndexTensor(OpKernelContext* context, const Batch& batch,
                                int output_index) {
    const TensorShape index_shape({batch.num_tasks(), 3});
    Tensor* index = nullptr;
    TF_RETURN_IF_ERROR(
        context->allocate_output(output_index, index_shape, &index));
    auto index_flat = index->shaped<int64, 2>({batch.num_tasks(), 3});
    size_t offset = 0;
    for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
      const BatchTask& task = batch.task(task_idx);
      index_flat(task_idx, 0) = task.guid;
      index_flat(task_idx, 1) = offset;
      index_flat(task_idx, 2) = offset + task.size();
      offset += task.size();
    }
    return Status::OK();
  }

  // Looks up the batcher queue for 'queue_name'. If it didn't previously exist,
  // creates it.
  Status LookupOrCreateBatcherQueue(const string& queue_name,
                                    BatcherQueue** queue) {
    mutex_lock l(batcher_queues_mu_);

    auto it = batcher_queues_.find(queue_name);
    if (it != batcher_queues_.end()) {
      *queue = it->second.get();
      return Status::OK();
    }

    std::unique_ptr<BatcherQueue> new_queue;
    auto process_batch_callback = [this](std::unique_ptr<Batch> batch) {
      ProcessBatch(std::move(batch));
    };
    TF_RETURN_IF_ERROR(batcher_->AddQueue(batcher_queue_options_,
                                          process_batch_callback, &new_queue));
    *queue = new_queue.get();
    batcher_queues_[queue_name] = std::move(new_queue);
    return Status::OK();
  }

  // A batch scheduler, and options for creating queues.
  std::shared_ptr<Batcher> batcher_;
  Batcher::QueueOptions batcher_queue_options_;

  // A collection of batcher queues, keyed on queue name.
  // TODO(olston): Garbage-collect unused queues (perhaps simply remove empty
  // ones (with a time delay?); it's okay if they get recreated later).
  mutable mutex batcher_queues_mu_;
  std::map<string, std::unique_ptr<BatcherQueue>> batcher_queues_
      GUARDED_BY(batcher_queues_mu_);

  std::vector<int32> allowed_batch_sizes_;
};

class BatchKernel : public AsyncOpKernel {
 public:
  explicit BatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name() instead (this prevents
    // collisions by default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
    OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
    OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
    OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
    OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
    OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    BatchResource* br;
    std::function<Status(BatchResource * *r)> creator =
        [this](BatchResource** r) {
          std::unique_ptr<BatchResource> new_resource;
          TF_RETURN_IF_ERROR(BatchResource::Create(
              num_batch_threads_, max_batch_size_, batch_timeout_micros_,
              max_enqueued_batches_, allowed_batch_sizes_, &new_resource));
          *r = new_resource.release();
          return Status::OK();
        };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &br, creator),
                         done);
    const Status status =
        br->RegisterInput(random::New64(), c, batcher_queue_, done);
    br->Unref();
    if (!status.ok()) {
      OP_REQUIRES_OK_ASYNC(c, status, done);
    }
    // Assume br calls done, so nothing to do here.
  }

  // Validates 'allowed_batch_sizes_'. The entries must increase monotonically,
  // and the last one must equal 'max_batch_size_'.
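  // For example, with max_batch_size_ = 8: {2, 4, 8} is accepted, {4, 2, 8} is
  // rejected (not increasing), and {2, 4} is rejected (the last entry differs
  // from max_batch_size_).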
  Status ValidateAllowedBatchSizes() const {
    if (allowed_batch_sizes_.empty()) {
      return Status::OK();
    }
    int32 last_size = 0;
    for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
      const int32 size = allowed_batch_sizes_.at(i);
      if (i > 0 && size <= last_size) {
        return errors::InvalidArgument(
            "allowed_batch_sizes entries must be monotonically increasing");
      }
      if (i == allowed_batch_sizes_.size() - 1 && size != max_batch_size_) {
        return errors::InvalidArgument(
            "final entry in allowed_batch_sizes must equal max_batch_size");
      }
      last_size = size;
    }
    return Status::OK();
  }

 private:
  string container_;
  string shared_name_;
  string batcher_queue_;
  int32 num_batch_threads_;
  int32 max_batch_size_;
  int32 batch_timeout_micros_;
  int32 max_enqueued_batches_;
  std::vector<int32> allowed_batch_sizes_;
};

REGISTER_KERNEL_BUILDER(Name("Batch").Device(DEVICE_CPU), BatchKernel);

// A class encapsulating the state and logic for unbatching tensors.
//
// UnbatchResource keeps two data structures indexed by batch-key: one which has
// the continuations for all concurrent kernels which are waiting for tensors
// and another which has tensors which are waiting for their corresponding
// kernels to run. Whenever a kernel runs, we either grab its tensor if it's
// waiting already, or we insert it in the queue and then look at its tensor to
// see if it can be used to dispatch any stored continuations.
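// For example, if the kernel for batch key K runs before its slice of the
// batched tensor has arrived, its callback is parked in waiting_callbacks_;
// when the batched tensor does arrive, K's slice is handed to that callback
// (or parked in waiting_tensors_ if no callback is waiting yet). Entries older
// than timeout_micros_ are evicted by the periodic EnforceTimeout() call.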
class UnbatchResource : public ResourceBase {
 public:
  explicit UnbatchResource(int32 timeout_micros)
      : timeout_micros_(timeout_micros),
        timeout_enforcer_(new serving::PeriodicFunction(
            [this] { EnforceTimeout(); }, 1000 /* 1 ms */)) {}

  ~UnbatchResource() override {
    // Tear down 'timeout_enforcer_' first, since it accesses other state in
    // this class.
    timeout_enforcer_ = nullptr;
  }

  string DebugString() final { return "UnbatchResource"; }

  Status Compute(OpKernelContext* context, AsyncOpKernel::DoneCallback done) {
    const Tensor& data_t = context->input(0);
    const Tensor& batch_index_t = context->input(1);

    if (batch_index_t.shape().dim_size(0) > data_t.shape().dim_size(0)) {
      return errors::InvalidArgument(
          "Wrong shape for index tensor. Expected 0th dimension size to be no "
          "greater than ",
          data_t.shape().dim_size(0),
          "; Got: ", batch_index_t.shape().dim_size(0), ".");
    }
    if (batch_index_t.shape().dim_size(1) != 3) {
      return errors::InvalidArgument(
          "Wrong shape for index tensor. Expected 1st dimension size to be 3; "
          "Got: ",
          batch_index_t.shape().dim_size(1), ".");
    }

    const int64 batch_key = context->input(2).scalar<int64>()();
    const bool nonempty_input = batch_index_t.dim_size(0) > 0;

    // If we have a non-empty tensor, slice it up.
    // (It is important to do this outside of the critical section below.)
    // The following variables are populated iff 'nonempty_input==true'.
    std::vector<int64> sizes;
    std::vector<int64> batch_keys;
    std::vector<Tensor> split_inputs;
    if (nonempty_input) {
      auto batch_indices =
          batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
      for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
        sizes.push_back(batch_indices(i, 2) - batch_indices(i, 1));
        batch_keys.push_back(batch_indices(i, 0));
      }

      const DataType type = data_t.dtype();
      switch (type) {
#define CASE(type)                                                          \
  case DataTypeToEnum<type>::value:                                         \
    TF_RETURN_IF_ERROR(Split<type>(context, data_t, sizes, &split_inputs)); \
    break;
        TF_CALL_ALL_TYPES(CASE);
#undef CASE
        default:
          return errors::InvalidArgument("Unsupported data type: ", type);
      }
    }

    // Critical section.
    std::vector<AsyncOpKernel::DoneCallback> done_callbacks_to_call;
    Status status = [&]() -> Status {
      mutex_lock ml(mu_);

      // Check to see whether the tensor we want is already ready.
      auto tensor_it = waiting_tensors_.find(batch_key);
      if (tensor_it != waiting_tensors_.end()) {
        context->set_output(0, tensor_it->second.tensor);
        waiting_tensors_.erase(tensor_it);
        done_callbacks_to_call.push_back(done);
        return Status::OK();
      }

      const uint64 deadline_micros =
          Env::Default()->NowMicros() + timeout_micros_;

      // Add ourselves to the waitlist for tensors.
      if (!waiting_callbacks_
               .emplace(batch_key,
                        WaitingCallback{deadline_micros, context, done})
               .second) {
        return errors::AlreadyExists(
            "Multiple session runs with the same batch key.");
      }

      // If we have a non-empty tensor, finish the waitlisted runs,
      // and store any remaining pieces.
      if (nonempty_input) {
        for (size_t i = 0; i < batch_keys.size(); ++i) {
          auto runs_it = waiting_callbacks_.find(batch_keys[i]);
          if (runs_it != waiting_callbacks_.end()) {
            runs_it->second.context->set_output(0, split_inputs[i]);
            done_callbacks_to_call.push_back(runs_it->second.done);
            waiting_callbacks_.erase(runs_it);
          } else {
            // Note: the deadline here is in case we are arriving late and the
            // kernel that should rendezvous with this tensor has already waited
            // and timed out.
            if (!waiting_tensors_
                     .emplace(batch_keys[i],
                              WaitingTensor{deadline_micros, split_inputs[i]})
                     .second) {
              return errors::AlreadyExists(
                  "Multiple tensors returned for same batch key.");
            }
          }
        }
      }

      return Status::OK();
    }();

    for (const AsyncOpKernel::DoneCallback& done_callback :
         done_callbacks_to_call) {
      done_callback();
    }

    return status;
  }

 private:
  // Evicts waiting tensors and callbacks that have exceeded their deadline.
  void EnforceTimeout() {
    const uint64 now = Env::Default()->NowMicros();
    std::vector<WaitingCallback> evicted_callbacks;

    {
      mutex_lock ml(mu_);

      for (auto it = waiting_tensors_.begin(); it != waiting_tensors_.end();) {
        const WaitingTensor& waiting_tensor = it->second;
        if (waiting_tensor.deadline_micros < now) {
          it = waiting_tensors_.erase(it);
        } else {
          ++it;
        }
      }

      for (auto it = waiting_callbacks_.begin();
           it != waiting_callbacks_.end();) {
        const WaitingCallback& waiting_callback = it->second;
        if (waiting_callback.deadline_micros < now) {
          evicted_callbacks.push_back(waiting_callback);
          it = waiting_callbacks_.erase(it);
        } else {
          ++it;
        }
      }
    }

    for (const WaitingCallback& evicted_callback : evicted_callbacks) {
      evicted_callback.context->CtxFailureWithWarning(errors::DeadlineExceeded(
          "Batched data did not arrive within timeout window."));
      evicted_callback.done();
    }
  }

  struct WaitingTensor {
    uint64 deadline_micros;
    Tensor tensor;
  };

  struct WaitingCallback {
    uint64 deadline_micros;
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done;
  };

  const int32 timeout_micros_;

  mutex mu_;

  // Maps keyed by BatchKey of tensors waiting for callbacks and callbacks
  // waiting for tensors.
  std::unordered_map<int64, WaitingTensor> waiting_tensors_ GUARDED_BY(mu_);
  std::unordered_map<int64, WaitingCallback> waiting_callbacks_ GUARDED_BY(mu_);

  // A thread that evicts waiting tensors and callbacks that have exceeded their
  // deadline.
  std::unique_ptr<serving::PeriodicFunction> timeout_enforcer_;
};

class UnbatchKernel : public AsyncOpKernel {
 public:
  explicit UnbatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name() instead (this prevents
    // collisions by default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
    OP_REQUIRES_OK(c, c->GetAttr("timeout_micros", &timeout_micros_));
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    UnbatchResource* ubr;
    std::function<Status(UnbatchResource * *r)> creator =
        [this](UnbatchResource** r) {
          *r = new UnbatchResource(timeout_micros_);
          return Status::OK();
        };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &ubr, creator),
                         done);
    auto status = ubr->Compute(c, done);
    ubr->Unref();
    if (!status.ok()) {
      OP_REQUIRES_OK_ASYNC(c, status, done);
    }
    // Assume ubr calls done, so nothing to do here.
  }

 private:
  string container_;
  string shared_name_;
  int32 timeout_micros_;
};
REGISTER_KERNEL_BUILDER(Name("Unbatch").Device(DEVICE_CPU), UnbatchKernel);

// A class encapsulating the state and logic for batching tensors
// deterministically for the gradient of unbatch.
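// Gradients arrive one per original Unbatch invocation, keyed by batch key;
// the batch that carried the index tensor is flushed (via OutputBatch) only
// once a gradient is available for every key listed in that index, so the
// concatenated gradient preserves the original batch order.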
class UnbatchGradResource : public ResourceBase {
 public:
  UnbatchGradResource() {}

  string DebugString() final { return "UnbatchGradResource"; }

  // Flushes the information for one batch, given its context and done
  // callback. Clears all information about it from the available_tensors_.
  Status OutputBatch(OpKernelContext* context,
                     const AsyncOpKernel::DoneCallback& done)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    const Tensor& batch_index_t = context->input(1);
    auto batch_index =
        batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
    std::vector<Tensor> tensors;
    for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
      auto available_it = available_tensors_.find(batch_index(i, 0));
      if (available_it == available_tensors_.end()) {
        return errors::Internal("bad bookkeeping of available tensors.");
      }
      tensors.push_back(available_it->second);
      available_tensors_.erase(available_it);
    }

    const DataType type = tensors[0].dtype();
    switch (type) {
#define CASE(type)                                         \
  case DataTypeToEnum<type>::value:                        \
    TF_RETURN_IF_ERROR(Concat<type>(context, tensors, 0)); \
    break;
      TF_CALL_ALL_TYPES(CASE);
#undef CASE
      default:
        return errors::InvalidArgument("Unsupported data type: ", type);
    }
    done();
    return Status::OK();
  }

  // Ingests data from one invocation of the op.
  Status Compute(OpKernelContext* context,
                 const AsyncOpKernel::DoneCallback& done) {
    const Tensor& data_t = context->input(0);
    const Tensor& batch_index_t = context->input(1);
    const Tensor& grad_t = context->input(2);

    mutex_lock ml(mu_);

    const int64 batch_key = context->input(3).scalar<int64>()();
    // Mark our tensor as available.
    if (!available_tensors_.emplace(batch_key, grad_t).second) {
      return errors::InvalidArgument("Two runs with the same batch key.");
    }

    // Check whether we have a valid input tensor and, if so, create its
    // dispatch logic.
    if (data_t.NumElements() > 0) {
      if (batch_index_t.NumElements() == 0) {
        return errors::InvalidArgument(
            "batch_index is empty while the tensor isn't.");
      }
      std::unordered_set<int64> missing_tensors;
      const auto batch_index =
          batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
      for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
        const int64 batch_key = batch_index(i, 0);
        if (available_tensors_.find(batch_key) == available_tensors_.end()) {
          missing_tensors.emplace(batch_key);
        }
      }
      if (missing_tensors.empty()) {
        return OutputBatch(context, done);
      }
      if (!available_batches_
               .emplace(batch_key, Batch{missing_tensors, context, done})
               .second) {
        return errors::InvalidArgument(
            "Batch key with valid batch used twice.");
      }
      for (const int64 i : missing_tensors) {
        if (!desired_tensor_to_batch_map_.emplace(i, batch_key).second) {
          return errors::InvalidArgument(
              "Missing tensor wanted by more than one batch.");
        }
      }
    } else {
      // If we don't have a valid input tensor we can output an empty tensor and
      // call our done closure.
      TensorShape output_shape(grad_t.shape());
      output_shape.set_dim(0, 0);
      Tensor* output = nullptr;
      TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, &output));
      done();
    }

    // Search to see whether our tensor is desired by any existing batch.
    auto desire_it = desired_tensor_to_batch_map_.find(batch_key);
    if (desire_it != desired_tensor_to_batch_map_.end()) {
      // Mark our tensor as no longer missing.
      auto batch_it = available_batches_.find(desire_it->second);
      desired_tensor_to_batch_map_.erase(desire_it);
      if (batch_it == available_batches_.end()) {
        return errors::InvalidArgument("Batch no longer exists.");
      }
      batch_it->second.missing_tensors.erase(batch_key);
      // If all tensors are available we should concatenate them and dispatch
      // the batch.
      if (batch_it->second.missing_tensors.empty()) {
        TF_RETURN_IF_ERROR(
            OutputBatch(batch_it->second.context, batch_it->second.done));
        available_batches_.erase(batch_it);
      }
    }
    return Status::OK();
  }

 private:
  mutex mu_;

  // Represents a still-incomplete batch of tensors. When all tensors become
  // available they will be concatenated in the right order and sent through the
  // context.
  struct Batch {
    // Batch keys for tensors which are still missing from this batch. When this
    // is empty the Tensors can be concatenated and forwarded.
    std::unordered_set<int64> missing_tensors;

    // Context and callback for the session responsible for finishing this
    // batch.
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done;
  };

  // Map from batch key of the session which will output the batched gradients
  // to still-incomplete batches.
  std::unordered_map<int64, Batch> available_batches_;

  // Map from batch key to tensors which are waiting for their batches to be
  // available.
  std::unordered_map<int64, Tensor> available_tensors_;

  // Map from batch key of a tensor which is not yet available to the batch key
  // of the batch to which it belongs.
  std::unordered_map<int64, int64> desired_tensor_to_batch_map_;
};

class UnbatchGradKernel : public AsyncOpKernel {
 public:
  explicit UnbatchGradKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name() instead (this prevents
    // collisions by default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    UnbatchGradResource* ubr;
    std::function<Status(UnbatchGradResource * *r)> creator =
        [this](UnbatchGradResource** r) {
          *r = new UnbatchGradResource();
          return Status::OK();
        };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &ubr, creator),
                         done);
    Status status = ubr->Compute(c, done);
    ubr->Unref();
    if (!status.ok()) {
      OP_REQUIRES_OK_ASYNC(c, status, done);
    }
    // Assume ubr calls done, so nothing to do here.
  }

 private:
  string container_;
  string shared_name_;
};
REGISTER_KERNEL_BUILDER(Name("UnbatchGrad").Device(DEVICE_CPU),
                        UnbatchGradKernel);

}  // namespace tensorflow