// parallel/balanced_quicksort.h
      1 // -*- C++ -*-
      2 
      3 // Copyright (C) 2007, 2008, 2009 Free Software Foundation, Inc.
      4 //
      5 // This file is part of the GNU ISO C++ Library.  This library is free
      6 // software; you can redistribute it and/or modify it under the terms
      7 // of the GNU General Public License as published by the Free Software
      8 // Foundation; either version 3, or (at your option) any later
      9 // version.
     10 
     11 // This library is distributed in the hope that it will be useful, but
     12 // WITHOUT ANY WARRANTY; without even the implied warranty of
     13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
     14 // General Public License for more details.
     15 
     16 // Under Section 7 of GPL version 3, you are granted additional
     17 // permissions described in the GCC Runtime Library Exception, version
     18 // 3.1, as published by the Free Software Foundation.
     19 
     20 // You should have received a copy of the GNU General Public License and
     21 // a copy of the GCC Runtime Library Exception along with this program;
     22 // see the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
     23 // <http://www.gnu.org/licenses/>.
     24 
     25 /** @file parallel/balanced_quicksort.h
     26  *  @brief Implementation of a dynamically load-balanced parallel quicksort.
     27  *
     28  *  It works in-place and needs only logarithmic extra memory.
     29  *  The algorithm is similar to the one proposed in
     30  *
     31  *  P. Tsigas and Y. Zhang.
     32  *  A simple, fast parallel implementation of quicksort and
     33  *  its performance evaluation on SUN enterprise 10000.
     34  *  In 11th Euromicro Conference on Parallel, Distributed and
     35  *  Network-Based Processing, page 372, 2003.
     36  *
     37  *  This file is a GNU parallel extension to the Standard C++ Library.
     38  */
     39 
     40 // Written by Johannes Singler.
     41 
     42 #ifndef _GLIBCXX_PARALLEL_BALANCED_QUICKSORT_H
     43 #define _GLIBCXX_PARALLEL_BALANCED_QUICKSORT_H 1
     44 
     45 #include <parallel/basic_iterator.h>
     46 #include <bits/stl_algo.h>
     47 
     48 #include <parallel/settings.h>
     49 #include <parallel/partition.h>
     50 #include <parallel/random_number.h>
     51 #include <parallel/queue.h>
     52 #include <functional>
     53 
     54 #if _GLIBCXX_ASSERTIONS
     55 #include <parallel/checkers.h>
     56 #endif
     57 
     58 namespace __gnu_parallel
     59 {
     60 /** @brief Information local to one thread in the parallel quicksort run. */
     61 template<typename RandomAccessIterator>
     62   struct QSBThreadLocal
     63   {
     64     typedef std::iterator_traits<RandomAccessIterator> traits_type;
     65     typedef typename traits_type::difference_type difference_type;
     66 
     67     /** @brief Continuous part of the sequence, described by an
     68     iterator pair. */
     69     typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;
     70 
     71     /** @brief Initial piece to work on. */
     72     Piece initial;
     73 
     74     /** @brief Work-stealing queue. */
     75     RestrictedBoundedConcurrentQueue<Piece> leftover_parts;
     76 
     77     /** @brief Number of threads involved in this algorithm. */
     78     thread_index_t num_threads;
     79 
     80     /** @brief Pointer to a counter of elements left over to sort. */
     81     volatile difference_type* elements_leftover;
     82 
     83     /** @brief The complete sequence to sort. */
     84     Piece global;
     85 
     86     /** @brief Constructor.
     87      *  @param queue_size Size of the work-stealing queue. */
     88     QSBThreadLocal(int queue_size) : leftover_parts(queue_size) { }
     89   };
     90 
     91 /** @brief Balanced quicksort divide step.
     92   *  @param begin Begin iterator of subsequence.
     93   *  @param end End iterator of subsequence.
     94   *  @param comp Comparator.
     95   *  @param num_threads Number of threads that are allowed to work on
     96   *  this part.
     97   *  @pre @c (end-begin)>=1 */
     98 template<typename RandomAccessIterator, typename Comparator>
     99   typename std::iterator_traits<RandomAccessIterator>::difference_type
    100   qsb_divide(RandomAccessIterator begin, RandomAccessIterator end,
    101              Comparator comp, thread_index_t num_threads)
    102   {
    103     _GLIBCXX_PARALLEL_ASSERT(num_threads > 0);
    104 
    105     typedef std::iterator_traits<RandomAccessIterator> traits_type;
    106     typedef typename traits_type::value_type value_type;
    107     typedef typename traits_type::difference_type difference_type;
    108 
    109     RandomAccessIterator pivot_pos =
    110       median_of_three_iterators(begin, begin + (end - begin) / 2,
    111 				end  - 1, comp);
    112 
    113 #if defined(_GLIBCXX_ASSERTIONS)
    114     // Must be in between somewhere.
    115     difference_type n = end - begin;
    116 
    117     _GLIBCXX_PARALLEL_ASSERT(
    118            (!comp(*pivot_pos, *begin) && !comp(*(begin + n / 2), *pivot_pos))
    119         || (!comp(*pivot_pos, *begin) && !comp(*(end - 1), *pivot_pos))
    120         || (!comp(*pivot_pos, *(begin + n / 2)) && !comp(*begin, *pivot_pos))
    121         || (!comp(*pivot_pos, *(begin + n / 2)) && !comp(*(end - 1), *pivot_pos))
    122         || (!comp(*pivot_pos, *(end - 1)) && !comp(*begin, *pivot_pos))
    123         || (!comp(*pivot_pos, *(end - 1)) && !comp(*(begin + n / 2), *pivot_pos)));
    124 #endif
    125 
    126     // Swap pivot value to end.
    127     if (pivot_pos != (end - 1))
    128       std::swap(*pivot_pos, *(end - 1));
    129     pivot_pos = end - 1;
    130 
    131     __gnu_parallel::binder2nd<Comparator, value_type, value_type, bool>
    132         pred(comp, *pivot_pos);
    133 
    134     // Divide, returning end - begin - 1 in the worst case.
    135     difference_type split_pos = parallel_partition(
    136         begin, end - 1, pred, num_threads);
    137 
    138     // Swap back pivot to middle.
    139     std::swap(*(begin + split_pos), *pivot_pos);
    140     pivot_pos = begin + split_pos;
    141 
    142 #if _GLIBCXX_ASSERTIONS
    143     RandomAccessIterator r;
    144     for (r = begin; r != pivot_pos; ++r)
    145       _GLIBCXX_PARALLEL_ASSERT(comp(*r, *pivot_pos));
    146     for (; r != end; ++r)
    147       _GLIBCXX_PARALLEL_ASSERT(!comp(*r, *pivot_pos));
    148 #endif
    149 
    150     return split_pos;
    151   }
    152 
/** @brief Quicksort conquer step.
  *  @param tls Array of thread-local storages.
  *  @param begin Begin iterator of subsequence.
  *  @param end End iterator of subsequence.
  *  @param comp Comparator.
  *  @param iam Number of the thread processing this function.
  *  @param num_threads
  *          Number of threads that are allowed to work on this part.
  *  @param parent_wait Whether this invocation may block waiting to
  *          steal work once its own part is done. */
template<typename RandomAccessIterator, typename Comparator>
  void
  qsb_conquer(QSBThreadLocal<RandomAccessIterator>** tls,
              RandomAccessIterator begin, RandomAccessIterator end,
              Comparator comp,
              thread_index_t iam, thread_index_t num_threads,
              bool parent_wait)
  {
    typedef std::iterator_traits<RandomAccessIterator> traits_type;
    typedef typename traits_type::value_type value_type;
    typedef typename traits_type::difference_type difference_type;

    difference_type n = end - begin;

    // Base case: no parallelism left (or nothing to split) -- hand the
    // whole range to the load-balanced sequential sort.
    if (num_threads <= 1 || n <= 1)
      {
        tls[iam]->initial.first  = begin;
        tls[iam]->initial.second = end;

        qsb_local_sort_with_helping(tls, comp, iam, parent_wait);

        return;
      }

    // Divide step.
    difference_type split_pos = qsb_divide(begin, end, comp, num_threads);

#if _GLIBCXX_ASSERTIONS
    _GLIBCXX_PARALLEL_ASSERT(0 <= split_pos && split_pos < (end - begin));
#endif

    // Assign threads to the two sides proportionally to their sizes,
    // but give each side at least one thread.
    thread_index_t num_threads_leftside =
        std::max<thread_index_t>(1, std::min<thread_index_t>(
                          num_threads - 1, split_pos * num_threads / n));

    // The pivot element is now in its final position; account for it.
#   pragma omp atomic
    *tls[iam]->elements_leftover -= (difference_type)1;

    // Conquer step.
#   pragma omp parallel num_threads(2)
    {
      bool wait;
      // If OpenMP granted only one thread, both sections run in this
      // thread sequentially, so it must not block waiting for work.
      if(omp_get_num_threads() < 2)
        wait = false;
      else
        wait = parent_wait;

#     pragma omp sections
        {
#         pragma omp section
            {
              qsb_conquer(tls, begin, begin + split_pos, comp,
                          iam,
                          num_threads_leftside,
                          wait);
              wait = parent_wait;
            }
          // The pivot_pos is left in place, to ensure termination.
#         pragma omp section
            {
              qsb_conquer(tls, begin + split_pos + 1, end, comp,
                          iam + num_threads_leftside,
                          num_threads - num_threads_leftside,
                          wait);
              wait = parent_wait;
            }
        }
    }
  }
    230 
/**
  *  @brief Quicksort step doing load-balanced local sort.
  *
  *  Sorts this thread's assigned piece with sequential quicksort,
  *  pushing larger sub-ranges onto its own work-stealing queue.  When
  *  the local queue runs dry, it tries to steal large pieces from
  *  randomly chosen victims until all elements are globally done.
  *  @param tls Array of thread-local storages.
  *  @param comp Comparator.
  *  @param iam Number of the thread processing this function.
  *  @param wait Whether this thread may block waiting to steal work
  *          from other threads once its own work is done.
  */
template<typename RandomAccessIterator, typename Comparator>
  void
  qsb_local_sort_with_helping(QSBThreadLocal<RandomAccessIterator>** tls,
                              Comparator& comp, int iam, bool wait)
  {
    typedef std::iterator_traits<RandomAccessIterator> traits_type;
    typedef typename traits_type::value_type value_type;
    typedef typename traits_type::difference_type difference_type;
    typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;

    QSBThreadLocal<RandomAccessIterator>& tl = *tls[iam];

    // Ranges at most this long are sorted sequentially right away.
    difference_type base_case_n =
        _Settings::get().sort_qsb_base_case_maximal_n;
    if (base_case_n < 2)
      base_case_n = 2;
    thread_index_t num_threads = tl.num_threads;

    // Every thread has its own random number generator.
    random_number rng(iam + 1);

    Piece current = tl.initial;

    // Elements finished since the last update of the shared counter;
    // batched to limit atomic-update traffic.
    difference_type elements_done = 0;
#if _GLIBCXX_ASSERTIONS
    difference_type total_elements_done = 0;
#endif

    for (;;)
      {
        // Invariant: current must be a valid (maybe empty) range.
        RandomAccessIterator begin = current.first, end = current.second;
        difference_type n = end - begin;

        if (n > base_case_n)
          {
            // Divide: choose a random pivot position.
            RandomAccessIterator pivot_pos = begin +  rng(n);

            // Swap pivot_pos value to end.
            if (pivot_pos != (end - 1))
              std::swap(*pivot_pos, *(end - 1));
            pivot_pos = end - 1;

            // Predicate "strictly smaller than the pivot value".
            __gnu_parallel::binder2nd
                <Comparator, value_type, value_type, bool>
                pred(comp, *pivot_pos);

            // Divide, leave pivot unchanged in last place.
            RandomAccessIterator split_pos1, split_pos2;
            split_pos1 = __gnu_sequential::partition(begin, end - 1, pred);

            // Left side: < pivot_pos; right side: >= pivot_pos.
#if _GLIBCXX_ASSERTIONS
            _GLIBCXX_PARALLEL_ASSERT(begin <= split_pos1 && split_pos1 < end);
#endif
            // Swap pivot back to middle.
            if (split_pos1 != pivot_pos)
              std::swap(*split_pos1, *pivot_pos);
            pivot_pos = split_pos1;

            // In case all elements are equal, split_pos1 == 0.
            if ((split_pos1 + 1 - begin) < (n >> 7)
            || (end - split_pos1) < (n >> 7))
              {
                // Very unequal split, one part smaller than one 128th
                // elements not strictly larger than the pivot.
                // Guard against quadratic behavior on many-equal-keys
                // inputs by grouping all pivot-equal elements together.
                __gnu_parallel::unary_negate<__gnu_parallel::binder1st
                  <Comparator, value_type, value_type, bool>, value_type>
                  pred(__gnu_parallel::binder1st
                       <Comparator, value_type, value_type, bool>(comp,
                                                                  *pivot_pos));

                // Find other end of pivot-equal range.
                split_pos2 = __gnu_sequential::partition(split_pos1 + 1,
                                                         end, pred);
              }
            else
              // Only skip the pivot.
              split_pos2 = split_pos1 + 1;

            // Elements equal to pivot are done.
            elements_done += (split_pos2 - split_pos1);
#if _GLIBCXX_ASSERTIONS
            total_elements_done += (split_pos2 - split_pos1);
#endif
            // Always push larger part onto stack, keep working on the
            // smaller one (bounds the queue depth logarithmically).
            if (((split_pos1 + 1) - begin) < (end - (split_pos2)))
              {
                // Right side larger.
                if ((split_pos2) != end)
                  tl.leftover_parts.push_front(std::make_pair(split_pos2,
                                                              end));

                //current.first = begin;	//already set anyway
                current.second = split_pos1;
                continue;
              }
            else
              {
                // Left side larger.
                if (begin != split_pos1)
                  tl.leftover_parts.push_front(std::make_pair(begin,
                                                              split_pos1));

                current.first = split_pos2;
                //current.second = end;	//already set anyway
                continue;
              }
          }
        else
          {
            // Base case: sort the small range sequentially.
            __gnu_sequential::sort(begin, end, comp);
            elements_done += n;
#if _GLIBCXX_ASSERTIONS
            total_elements_done += n;
#endif

            // Prefer own stack, small pieces.
            if (tl.leftover_parts.pop_front(current))
              continue;

            // Own queue exhausted: publish batched progress so other
            // threads can observe global termination.
#           pragma omp atomic
            *tl.elements_leftover -= elements_done;

            elements_done = 0;

#if _GLIBCXX_ASSERTIONS
            double search_start = omp_get_wtime();
#endif

            // Look for new work: steal a large (back-of-queue) piece
            // from a randomly chosen victim thread.
            bool successfully_stolen = false;
            while (wait && *tl.elements_leftover > 0 && !successfully_stolen
#if _GLIBCXX_ASSERTIONS
              // Possible dead-lock.
              && (omp_get_wtime() < (search_start + 1.0))
#endif
              )
              {
                thread_index_t victim;
                victim = rng(num_threads);

                // Large pieces.
                successfully_stolen = (victim != iam)
                    && tls[victim]->leftover_parts.pop_back(current);
                if (!successfully_stolen)
                  yield();
#if !defined(__ICC) && !defined(__ECC)
#               pragma omp flush
#endif
              }

#if _GLIBCXX_ASSERTIONS
            // If the steal loop above timed out, report the suspected
            // dead-lock via the (necessarily failing) assertion.
            if (omp_get_wtime() >= (search_start + 1.0))
              {
                sleep(1);
                _GLIBCXX_PARALLEL_ASSERT(omp_get_wtime()
                                         < (search_start + 1.0));
              }
#endif
            // Nothing stolen and no waiting allowed (or nothing left):
            // this thread is finished.
            if (!successfully_stolen)
              {
#if _GLIBCXX_ASSERTIONS
                _GLIBCXX_PARALLEL_ASSERT(*tl.elements_leftover == 0);
#endif
                return;
              }
          }
      }
  }
    408 
    409 /** @brief Top-level quicksort routine.
    410   *  @param begin Begin iterator of sequence.
    411   *  @param end End iterator of sequence.
    412   *  @param comp Comparator.
    413   *  @param num_threads Number of threads that are allowed to work on
    414   *  this part.
    415   */
    416 template<typename RandomAccessIterator, typename Comparator>
    417   void
    418   parallel_sort_qsb(RandomAccessIterator begin, RandomAccessIterator end,
    419                     Comparator comp,
    420                     thread_index_t num_threads)
    421   {
    422     _GLIBCXX_CALL(end - begin)
    423 
    424     typedef std::iterator_traits<RandomAccessIterator> traits_type;
    425     typedef typename traits_type::value_type value_type;
    426     typedef typename traits_type::difference_type difference_type;
    427     typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;
    428 
    429     typedef QSBThreadLocal<RandomAccessIterator> tls_type;
    430 
    431     difference_type n = end - begin;
    432 
    433     if (n <= 1)
    434       return;
    435 
    436     // At least one element per processor.
    437     if (num_threads > n)
    438       num_threads = static_cast<thread_index_t>(n);
    439 
    440     // Initialize thread local storage
    441     tls_type** tls = new tls_type*[num_threads];
    442     difference_type queue_size = num_threads * (thread_index_t)(log2(n) + 1);
    443     for (thread_index_t t = 0; t < num_threads; ++t)
    444       tls[t] = new QSBThreadLocal<RandomAccessIterator>(queue_size);
    445 
    446     // There can never be more than ceil(log2(n)) ranges on the stack, because
    447     // 1. Only one processor pushes onto the stack
    448     // 2. The largest range has at most length n
    449     // 3. Each range is larger than half of the range remaining
    450     volatile difference_type elements_leftover = n;
    451     for (int i = 0; i < num_threads; ++i)
    452       {
    453         tls[i]->elements_leftover = &elements_leftover;
    454         tls[i]->num_threads = num_threads;
    455         tls[i]->global = std::make_pair(begin, end);
    456 
    457         // Just in case nothing is left to assign.
    458         tls[i]->initial = std::make_pair(end, end);
    459       }
    460 
    461     // Main recursion call.
    462     qsb_conquer(tls, begin, begin + n, comp, 0, num_threads, true);
    463 
    464 #if _GLIBCXX_ASSERTIONS
    465     // All stack must be empty.
    466     Piece dummy;
    467     for (int i = 1; i < num_threads; ++i)
    468       _GLIBCXX_PARALLEL_ASSERT(!tls[i]->leftover_parts.pop_back(dummy));
    469 #endif
    470 
    471     for (int i = 0; i < num_threads; ++i)
    472       delete tls[i];
    473     delete[] tls;
    474   }
    475 } // namespace __gnu_parallel
    476 
    477 #endif /* _GLIBCXX_PARALLEL_BALANCED_QUICKSORT_H */
    478