/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LIB_CORE_THREADPOOL_H_
#define TENSORFLOW_LIB_CORE_THREADPOOL_H_

#include <functional>
#include <memory>
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace thread {

class ThreadPool {
 public:
  // Constructs a pool that contains "num_threads" threads with the specified
  // "name". env->StartThread() is used to create individual threads with the
  // given ThreadOptions. If "low_latency_hint" is true, the thread pool
  // implementation may use it as a hint that lower latency is preferred at the
  // cost of higher CPU usage, e.g. by letting one or more idle threads spin
  // wait. Conversely, if the threadpool is used to schedule high-latency
  // operations like I/O, the hint should be set to false.
  //
  // REQUIRES: num_threads > 0
  ThreadPool(Env* env, const ThreadOptions& thread_options, const string& name,
             int num_threads, bool low_latency_hint);
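
  // A minimal construction sketch (illustrative only, not part of this
  // header; the pool name and thread count are hypothetical, and
  // Env::Default() is assumed to be an acceptable environment):
  //
  //   thread::ThreadPool pool(Env::Default(), ThreadOptions(), "worker_pool",
  //                           /*num_threads=*/4, /*low_latency_hint=*/true);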

  // Constructs a pool for low-latency ops that contains "num_threads" threads
  // with specified "name". env->StartThread() is used to create individual
  // threads.
  // REQUIRES: num_threads > 0
  ThreadPool(Env* env, const string& name, int num_threads);

  // Constructs a pool for low-latency ops that contains "num_threads" threads
  // with specified "name". env->StartThread() is used to create individual
  // threads with the given ThreadOptions.
  // REQUIRES: num_threads > 0
  ThreadPool(Env* env, const ThreadOptions& thread_options, const string& name,
             int num_threads);

  // Waits until all scheduled work has finished and then destroys the
  // set of threads.
  ~ThreadPool();

  // Schedules fn() for execution in the pool of threads.
  void Schedule(std::function<void()> fn);
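
  // A usage sketch (illustrative; assumes waiting on a tensorflow
  // Notification is an acceptable way for the caller to observe completion):
  //
  //   Notification done;
  //   pool.Schedule([&done]() {
  //     // ... do the work ...
  //     done.Notify();
  //   });
  //   done.WaitForNotification();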

  // ParallelFor shards the "total" units of work, assuming each unit of work
  // has roughly "cost_per_unit" cost, in cycles. Each unit of work is
  // indexed 0, 1, ..., total - 1. Each shard contains 1 or more units of work,
  // and the total cost of each shard is roughly the same.
  //
  // "cost_per_unit" is an estimate of the number of CPU cycles (or nanoseconds
  // if not CPU-bound) to complete a unit of work. Overestimating creates too
  // many shards, and CPU time will be dominated by per-shard overhead, such as
  // Context creation. Underestimating may not fully make use of the specified
  // parallelism.
  void ParallelFor(int64 total, int64 cost_per_unit,
                   std::function<void(int64, int64)> fn);
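
  // A sharding sketch (illustrative; Work() and the cost estimate of 10000
  // cycles per element are hypothetical):
  //
  //   std::vector<float> out(total);
  //   pool.ParallelFor(total, /*cost_per_unit=*/10000,
  //                    [&out](int64 start, int64 end) {
  //                      // fn is called with a half-open index range
  //                      // [start, end).
  //                      for (int64 i = start; i < end; ++i) out[i] = Work(i);
  //                    });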

  // Shards the "total" units of work. For more details, see "ParallelFor".
  //
  // The function is passed a thread_id between 0 and NumThreads() *inclusive*.
  // This is because some work can happen on the caller thread while the
  // threads in the pool are also being used.
  //
  // The caller can allocate NumThreads() + 1 separate buffers, one for each
  // possible thread_id. Each thread can safely write to the buffer given by
  // its id without synchronization. However, the worker fn may be called
  // multiple times sequentially with the same id.
  //
  // At most NumThreads() unique ids will actually be used, and only a few may
  // be used for small workloads. If each buffer is expensive, the buffers
  // should be stored in an array initially filled with null, and a buffer
  // should be allocated by fn the first time that the id is used.
  void ParallelForWithWorkerId(
      int64 total, int64 cost_per_unit,
      const std::function<void(int64, int64, int)>& fn);
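
  // A per-worker accumulation sketch (illustrative; the input vector v and
  // the cost estimate are hypothetical):
  //
  //   std::vector<double> partial(pool.NumThreads() + 1, 0.0);
  //   pool.ParallelForWithWorkerId(
  //       v.size(), /*cost_per_unit=*/5000,
  //       [&v, &partial](int64 start, int64 end, int id) {
  //         // Each id owns partial[id], so no synchronization is needed.
  //         for (int64 i = start; i < end; ++i) partial[id] += v[i];
  //       });
  //   double sum = std::accumulate(partial.begin(), partial.end(), 0.0);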

  // Returns the number of threads in the pool.
  int NumThreads() const;

  // Returns current thread id between 0 and NumThreads() - 1, if called from a
  // thread in the pool. Returns -1 otherwise.
  int CurrentThreadId() const;

  struct Impl;

 private:
  std::unique_ptr<Impl> impl_;
  TF_DISALLOW_COPY_AND_ASSIGN(ThreadPool);
};

}  // namespace thread
}  // namespace tensorflow

#endif  // TENSORFLOW_LIB_CORE_THREADPOOL_H_