Home | History | Annotate | Download | only in quic
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/tools/quic/quic_in_memory_cache.h"
      6 
      7 #include "base/files/file_enumerator.h"
      8 #include "base/files/file_util.h"
      9 #include "base/stl_util.h"
     10 #include "base/strings/string_number_conversions.h"
     11 #include "net/tools/balsa/balsa_headers.h"
     12 
     13 using base::FilePath;
     14 using base::StringPiece;
     15 using std::string;
     16 
// Specifies the directory used during QuicInMemoryCache
// construction to seed the cache. The cache directory can be
// generated using `wget -p --save-headers <url>`.
     20 
     21 namespace net {
     22 namespace tools {
     23 
     24 std::string FLAGS_quic_in_memory_cache_dir = "";
     25 
     26 namespace {
     27 
     28 // BalsaVisitor implementation (glue) which caches response bodies.
     29 class CachingBalsaVisitor : public NoOpBalsaVisitor {
     30  public:
     31   CachingBalsaVisitor() : done_framing_(false) {}
     32   virtual void ProcessBodyData(const char* input, size_t size) OVERRIDE {
     33     AppendToBody(input, size);
     34   }
     35   virtual void MessageDone() OVERRIDE {
     36     done_framing_ = true;
     37   }
     38   virtual void HandleHeaderError(BalsaFrame* framer) OVERRIDE {
     39     UnhandledError();
     40   }
     41   virtual void HandleHeaderWarning(BalsaFrame* framer) OVERRIDE {
     42     UnhandledError();
     43   }
     44   virtual void HandleChunkingError(BalsaFrame* framer) OVERRIDE {
     45     UnhandledError();
     46   }
     47   virtual void HandleBodyError(BalsaFrame* framer) OVERRIDE {
     48     UnhandledError();
     49   }
     50   void UnhandledError() {
     51     LOG(DFATAL) << "Unhandled error framing HTTP.";
     52   }
     53   void AppendToBody(const char* input, size_t size) {
     54     body_.append(input, size);
     55   }
     56   bool done_framing() const { return done_framing_; }
     57   const string& body() const { return body_; }
     58 
     59  private:
     60   bool done_framing_;
     61   string body_;
     62 };
     63 
     64 }  // namespace
     65 
     66 // static
     67 QuicInMemoryCache* QuicInMemoryCache::GetInstance() {
     68   return Singleton<QuicInMemoryCache>::get();
     69 }
     70 
     71 const QuicInMemoryCache::Response* QuicInMemoryCache::GetResponse(
     72     const BalsaHeaders& request_headers) const {
     73   ResponseMap::const_iterator it = responses_.find(GetKey(request_headers));
     74   if (it == responses_.end()) {
     75     return NULL;
     76   }
     77   return it->second;
     78 }
     79 
     80 void QuicInMemoryCache::AddSimpleResponse(StringPiece method,
     81                                           StringPiece path,
     82                                           StringPiece version,
     83                                           StringPiece response_code,
     84                                           StringPiece response_detail,
     85                                           StringPiece body) {
     86   BalsaHeaders request_headers, response_headers;
     87   request_headers.SetRequestFirstlineFromStringPieces(method,
     88                                                       path,
     89                                                       version);
     90   response_headers.SetRequestFirstlineFromStringPieces(version,
     91                                                        response_code,
     92                                                        response_detail);
     93   response_headers.AppendHeader("content-length",
     94                                 base::IntToString(body.length()));
     95 
     96   AddResponse(request_headers, response_headers, body);
     97 }
     98 
     99 void QuicInMemoryCache::AddResponse(const BalsaHeaders& request_headers,
    100                                     const BalsaHeaders& response_headers,
    101                                     StringPiece response_body) {
    102   VLOG(1) << "Adding response for: " << GetKey(request_headers);
    103   if (ContainsKey(responses_, GetKey(request_headers))) {
    104     LOG(DFATAL) << "Response for given request already exists!";
    105     return;
    106   }
    107   Response* new_response = new Response();
    108   new_response->set_headers(response_headers);
    109   new_response->set_body(response_body);
    110   responses_[GetKey(request_headers)] = new_response;
    111 }
    112 
    113 void QuicInMemoryCache::AddSpecialResponse(StringPiece method,
    114                                            StringPiece path,
    115                                            StringPiece version,
    116                                            SpecialResponseType response_type) {
    117   BalsaHeaders request_headers, response_headers;
    118   request_headers.SetRequestFirstlineFromStringPieces(method,
    119                                                       path,
    120                                                       version);
    121   AddResponse(request_headers, response_headers, "");
    122   responses_[GetKey(request_headers)]->response_type_ = response_type;
    123 }
    124 
    125 QuicInMemoryCache::QuicInMemoryCache() {
    126   Initialize();
    127 }
    128 
    129 void QuicInMemoryCache::ResetForTests() {
    130   STLDeleteValues(&responses_);
    131   Initialize();
    132 }
    133 
    134 void QuicInMemoryCache::Initialize() {
    135   // If there's no defined cache dir, we have no initialization to do.
    136   if (FLAGS_quic_in_memory_cache_dir.empty()) {
    137     VLOG(1) << "No cache directory found. Skipping initialization.";
    138     return;
    139   }
    140   VLOG(1) << "Attempting to initialize QuicInMemoryCache from directory: "
    141           << FLAGS_quic_in_memory_cache_dir;
    142 
    143   FilePath directory(FLAGS_quic_in_memory_cache_dir);
    144   base::FileEnumerator file_list(directory,
    145                                  true,
    146                                  base::FileEnumerator::FILES);
    147 
    148   FilePath file = file_list.Next();
    149   while (!file.empty()) {
    150     // Need to skip files in .svn directories
    151     if (file.value().find("/.svn/") != std::string::npos) {
    152       file = file_list.Next();
    153       continue;
    154     }
    155 
    156     BalsaHeaders request_headers, response_headers;
    157 
    158     string file_contents;
    159     base::ReadFileToString(file, &file_contents);
    160 
    161     // Frame HTTP.
    162     CachingBalsaVisitor caching_visitor;
    163     BalsaFrame framer;
    164     framer.set_balsa_headers(&response_headers);
    165     framer.set_balsa_visitor(&caching_visitor);
    166     size_t processed = 0;
    167     while (processed < file_contents.length() &&
    168            !caching_visitor.done_framing()) {
    169       processed += framer.ProcessInput(file_contents.c_str() + processed,
    170                                        file_contents.length() - processed);
    171     }
    172 
    173     if (!caching_visitor.done_framing()) {
    174       LOG(DFATAL) << "Did not frame entire message from file: " << file.value()
    175                   << " (" << processed << " of " << file_contents.length()
    176                   << " bytes).";
    177     }
    178     if (processed < file_contents.length()) {
    179       // Didn't frame whole file. Assume remainder is body.
    180       // This sometimes happens as a result of incompatibilities between
    181       // BalsaFramer and wget's serialization of HTTP sans content-length.
    182       caching_visitor.AppendToBody(file_contents.c_str() + processed,
    183                                    file_contents.length() - processed);
    184       processed += file_contents.length();
    185     }
    186 
    187     StringPiece base = file.value();
    188     if (response_headers.HasHeader("X-Original-Url")) {
    189       base = response_headers.GetHeader("X-Original-Url");
    190       response_headers.RemoveAllOfHeader("X-Original-Url");
    191       // Remove the protocol so that the string is of the form host + path,
    192       // which is parsed properly below.
    193       if (StringPieceUtils::StartsWithIgnoreCase(base, "https://")) {
    194         base.remove_prefix(8);
    195       } else if (StringPieceUtils::StartsWithIgnoreCase(base, "http://")) {
    196         base.remove_prefix(7);
    197       }
    198     }
    199     int path_start = base.find_first_of('/');
    200     DCHECK_LT(0, path_start);
    201     StringPiece host(base.substr(0, path_start));
    202     StringPiece path(base.substr(path_start));
    203     if (path[path.length() - 1] == ',') {
    204       path.remove_suffix(1);
    205     }
    206     // Set up request headers. Assume method is GET and protocol is HTTP/1.1.
    207     request_headers.SetRequestFirstlineFromStringPieces("GET",
    208                                                         path,
    209                                                         "HTTP/1.1");
    210     request_headers.ReplaceOrAppendHeader("host", host);
    211 
    212     VLOG(1) << "Inserting 'http://" << GetKey(request_headers)
    213             << "' into QuicInMemoryCache.";
    214 
    215     AddResponse(request_headers, response_headers, caching_visitor.body());
    216 
    217     file = file_list.Next();
    218   }
    219 }
    220 
    221 QuicInMemoryCache::~QuicInMemoryCache() {
    222   STLDeleteValues(&responses_);
    223 }
    224 
    225 string QuicInMemoryCache::GetKey(const BalsaHeaders& request_headers) const {
    226   StringPiece uri = request_headers.request_uri();
    227   if (uri.size() == 0) {
    228     return "";
    229   }
    230   StringPiece host;
    231   if (uri[0] == '/') {
    232     host = request_headers.GetHeader("host");
    233   } else if (StringPieceUtils::StartsWithIgnoreCase(uri, "https://")) {
    234     uri.remove_prefix(8);
    235   } else if (StringPieceUtils::StartsWithIgnoreCase(uri, "http://")) {
    236     uri.remove_prefix(7);
    237   }
    238   return host.as_string() + uri.as_string();
    239 }
    240 
    241 }  // namespace tools
    242 }  // namespace net
    243