Home | History | Annotate | Download | only in end.to.end
      1 // -*- C++ -*-
      2 //===----------------------------------------------------------------------===//
      3 //
      4 //                     The LLVM Compiler Infrastructure
      5 //
      6 // This file is dual licensed under the MIT and the University of Illinois Open
      7 // Source Licenses. See LICENSE.TXT for details.
      8 //
      9 //===----------------------------------------------------------------------===//
     10 
     11 // UNSUPPORTED: c++98, c++03, c++11
     12 
     13 #include <experimental/coroutine>
     14 #include <cassert>
     15 
     16 using namespace std::experimental;
     17 
     18 bool cancel = false;
     19 
     20 struct goroutine
     21 {
     22   static int const N = 10;
     23   static int count;
     24   static coroutine_handle<> stack[N];
     25 
     26   static void schedule(coroutine_handle<>& rh)
     27   {
     28     assert(count < N);
     29     stack[count++] = rh;
     30     rh = nullptr;
     31   }
     32 
     33   ~goroutine() {}
     34 
     35   static void go(goroutine) {}
     36 
     37   static void run_one()
     38   {
     39     assert(count > 0);
     40     stack[--count]();
     41   }
     42 
     43   struct promise_type
     44   {
     45     suspend_never initial_suspend() {
     46       return {};
     47     }
     48     suspend_never final_suspend() {
     49       return {};
     50     }
     51     void return_void() {}
     52     goroutine get_return_object() {
     53       return{};
     54     }
     55     void unhandled_exception() {}
     56   };
     57 };
     58 int goroutine::count;
     59 coroutine_handle<> goroutine::stack[N];
     60 
     61 coroutine_handle<goroutine::promise_type> workaround;
     62 
     63 class channel;
     64 
     65 struct push_awaiter {
     66   channel* ch;
     67   bool await_ready() {return false; }
     68   void await_suspend(coroutine_handle<> rh);
     69   void await_resume() {}
     70 };
     71 
     72 struct pull_awaiter {
     73   channel * ch;
     74 
     75   bool await_ready();
     76   void await_suspend(coroutine_handle<> rh);
     77   int await_resume();
     78 };
     79 
     80 class channel
     81 {
     82   using T = int;
     83 
     84   friend struct push_awaiter;
     85   friend struct pull_awaiter;
     86 
     87   T const* pvalue = nullptr;
     88   coroutine_handle<> reader = nullptr;
     89   coroutine_handle<> writer = nullptr;
     90 public:
     91   push_awaiter push(T const& value)
     92   {
     93     assert(pvalue == nullptr);
     94     assert(!writer);
     95     pvalue = &value;
     96 
     97     return { this };
     98   }
     99 
    100   pull_awaiter pull()
    101   {
    102     assert(!reader);
    103 
    104     return { this };
    105   }
    106 
    107   void sync_push(T const& value)
    108   {
    109     assert(!pvalue);
    110     pvalue = &value;
    111     assert(reader);
    112     reader();
    113     assert(!pvalue);
    114     reader = nullptr;
    115   }
    116 
    117   auto sync_pull()
    118   {
    119     while (!pvalue) goroutine::run_one();
    120     auto result = *pvalue;
    121     pvalue = nullptr;
    122     if (writer)
    123     {
    124       auto wr = writer;
    125       writer = nullptr;
    126       wr();
    127     }
    128     return result;
    129   }
    130 };
    131 
    132 void push_awaiter::await_suspend(coroutine_handle<> rh)
    133 {
    134   ch->writer = rh;
    135   if (ch->reader) goroutine::schedule(ch->reader);
    136 }
    137 
    138 
    139 bool pull_awaiter::await_ready() {
    140   return !!ch->writer;
    141 }
    142 void pull_awaiter::await_suspend(coroutine_handle<> rh) {
    143   ch->reader = rh;
    144 }
    145 int pull_awaiter::await_resume() {
    146   auto result = *ch->pvalue;
    147   ch->pvalue = nullptr;
    148   if (ch->writer) {
    149     //goroutine::schedule(ch->writer);
    150     auto wr = ch->writer;
    151     ch->writer = nullptr;
    152     wr();
    153   }
    154   return result;
    155 }
    156 
    157 goroutine pusher(channel& left, channel& right)
    158 {
    159   for (;;) {
    160     auto val = co_await left.pull();
    161     co_await right.push(val + 1);
    162   }
    163 }
    164 
    165 const int N = 100;
    166 channel* c = new channel[N + 1];
    167 
    168 int main() {
    169   for (int i = 0; i < N; ++i)
    170     goroutine::go(pusher(c[i], c[i + 1]));
    171 
    172   c[0].sync_push(0);
    173   int result = c[N].sync_pull();
    174 
    175   assert(result == 100);
    176 }
    177