1 // Copyright 2015 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "CheckTraceVisitor.h" 6 7 #include <vector> 8 9 #include "Config.h" 10 11 using namespace clang; 12 13 CheckTraceVisitor::CheckTraceVisitor(CXXMethodDecl* trace, 14 RecordInfo* info, 15 RecordCache* cache) 16 : trace_(trace), info_(info), cache_(cache) {} 17 18 bool CheckTraceVisitor::VisitMemberExpr(MemberExpr* member) { 19 // In weak callbacks, consider any occurrence as a correct usage. 20 // TODO: We really want to require that isAlive is checked on manually 21 // processed weak fields. 22 if (IsWeakCallback()) { 23 if (FieldDecl* field = dyn_cast<FieldDecl>(member->getMemberDecl())) 24 FoundField(field); 25 } 26 return true; 27 } 28 29 bool CheckTraceVisitor::VisitCallExpr(CallExpr* call) { 30 // In weak callbacks we don't check calls (see VisitMemberExpr). 31 if (IsWeakCallback()) 32 return true; 33 34 Expr* callee = call->getCallee(); 35 36 // Trace calls from a templated derived class result in a 37 // DependentScopeMemberExpr because the concrete trace call depends on the 38 // instantiation of any shared template parameters. In this case the call is 39 // "unresolved" and we resort to comparing the syntactic type names. 40 if (CXXDependentScopeMemberExpr* expr = 41 dyn_cast<CXXDependentScopeMemberExpr>(callee)) { 42 CheckCXXDependentScopeMemberExpr(call, expr); 43 return true; 44 } 45 46 // A tracing call will have either a |visitor| or a |m_field| argument. 47 // A registerWeakMembers call will have a |this| argument. 48 if (call->getNumArgs() != 1) 49 return true; 50 Expr* arg = call->getArg(0); 51 52 if (UnresolvedMemberExpr* expr = dyn_cast<UnresolvedMemberExpr>(callee)) { 53 // This could be a trace call of a base class, as explained in the 54 // comments of CheckTraceBaseCall(). 55 if (CheckTraceBaseCall(call)) 56 return true; 57 58 if (expr->getMemberName().getAsString() == kRegisterWeakMembersName) 59 MarkAllWeakMembersTraced(); 60 61 QualType base = expr->getBaseType(); 62 if (!base->isPointerType()) 63 return true; 64 CXXRecordDecl* decl = base->getPointeeType()->getAsCXXRecordDecl(); 65 if (decl) 66 CheckTraceFieldCall(expr->getMemberName().getAsString(), decl, arg); 67 return true; 68 } 69 70 if (CXXMemberCallExpr* expr = dyn_cast<CXXMemberCallExpr>(call)) { 71 if (CheckTraceFieldMemberCall(expr) || CheckRegisterWeakMembers(expr)) 72 return true; 73 74 } 75 76 CheckTraceBaseCall(call); 77 return true; 78 } 79 80 bool CheckTraceVisitor::IsTraceCallName(const std::string& name) { 81 // Currently, a manually dispatched class cannot have mixin bases (having 82 // one would add a vtable which we explicitly check against). This means 83 // that we can only make calls to a trace method of the same name. Revisit 84 // this if our mixin/vtable assumption changes. 85 return name == trace_->getName(); 86 } 87 88 CXXRecordDecl* CheckTraceVisitor::GetDependentTemplatedDecl( 89 CXXDependentScopeMemberExpr* expr) { 90 NestedNameSpecifier* qual = expr->getQualifier(); 91 if (!qual) 92 return 0; 93 94 const Type* type = qual->getAsType(); 95 if (!type) 96 return 0; 97 98 return RecordInfo::GetDependentTemplatedDecl(*type); 99 } 100 101 namespace { 102 103 class FindFieldVisitor : public RecursiveASTVisitor<FindFieldVisitor> { 104 public: 105 FindFieldVisitor(); 106 MemberExpr* member() const; 107 FieldDecl* field() const; 108 bool TraverseMemberExpr(MemberExpr* member); 109 110 private: 111 MemberExpr* member_; 112 FieldDecl* field_; 113 }; 114 115 FindFieldVisitor::FindFieldVisitor() 116 : member_(0), 117 field_(0) { 118 } 119 120 MemberExpr* FindFieldVisitor::member() const { 121 return member_; 122 } 123 124 FieldDecl* FindFieldVisitor::field() const { 125 return field_; 126 } 127 128 bool FindFieldVisitor::TraverseMemberExpr(MemberExpr* member) { 129 if (FieldDecl* field = dyn_cast<FieldDecl>(member->getMemberDecl())) { 130 member_ = member; 131 field_ = field; 132 return false; 133 } 134 return true; 135 } 136 137 } // namespace 138 139 void CheckTraceVisitor::CheckCXXDependentScopeMemberExpr( 140 CallExpr* call, 141 CXXDependentScopeMemberExpr* expr) { 142 std::string fn_name = expr->getMember().getAsString(); 143 144 // Check for VisitorDispatcher::trace(field) and 145 // VisitorDispatcher::registerWeakMembers. 146 if (!expr->isImplicitAccess()) { 147 if (DeclRefExpr* base_decl = dyn_cast<DeclRefExpr>(expr->getBase())) { 148 if (Config::IsVisitorDispatcherType(base_decl->getType())) { 149 if (call->getNumArgs() == 1 && fn_name == kTraceName) { 150 FindFieldVisitor finder; 151 finder.TraverseStmt(call->getArg(0)); 152 if (finder.field()) 153 FoundField(finder.field()); 154 155 return; 156 } else if (call->getNumArgs() == 1 && 157 fn_name == kRegisterWeakMembersName) { 158 MarkAllWeakMembersTraced(); 159 } 160 } 161 } 162 } 163 164 CXXRecordDecl* tmpl = GetDependentTemplatedDecl(expr); 165 if (!tmpl) 166 return; 167 168 // Check for Super<T>::trace(visitor) 169 if (call->getNumArgs() == 1 && IsTraceCallName(fn_name)) { 170 RecordInfo::Bases::iterator it = info_->GetBases().begin(); 171 for (; it != info_->GetBases().end(); ++it) { 172 if (it->first->getName() == tmpl->getName()) 173 it->second.MarkTraced(); 174 } 175 } 176 177 // Check for TraceIfNeeded<T>::trace(visitor, &field) 178 if (call->getNumArgs() == 2 && fn_name == kTraceName && 179 tmpl->getName() == kTraceIfNeededName) { 180 FindFieldVisitor finder; 181 finder.TraverseStmt(call->getArg(1)); 182 if (finder.field()) 183 FoundField(finder.field()); 184 } 185 } 186 187 bool CheckTraceVisitor::CheckTraceBaseCall(CallExpr* call) { 188 // Checks for "Base::trace(visitor)"-like calls. 189 190 // Checking code for these two variables is shared among MemberExpr* case 191 // and UnresolvedMemberCase* case below. 192 // 193 // For example, if we've got "Base::trace(visitor)" as |call|, 194 // callee_record will be "Base", and func_name will be "trace". 195 CXXRecordDecl* callee_record = nullptr; 196 std::string func_name; 197 198 if (MemberExpr* callee = dyn_cast<MemberExpr>(call->getCallee())) { 199 if (!callee->hasQualifier()) 200 return false; 201 202 FunctionDecl* trace_decl = 203 dyn_cast<FunctionDecl>(callee->getMemberDecl()); 204 if (!trace_decl || !Config::IsTraceMethod(trace_decl)) 205 return false; 206 207 const Type* type = callee->getQualifier()->getAsType(); 208 if (!type) 209 return false; 210 211 callee_record = type->getAsCXXRecordDecl(); 212 func_name = trace_decl->getName(); 213 } else if (UnresolvedMemberExpr* callee = 214 dyn_cast<UnresolvedMemberExpr>(call->getCallee())) { 215 // Callee part may become unresolved if the type of the argument 216 // ("visitor") is a template parameter and the called function is 217 // overloaded. 218 // 219 // Here, we try to find a function that looks like trace() from the 220 // candidate overloaded functions, and if we find one, we assume it is 221 // called here. 222 223 CXXMethodDecl* trace_decl = nullptr; 224 for (NamedDecl* named_decl : callee->decls()) { 225 if (CXXMethodDecl* method_decl = dyn_cast<CXXMethodDecl>(named_decl)) { 226 if (Config::IsTraceMethod(method_decl)) { 227 trace_decl = method_decl; 228 break; 229 } 230 } 231 } 232 if (!trace_decl) 233 return false; 234 235 // Check if the passed argument is named "visitor". 236 if (call->getNumArgs() != 1) 237 return false; 238 DeclRefExpr* arg = dyn_cast<DeclRefExpr>(call->getArg(0)); 239 if (!arg || arg->getNameInfo().getAsString() != kVisitorVarName) 240 return false; 241 242 callee_record = trace_decl->getParent(); 243 func_name = callee->getMemberName().getAsString(); 244 } 245 246 if (!callee_record) 247 return false; 248 249 if (!IsTraceCallName(func_name)) 250 return false; 251 252 for (auto& base : info_->GetBases()) { 253 // We want to deal with omitted trace() function in an intermediary 254 // class in the class hierarchy, e.g.: 255 // class A : public GarbageCollected<A> { trace() { ... } }; 256 // class B : public A { /* No trace(); have nothing to trace. */ }; 257 // class C : public B { trace() { B::trace(visitor); } } 258 // where, B::trace() is actually A::trace(), and in some cases we get 259 // A as |callee_record| instead of B. We somehow need to mark B as 260 // traced if we find A::trace() call. 261 // 262 // To solve this, here we keep going up the class hierarchy as long as 263 // they are not required to have a trace method. The implementation is 264 // a simple DFS, where |base_records| represents the set of base classes 265 // we need to visit. 266 267 std::vector<CXXRecordDecl*> base_records; 268 base_records.push_back(base.first); 269 270 while (!base_records.empty()) { 271 CXXRecordDecl* base_record = base_records.back(); 272 base_records.pop_back(); 273 274 if (base_record == callee_record) { 275 // If we find a matching trace method, pretend the user has written 276 // a correct trace() method of the base; in the example above, we 277 // find A::trace() here and mark B as correctly traced. 278 base.second.MarkTraced(); 279 return true; 280 } 281 282 if (RecordInfo* base_info = cache_->Lookup(base_record)) { 283 if (!base_info->RequiresTraceMethod()) { 284 // If this base class is not required to have a trace method, then 285 // the actual trace method may be defined in an ancestor. 286 for (auto& inner_base : base_info->GetBases()) 287 base_records.push_back(inner_base.first); 288 } 289 } 290 } 291 } 292 293 return false; 294 } 295 296 bool CheckTraceVisitor::CheckTraceFieldMemberCall(CXXMemberCallExpr* call) { 297 return CheckTraceFieldCall(call->getMethodDecl()->getNameAsString(), 298 call->getRecordDecl(), 299 call->getArg(0)); 300 } 301 302 bool CheckTraceVisitor::CheckTraceFieldCall( 303 const std::string& name, 304 CXXRecordDecl* callee, 305 Expr* arg) { 306 if (name != kTraceName || !Config::IsVisitor(callee->getName())) 307 return false; 308 309 FindFieldVisitor finder; 310 finder.TraverseStmt(arg); 311 if (finder.field()) 312 FoundField(finder.field()); 313 314 return true; 315 } 316 317 bool CheckTraceVisitor::CheckRegisterWeakMembers(CXXMemberCallExpr* call) { 318 CXXMethodDecl* fn = call->getMethodDecl(); 319 if (fn->getName() != kRegisterWeakMembersName) 320 return false; 321 322 if (fn->isTemplateInstantiation()) { 323 const TemplateArgumentList& args = 324 *fn->getTemplateSpecializationInfo()->TemplateArguments; 325 // The second template argument is the callback method. 326 if (args.size() > 1 && 327 args[1].getKind() == TemplateArgument::Declaration) { 328 if (FunctionDecl* callback = 329 dyn_cast<FunctionDecl>(args[1].getAsDecl())) { 330 if (callback->hasBody()) { 331 CheckTraceVisitor nested_visitor(nullptr, info_, nullptr); 332 nested_visitor.TraverseStmt(callback->getBody()); 333 } 334 } 335 // TODO: mark all WeakMember<>s as traced even if 336 // the body isn't available? 337 } 338 } 339 return true; 340 } 341 342 bool CheckTraceVisitor::IsWeakCallback() const { 343 return !trace_; 344 } 345 346 void CheckTraceVisitor::MarkTraced(RecordInfo::Fields::iterator it) { 347 // In a weak callback we can't mark strong fields as traced. 348 if (IsWeakCallback() && !it->second.edge()->IsWeakMember()) 349 return; 350 it->second.MarkTraced(); 351 } 352 353 void CheckTraceVisitor::FoundField(FieldDecl* field) { 354 if (Config::IsTemplateInstantiation(info_->record())) { 355 // Pointer equality on fields does not work for template instantiations. 356 // The trace method refers to fields of the template definition which 357 // are different from the instantiated fields that need to be traced. 358 const std::string& name = field->getNameAsString(); 359 for (RecordInfo::Fields::iterator it = info_->GetFields().begin(); 360 it != info_->GetFields().end(); 361 ++it) { 362 if (it->first->getNameAsString() == name) { 363 MarkTraced(it); 364 break; 365 } 366 } 367 } else { 368 RecordInfo::Fields::iterator it = info_->GetFields().find(field); 369 if (it != info_->GetFields().end()) 370 MarkTraced(it); 371 } 372 } 373 374 void CheckTraceVisitor::MarkAllWeakMembersTraced() { 375 // If we find a call to registerWeakMembers which is unresolved we 376 // unsoundly consider all weak members as traced. 377 // TODO: Find out how to validate weak member tracing for unresolved call. 378 for (auto& field : info_->GetFields()) { 379 if (field.second.edge()->IsWeakMember()) 380 field.second.MarkTraced(); 381 } 382 } 383