Home | History | Annotate | Download | only in Support
      1 //===- llvm/Support/Parallel.h - Parallel algorithms ----------------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 
     10 #ifndef LLVM_SUPPORT_PARALLEL_H
     11 #define LLVM_SUPPORT_PARALLEL_H
     12 
     13 #include "llvm/ADT/STLExtras.h"
     14 #include "llvm/Config/llvm-config.h"
     15 #include "llvm/Support/MathExtras.h"
     16 
     17 #include <algorithm>
     18 #include <condition_variable>
     19 #include <functional>
     20 #include <mutex>
     21 
     22 #if defined(_MSC_VER) && LLVM_ENABLE_THREADS
     23 #pragma warning(push)
     24 #pragma warning(disable : 4530)
     25 #include <concrt.h>
     26 #include <ppl.h>
     27 #pragma warning(pop)
     28 #endif
     29 
     30 namespace llvm {
     31 
     32 namespace parallel {
     33 struct sequential_execution_policy {};
     34 struct parallel_execution_policy {};
     35 
     36 template <typename T>
     37 struct is_execution_policy
     38     : public std::integral_constant<
     39           bool, llvm::is_one_of<T, sequential_execution_policy,
     40                                 parallel_execution_policy>::value> {};
     41 
     42 constexpr sequential_execution_policy seq{};
     43 constexpr parallel_execution_policy par{};
     44 
     45 namespace detail {
     46 
     47 #if LLVM_ENABLE_THREADS
     48 
     49 class Latch {
     50   uint32_t Count;
     51   mutable std::mutex Mutex;
     52   mutable std::condition_variable Cond;
     53 
     54 public:
     55   explicit Latch(uint32_t Count = 0) : Count(Count) {}
     56   ~Latch() { sync(); }
     57 
     58   void inc() {
     59     std::lock_guard<std::mutex> lock(Mutex);
     60     ++Count;
     61   }
     62 
     63   void dec() {
     64     std::lock_guard<std::mutex> lock(Mutex);
     65     if (--Count == 0)
     66       Cond.notify_all();
     67   }
     68 
     69   void sync() const {
     70     std::unique_lock<std::mutex> lock(Mutex);
     71     Cond.wait(lock, [&] { return Count == 0; });
     72   }
     73 };
     74 
     75 class TaskGroup {
     76   Latch L;
     77 
     78 public:
     79   void spawn(std::function<void()> f);
     80 
     81   void sync() const { L.sync(); }
     82 };
     83 
     84 #if defined(_MSC_VER)
     85 template <class RandomAccessIterator, class Comparator>
     86 void parallel_sort(RandomAccessIterator Start, RandomAccessIterator End,
     87                    const Comparator &Comp) {
     88   concurrency::parallel_sort(Start, End, Comp);
     89 }
     90 template <class IterTy, class FuncTy>
     91 void parallel_for_each(IterTy Begin, IterTy End, FuncTy Fn) {
     92   concurrency::parallel_for_each(Begin, End, Fn);
     93 }
     94 
     95 template <class IndexTy, class FuncTy>
     96 void parallel_for_each_n(IndexTy Begin, IndexTy End, FuncTy Fn) {
     97   concurrency::parallel_for(Begin, End, Fn);
     98 }
     99 
    100 #else
    101 const ptrdiff_t MinParallelSize = 1024;
    102 
    103 /// Inclusive median.
    104 template <class RandomAccessIterator, class Comparator>
    105 RandomAccessIterator medianOf3(RandomAccessIterator Start,
    106                                RandomAccessIterator End,
    107                                const Comparator &Comp) {
    108   RandomAccessIterator Mid = Start + (std::distance(Start, End) / 2);
    109   return Comp(*Start, *(End - 1))
    110              ? (Comp(*Mid, *(End - 1)) ? (Comp(*Start, *Mid) ? Mid : Start)
    111                                        : End - 1)
    112              : (Comp(*Mid, *Start) ? (Comp(*(End - 1), *Mid) ? Mid : End - 1)
    113                                    : Start);
    114 }
    115 
    116 template <class RandomAccessIterator, class Comparator>
    117 void parallel_quick_sort(RandomAccessIterator Start, RandomAccessIterator End,
    118                          const Comparator &Comp, TaskGroup &TG, size_t Depth) {
    119   // Do a sequential sort for small inputs.
    120   if (std::distance(Start, End) < detail::MinParallelSize || Depth == 0) {
    121     llvm::sort(Start, End, Comp);
    122     return;
    123   }
    124 
    125   // Partition.
    126   auto Pivot = medianOf3(Start, End, Comp);
    127   // Move Pivot to End.
    128   std::swap(*(End - 1), *Pivot);
    129   Pivot = std::partition(Start, End - 1, [&Comp, End](decltype(*Start) V) {
    130     return Comp(V, *(End - 1));
    131   });
    132   // Move Pivot to middle of partition.
    133   std::swap(*Pivot, *(End - 1));
    134 
    135   // Recurse.
    136   TG.spawn([=, &Comp, &TG] {
    137     parallel_quick_sort(Start, Pivot, Comp, TG, Depth - 1);
    138   });
    139   parallel_quick_sort(Pivot + 1, End, Comp, TG, Depth - 1);
    140 }
    141 
    142 template <class RandomAccessIterator, class Comparator>
    143 void parallel_sort(RandomAccessIterator Start, RandomAccessIterator End,
    144                    const Comparator &Comp) {
    145   TaskGroup TG;
    146   parallel_quick_sort(Start, End, Comp, TG,
    147                       llvm::Log2_64(std::distance(Start, End)) + 1);
    148 }
    149 
    150 template <class IterTy, class FuncTy>
    151 void parallel_for_each(IterTy Begin, IterTy End, FuncTy Fn) {
    152   // TaskGroup has a relatively high overhead, so we want to reduce
    153   // the number of spawn() calls. We'll create up to 1024 tasks here.
    154   // (Note that 1024 is an arbitrary number. This code probably needs
    155   // improving to take the number of available cores into account.)
    156   ptrdiff_t TaskSize = std::distance(Begin, End) / 1024;
    157   if (TaskSize == 0)
    158     TaskSize = 1;
    159 
    160   TaskGroup TG;
    161   while (TaskSize < std::distance(Begin, End)) {
    162     TG.spawn([=, &Fn] { std::for_each(Begin, Begin + TaskSize, Fn); });
    163     Begin += TaskSize;
    164   }
    165   std::for_each(Begin, End, Fn);
    166 }
    167 
    168 template <class IndexTy, class FuncTy>
    169 void parallel_for_each_n(IndexTy Begin, IndexTy End, FuncTy Fn) {
    170   ptrdiff_t TaskSize = (End - Begin) / 1024;
    171   if (TaskSize == 0)
    172     TaskSize = 1;
    173 
    174   TaskGroup TG;
    175   IndexTy I = Begin;
    176   for (; I + TaskSize < End; I += TaskSize) {
    177     TG.spawn([=, &Fn] {
    178       for (IndexTy J = I, E = I + TaskSize; J != E; ++J)
    179         Fn(J);
    180     });
    181   }
    182   for (IndexTy J = I; J < End; ++J)
    183     Fn(J);
    184 }
    185 
    186 #endif
    187 
    188 #endif
    189 
    190 template <typename Iter>
    191 using DefComparator =
    192     std::less<typename std::iterator_traits<Iter>::value_type>;
    193 
    194 } // namespace detail
    195 
    196 // sequential algorithm implementations.
    197 template <class Policy, class RandomAccessIterator,
    198           class Comparator = detail::DefComparator<RandomAccessIterator>>
    199 void sort(Policy policy, RandomAccessIterator Start, RandomAccessIterator End,
    200           const Comparator &Comp = Comparator()) {
    201   static_assert(is_execution_policy<Policy>::value,
    202                 "Invalid execution policy!");
    203   llvm::sort(Start, End, Comp);
    204 }
    205 
    206 template <class Policy, class IterTy, class FuncTy>
    207 void for_each(Policy policy, IterTy Begin, IterTy End, FuncTy Fn) {
    208   static_assert(is_execution_policy<Policy>::value,
    209                 "Invalid execution policy!");
    210   std::for_each(Begin, End, Fn);
    211 }
    212 
    213 template <class Policy, class IndexTy, class FuncTy>
    214 void for_each_n(Policy policy, IndexTy Begin, IndexTy End, FuncTy Fn) {
    215   static_assert(is_execution_policy<Policy>::value,
    216                 "Invalid execution policy!");
    217   for (IndexTy I = Begin; I != End; ++I)
    218     Fn(I);
    219 }
    220 
    221 // Parallel algorithm implementations, only available when LLVM_ENABLE_THREADS
    222 // is true.
    223 #if LLVM_ENABLE_THREADS
    224 template <class RandomAccessIterator,
    225           class Comparator = detail::DefComparator<RandomAccessIterator>>
    226 void sort(parallel_execution_policy policy, RandomAccessIterator Start,
    227           RandomAccessIterator End, const Comparator &Comp = Comparator()) {
    228   detail::parallel_sort(Start, End, Comp);
    229 }
    230 
    231 template <class IterTy, class FuncTy>
    232 void for_each(parallel_execution_policy policy, IterTy Begin, IterTy End,
    233               FuncTy Fn) {
    234   detail::parallel_for_each(Begin, End, Fn);
    235 }
    236 
    237 template <class IndexTy, class FuncTy>
    238 void for_each_n(parallel_execution_policy policy, IndexTy Begin, IndexTy End,
    239                 FuncTy Fn) {
    240   detail::parallel_for_each_n(Begin, End, Fn);
    241 }
    242 #endif
    243 
    244 } // namespace parallel
    245 } // namespace llvm
    246 
    247 #endif // LLVM_SUPPORT_PARALLEL_H
    248