1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Abstractions for processing small tasks in a batched fashion, to reduce 17 // processing times and costs that can be amortized across multiple tasks. 18 // 19 // The core class is BatchScheduler, which groups tasks into batches. 20 // 21 // BatchScheduler encapsulates logic for aggregating multiple tasks into a 22 // batch, and kicking off processing of a batch on a thread pool it manages. 23 // 24 // This file defines an abstract BatchScheduler class. 25 26 #ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_ 27 #define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_ 28 29 #include <stddef.h> 30 #include <algorithm> 31 #include <functional> 32 #include <memory> 33 #include <utility> 34 #include <vector> 35 36 #include "tensorflow/core/lib/core/notification.h" 37 #include "tensorflow/core/lib/core/status.h" 38 #include "tensorflow/core/platform/logging.h" 39 #include "tensorflow/core/platform/macros.h" 40 #include "tensorflow/core/platform/mutex.h" 41 #include "tensorflow/core/platform/thread_annotations.h" 42 #include "tensorflow/core/platform/types.h" 43 44 namespace tensorflow { 45 namespace serving { 46 47 // The abstract superclass for a unit of work to be done as part of a batch. 48 // 49 // An implementing subclass typically contains (or points to): 50 // (a) input data; 51 // (b) a thread-safe completion signal (e.g. a Notification); 52 // (c) a place to store the outcome (success, or some error), upon completion; 53 // (d) a place to store the output data, upon success. 54 // 55 // Items (b), (c) and (d) are typically non-owned pointers to data homed 56 // elsewhere, because a task's ownership gets transferred to a BatchScheduler 57 // (see below) and it may be deleted as soon as it is done executing. 58 class BatchTask { 59 public: 60 virtual ~BatchTask() = default; 61 62 // Returns the size of the task, in terms of how much it contributes to the 63 // size of a batch. (A batch's size is the sum of its task sizes.) 64 virtual size_t size() const = 0; 65 }; 66 67 // A thread-safe collection of BatchTasks, to be executed together in some 68 // fashion. 69 // 70 // At a given time, a batch is either "open" or "closed": an open batch can 71 // accept new tasks; a closed one cannot. A batch is monotonic: initially it is 72 // open and tasks can be added to it; then it is closed and its set of tasks 73 // remains fixed for the remainder of its life. A closed batch cannot be re- 74 // opened. Tasks can never be removed from a batch. 75 // 76 // Type parameter TaskType must be a subclass of BatchTask. 77 template <typename TaskType> 78 class Batch { 79 public: 80 Batch() = default; 81 virtual ~Batch(); // Blocks until the batch is closed. 82 83 // Appends 'task' to the batch. After calling AddTask(), the newly-added task 84 // can be accessed via task(num_tasks()-1) or mutable_task(num_tasks()-1). 85 // Dies if the batch is closed. 86 void AddTask(std::unique_ptr<TaskType> task); 87 88 // Removes the most recently added task. Returns nullptr if the batch is 89 // empty. 90 std::unique_ptr<TaskType> RemoveTask(); 91 92 // Returns the number of tasks in the batch. 93 int num_tasks() const; 94 95 // Returns true iff the batch contains 0 tasks. 96 bool empty() const; 97 98 // Returns a reference to the ith task (in terms of insertion order). 99 const TaskType& task(int i) const; 100 101 // Returns a pointer to the ith task (in terms of insertion order). 102 TaskType* mutable_task(int i); 103 104 // Returns the sum of the task sizes. 105 size_t size() const; 106 107 // Returns true iff the batch is currently closed. 108 bool IsClosed() const; 109 110 // Blocks until the batch is closed. 111 void WaitUntilClosed() const; 112 113 // Marks the batch as closed. Dies if called more than once. 114 void Close(); 115 116 private: 117 mutable mutex mu_; 118 119 // The tasks in the batch. 120 std::vector<std::unique_ptr<TaskType>> tasks_ GUARDED_BY(mu_); 121 122 // The sum of the sizes of the tasks in 'tasks_'. 123 size_t size_ GUARDED_BY(mu_) = 0; 124 125 // Whether the batch has been closed. 126 Notification closed_; 127 128 TF_DISALLOW_COPY_AND_ASSIGN(Batch); 129 }; 130 131 // An abstract batch scheduler class. Collects individual tasks into batches, 132 // and processes each batch on a pool of "batch threads" that it manages. The 133 // actual logic for processing a batch is accomplished via a callback. 134 // 135 // Type parameter TaskType must be a subclass of BatchTask. 136 template <typename TaskType> 137 class BatchScheduler { 138 public: 139 virtual ~BatchScheduler() = default; 140 141 // Submits a task to be processed as part of a batch. 142 // 143 // Ownership of '*task' is transferred to the callee iff the method returns 144 // Status::OK. In that case, '*task' is left as nullptr. Otherwise, '*task' is 145 // left as-is. 146 // 147 // If no batch processing capacity is available to process this task at the 148 // present time, and any task queue maintained by the implementing subclass is 149 // full, this method returns an UNAVAILABLE error code. The client may retry 150 // later. 151 // 152 // Other problems, such as the task size being larger than the maximum batch 153 // size, yield other, permanent error types. 154 // 155 // In all cases, this method returns "quickly" without blocking for any 156 // substantial amount of time. If the method returns Status::OK, the task is 157 // processed asynchronously, and any errors that occur during the processing 158 // of the batch that includes the task can be reported to 'task'. 159 virtual Status Schedule(std::unique_ptr<TaskType>* task) = 0; 160 161 // Returns the number of tasks that have been scheduled (i.e. accepted by 162 // Schedule()), but have yet to be handed to a thread for execution as part of 163 // a batch. Note that this returns the number of tasks, not the aggregate task 164 // size (so if there is one task of size 3 and one task of size 5, this method 165 // returns 2 rather than 8). 166 virtual size_t NumEnqueuedTasks() const = 0; 167 168 // Returns a guaranteed number of size 1 tasks that can be Schedule()d without 169 // getting an UNAVAILABLE error. In a typical implementation, returns the 170 // available space on a queue. 171 // 172 // There are two important caveats: 173 // 1. The guarantee does not extend to varying-size tasks due to possible 174 // internal fragmentation of batches. 175 // 2. The guarantee only holds in a single-thread environment or critical 176 // section, i.e. if an intervening thread cannot call Schedule(). 177 // 178 // This method is useful for monitoring, or for guaranteeing a future slot in 179 // the schedule (but being mindful about the caveats listed above). 180 virtual size_t SchedulingCapacity() const = 0; 181 182 // Returns the maximum allowed size of tasks submitted to the scheduler. (This 183 // is typically equal to a configured maximum batch size.) 184 virtual size_t max_task_size() const = 0; 185 }; 186 187 ////////// 188 // Implementation details follow. API users need not read. 189 190 template <typename TaskType> 191 Batch<TaskType>::~Batch() { 192 WaitUntilClosed(); 193 } 194 195 template <typename TaskType> 196 void Batch<TaskType>::AddTask(std::unique_ptr<TaskType> task) { 197 DCHECK(!IsClosed()); 198 { 199 mutex_lock l(mu_); 200 size_ += task->size(); 201 tasks_.push_back(std::move(task)); 202 } 203 } 204 205 template <typename TaskType> 206 std::unique_ptr<TaskType> Batch<TaskType>::RemoveTask() { 207 { 208 mutex_lock l(mu_); 209 if (tasks_.empty()) { 210 return nullptr; 211 } 212 std::unique_ptr<TaskType> task = std::move(tasks_.back()); 213 size_ -= task->size(); 214 tasks_.pop_back(); 215 return task; 216 } 217 } 218 219 template <typename TaskType> 220 int Batch<TaskType>::num_tasks() const { 221 { 222 mutex_lock l(mu_); 223 return tasks_.size(); 224 } 225 } 226 227 template <typename TaskType> 228 bool Batch<TaskType>::empty() const { 229 { 230 mutex_lock l(mu_); 231 return tasks_.empty(); 232 } 233 } 234 235 template <typename TaskType> 236 const TaskType& Batch<TaskType>::task(int i) const { 237 DCHECK_GE(i, 0); 238 { 239 mutex_lock l(mu_); 240 DCHECK_LT(i, tasks_.size()); 241 return *tasks_[i].get(); 242 } 243 } 244 245 template <typename TaskType> 246 TaskType* Batch<TaskType>::mutable_task(int i) { 247 DCHECK_GE(i, 0); 248 { 249 mutex_lock l(mu_); 250 DCHECK_LT(i, tasks_.size()); 251 return tasks_[i].get(); 252 } 253 } 254 255 template <typename TaskType> 256 size_t Batch<TaskType>::size() const { 257 { 258 mutex_lock l(mu_); 259 return size_; 260 } 261 } 262 263 template <typename TaskType> 264 bool Batch<TaskType>::IsClosed() const { 265 return const_cast<Notification*>(&closed_)->HasBeenNotified(); 266 } 267 268 template <typename TaskType> 269 void Batch<TaskType>::WaitUntilClosed() const { 270 const_cast<Notification*>(&closed_)->WaitForNotification(); 271 } 272 273 template <typename TaskType> 274 void Batch<TaskType>::Close() { 275 closed_.Notify(); 276 } 277 278 } // namespace serving 279 } // namespace tensorflow 280 281 #endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_ 282