Home | History | Annotate | Download | only in Tooling
      1 //===--- Refactoring.cpp - Framework for clang refactoring tools ----------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 //  Implements tools to support refactorings.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #include "clang/Basic/DiagnosticOptions.h"
     15 #include "clang/Basic/FileManager.h"
     16 #include "clang/Basic/SourceManager.h"
     17 #include "clang/Frontend/TextDiagnosticPrinter.h"
     18 #include "clang/Lex/Lexer.h"
     19 #include "clang/Rewrite/Core/Rewriter.h"
     20 #include "clang/Tooling/Refactoring.h"
     21 #include "llvm/Support/FileSystem.h"
     22 #include "llvm/Support/Path.h"
     23 #include "llvm/Support/raw_os_ostream.h"
     24 
     25 namespace clang {
     26 namespace tooling {
     27 
     28 static const char * const InvalidLocation = "";
     29 
     30 Replacement::Replacement()
     31   : FilePath(InvalidLocation) {}
     32 
     33 Replacement::Replacement(StringRef FilePath, unsigned Offset, unsigned Length,
     34                          StringRef ReplacementText)
     35     : FilePath(FilePath), ReplacementRange(Offset, Length),
     36       ReplacementText(ReplacementText) {}
     37 
     38 Replacement::Replacement(const SourceManager &Sources, SourceLocation Start,
     39                          unsigned Length, StringRef ReplacementText) {
     40   setFromSourceLocation(Sources, Start, Length, ReplacementText);
     41 }
     42 
     43 Replacement::Replacement(const SourceManager &Sources,
     44                          const CharSourceRange &Range,
     45                          StringRef ReplacementText) {
     46   setFromSourceRange(Sources, Range, ReplacementText);
     47 }
     48 
     49 bool Replacement::isApplicable() const {
     50   return FilePath != InvalidLocation;
     51 }
     52 
     53 bool Replacement::apply(Rewriter &Rewrite) const {
     54   SourceManager &SM = Rewrite.getSourceMgr();
     55   const FileEntry *Entry = SM.getFileManager().getFile(FilePath);
     56   if (!Entry)
     57     return false;
     58   FileID ID;
     59   // FIXME: Use SM.translateFile directly.
     60   SourceLocation Location = SM.translateFileLineCol(Entry, 1, 1);
     61   ID = Location.isValid() ?
     62     SM.getFileID(Location) :
     63     SM.createFileID(Entry, SourceLocation(), SrcMgr::C_User);
     64   // FIXME: We cannot check whether Offset + Length is in the file, as
     65   // the remapping API is not public in the RewriteBuffer.
     66   const SourceLocation Start =
     67     SM.getLocForStartOfFile(ID).
     68     getLocWithOffset(ReplacementRange.getOffset());
     69   // ReplaceText returns false on success.
     70   // ReplaceText only fails if the source location is not a file location, in
     71   // which case we already returned false earlier.
     72   bool RewriteSucceeded = !Rewrite.ReplaceText(
     73       Start, ReplacementRange.getLength(), ReplacementText);
     74   assert(RewriteSucceeded);
     75   return RewriteSucceeded;
     76 }
     77 
     78 std::string Replacement::toString() const {
     79   std::string result;
     80   llvm::raw_string_ostream stream(result);
     81   stream << FilePath << ": " << ReplacementRange.getOffset() << ":+"
     82          << ReplacementRange.getLength() << ":\"" << ReplacementText << "\"";
     83   return result;
     84 }
     85 
     86 bool operator<(const Replacement &LHS, const Replacement &RHS) {
     87   if (LHS.getOffset() != RHS.getOffset())
     88     return LHS.getOffset() < RHS.getOffset();
     89   if (LHS.getLength() != RHS.getLength())
     90     return LHS.getLength() < RHS.getLength();
     91   if (LHS.getFilePath() != RHS.getFilePath())
     92     return LHS.getFilePath() < RHS.getFilePath();
     93   return LHS.getReplacementText() < RHS.getReplacementText();
     94 }
     95 
     96 bool operator==(const Replacement &LHS, const Replacement &RHS) {
     97   return LHS.getOffset() == RHS.getOffset() &&
     98          LHS.getLength() == RHS.getLength() &&
     99          LHS.getFilePath() == RHS.getFilePath() &&
    100          LHS.getReplacementText() == RHS.getReplacementText();
    101 }
    102 
    103 void Replacement::setFromSourceLocation(const SourceManager &Sources,
    104                                         SourceLocation Start, unsigned Length,
    105                                         StringRef ReplacementText) {
    106   const std::pair<FileID, unsigned> DecomposedLocation =
    107       Sources.getDecomposedLoc(Start);
    108   const FileEntry *Entry = Sources.getFileEntryForID(DecomposedLocation.first);
    109   if (Entry) {
    110     // Make FilePath absolute so replacements can be applied correctly when
    111     // relative paths for files are used.
    112     llvm::SmallString<256> FilePath(Entry->getName());
    113     std::error_code EC = llvm::sys::fs::make_absolute(FilePath);
    114     this->FilePath = EC ? FilePath.c_str() : Entry->getName();
    115   } else {
    116     this->FilePath = InvalidLocation;
    117   }
    118   this->ReplacementRange = Range(DecomposedLocation.second, Length);
    119   this->ReplacementText = ReplacementText;
    120 }
    121 
    122 // FIXME: This should go into the Lexer, but we need to figure out how
    123 // to handle ranges for refactoring in general first - there is no obvious
    124 // good way how to integrate this into the Lexer yet.
    125 static int getRangeSize(const SourceManager &Sources,
    126                         const CharSourceRange &Range) {
    127   SourceLocation SpellingBegin = Sources.getSpellingLoc(Range.getBegin());
    128   SourceLocation SpellingEnd = Sources.getSpellingLoc(Range.getEnd());
    129   std::pair<FileID, unsigned> Start = Sources.getDecomposedLoc(SpellingBegin);
    130   std::pair<FileID, unsigned> End = Sources.getDecomposedLoc(SpellingEnd);
    131   if (Start.first != End.first) return -1;
    132   if (Range.isTokenRange())
    133     End.second += Lexer::MeasureTokenLength(SpellingEnd, Sources,
    134                                             LangOptions());
    135   return End.second - Start.second;
    136 }
    137 
    138 void Replacement::setFromSourceRange(const SourceManager &Sources,
    139                                      const CharSourceRange &Range,
    140                                      StringRef ReplacementText) {
    141   setFromSourceLocation(Sources, Sources.getSpellingLoc(Range.getBegin()),
    142                         getRangeSize(Sources, Range), ReplacementText);
    143 }
    144 
    145 bool applyAllReplacements(const Replacements &Replaces, Rewriter &Rewrite) {
    146   bool Result = true;
    147   for (Replacements::const_iterator I = Replaces.begin(),
    148                                     E = Replaces.end();
    149        I != E; ++I) {
    150     if (I->isApplicable()) {
    151       Result = I->apply(Rewrite) && Result;
    152     } else {
    153       Result = false;
    154     }
    155   }
    156   return Result;
    157 }
    158 
    159 // FIXME: Remove this function when Replacements is implemented as std::vector
    160 // instead of std::set.
    161 bool applyAllReplacements(const std::vector<Replacement> &Replaces,
    162                           Rewriter &Rewrite) {
    163   bool Result = true;
    164   for (std::vector<Replacement>::const_iterator I = Replaces.begin(),
    165                                                 E = Replaces.end();
    166        I != E; ++I) {
    167     if (I->isApplicable()) {
    168       Result = I->apply(Rewrite) && Result;
    169     } else {
    170       Result = false;
    171     }
    172   }
    173   return Result;
    174 }
    175 
    176 std::string applyAllReplacements(StringRef Code, const Replacements &Replaces) {
    177   FileManager Files((FileSystemOptions()));
    178   DiagnosticsEngine Diagnostics(
    179       IntrusiveRefCntPtr<DiagnosticIDs>(new DiagnosticIDs),
    180       new DiagnosticOptions);
    181   Diagnostics.setClient(new TextDiagnosticPrinter(
    182       llvm::outs(), &Diagnostics.getDiagnosticOptions()));
    183   SourceManager SourceMgr(Diagnostics, Files);
    184   Rewriter Rewrite(SourceMgr, LangOptions());
    185   llvm::MemoryBuffer *Buf = llvm::MemoryBuffer::getMemBuffer(Code, "<stdin>");
    186   const clang::FileEntry *Entry =
    187       Files.getVirtualFile("<stdin>", Buf->getBufferSize(), 0);
    188   SourceMgr.overrideFileContents(Entry, Buf);
    189   FileID ID =
    190       SourceMgr.createFileID(Entry, SourceLocation(), clang::SrcMgr::C_User);
    191   for (Replacements::const_iterator I = Replaces.begin(), E = Replaces.end();
    192        I != E; ++I) {
    193     Replacement Replace("<stdin>", I->getOffset(), I->getLength(),
    194                         I->getReplacementText());
    195     if (!Replace.apply(Rewrite))
    196       return "";
    197   }
    198   std::string Result;
    199   llvm::raw_string_ostream OS(Result);
    200   Rewrite.getEditBuffer(ID).write(OS);
    201   OS.flush();
    202   return Result;
    203 }
    204 
    205 unsigned shiftedCodePosition(const Replacements &Replaces, unsigned Position) {
    206   unsigned NewPosition = Position;
    207   for (Replacements::iterator I = Replaces.begin(), E = Replaces.end(); I != E;
    208        ++I) {
    209     if (I->getOffset() >= Position)
    210       break;
    211     if (I->getOffset() + I->getLength() > Position)
    212       NewPosition += I->getOffset() + I->getLength() - Position;
    213     NewPosition += I->getReplacementText().size() - I->getLength();
    214   }
    215   return NewPosition;
    216 }
    217 
    218 // FIXME: Remove this function when Replacements is implemented as std::vector
    219 // instead of std::set.
    220 unsigned shiftedCodePosition(const std::vector<Replacement> &Replaces,
    221                              unsigned Position) {
    222   unsigned NewPosition = Position;
    223   for (std::vector<Replacement>::const_iterator I = Replaces.begin(),
    224                                                 E = Replaces.end();
    225        I != E; ++I) {
    226     if (I->getOffset() >= Position)
    227       break;
    228     if (I->getOffset() + I->getLength() > Position)
    229       NewPosition += I->getOffset() + I->getLength() - Position;
    230     NewPosition += I->getReplacementText().size() - I->getLength();
    231   }
    232   return NewPosition;
    233 }
    234 
    235 void deduplicate(std::vector<Replacement> &Replaces,
    236                  std::vector<Range> &Conflicts) {
    237   if (Replaces.empty())
    238     return;
    239 
    240   // Deduplicate
    241   std::sort(Replaces.begin(), Replaces.end());
    242   std::vector<Replacement>::iterator End =
    243       std::unique(Replaces.begin(), Replaces.end());
    244   Replaces.erase(End, Replaces.end());
    245 
    246   // Detect conflicts
    247   Range ConflictRange(Replaces.front().getOffset(),
    248                       Replaces.front().getLength());
    249   unsigned ConflictStart = 0;
    250   unsigned ConflictLength = 1;
    251   for (unsigned i = 1; i < Replaces.size(); ++i) {
    252     Range Current(Replaces[i].getOffset(), Replaces[i].getLength());
    253     if (ConflictRange.overlapsWith(Current)) {
    254       // Extend conflicted range
    255       ConflictRange = Range(ConflictRange.getOffset(),
    256                             std::max(ConflictRange.getLength(),
    257                                      Current.getOffset() + Current.getLength() -
    258                                          ConflictRange.getOffset()));
    259       ++ConflictLength;
    260     } else {
    261       if (ConflictLength > 1)
    262         Conflicts.push_back(Range(ConflictStart, ConflictLength));
    263       ConflictRange = Current;
    264       ConflictStart = i;
    265       ConflictLength = 1;
    266     }
    267   }
    268 
    269   if (ConflictLength > 1)
    270     Conflicts.push_back(Range(ConflictStart, ConflictLength));
    271 }
    272 
    273 
    274 RefactoringTool::RefactoringTool(const CompilationDatabase &Compilations,
    275                                  ArrayRef<std::string> SourcePaths)
    276   : ClangTool(Compilations, SourcePaths) {}
    277 
    278 Replacements &RefactoringTool::getReplacements() { return Replace; }
    279 
    280 int RefactoringTool::runAndSave(FrontendActionFactory *ActionFactory) {
    281   if (int Result = run(ActionFactory)) {
    282     return Result;
    283   }
    284 
    285   LangOptions DefaultLangOptions;
    286   IntrusiveRefCntPtr<DiagnosticOptions> DiagOpts = new DiagnosticOptions();
    287   TextDiagnosticPrinter DiagnosticPrinter(llvm::errs(), &*DiagOpts);
    288   DiagnosticsEngine Diagnostics(
    289       IntrusiveRefCntPtr<DiagnosticIDs>(new DiagnosticIDs()),
    290       &*DiagOpts, &DiagnosticPrinter, false);
    291   SourceManager Sources(Diagnostics, getFiles());
    292   Rewriter Rewrite(Sources, DefaultLangOptions);
    293 
    294   if (!applyAllReplacements(Rewrite)) {
    295     llvm::errs() << "Skipped some replacements.\n";
    296   }
    297 
    298   return saveRewrittenFiles(Rewrite);
    299 }
    300 
    301 bool RefactoringTool::applyAllReplacements(Rewriter &Rewrite) {
    302   return tooling::applyAllReplacements(Replace, Rewrite);
    303 }
    304 
    305 int RefactoringTool::saveRewrittenFiles(Rewriter &Rewrite) {
    306   return Rewrite.overwriteChangedFiles() ? 1 : 0;
    307 }
    308 
    309 } // end namespace tooling
    310 } // end namespace clang
    311