// -*- C++ -*-
//===----------------------------------------------------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is dual licensed under the MIT and the University of Illinois Open
// Source Licenses. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//

// UNSUPPORTED: c++98, c++03, c++11

// End-to-end coroutine test: a chain of N "goroutines" joined by unbuffered
// channels. Goroutine i pulls a value from c[i] and pushes value + 1 into
// c[i + 1], so pushing 0 into c[0] must produce N out of c[N].

#include <experimental/coroutine>
#include <cassert>

using namespace std::experimental;

bool cancel = false; // not referenced anywhere in this test

// Fire-and-forget coroutine return type plus a tiny LIFO scheduler.
struct goroutine
{
  static int const N = 10;            // capacity of the scheduler stack
  static int count;                   // number of handles currently scheduled
  static coroutine_handle<> stack[N]; // scheduled, not-yet-resumed coroutines

  // Pushes `rh` onto the scheduler stack and clears the caller's copy, so the
  // handle is owned by the scheduler until run_one() resumes it.
  static void schedule(coroutine_handle<>& rh)
  {
    assert(count < N);
    stack[count++] = rh;
    rh = nullptr;
  }

  ~goroutine() {}

  // Nothing to do: the coroutine already started running eagerly
  // (initial_suspend is suspend_never) before this is called.
  static void go(goroutine) {}

  // Pops the most recently scheduled coroutine and resumes it.
  static void run_one()
  {
    assert(count > 0);
    stack[--count]();
  }

  struct promise_type
  {
    // Run the coroutine body immediately on creation.
    suspend_never initial_suspend() {
      return {};
    }
    // Must be noexcept: C++20 requires `co_await promise.final_suspend()` to
    // be non-throwing, and compilers enforcing that rule reject it otherwise.
    suspend_never final_suspend() noexcept {
      return {};
    }
    void return_void() {}
    goroutine get_return_object() {
      return {};
    }
    void unhandled_exception() {}
  };
};
int goroutine::count;
coroutine_handle<> goroutine::stack[N];

// Never used directly. NOTE(review): presumably kept to force instantiation of
// coroutine_handle<promise_type> as a compiler workaround — TODO confirm.
coroutine_handle<goroutine::promise_type> workaround;

class channel;

// Awaiter returned by channel::push(). It always suspends the writer and, if
// a reader is already parked on the channel, schedules that reader to run.
struct push_awaiter {
  channel* ch;
  bool await_ready() { return false; }
  void await_suspend(coroutine_handle<> rh);
  void await_resume() {}
};

// Awaiter returned by channel::pull(). It completes without suspending when a
// writer is already parked; otherwise it registers itself as the reader.
struct pull_awaiter {
  channel* ch;

  bool await_ready();
  void await_suspend(coroutine_handle<> rh);
  int await_resume();
};

// Unbuffered rendezvous channel of ints. At most one reader and one writer may
// be waiting at any time (enforced by the asserts in push()/pull()).
class channel
{
  using T = int;

  friend struct push_awaiter;
  friend struct pull_awaiter;

  T const* pvalue = nullptr;           // value offered by the pending writer
  coroutine_handle<> reader = nullptr; // suspended reader, if any
  coroutine_handle<> writer = nullptr; // suspended writer, if any
public:
  // Offers `value` on the channel. Only its address is stored, so `value` must
  // stay alive until a reader consumes it. A temporary argument is fine when
  // the result is co_awaited in the same full-expression, because temporaries
  // live until the end of that full-expression (across the suspension).
  push_awaiter push(T const& value)
  {
    assert(pvalue == nullptr);
    assert(!writer);
    pvalue = &value;

    return { this };
  }

  pull_awaiter pull()
  {
    assert(!reader);

    return { this };
  }

  // Hands `value` to the reader currently suspended on this channel (one must
  // exist) and runs that reader before returning.
  void sync_push(T const& value)
  {
    assert(!pvalue);
    pvalue = &value;
    assert(reader);
    // Detach the reader *before* resuming it. If it pulls from this channel
    // again while running, pull() must see no stale reader, and its new
    // registration must not be wiped out afterwards. This mirrors how
    // sync_pull() and pull_awaiter::await_resume() treat `writer`.
    auto rd = reader;
    reader = nullptr;
    rd();
    assert(!pvalue);
  }

  // Runs scheduled coroutines until some writer offers a value, then takes
  // the value and resumes that writer.
  auto sync_pull()
  {
    while (!pvalue) goroutine::run_one();
    auto result = *pvalue;
    pvalue = nullptr;
    if (writer)
    {
      auto wr = writer;
      writer = nullptr;
      wr();
    }
    return result;
  }
};

// Parks the writer. If a reader is waiting, it is handed to the scheduler
// (which also clears ch->reader) so it later resumes and picks up the value.
void push_awaiter::await_suspend(coroutine_handle<> rh)
{
  ch->writer = rh;
  if (ch->reader) goroutine::schedule(ch->reader);
}

// A parked writer means a value is already available.
bool pull_awaiter::await_ready() {
  return !!ch->writer;
}
void pull_awaiter::await_suspend(coroutine_handle<> rh) {
  ch->reader = rh;
}
// Takes the offered value, then resumes the parked writer (if any) inline,
// i.e. directly rather than through goroutine::schedule.
int pull_awaiter::await_resume() {
  auto result = *ch->pvalue;
  ch->pvalue = nullptr;
  if (ch->writer) {
    auto wr = ch->writer;
    ch->writer = nullptr;
    wr();
  }
  return result;
}

// Relays values from `left` to `right`, adding one to each; never returns.
goroutine pusher(channel& left, channel& right)
{
  for (;;) {
    auto val = co_await left.pull();
    co_await right.push(val + 1);
  }
}

const int N = 100;
// c[i] is read by pusher i and written by pusher i - 1. A plain array avoids
// the original unmatched `new[]`.
channel c[N + 1];

int main() {
  for (int i = 0; i < N; ++i)
    goroutine::go(pusher(c[i], c[i + 1]));

  c[0].sync_push(0);
  int result = c[N].sync_pull();

  // Each of the N pushers adds one to the value that started at 0.
  assert(result == N);
}