Home | History | Annotate | Download | only in Sema
      1 //===--- SemaCoroutines.cpp - Semantic Analysis for Coroutines ------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 //  This file implements semantic analysis for C++ Coroutines.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #include "clang/Sema/SemaInternal.h"
     15 #include "clang/AST/Decl.h"
     16 #include "clang/AST/ExprCXX.h"
     17 #include "clang/AST/StmtCXX.h"
     18 #include "clang/Lex/Preprocessor.h"
     19 #include "clang/Sema/Initialization.h"
     20 #include "clang/Sema/Overload.h"
     21 using namespace clang;
     22 using namespace sema;
     23 
     24 /// Look up the std::coroutine_traits<...>::promise_type for the given
     25 /// function type.
     26 static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType,
     27                                   SourceLocation Loc) {
     28   // FIXME: Cache std::coroutine_traits once we've found it.
     29   NamespaceDecl *Std = S.getStdNamespace();
     30   if (!Std) {
     31     S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
     32     return QualType();
     33   }
     34 
     35   LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"),
     36                       Loc, Sema::LookupOrdinaryName);
     37   if (!S.LookupQualifiedName(Result, Std)) {
     38     S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
     39     return QualType();
     40   }
     41 
     42   ClassTemplateDecl *CoroTraits = Result.getAsSingle<ClassTemplateDecl>();
     43   if (!CoroTraits) {
     44     Result.suppressDiagnostics();
     45     // We found something weird. Complain about the first thing we found.
     46     NamedDecl *Found = *Result.begin();
     47     S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits);
     48     return QualType();
     49   }
     50 
     51   // Form template argument list for coroutine_traits<R, P1, P2, ...>.
     52   TemplateArgumentListInfo Args(Loc, Loc);
     53   Args.addArgument(TemplateArgumentLoc(
     54       TemplateArgument(FnType->getReturnType()),
     55       S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), Loc)));
     56   // FIXME: If the function is a non-static member function, add the type
     57   // of the implicit object parameter before the formal parameters.
     58   for (QualType T : FnType->getParamTypes())
     59     Args.addArgument(TemplateArgumentLoc(
     60         TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, Loc)));
     61 
     62   // Build the template-id.
     63   QualType CoroTrait =
     64       S.CheckTemplateIdType(TemplateName(CoroTraits), Loc, Args);
     65   if (CoroTrait.isNull())
     66     return QualType();
     67   if (S.RequireCompleteType(Loc, CoroTrait,
     68                             diag::err_coroutine_traits_missing_specialization))
     69     return QualType();
     70 
     71   CXXRecordDecl *RD = CoroTrait->getAsCXXRecordDecl();
     72   assert(RD && "specialization of class template is not a class?");
     73 
     74   // Look up the ::promise_type member.
     75   LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), Loc,
     76                  Sema::LookupOrdinaryName);
     77   S.LookupQualifiedName(R, RD);
     78   auto *Promise = R.getAsSingle<TypeDecl>();
     79   if (!Promise) {
     80     S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_found)
     81       << RD;
     82     return QualType();
     83   }
     84 
     85   // The promise type is required to be a class type.
     86   QualType PromiseType = S.Context.getTypeDeclType(Promise);
     87   if (!PromiseType->getAsCXXRecordDecl()) {
     88     // Use the fully-qualified name of the type.
     89     auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, Std);
     90     NNS = NestedNameSpecifier::Create(S.Context, NNS, false,
     91                                       CoroTrait.getTypePtr());
     92     PromiseType = S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
     93 
     94     S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_class)
     95       << PromiseType;
     96     return QualType();
     97   }
     98 
     99   return PromiseType;
    100 }
    101 
    102 /// Check that this is a context in which a coroutine suspension can appear.
    103 static FunctionScopeInfo *
    104 checkCoroutineContext(Sema &S, SourceLocation Loc, StringRef Keyword) {
    105   // 'co_await' and 'co_yield' are not permitted in unevaluated operands.
    106   if (S.isUnevaluatedContext()) {
    107     S.Diag(Loc, diag::err_coroutine_unevaluated_context) << Keyword;
    108     return nullptr;
    109   }
    110 
    111   // Any other usage must be within a function.
    112   // FIXME: Reject a coroutine with a deduced return type.
    113   auto *FD = dyn_cast<FunctionDecl>(S.CurContext);
    114   if (!FD) {
    115     S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext)
    116                     ? diag::err_coroutine_objc_method
    117                     : diag::err_coroutine_outside_function) << Keyword;
    118   } else if (isa<CXXConstructorDecl>(FD) || isa<CXXDestructorDecl>(FD)) {
    119     // Coroutines TS [special]/6:
    120     //   A special member function shall not be a coroutine.
    121     //
    122     // FIXME: We assume that this really means that a coroutine cannot
    123     //        be a constructor or destructor.
    124     S.Diag(Loc, diag::err_coroutine_ctor_dtor)
    125       << isa<CXXDestructorDecl>(FD) << Keyword;
    126   } else if (FD->isConstexpr()) {
    127     S.Diag(Loc, diag::err_coroutine_constexpr) << Keyword;
    128   } else if (FD->isVariadic()) {
    129     S.Diag(Loc, diag::err_coroutine_varargs) << Keyword;
    130   } else {
    131     auto *ScopeInfo = S.getCurFunction();
    132     assert(ScopeInfo && "missing function scope for function");
    133 
    134     // If we don't have a promise variable, build one now.
    135     if (!ScopeInfo->CoroutinePromise) {
    136       QualType T =
    137           FD->getType()->isDependentType()
    138               ? S.Context.DependentTy
    139               : lookupPromiseType(S, FD->getType()->castAs<FunctionProtoType>(),
    140                                   Loc);
    141       if (T.isNull())
    142         return nullptr;
    143 
    144       // Create and default-initialize the promise.
    145       ScopeInfo->CoroutinePromise =
    146           VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(),
    147                           &S.PP.getIdentifierTable().get("__promise"), T,
    148                           S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
    149       S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise);
    150       if (!ScopeInfo->CoroutinePromise->isInvalidDecl())
    151         S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise, false);
    152     }
    153 
    154     return ScopeInfo;
    155   }
    156 
    157   return nullptr;
    158 }
    159 
    160 /// Build a call to 'operator co_await' if there is a suitable operator for
    161 /// the given expression.
    162 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
    163                                            SourceLocation Loc, Expr *E) {
    164   UnresolvedSet<16> Functions;
    165   SemaRef.LookupOverloadedOperatorName(OO_Coawait, S, E->getType(), QualType(),
    166                                        Functions);
    167   return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
    168 }
    169 
    170 struct ReadySuspendResumeResult {
    171   bool IsInvalid;
    172   Expr *Results[3];
    173 };
    174 
    175 static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
    176                                   StringRef Name,
    177                                   MutableArrayRef<Expr *> Args) {
    178   DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc);
    179 
    180   // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&.
    181   CXXScopeSpec SS;
    182   ExprResult Result = S.BuildMemberReferenceExpr(
    183       Base, Base->getType(), Loc, /*IsPtr=*/false, SS,
    184       SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr,
    185       /*Scope=*/nullptr);
    186   if (Result.isInvalid())
    187     return ExprError();
    188 
    189   return S.ActOnCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr);
    190 }
    191 
    192 /// Build calls to await_ready, await_suspend, and await_resume for a co_await
    193 /// expression.
    194 static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, SourceLocation Loc,
    195                                                   Expr *E) {
    196   // Assume invalid until we see otherwise.
    197   ReadySuspendResumeResult Calls = {true, {}};
    198 
    199   const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"};
    200   for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) {
    201     Expr *Operand = new (S.Context) OpaqueValueExpr(
    202         Loc, E->getType(), VK_LValue, E->getObjectKind(), E);
    203 
    204     // FIXME: Pass coroutine handle to await_suspend.
    205     ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], None);
    206     if (Result.isInvalid())
    207       return Calls;
    208     Calls.Results[I] = Result.get();
    209   }
    210 
    211   Calls.IsInvalid = false;
    212   return Calls;
    213 }
    214 
    215 ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
    216   if (E->getType()->isPlaceholderType()) {
    217     ExprResult R = CheckPlaceholderExpr(E);
    218     if (R.isInvalid()) return ExprError();
    219     E = R.get();
    220   }
    221 
    222   ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E);
    223   if (Awaitable.isInvalid())
    224     return ExprError();
    225   return BuildCoawaitExpr(Loc, Awaitable.get());
    226 }
    227 ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) {
    228   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await");
    229   if (!Coroutine)
    230     return ExprError();
    231 
    232   if (E->getType()->isPlaceholderType()) {
    233     ExprResult R = CheckPlaceholderExpr(E);
    234     if (R.isInvalid()) return ExprError();
    235     E = R.get();
    236   }
    237 
    238   if (E->getType()->isDependentType()) {
    239     Expr *Res = new (Context) CoawaitExpr(Loc, Context.DependentTy, E);
    240     Coroutine->CoroutineStmts.push_back(Res);
    241     return Res;
    242   }
    243 
    244   // If the expression is a temporary, materialize it as an lvalue so that we
    245   // can use it multiple times.
    246   if (E->getValueKind() == VK_RValue)
    247     E = new (Context) MaterializeTemporaryExpr(E->getType(), E, true);
    248 
    249   // Build the await_ready, await_suspend, await_resume calls.
    250   ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E);
    251   if (RSS.IsInvalid)
    252     return ExprError();
    253 
    254   Expr *Res = new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1],
    255                                         RSS.Results[2]);
    256   Coroutine->CoroutineStmts.push_back(Res);
    257   return Res;
    258 }
    259 
    260 static ExprResult buildPromiseCall(Sema &S, FunctionScopeInfo *Coroutine,
    261                                    SourceLocation Loc, StringRef Name,
    262                                    MutableArrayRef<Expr *> Args) {
    263   assert(Coroutine->CoroutinePromise && "no promise for coroutine");
    264 
    265   // Form a reference to the promise.
    266   auto *Promise = Coroutine->CoroutinePromise;
    267   ExprResult PromiseRef = S.BuildDeclRefExpr(
    268       Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
    269   if (PromiseRef.isInvalid())
    270     return ExprError();
    271 
    272   // Call 'yield_value', passing in E.
    273   return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
    274 }
    275 
    276 ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) {
    277   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
    278   if (!Coroutine)
    279     return ExprError();
    280 
    281   // Build yield_value call.
    282   ExprResult Awaitable =
    283       buildPromiseCall(*this, Coroutine, Loc, "yield_value", E);
    284   if (Awaitable.isInvalid())
    285     return ExprError();
    286 
    287   // Build 'operator co_await' call.
    288   Awaitable = buildOperatorCoawaitCall(*this, S, Loc, Awaitable.get());
    289   if (Awaitable.isInvalid())
    290     return ExprError();
    291 
    292   return BuildCoyieldExpr(Loc, Awaitable.get());
    293 }
    294 ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) {
    295   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
    296   if (!Coroutine)
    297     return ExprError();
    298 
    299   if (E->getType()->isPlaceholderType()) {
    300     ExprResult R = CheckPlaceholderExpr(E);
    301     if (R.isInvalid()) return ExprError();
    302     E = R.get();
    303   }
    304 
    305   if (E->getType()->isDependentType()) {
    306     Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, E);
    307     Coroutine->CoroutineStmts.push_back(Res);
    308     return Res;
    309   }
    310 
    311   // If the expression is a temporary, materialize it as an lvalue so that we
    312   // can use it multiple times.
    313   if (E->getValueKind() == VK_RValue)
    314     E = new (Context) MaterializeTemporaryExpr(E->getType(), E, true);
    315 
    316   // Build the await_ready, await_suspend, await_resume calls.
    317   ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E);
    318   if (RSS.IsInvalid)
    319     return ExprError();
    320 
    321   Expr *Res = new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1],
    322                                         RSS.Results[2]);
    323   Coroutine->CoroutineStmts.push_back(Res);
    324   return Res;
    325 }
    326 
    327 StmtResult Sema::ActOnCoreturnStmt(SourceLocation Loc, Expr *E) {
    328   return BuildCoreturnStmt(Loc, E);
    329 }
    330 StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) {
    331   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return");
    332   if (!Coroutine)
    333     return StmtError();
    334 
    335   if (E && E->getType()->isPlaceholderType() &&
    336       !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)) {
    337     ExprResult R = CheckPlaceholderExpr(E);
    338     if (R.isInvalid()) return StmtError();
    339     E = R.get();
    340   }
    341 
    342   // FIXME: If the operand is a reference to a variable that's about to go out
    343   // of scope, we should treat the operand as an xvalue for this overload
    344   // resolution.
    345   ExprResult PC;
    346   if (E && !E->getType()->isVoidType()) {
    347     PC = buildPromiseCall(*this, Coroutine, Loc, "return_value", E);
    348   } else {
    349     E = MakeFullDiscardedValueExpr(E).get();
    350     PC = buildPromiseCall(*this, Coroutine, Loc, "return_void", None);
    351   }
    352   if (PC.isInvalid())
    353     return StmtError();
    354 
    355   Expr *PCE = ActOnFinishFullExpr(PC.get()).get();
    356 
    357   Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE);
    358   Coroutine->CoroutineStmts.push_back(Res);
    359   return Res;
    360 }
    361 
    362 void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
    363   FunctionScopeInfo *Fn = getCurFunction();
    364   assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine");
    365 
    366   // Coroutines [stmt.return]p1:
    367   //   A return statement shall not appear in a coroutine.
    368   if (Fn->FirstReturnLoc.isValid()) {
    369     Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
    370     auto *First = Fn->CoroutineStmts[0];
    371     Diag(First->getLocStart(), diag::note_declared_coroutine_here)
    372       << (isa<CoawaitExpr>(First) ? 0 :
    373           isa<CoyieldExpr>(First) ? 1 : 2);
    374   }
    375 
    376   bool AnyCoawaits = false;
    377   bool AnyCoyields = false;
    378   for (auto *CoroutineStmt : Fn->CoroutineStmts) {
    379     AnyCoawaits |= isa<CoawaitExpr>(CoroutineStmt);
    380     AnyCoyields |= isa<CoyieldExpr>(CoroutineStmt);
    381   }
    382 
    383   if (!AnyCoawaits && !AnyCoyields)
    384     Diag(Fn->CoroutineStmts.front()->getLocStart(),
    385          diag::ext_coroutine_without_co_await_co_yield);
    386 
    387   SourceLocation Loc = FD->getLocation();
    388 
    389   // Form a declaration statement for the promise declaration, so that AST
    390   // visitors can more easily find it.
    391   StmtResult PromiseStmt =
    392       ActOnDeclStmt(ConvertDeclToDeclGroup(Fn->CoroutinePromise), Loc, Loc);
    393   if (PromiseStmt.isInvalid())
    394     return FD->setInvalidDecl();
    395 
    396   // Form and check implicit 'co_await p.initial_suspend();' statement.
    397   ExprResult InitialSuspend =
    398       buildPromiseCall(*this, Fn, Loc, "initial_suspend", None);
    399   // FIXME: Support operator co_await here.
    400   if (!InitialSuspend.isInvalid())
    401     InitialSuspend = BuildCoawaitExpr(Loc, InitialSuspend.get());
    402   InitialSuspend = ActOnFinishFullExpr(InitialSuspend.get());
    403   if (InitialSuspend.isInvalid())
    404     return FD->setInvalidDecl();
    405 
    406   // Form and check implicit 'co_await p.final_suspend();' statement.
    407   ExprResult FinalSuspend =
    408       buildPromiseCall(*this, Fn, Loc, "final_suspend", None);
    409   // FIXME: Support operator co_await here.
    410   if (!FinalSuspend.isInvalid())
    411     FinalSuspend = BuildCoawaitExpr(Loc, FinalSuspend.get());
    412   FinalSuspend = ActOnFinishFullExpr(FinalSuspend.get());
    413   if (FinalSuspend.isInvalid())
    414     return FD->setInvalidDecl();
    415 
    416   // FIXME: Perform analysis of set_exception call.
    417 
    418   // FIXME: Try to form 'p.return_void();' expression statement to handle
    419   // control flowing off the end of the coroutine.
    420 
    421   // Build implicit 'p.get_return_object()' expression and form initialization
    422   // of return type from it.
    423   ExprResult ReturnObject =
    424     buildPromiseCall(*this, Fn, Loc, "get_return_object", None);
    425   if (ReturnObject.isInvalid())
    426     return FD->setInvalidDecl();
    427   QualType RetType = FD->getReturnType();
    428   if (!RetType->isDependentType()) {
    429     InitializedEntity Entity =
    430         InitializedEntity::InitializeResult(Loc, RetType, false);
    431     ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType,
    432                                                    ReturnObject.get());
    433     if (ReturnObject.isInvalid())
    434       return FD->setInvalidDecl();
    435   }
    436   ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc);
    437   if (ReturnObject.isInvalid())
    438     return FD->setInvalidDecl();
    439 
    440   // FIXME: Perform move-initialization of parameters into frame-local copies.
    441   SmallVector<Expr*, 16> ParamMoves;
    442 
    443   // Build body for the coroutine wrapper statement.
    444   Body = new (Context) CoroutineBodyStmt(
    445       Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(),
    446       /*SetException*/nullptr, /*Fallthrough*/nullptr,
    447       ReturnObject.get(), ParamMoves);
    448 }
    449