/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Abstractions for processing small tasks in a batched fashion, so that the
// fixed processing times and costs can be amortized across multiple tasks.
//
// The core class is BatchScheduler, which groups tasks into batches.
//
// BatchScheduler encapsulates the logic for aggregating multiple tasks into a
// batch, and for kicking off the processing of a batch on a thread pool that
// it manages.
//
// This file defines an abstract BatchScheduler class.

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_

#include <stddef.h>
#include <algorithm>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace serving {

// The abstract superclass for a unit of work to be done as part of a batch.
//
// An implementing subclass typically contains (or points to):
//  (a) input data;
//  (b) a thread-safe completion signal (e.g. a Notification);
//  (c) a place to store the outcome (success, or some error), upon completion;
//  (d) a place to store the output data, upon success.
//
// Items (b), (c) and (d) are typically non-owned pointers to data homed
// elsewhere, because a task's ownership gets transferred to a BatchScheduler
// (see below) and it may be deleted as soon as it is done executing.
class BatchTask {
 public:
  virtual ~BatchTask() = default;

  // Returns the size of the task, in terms of how much it contributes to the
  // size of a batch. (A batch's size is the sum of its task sizes.)
  virtual size_t size() const = 0;
};

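// As a minimal sketch (for illustration only; 'MyTask' and its members are
// hypothetical names, not part of this header), a concrete task type covering
// items (a)-(d) above might look like:
//
//   class MyTask : public BatchTask {
//    public:
//     MyTask(std::vector<float> input, std::vector<float>* output,
//            Status* status, Notification* done)
//         : input_(std::move(input)),
//           output_(output),
//           status_(status),
//           done_(done) {}
//
//     size_t size() const override { return input_.size(); }
//
//     std::vector<float> input_;    // (a) input data, owned by the task.
//     std::vector<float>* output_;  // (d) output location; not owned.
//     Status* status_;              // (c) outcome location; not owned.
//     Notification* done_;          // (b) completion signal; not owned.
//   };
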
// A thread-safe collection of BatchTasks, to be executed together in some
// fashion.
//
// At a given time, a batch is either "open" or "closed": an open batch can
// accept new tasks; a closed one cannot. A batch is monotonic: initially it is
// open and tasks can be added to it; then it is closed and its set of tasks
// remains fixed for the remainder of its life. A closed batch cannot be
// re-opened. Tasks can never be removed from a batch.
//
// Type parameter TaskType must be a subclass of BatchTask.
template <typename TaskType>
class Batch {
 public:
  Batch() = default;
  virtual ~Batch();  // Blocks until the batch is closed.

  // Appends 'task' to the batch. After calling AddTask(), the newly-added task
  // can be accessed via task(num_tasks()-1) or mutable_task(num_tasks()-1).
  // Dies if the batch is closed.
  void AddTask(std::unique_ptr<TaskType> task);

  // Removes the most recently added task. Returns nullptr if the batch is
  // empty.
  std::unique_ptr<TaskType> RemoveTask();

  // Returns the number of tasks in the batch.
  int num_tasks() const;

  // Returns true iff the batch contains 0 tasks.
  bool empty() const;

  // Returns a reference to the ith task (in terms of insertion order).
  const TaskType& task(int i) const;

  // Returns a pointer to the ith task (in terms of insertion order).
  TaskType* mutable_task(int i);

  // Returns the sum of the task sizes.
  size_t size() const;

  // Returns true iff the batch is currently closed.
  bool IsClosed() const;

  // Blocks until the batch is closed.
  void WaitUntilClosed() const;

  // Marks the batch as closed. Dies if called more than once.
  void Close();

 private:
  mutable mutex mu_;

  // The tasks in the batch.
  std::vector<std::unique_ptr<TaskType>> tasks_ GUARDED_BY(mu_);

  // The sum of the sizes of the tasks in 'tasks_'.
  size_t size_ GUARDED_BY(mu_) = 0;

  // Whether the batch has been closed.
  Notification closed_;

  TF_DISALLOW_COPY_AND_ASSIGN(Batch);
};

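// A brief usage sketch (again using the hypothetical MyTask type from above;
// the '...' arguments are placeholders):
//
//   Batch<MyTask> batch;
//   batch.AddTask(std::unique_ptr<MyTask>(new MyTask(...)));
//   batch.Close();  // The set of tasks is now fixed.
//   for (int i = 0; i < batch.num_tasks(); ++i) {
//     MyTask* task = batch.mutable_task(i);
//     // ... process *task ...
//   }
//
// Note that the destructor blocks until the batch is closed, so Close() must
// eventually be called on every batch.
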
// An abstract batch scheduler class. Collects individual tasks into batches,
// and processes each batch on a pool of "batch threads" that it manages. The
// actual logic for processing a batch is accomplished via a callback.
//
// Type parameter TaskType must be a subclass of BatchTask.
template <typename TaskType>
class BatchScheduler {
 public:
  virtual ~BatchScheduler() = default;

  // Submits a task to be processed as part of a batch.
  //
  // Ownership of '*task' is transferred to the callee iff the method returns
  // Status::OK. In that case, '*task' is left as nullptr. Otherwise, '*task'
  // is left as-is.
  //
  // If no batch processing capacity is available to process this task at the
  // present time, and any task queue maintained by the implementing subclass
  // is full, this method returns an UNAVAILABLE error code. The client may
  // retry later.
  //
  // Other problems, such as the task size being larger than the maximum batch
  // size, yield other, permanent error types.
  //
  // In all cases, this method returns "quickly" without blocking for any
  // substantial amount of time. If the method returns Status::OK, the task is
  // processed asynchronously, and any errors that occur during the processing
  // of the batch that includes the task can be reported to 'task'.
  virtual Status Schedule(std::unique_ptr<TaskType>* task) = 0;

  // Returns the number of tasks that have been scheduled (i.e. accepted by
  // Schedule()), but have yet to be handed to a thread for execution as part
  // of a batch. Note that this returns the number of tasks, not the aggregate
  // task size (so if there is one task of size 3 and one task of size 5, this
  // method returns 2 rather than 8).
  virtual size_t NumEnqueuedTasks() const = 0;

  // Returns a guaranteed number of size-1 tasks that can be Schedule()d
  // without getting an UNAVAILABLE error. In a typical implementation, returns
  // the available space on a queue.
  //
  // There are two important caveats:
  //  1. The guarantee does not extend to varying-size tasks due to possible
  //     internal fragmentation of batches.
  //  2. The guarantee only holds in a single-thread environment or critical
  //     section, i.e. if an intervening thread cannot call Schedule().
  //
  // This method is useful for monitoring, or for guaranteeing a future slot in
  // the schedule (while being mindful of the caveats listed above).
  virtual size_t SchedulingCapacity() const = 0;

  // Returns the maximum allowed size of tasks submitted to the scheduler.
  // (This is typically equal to a configured maximum batch size.)
  virtual size_t max_task_size() const = 0;
};

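// For illustration only, a client might submit a task along these lines,
// where 'scheduler' points to some concrete BatchScheduler<MyTask>
// implementation (both names are hypothetical):
//
//   std::unique_ptr<MyTask> task(new MyTask(...));
//   if (task->size() > scheduler->max_task_size()) {
//     // The task can never be scheduled; report a permanent error.
//   }
//   Status status = scheduler->Schedule(&task);
//   if (status.ok()) {
//     // Ownership was transferred; 'task' is now nullptr. Completion (and
//     // any processing error) is reported through the locations the task
//     // points to, e.g. its Notification and Status fields.
//   } else if (status.code() == error::UNAVAILABLE) {
//     // The queue is full. 'task' is still owned here and may be retried
//     // later, e.g. once SchedulingCapacity() becomes positive.
//   }
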
//////////
// Implementation details follow. API users need not read.

template <typename TaskType>
Batch<TaskType>::~Batch() {
  WaitUntilClosed();
}

template <typename TaskType>
void Batch<TaskType>::AddTask(std::unique_ptr<TaskType> task) {
  DCHECK(!IsClosed());
  {
    mutex_lock l(mu_);
    size_ += task->size();
    tasks_.push_back(std::move(task));
  }
}

template <typename TaskType>
std::unique_ptr<TaskType> Batch<TaskType>::RemoveTask() {
  {
    mutex_lock l(mu_);
    if (tasks_.empty()) {
      return nullptr;
    }
    std::unique_ptr<TaskType> task = std::move(tasks_.back());
    size_ -= task->size();
    tasks_.pop_back();
    return task;
  }
}

template <typename TaskType>
int Batch<TaskType>::num_tasks() const {
  {
    mutex_lock l(mu_);
    return tasks_.size();
  }
}

template <typename TaskType>
bool Batch<TaskType>::empty() const {
  {
    mutex_lock l(mu_);
    return tasks_.empty();
  }
}

template <typename TaskType>
const TaskType& Batch<TaskType>::task(int i) const {
  DCHECK_GE(i, 0);
  {
    mutex_lock l(mu_);
    DCHECK_LT(i, tasks_.size());
    return *tasks_[i];
  }
}

template <typename TaskType>
TaskType* Batch<TaskType>::mutable_task(int i) {
  DCHECK_GE(i, 0);
  {
    mutex_lock l(mu_);
    DCHECK_LT(i, tasks_.size());
    return tasks_[i].get();
  }
}

template <typename TaskType>
size_t Batch<TaskType>::size() const {
  {
    mutex_lock l(mu_);
    return size_;
  }
}

// Notification's query/wait methods are not declared const, hence the
// const_casts in the two const methods below.
template <typename TaskType>
bool Batch<TaskType>::IsClosed() const {
  return const_cast<Notification*>(&closed_)->HasBeenNotified();
}

template <typename TaskType>
void Batch<TaskType>::WaitUntilClosed() const {
  const_cast<Notification*>(&closed_)->WaitForNotification();
}

template <typename TaskType>
void Batch<TaskType>::Close() {
  closed_.Notify();
}

}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_