Home | History | Annotate | Download | only in parallel
      1 // -*- C++ -*-
      2 
      3 // Copyright (C) 2007, 2008, 2009 Free Software Foundation, Inc.
      4 //
      5 // This file is part of the GNU ISO C++ Library.  This library is free
      6 // software; you can redistribute it and/or modify it under the terms
      7 // of the GNU General Public License as published by the Free Software
      8 // Foundation; either version 3, or (at your option) any later
      9 // version.
     10 
     11 // This library is distributed in the hope that it will be useful, but
     12 // WITHOUT ANY WARRANTY; without even the implied warranty of
     13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
     14 // General Public License for more details.
     15 
     16 // Under Section 7 of GPL version 3, you are granted additional
     17 // permissions described in the GCC Runtime Library Exception, version
     18 // 3.1, as published by the Free Software Foundation.
     19 
     20 // You should have received a copy of the GNU General Public License and
     21 // a copy of the GCC Runtime Library Exception along with this program;
     22 // see the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
     23 // <http://www.gnu.org/licenses/>.
     24 
     25 /** @file parallel/multiway_mergesort.h
     26  *  @brief Parallel multiway merge sort.
     27  *  This file is a GNU parallel extension to the Standard C++ Library.
     28  */
     29 
     30 // Written by Johannes Singler.
     31 
     32 #ifndef _GLIBCXX_PARALLEL_MULTIWAY_MERGESORT_H
     33 #define _GLIBCXX_PARALLEL_MULTIWAY_MERGESORT_H 1
     34 
     35 #include <vector>
     36 
     37 #include <parallel/basic_iterator.h>
     38 #include <bits/stl_algo.h>
     39 #include <parallel/parallel.h>
     40 #include <parallel/multiway_merge.h>
     41 
     42 namespace __gnu_parallel
     43 {
     44 
     45 /** @brief Subsequence description. */
     46 template<typename _DifferenceTp>
     47   struct Piece
     48   {
     49     typedef _DifferenceTp difference_type;
     50 
     51     /** @brief Begin of subsequence. */
     52     difference_type begin;
     53 
     54     /** @brief End of subsequence. */
     55     difference_type end;
     56   };
     57 
     58 /** @brief Data accessed by all threads.
     59   *
     60   *  PMWMS = parallel multiway mergesort */
     61 template<typename RandomAccessIterator>
     62   struct PMWMSSortingData
     63   {
     64     typedef std::iterator_traits<RandomAccessIterator> traits_type;
     65     typedef typename traits_type::value_type value_type;
     66     typedef typename traits_type::difference_type difference_type;
     67 
     68     /** @brief Number of threads involved. */
     69     thread_index_t num_threads;
     70 
     71     /** @brief Input begin. */
     72     RandomAccessIterator source;
     73 
     74     /** @brief Start indices, per thread. */
     75     difference_type* starts;
     76 
     77     /** @brief Storage in which to sort. */
     78     value_type** temporary;
     79 
     80     /** @brief Samples. */
     81     value_type* samples;
     82 
     83     /** @brief Offsets to add to the found positions. */
     84     difference_type* offsets;
     85 
     86     /** @brief Pieces of data to merge @c [thread][sequence] */
     87     std::vector<Piece<difference_type> >* pieces;
     88 };
     89 
     90 /**
     91   *  @brief Select samples from a sequence.
     92   *  @param sd Pointer to algorithm data. Result will be placed in
     93   *  @c sd->samples.
     94   *  @param num_samples Number of samples to select.
     95   */
     96 template<typename RandomAccessIterator, typename _DifferenceTp>
     97   void
     98   determine_samples(PMWMSSortingData<RandomAccessIterator>* sd,
     99                     _DifferenceTp num_samples)
    100   {
    101     typedef std::iterator_traits<RandomAccessIterator> traits_type;
    102     typedef typename traits_type::value_type value_type;
    103     typedef _DifferenceTp difference_type;
    104 
    105     thread_index_t iam = omp_get_thread_num();
    106 
    107     difference_type* es = new difference_type[num_samples + 2];
    108 
    109     equally_split(sd->starts[iam + 1] - sd->starts[iam],
    110                   num_samples + 1, es);
    111 
    112     for (difference_type i = 0; i < num_samples; ++i)
    113       ::new(&(sd->samples[iam * num_samples + i]))
    114 	  value_type(sd->source[sd->starts[iam] + es[i + 1]]);
    115 
    116     delete[] es;
    117   }
    118 
    119 /** @brief Split consistently. */
    120 template<bool exact, typename RandomAccessIterator,
    121           typename Comparator, typename SortingPlacesIterator>
    122   struct split_consistently
    123   {
    124   };
    125 
    126 /** @brief Split by exact splitting. */
    127 template<typename RandomAccessIterator, typename Comparator,
    128           typename SortingPlacesIterator>
    129   struct split_consistently
    130     <true, RandomAccessIterator, Comparator, SortingPlacesIterator>
    131   {
    132     void operator()(
    133       const thread_index_t iam,
    134       PMWMSSortingData<RandomAccessIterator>* sd,
    135       Comparator& comp,
    136       const typename
    137         std::iterator_traits<RandomAccessIterator>::difference_type
    138           num_samples)
    139       const
    140   {
    141 #   pragma omp barrier
    142 
    143     std::vector<std::pair<SortingPlacesIterator, SortingPlacesIterator> >
    144         seqs(sd->num_threads);
    145     for (thread_index_t s = 0; s < sd->num_threads; s++)
    146       seqs[s] = std::make_pair(sd->temporary[s],
    147                                 sd->temporary[s]
    148                                     + (sd->starts[s + 1] - sd->starts[s]));
    149 
    150     std::vector<SortingPlacesIterator> offsets(sd->num_threads);
    151 
    152     // if not last thread
    153     if (iam < sd->num_threads - 1)
    154       multiseq_partition(seqs.begin(), seqs.end(),
    155                           sd->starts[iam + 1], offsets.begin(), comp);
    156 
    157     for (int seq = 0; seq < sd->num_threads; seq++)
    158       {
    159         // for each sequence
    160         if (iam < (sd->num_threads - 1))
    161           sd->pieces[iam][seq].end = offsets[seq] - seqs[seq].first;
    162         else
    163           // very end of this sequence
    164           sd->pieces[iam][seq].end =
    165               sd->starts[seq + 1] - sd->starts[seq];
    166       }
    167 
    168 #   pragma omp barrier
    169 
    170     for (thread_index_t seq = 0; seq < sd->num_threads; seq++)
    171       {
    172         // For each sequence.
    173         if (iam > 0)
    174           sd->pieces[iam][seq].begin = sd->pieces[iam - 1][seq].end;
    175         else
    176           // Absolute beginning.
    177           sd->pieces[iam][seq].begin = 0;
    178       }
    179   }
    180   };
    181 
    182 /** @brief Split by sampling. */
    183 template<typename RandomAccessIterator, typename Comparator,
    184           typename SortingPlacesIterator>
    185   struct split_consistently<false, RandomAccessIterator, Comparator,
    186                              SortingPlacesIterator>
    187   {
    188     void operator()(
    189         const thread_index_t iam,
    190         PMWMSSortingData<RandomAccessIterator>* sd,
    191         Comparator& comp,
    192         const typename
    193           std::iterator_traits<RandomAccessIterator>::difference_type
    194             num_samples)
    195         const
    196     {
    197       typedef std::iterator_traits<RandomAccessIterator> traits_type;
    198       typedef typename traits_type::value_type value_type;
    199       typedef typename traits_type::difference_type difference_type;
    200 
    201       determine_samples(sd, num_samples);
    202 
    203 #     pragma omp barrier
    204 
    205 #     pragma omp single
    206       __gnu_sequential::sort(sd->samples,
    207                              sd->samples + (num_samples * sd->num_threads),
    208                              comp);
    209 
    210 #     pragma omp barrier
    211 
    212       for (thread_index_t s = 0; s < sd->num_threads; ++s)
    213         {
    214           // For each sequence.
    215           if (num_samples * iam > 0)
    216             sd->pieces[iam][s].begin =
    217                 std::lower_bound(sd->temporary[s],
    218                     sd->temporary[s]
    219                         + (sd->starts[s + 1] - sd->starts[s]),
    220                     sd->samples[num_samples * iam],
    221                     comp)
    222                 - sd->temporary[s];
    223           else
    224             // Absolute beginning.
    225             sd->pieces[iam][s].begin = 0;
    226 
    227           if ((num_samples * (iam + 1)) < (num_samples * sd->num_threads))
    228             sd->pieces[iam][s].end =
    229                 std::lower_bound(sd->temporary[s],
    230                         sd->temporary[s]
    231                             + (sd->starts[s + 1] - sd->starts[s]),
    232                         sd->samples[num_samples * (iam + 1)],
    233                         comp)
    234                 - sd->temporary[s];
    235           else
    236             // Absolute end.
    237             sd->pieces[iam][s].end = sd->starts[s + 1] - sd->starts[s];
    238         }
    239     }
    240   };
    241 
    242 template<bool stable, typename RandomAccessIterator, typename Comparator>
    243   struct possibly_stable_sort
    244   {
    245   };
    246 
    247 template<typename RandomAccessIterator, typename Comparator>
    248   struct possibly_stable_sort<true, RandomAccessIterator, Comparator>
    249   {
    250     void operator()(const RandomAccessIterator& begin,
    251                      const RandomAccessIterator& end, Comparator& comp) const
    252     {
    253       __gnu_sequential::stable_sort(begin, end, comp);
    254     }
    255   };
    256 
    257 template<typename RandomAccessIterator, typename Comparator>
    258   struct possibly_stable_sort<false, RandomAccessIterator, Comparator>
    259   {
    260     void operator()(const RandomAccessIterator begin,
    261                      const RandomAccessIterator end, Comparator& comp) const
    262     {
    263       __gnu_sequential::sort(begin, end, comp);
    264     }
    265   };
    266 
    267 template<bool stable, typename SeqRandomAccessIterator,
    268           typename RandomAccessIterator, typename Comparator,
    269           typename DiffType>
    270   struct possibly_stable_multiway_merge
    271   {
    272   };
    273 
    274 template<typename SeqRandomAccessIterator, typename RandomAccessIterator,
    275           typename Comparator, typename DiffType>
    276   struct possibly_stable_multiway_merge
    277     <true, SeqRandomAccessIterator, RandomAccessIterator, Comparator,
    278     DiffType>
    279   {
    280     void operator()(const SeqRandomAccessIterator& seqs_begin,
    281                       const SeqRandomAccessIterator& seqs_end,
    282                       const RandomAccessIterator& target,
    283                       Comparator& comp,
    284                       DiffType length_am) const
    285     {
    286       stable_multiway_merge(seqs_begin, seqs_end, target, length_am, comp,
    287                        sequential_tag());
    288     }
    289   };
    290 
    291 template<typename SeqRandomAccessIterator, typename RandomAccessIterator,
    292           typename Comparator, typename DiffType>
    293   struct possibly_stable_multiway_merge
    294     <false, SeqRandomAccessIterator, RandomAccessIterator, Comparator,
    295     DiffType>
    296   {
    297     void operator()(const SeqRandomAccessIterator& seqs_begin,
    298                       const SeqRandomAccessIterator& seqs_end,
    299                       const RandomAccessIterator& target,
    300                       Comparator& comp,
    301                       DiffType length_am) const
    302     {
    303       multiway_merge(seqs_begin, seqs_end, target, length_am, comp,
    304                        sequential_tag());
    305     }
    306   };
    307 
    308 /** @brief PMWMS code executed by each thread.
    309   *  @param sd Pointer to algorithm data.
    310   *  @param comp Comparator.
    311   */
    312 template<bool stable, bool exact, typename RandomAccessIterator,
    313           typename Comparator>
    314   void
    315   parallel_sort_mwms_pu(PMWMSSortingData<RandomAccessIterator>* sd,
    316                         Comparator& comp)
    317   {
    318     typedef std::iterator_traits<RandomAccessIterator> traits_type;
    319     typedef typename traits_type::value_type value_type;
    320     typedef typename traits_type::difference_type difference_type;
    321 
    322     thread_index_t iam = omp_get_thread_num();
    323 
    324     // Length of this thread's chunk, before merging.
    325     difference_type length_local = sd->starts[iam + 1] - sd->starts[iam];
    326 
    327     // Sort in temporary storage, leave space for sentinel.
    328 
    329     typedef value_type* SortingPlacesIterator;
    330 
    331     sd->temporary[iam] =
    332         static_cast<value_type*>(
    333         ::operator new(sizeof(value_type) * (length_local + 1)));
    334 
    335     // Copy there.
    336     std::uninitialized_copy(sd->source + sd->starts[iam],
    337                             sd->source + sd->starts[iam] + length_local,
    338                             sd->temporary[iam]);
    339 
    340     possibly_stable_sort<stable, SortingPlacesIterator, Comparator>()
    341         (sd->temporary[iam], sd->temporary[iam] + length_local, comp);
    342 
    343     // Invariant: locally sorted subsequence in sd->temporary[iam],
    344     // sd->temporary[iam] + length_local.
    345 
    346     // No barrier here: Synchronization is done by the splitting routine.
    347 
    348     difference_type num_samples =
    349         _Settings::get().sort_mwms_oversampling * sd->num_threads - 1;
    350     split_consistently
    351       <exact, RandomAccessIterator, Comparator, SortingPlacesIterator>()
    352         (iam, sd, comp, num_samples);
    353 
    354     // Offset from target begin, length after merging.
    355     difference_type offset = 0, length_am = 0;
    356     for (thread_index_t s = 0; s < sd->num_threads; s++)
    357       {
    358         length_am += sd->pieces[iam][s].end - sd->pieces[iam][s].begin;
    359         offset += sd->pieces[iam][s].begin;
    360       }
    361 
    362     typedef std::vector<
    363       std::pair<SortingPlacesIterator, SortingPlacesIterator> >
    364         seq_vector_type;
    365     seq_vector_type seqs(sd->num_threads);
    366 
    367     for (int s = 0; s < sd->num_threads; ++s)
    368       {
    369         seqs[s] =
    370           std::make_pair(sd->temporary[s] + sd->pieces[iam][s].begin,
    371         sd->temporary[s] + sd->pieces[iam][s].end);
    372       }
    373 
    374     possibly_stable_multiway_merge<
    375         stable,
    376         typename seq_vector_type::iterator,
    377         RandomAccessIterator,
    378         Comparator, difference_type>()
    379           (seqs.begin(), seqs.end(),
    380            sd->source + offset, comp,
    381            length_am);
    382 
    383 #   pragma omp barrier
    384 
    385     ::operator delete(sd->temporary[iam]);
    386   }
    387 
    388 /** @brief PMWMS main call.
    389   *  @param begin Begin iterator of sequence.
    390   *  @param end End iterator of sequence.
    391   *  @param comp Comparator.
    392   *  @param n Length of sequence.
    393   *  @param num_threads Number of threads to use.
    394   */
    395 template<bool stable, bool exact, typename RandomAccessIterator,
    396            typename Comparator>
    397   void
    398   parallel_sort_mwms(RandomAccessIterator begin, RandomAccessIterator end,
    399                      Comparator comp,
    400                      thread_index_t num_threads)
    401   {
    402     _GLIBCXX_CALL(end - begin)
    403 
    404     typedef std::iterator_traits<RandomAccessIterator> traits_type;
    405     typedef typename traits_type::value_type value_type;
    406     typedef typename traits_type::difference_type difference_type;
    407 
    408     difference_type n = end - begin;
    409 
    410     if (n <= 1)
    411       return;
    412 
    413     // at least one element per thread
    414     if (num_threads > n)
    415       num_threads = static_cast<thread_index_t>(n);
    416 
    417     // shared variables
    418     PMWMSSortingData<RandomAccessIterator> sd;
    419     difference_type* starts;
    420 
    421 #   pragma omp parallel num_threads(num_threads)
    422       {
    423         num_threads = omp_get_num_threads();  //no more threads than requested
    424 
    425 #       pragma omp single
    426           {
    427             sd.num_threads = num_threads;
    428             sd.source = begin;
    429 
    430             sd.temporary = new value_type*[num_threads];
    431 
    432             if (!exact)
    433               {
    434                 difference_type size =
    435                     (_Settings::get().sort_mwms_oversampling * num_threads - 1)
    436                         * num_threads;
    437                 sd.samples = static_cast<value_type*>(
    438                               ::operator new(size * sizeof(value_type)));
    439               }
    440             else
    441               sd.samples = NULL;
    442 
    443             sd.offsets = new difference_type[num_threads - 1];
    444             sd.pieces = new std::vector<Piece<difference_type> >[num_threads];
    445             for (int s = 0; s < num_threads; ++s)
    446               sd.pieces[s].resize(num_threads);
    447             starts = sd.starts = new difference_type[num_threads + 1];
    448 
    449             difference_type chunk_length = n / num_threads;
    450             difference_type split = n % num_threads;
    451             difference_type pos = 0;
    452             for (int i = 0; i < num_threads; ++i)
    453               {
    454                 starts[i] = pos;
    455                 pos += (i < split) ? (chunk_length + 1) : chunk_length;
    456               }
    457             starts[num_threads] = pos;
    458           } //single
    459 
    460         // Now sort in parallel.
    461         parallel_sort_mwms_pu<stable, exact>(&sd, comp);
    462       } //parallel
    463 
    464     delete[] starts;
    465     delete[] sd.temporary;
    466 
    467     if (!exact)
    468       ::operator delete(sd.samples);
    469 
    470     delete[] sd.offsets;
    471     delete[] sd.pieces;
    472   }
    473 } //namespace __gnu_parallel
    474 
    475 #endif /* _GLIBCXX_PARALLEL_MULTIWAY_MERGESORT_H */
    476