1 //===--- SemaCoroutines.cpp - Semantic Analysis for Coroutines ------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file implements semantic analysis for C++ Coroutines. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "clang/Sema/SemaInternal.h" 15 #include "clang/AST/Decl.h" 16 #include "clang/AST/ExprCXX.h" 17 #include "clang/AST/StmtCXX.h" 18 #include "clang/Lex/Preprocessor.h" 19 #include "clang/Sema/Initialization.h" 20 #include "clang/Sema/Overload.h" 21 using namespace clang; 22 using namespace sema; 23 24 /// Look up the std::coroutine_traits<...>::promise_type for the given 25 /// function type. 26 static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType, 27 SourceLocation Loc) { 28 // FIXME: Cache std::coroutine_traits once we've found it. 29 NamespaceDecl *Std = S.getStdNamespace(); 30 if (!Std) { 31 S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found); 32 return QualType(); 33 } 34 35 LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"), 36 Loc, Sema::LookupOrdinaryName); 37 if (!S.LookupQualifiedName(Result, Std)) { 38 S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found); 39 return QualType(); 40 } 41 42 ClassTemplateDecl *CoroTraits = Result.getAsSingle<ClassTemplateDecl>(); 43 if (!CoroTraits) { 44 Result.suppressDiagnostics(); 45 // We found something weird. Complain about the first thing we found. 46 NamedDecl *Found = *Result.begin(); 47 S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits); 48 return QualType(); 49 } 50 51 // Form template argument list for coroutine_traits<R, P1, P2, ...>. 52 TemplateArgumentListInfo Args(Loc, Loc); 53 Args.addArgument(TemplateArgumentLoc( 54 TemplateArgument(FnType->getReturnType()), 55 S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), Loc))); 56 // FIXME: If the function is a non-static member function, add the type 57 // of the implicit object parameter before the formal parameters. 58 for (QualType T : FnType->getParamTypes()) 59 Args.addArgument(TemplateArgumentLoc( 60 TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, Loc))); 61 62 // Build the template-id. 63 QualType CoroTrait = 64 S.CheckTemplateIdType(TemplateName(CoroTraits), Loc, Args); 65 if (CoroTrait.isNull()) 66 return QualType(); 67 if (S.RequireCompleteType(Loc, CoroTrait, 68 diag::err_coroutine_traits_missing_specialization)) 69 return QualType(); 70 71 CXXRecordDecl *RD = CoroTrait->getAsCXXRecordDecl(); 72 assert(RD && "specialization of class template is not a class?"); 73 74 // Look up the ::promise_type member. 75 LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), Loc, 76 Sema::LookupOrdinaryName); 77 S.LookupQualifiedName(R, RD); 78 auto *Promise = R.getAsSingle<TypeDecl>(); 79 if (!Promise) { 80 S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_found) 81 << RD; 82 return QualType(); 83 } 84 85 // The promise type is required to be a class type. 86 QualType PromiseType = S.Context.getTypeDeclType(Promise); 87 if (!PromiseType->getAsCXXRecordDecl()) { 88 // Use the fully-qualified name of the type. 89 auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, Std); 90 NNS = NestedNameSpecifier::Create(S.Context, NNS, false, 91 CoroTrait.getTypePtr()); 92 PromiseType = S.Context.getElaboratedType(ETK_None, NNS, PromiseType); 93 94 S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_class) 95 << PromiseType; 96 return QualType(); 97 } 98 99 return PromiseType; 100 } 101 102 /// Check that this is a context in which a coroutine suspension can appear. 103 static FunctionScopeInfo * 104 checkCoroutineContext(Sema &S, SourceLocation Loc, StringRef Keyword) { 105 // 'co_await' and 'co_yield' are not permitted in unevaluated operands. 106 if (S.isUnevaluatedContext()) { 107 S.Diag(Loc, diag::err_coroutine_unevaluated_context) << Keyword; 108 return nullptr; 109 } 110 111 // Any other usage must be within a function. 112 // FIXME: Reject a coroutine with a deduced return type. 113 auto *FD = dyn_cast<FunctionDecl>(S.CurContext); 114 if (!FD) { 115 S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext) 116 ? diag::err_coroutine_objc_method 117 : diag::err_coroutine_outside_function) << Keyword; 118 } else if (isa<CXXConstructorDecl>(FD) || isa<CXXDestructorDecl>(FD)) { 119 // Coroutines TS [special]/6: 120 // A special member function shall not be a coroutine. 121 // 122 // FIXME: We assume that this really means that a coroutine cannot 123 // be a constructor or destructor. 124 S.Diag(Loc, diag::err_coroutine_ctor_dtor) 125 << isa<CXXDestructorDecl>(FD) << Keyword; 126 } else if (FD->isConstexpr()) { 127 S.Diag(Loc, diag::err_coroutine_constexpr) << Keyword; 128 } else if (FD->isVariadic()) { 129 S.Diag(Loc, diag::err_coroutine_varargs) << Keyword; 130 } else { 131 auto *ScopeInfo = S.getCurFunction(); 132 assert(ScopeInfo && "missing function scope for function"); 133 134 // If we don't have a promise variable, build one now. 135 if (!ScopeInfo->CoroutinePromise) { 136 QualType T = 137 FD->getType()->isDependentType() 138 ? S.Context.DependentTy 139 : lookupPromiseType(S, FD->getType()->castAs<FunctionProtoType>(), 140 Loc); 141 if (T.isNull()) 142 return nullptr; 143 144 // Create and default-initialize the promise. 145 ScopeInfo->CoroutinePromise = 146 VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(), 147 &S.PP.getIdentifierTable().get("__promise"), T, 148 S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None); 149 S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise); 150 if (!ScopeInfo->CoroutinePromise->isInvalidDecl()) 151 S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise, false); 152 } 153 154 return ScopeInfo; 155 } 156 157 return nullptr; 158 } 159 160 /// Build a call to 'operator co_await' if there is a suitable operator for 161 /// the given expression. 162 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S, 163 SourceLocation Loc, Expr *E) { 164 UnresolvedSet<16> Functions; 165 SemaRef.LookupOverloadedOperatorName(OO_Coawait, S, E->getType(), QualType(), 166 Functions); 167 return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E); 168 } 169 170 struct ReadySuspendResumeResult { 171 bool IsInvalid; 172 Expr *Results[3]; 173 }; 174 175 static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc, 176 StringRef Name, 177 MutableArrayRef<Expr *> Args) { 178 DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc); 179 180 // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&. 181 CXXScopeSpec SS; 182 ExprResult Result = S.BuildMemberReferenceExpr( 183 Base, Base->getType(), Loc, /*IsPtr=*/false, SS, 184 SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr, 185 /*Scope=*/nullptr); 186 if (Result.isInvalid()) 187 return ExprError(); 188 189 return S.ActOnCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr); 190 } 191 192 /// Build calls to await_ready, await_suspend, and await_resume for a co_await 193 /// expression. 194 static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, SourceLocation Loc, 195 Expr *E) { 196 // Assume invalid until we see otherwise. 197 ReadySuspendResumeResult Calls = {true, {}}; 198 199 const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"}; 200 for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) { 201 Expr *Operand = new (S.Context) OpaqueValueExpr( 202 Loc, E->getType(), VK_LValue, E->getObjectKind(), E); 203 204 // FIXME: Pass coroutine handle to await_suspend. 205 ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], None); 206 if (Result.isInvalid()) 207 return Calls; 208 Calls.Results[I] = Result.get(); 209 } 210 211 Calls.IsInvalid = false; 212 return Calls; 213 } 214 215 ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) { 216 if (E->getType()->isPlaceholderType()) { 217 ExprResult R = CheckPlaceholderExpr(E); 218 if (R.isInvalid()) return ExprError(); 219 E = R.get(); 220 } 221 222 ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E); 223 if (Awaitable.isInvalid()) 224 return ExprError(); 225 return BuildCoawaitExpr(Loc, Awaitable.get()); 226 } 227 ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) { 228 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await"); 229 if (!Coroutine) 230 return ExprError(); 231 232 if (E->getType()->isPlaceholderType()) { 233 ExprResult R = CheckPlaceholderExpr(E); 234 if (R.isInvalid()) return ExprError(); 235 E = R.get(); 236 } 237 238 if (E->getType()->isDependentType()) { 239 Expr *Res = new (Context) CoawaitExpr(Loc, Context.DependentTy, E); 240 Coroutine->CoroutineStmts.push_back(Res); 241 return Res; 242 } 243 244 // If the expression is a temporary, materialize it as an lvalue so that we 245 // can use it multiple times. 246 if (E->getValueKind() == VK_RValue) 247 E = new (Context) MaterializeTemporaryExpr(E->getType(), E, true); 248 249 // Build the await_ready, await_suspend, await_resume calls. 250 ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E); 251 if (RSS.IsInvalid) 252 return ExprError(); 253 254 Expr *Res = new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1], 255 RSS.Results[2]); 256 Coroutine->CoroutineStmts.push_back(Res); 257 return Res; 258 } 259 260 static ExprResult buildPromiseCall(Sema &S, FunctionScopeInfo *Coroutine, 261 SourceLocation Loc, StringRef Name, 262 MutableArrayRef<Expr *> Args) { 263 assert(Coroutine->CoroutinePromise && "no promise for coroutine"); 264 265 // Form a reference to the promise. 266 auto *Promise = Coroutine->CoroutinePromise; 267 ExprResult PromiseRef = S.BuildDeclRefExpr( 268 Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc); 269 if (PromiseRef.isInvalid()) 270 return ExprError(); 271 272 // Call 'yield_value', passing in E. 273 return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args); 274 } 275 276 ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) { 277 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield"); 278 if (!Coroutine) 279 return ExprError(); 280 281 // Build yield_value call. 282 ExprResult Awaitable = 283 buildPromiseCall(*this, Coroutine, Loc, "yield_value", E); 284 if (Awaitable.isInvalid()) 285 return ExprError(); 286 287 // Build 'operator co_await' call. 288 Awaitable = buildOperatorCoawaitCall(*this, S, Loc, Awaitable.get()); 289 if (Awaitable.isInvalid()) 290 return ExprError(); 291 292 return BuildCoyieldExpr(Loc, Awaitable.get()); 293 } 294 ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) { 295 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield"); 296 if (!Coroutine) 297 return ExprError(); 298 299 if (E->getType()->isPlaceholderType()) { 300 ExprResult R = CheckPlaceholderExpr(E); 301 if (R.isInvalid()) return ExprError(); 302 E = R.get(); 303 } 304 305 if (E->getType()->isDependentType()) { 306 Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, E); 307 Coroutine->CoroutineStmts.push_back(Res); 308 return Res; 309 } 310 311 // If the expression is a temporary, materialize it as an lvalue so that we 312 // can use it multiple times. 313 if (E->getValueKind() == VK_RValue) 314 E = new (Context) MaterializeTemporaryExpr(E->getType(), E, true); 315 316 // Build the await_ready, await_suspend, await_resume calls. 317 ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E); 318 if (RSS.IsInvalid) 319 return ExprError(); 320 321 Expr *Res = new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1], 322 RSS.Results[2]); 323 Coroutine->CoroutineStmts.push_back(Res); 324 return Res; 325 } 326 327 StmtResult Sema::ActOnCoreturnStmt(SourceLocation Loc, Expr *E) { 328 return BuildCoreturnStmt(Loc, E); 329 } 330 StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) { 331 auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return"); 332 if (!Coroutine) 333 return StmtError(); 334 335 if (E && E->getType()->isPlaceholderType() && 336 !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)) { 337 ExprResult R = CheckPlaceholderExpr(E); 338 if (R.isInvalid()) return StmtError(); 339 E = R.get(); 340 } 341 342 // FIXME: If the operand is a reference to a variable that's about to go out 343 // of scope, we should treat the operand as an xvalue for this overload 344 // resolution. 345 ExprResult PC; 346 if (E && !E->getType()->isVoidType()) { 347 PC = buildPromiseCall(*this, Coroutine, Loc, "return_value", E); 348 } else { 349 E = MakeFullDiscardedValueExpr(E).get(); 350 PC = buildPromiseCall(*this, Coroutine, Loc, "return_void", None); 351 } 352 if (PC.isInvalid()) 353 return StmtError(); 354 355 Expr *PCE = ActOnFinishFullExpr(PC.get()).get(); 356 357 Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE); 358 Coroutine->CoroutineStmts.push_back(Res); 359 return Res; 360 } 361 362 void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) { 363 FunctionScopeInfo *Fn = getCurFunction(); 364 assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine"); 365 366 // Coroutines [stmt.return]p1: 367 // A return statement shall not appear in a coroutine. 368 if (Fn->FirstReturnLoc.isValid()) { 369 Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine); 370 auto *First = Fn->CoroutineStmts[0]; 371 Diag(First->getLocStart(), diag::note_declared_coroutine_here) 372 << (isa<CoawaitExpr>(First) ? 0 : 373 isa<CoyieldExpr>(First) ? 1 : 2); 374 } 375 376 bool AnyCoawaits = false; 377 bool AnyCoyields = false; 378 for (auto *CoroutineStmt : Fn->CoroutineStmts) { 379 AnyCoawaits |= isa<CoawaitExpr>(CoroutineStmt); 380 AnyCoyields |= isa<CoyieldExpr>(CoroutineStmt); 381 } 382 383 if (!AnyCoawaits && !AnyCoyields) 384 Diag(Fn->CoroutineStmts.front()->getLocStart(), 385 diag::ext_coroutine_without_co_await_co_yield); 386 387 SourceLocation Loc = FD->getLocation(); 388 389 // Form a declaration statement for the promise declaration, so that AST 390 // visitors can more easily find it. 391 StmtResult PromiseStmt = 392 ActOnDeclStmt(ConvertDeclToDeclGroup(Fn->CoroutinePromise), Loc, Loc); 393 if (PromiseStmt.isInvalid()) 394 return FD->setInvalidDecl(); 395 396 // Form and check implicit 'co_await p.initial_suspend();' statement. 397 ExprResult InitialSuspend = 398 buildPromiseCall(*this, Fn, Loc, "initial_suspend", None); 399 // FIXME: Support operator co_await here. 400 if (!InitialSuspend.isInvalid()) 401 InitialSuspend = BuildCoawaitExpr(Loc, InitialSuspend.get()); 402 InitialSuspend = ActOnFinishFullExpr(InitialSuspend.get()); 403 if (InitialSuspend.isInvalid()) 404 return FD->setInvalidDecl(); 405 406 // Form and check implicit 'co_await p.final_suspend();' statement. 407 ExprResult FinalSuspend = 408 buildPromiseCall(*this, Fn, Loc, "final_suspend", None); 409 // FIXME: Support operator co_await here. 410 if (!FinalSuspend.isInvalid()) 411 FinalSuspend = BuildCoawaitExpr(Loc, FinalSuspend.get()); 412 FinalSuspend = ActOnFinishFullExpr(FinalSuspend.get()); 413 if (FinalSuspend.isInvalid()) 414 return FD->setInvalidDecl(); 415 416 // FIXME: Perform analysis of set_exception call. 417 418 // FIXME: Try to form 'p.return_void();' expression statement to handle 419 // control flowing off the end of the coroutine. 420 421 // Build implicit 'p.get_return_object()' expression and form initialization 422 // of return type from it. 423 ExprResult ReturnObject = 424 buildPromiseCall(*this, Fn, Loc, "get_return_object", None); 425 if (ReturnObject.isInvalid()) 426 return FD->setInvalidDecl(); 427 QualType RetType = FD->getReturnType(); 428 if (!RetType->isDependentType()) { 429 InitializedEntity Entity = 430 InitializedEntity::InitializeResult(Loc, RetType, false); 431 ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType, 432 ReturnObject.get()); 433 if (ReturnObject.isInvalid()) 434 return FD->setInvalidDecl(); 435 } 436 ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc); 437 if (ReturnObject.isInvalid()) 438 return FD->setInvalidDecl(); 439 440 // FIXME: Perform move-initialization of parameters into frame-local copies. 441 SmallVector<Expr*, 16> ParamMoves; 442 443 // Build body for the coroutine wrapper statement. 444 Body = new (Context) CoroutineBodyStmt( 445 Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(), 446 /*SetException*/nullptr, /*Fallthrough*/nullptr, 447 ReturnObject.get(), ParamMoves); 448 } 449