Home | History | Annotate | Download | only in flip_server
      1 // Copyright (c) 2009 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include <dirent.h>
      6 #include <linux/tcp.h>  // For TCP_NODELAY
      7 #include <sys/socket.h>
      8 #include <sys/types.h>
      9 #include <unistd.h>
     10 #include <openssl/ssl.h>
     11 
     12 #include <deque>
     13 #include <iostream>
     14 #include <limits>
     15 #include <vector>
     16 #include <list>
     17 
     18 #include "base/logging.h"
     19 #include "base/simple_thread.h"
     20 #include "base/timer.h"
     21 #include "base/lock.h"
     22 #include "net/flip/flip_frame_builder.h"
     23 #include "net/flip/flip_framer.h"
     24 #include "net/flip/flip_protocol.h"
     25 #include "net/tools/flip_server/balsa_enums.h"
     26 #include "net/tools/flip_server/balsa_frame.h"
     27 #include "net/tools/flip_server/balsa_headers.h"
     28 #include "net/tools/flip_server/balsa_visitor_interface.h"
     29 #include "net/tools/flip_server/buffer_interface.h"
     30 #include "net/tools/flip_server/create_listener.h"
     31 #include "net/tools/flip_server/epoll_server.h"
     32 #include "net/tools/flip_server/loadtime_measurement.h"
     33 #include "net/tools/flip_server/other_defines.h"
     34 #include "net/tools/flip_server/ring_buffer.h"
     35 #include "net/tools/flip_server/simple_buffer.h"
     36 #include "net/tools/flip_server/split.h"
     37 #include "net/tools/flip_server/url_to_filename_encoder.h"
     38 #include "net/tools/flip_server/url_utilities.h"
     39 
     40 ////////////////////////////////////////////////////////////////////////////////
     41 
     42 using base::StringPiece;
     43 using base::SimpleThread;
     44 // using base::Lock;  // heh, this isn't in base namespace?!
     45 // using base::AutoLock;  // ditto!
     46 using flip::CONTROL_FLAG_NONE;
     47 using flip::DATA_FLAG_COMPRESSED;
     48 using flip::DATA_FLAG_FIN;
     49 using flip::FIN_STREAM;
     50 using flip::FlipControlFrame;
     51 using flip::FlipDataFlags;
     52 using flip::FlipDataFrame;
     53 using flip::FlipFinStreamControlFrame;
     54 using flip::FlipFrame;
     55 using flip::FlipFrameBuilder;
     56 using flip::FlipFramer;
     57 using flip::FlipFramerVisitorInterface;
     58 using flip::FlipHeaderBlock;
     59 using flip::FlipStreamId;
     60 using flip::FlipSynReplyControlFrame;
     61 using flip::FlipSynStreamControlFrame;
     62 using flip::SYN_REPLY;
     63 using flip::SYN_STREAM;
     64 using net::BalsaFrame;
     65 using net::BalsaFrameEnums;
     66 using net::BalsaHeaders;
     67 using net::BalsaHeadersEnums;
     68 using net::BalsaVisitorInterface;
     69 using net::EpollAlarmCallbackInterface;
     70 using net::EpollCallbackInterface;
     71 using net::EpollEvent;
     72 using net::EpollServer;
     73 using net::RingBuffer;
     74 using net::SimpleBuffer;
     75 using net::SplitStringPieceToVector;
     76 using net::UrlUtilities;
     77 using std::deque;
     78 using std::map;
     79 using std::pair;
     80 using std::string;
     81 using std::vector;
     82 using std::list;
     83 using std::ostream;
     84 using std::cerr;
     85 
     86 ////////////////////////////////////////////////////////////////////////////////
     87 
// If set to true, then the server will act as an SSL server for both
// HTTP and FLIP.
bool FLAGS_use_ssl = true;

// The name of the cert .pem file.
string FLAGS_ssl_cert_name = "cert.pem";

// The name of the key .pem file.
string FLAGS_ssl_key_name = "key.pem";

// The number of responses given before the server closes the
// connection.
int32 FLAGS_response_count_until_close = 1000*1000;

// If true, then disables the Nagle algorithm.
bool FLAGS_no_nagle = true;

// The number of times that accept() will be called when the
// alarm goes off when the accept_using_alarm flag is set to true.
// If set to 0, accept() will be performed until the accept queue
// is completely drained and the accept() call returns an error.
int32 FLAGS_accepts_per_wake = 0;

// The port on which the flip server listens.
int32 FLAGS_flip_port = 10040;

// The port on which the http server listens.
int32 FLAGS_port = 16002;

// The size of the TCP accept backlog.
int32 FLAGS_accept_backlog_size = 1024;

// The directory where the cache is located.
string FLAGS_cache_base_dir = ".";

// If true, then encode url to filename.
bool FLAGS_need_to_encode_url = true;

// If set to false a single socket will be used. If set to true
// then a new socket will be created for each accept thread.
// Note that this only works with kernels that support
// SO_REUSEPORT.
bool FLAGS_reuseport = false;

// The amount of time (in seconds) the server delays before sending back
// the reply.
double FLAGS_server_think_time_in_s = 0;

// Whether the server sends X-Subresource headers.
bool FLAGS_use_xsub = false;

// Whether the server sends X-Associated-Content headers.
bool FLAGS_use_xac = false;

// Whether the server advances cwnd by sending no-op packets.
bool FLAGS_use_cwnd_opener = false;

// Whether the server compresses data frames.
bool FLAGS_use_compression = false;

// The path to the urls file which includes the urls for testing.
string FLAGS_urls_file = "experimental/users/fenix/flip/urls.txt";

// The path to the html file that does the pageload in an iframe.
string FLAGS_pageload_html_file =
  "experimental/users/fenix/flip/loadtime_measurement.html";

// If set to true, record requests in a file named after the fd used.
bool FLAGS_record_mode = false;

// The directory in which the record files are saved.
string FLAGS_record_path = ".";
    160 
    161 ////////////////////////////////////////////////////////////////////////////////
    162 
    163 // Creates a socket with domain, type and protocol parameters.
    164 // Assigns the return value of socket() to *fd.
    165 // Returns errno if an error occurs, else returns zero.
    166 int CreateSocket(int domain, int type, int protocol, int *fd) {
    167   CHECK(fd != NULL);
    168   *fd = ::socket(domain, type, protocol);
    169   return (*fd == -1) ? errno : 0;
    170 }
    171 
    172 ////////////////////////////////////////////////////////////////////////////////
    173 
    174 // Sets an FD to be nonblocking.
    175 void SetNonBlocking(int fd) {
    176   DCHECK(fd >= 0);
    177 
    178   int fcntl_return = fcntl(fd, F_GETFL, 0);
    179   CHECK_NE(fcntl_return, -1)
    180     << "error doing fcntl(fd, F_GETFL, 0) fd: " << fd
    181     << " errno=" << errno;
    182 
    183   if (fcntl_return & O_NONBLOCK)
    184     return;
    185 
    186   fcntl_return = fcntl(fd, F_SETFL, fcntl_return | O_NONBLOCK);
    187   CHECK_NE(fcntl_return, -1)
    188     << "error doing fcntl(fd, F_SETFL, fcntl_return) fd: " << fd
    189     << " errno=" << errno;
    190 }
    191 
    192 ////////////////////////////////////////////////////////////////////////////////
    193 
// File-scope LoadtimeMeasurement instance, constructed during static
// initialization from FLAGS_urls_file and FLAGS_pageload_html_file.
// NOTE(review): both flags are defined earlier in this same file, so they are
// initialized before this object; keep that ordering if the flags ever move.
LoadtimeMeasurement global_loadtime_measurement(FLAGS_urls_file,
                                                FLAGS_pageload_html_file);
    196 
    197 ////////////////////////////////////////////////////////////////////////////////
    198 
// Holds the OpenSSL objects that flip_init_ssl() fills in.
struct GlobalSSLState {
  // Assigned from TLSv1_server_method() by flip_init_ssl().
  SSL_METHOD* ssl_method;
  // Context created from |ssl_method|; passed to flip_new_ssl() when a
  // connection needs its own SSL object.
  SSL_CTX* ssl_ctx;
};
    203 
    204 ////////////////////////////////////////////////////////////////////////////////
    205 
// SSL state shared by all connections.  Starts out NULL;
// SMServerConnection::InitSMServerConnection() only creates an SSL object
// for a connection when this pointer is non-NULL.
GlobalSSLState* global_ssl_state = NULL;
    207 
    208 ////////////////////////////////////////////////////////////////////////////////
    209 
    210 // SSL stuff
    211 void flip_init_ssl(GlobalSSLState* state) {
    212   SSL_library_init();
    213   SSL_load_error_strings();
    214 
    215   state->ssl_method = TLSv1_server_method();
    216   state->ssl_ctx = SSL_CTX_new(state->ssl_method);
    217   if (!state->ssl_ctx) {
    218     LOG(FATAL) << "Unable to create SSL context";
    219   }
    220   if (SSL_CTX_use_certificate_file(state->ssl_ctx,
    221                                    FLAGS_ssl_cert_name.c_str(),
    222                                    SSL_FILETYPE_PEM) <= 0) {
    223     LOG(FATAL) << "Unable to use cert.pem as SSL cert.";
    224   }
    225   if (SSL_CTX_use_PrivateKey_file(state->ssl_ctx,
    226                                   FLAGS_ssl_key_name.c_str(),
    227                                   SSL_FILETYPE_PEM) <= 0) {
    228     LOG(FATAL) << "Unable to use key.pem as SSL key.";
    229   }
    230   if (!SSL_CTX_check_private_key(state->ssl_ctx)) {
    231     LOG(FATAL) << "The cert.pem and key.pem files don't match";
    232   }
    233 }
    234 
    235 SSL* flip_new_ssl(SSL_CTX* ssl_ctx) {
    236   SSL* ssl = SSL_new(ssl_ctx);
    237   SSL_set_accept_state(ssl);
    238   return ssl;
    239 }
    240 
    241 ////////////////////////////////////////////////////////////////////////////////
    242 
// Byte budgets computed as two 1460-byte units minus FlipFrame::size().
// NOTE(review): 1460 is presumably the TCP payload size on a 1500-byte MTU
// link - confirm.
// Used as the initial MemCacheIter::max_segment_size.
const int kInitialDataSendersThreshold =  (2 * 1460) - FlipFrame::size();
// Same value as above; not referenced in the visible portion of this file.
const int kNormalSegmentSize = (2 * 1460) - FlipFrame::size();
    245 
    246 ////////////////////////////////////////////////////////////////////////////////
    247 
    248 class DataFrame {
    249  public:
    250   const char* data;
    251   size_t size;
    252   bool delete_when_done;
    253   size_t index;
    254   DataFrame() : data(NULL), size(0), delete_when_done(false), index(0) {}
    255   void MaybeDelete() {
    256     if (delete_when_done) {
    257       delete[] data;
    258     }
    259   }
    260 };
    261 
    262 ////////////////////////////////////////////////////////////////////////////////
    263 
    264 class StoreBodyAndHeadersVisitor: public BalsaVisitorInterface {
    265  public:
    266   BalsaHeaders headers;
    267   string body;
    268   bool error_;
    269 
    270   virtual void ProcessBodyInput(const char *input, size_t size) {}
    271   virtual void ProcessBodyData(const char *input, size_t size) {
    272     body.append(input, size);
    273   }
    274   virtual void ProcessHeaderInput(const char *input, size_t size) {}
    275   virtual void ProcessTrailerInput(const char *input, size_t size) {}
    276   virtual void ProcessHeaders(const BalsaHeaders& headers) {
    277     // nothing to do here-- we're assuming that the BalsaFrame has
    278     // been handed our headers.
    279   }
    280   virtual void ProcessRequestFirstLine(const char* line_input,
    281                                        size_t line_length,
    282                                        const char* method_input,
    283                                        size_t method_length,
    284                                        const char* request_uri_input,
    285                                        size_t request_uri_length,
    286                                        const char* version_input,
    287                                        size_t version_length) {}
    288   virtual void ProcessResponseFirstLine(const char *line_input,
    289                                         size_t line_length,
    290                                         const char *version_input,
    291                                         size_t version_length,
    292                                         const char *status_input,
    293                                         size_t status_length,
    294                                         const char *reason_input,
    295                                         size_t reason_length) {}
    296   virtual void ProcessChunkLength(size_t chunk_length) {}
    297   virtual void ProcessChunkExtensions(const char *input, size_t size) {}
    298   virtual void HeaderDone() {}
    299   virtual void MessageDone() {}
    300   virtual void HandleHeaderError(BalsaFrame* framer) { HandleError(); }
    301   virtual void HandleHeaderWarning(BalsaFrame* framer) { HandleError(); }
    302   virtual void HandleChunkingError(BalsaFrame* framer) { HandleError(); }
    303   virtual void HandleBodyError(BalsaFrame* framer) { HandleError(); }
    304 
    305   void HandleError() { error_ = true; }
    306 };
    307 
    308 ////////////////////////////////////////////////////////////////////////////////
    309 
    310 struct FileData {
    311   void CopyFrom(const FileData& file_data) {
    312     headers = new BalsaHeaders;
    313     headers->CopyFrom(*(file_data.headers));
    314     filename = file_data.filename;
    315     related_files = file_data.related_files;
    316     body = file_data.body;
    317   }
    318   FileData(BalsaHeaders* h, const string& b) : headers(h), body(b) {}
    319   FileData() {}
    320   BalsaHeaders* headers;
    321   string filename;
    322   vector< pair<int, string> > related_files;   // priority, filename
    323   string body;
    324 };
    325 
    326 ////////////////////////////////////////////////////////////////////////////////
    327 
    328 class MemCacheIter {
    329  public:
    330   MemCacheIter() :
    331       file_data(NULL),
    332       priority(0),
    333       transformed_header(false),
    334       body_bytes_consumed(0),
    335       stream_id(0),
    336       max_segment_size(kInitialDataSendersThreshold),
    337       bytes_sent(0) {}
    338   explicit MemCacheIter(FileData* fd) :
    339       file_data(fd),
    340       priority(0),
    341       transformed_header(false),
    342       body_bytes_consumed(0),
    343       stream_id(0),
    344       max_segment_size(kInitialDataSendersThreshold),
    345       bytes_sent(0) {}
    346   FileData* file_data;
    347   int priority;
    348   bool transformed_header;
    349   size_t body_bytes_consumed;
    350   uint32 stream_id;
    351   uint32 max_segment_size;
    352   size_t bytes_sent;
    353 };
    354 
    355 ////////////////////////////////////////////////////////////////////////////////
    356 
    357 class MemoryCache {
    358  public:
    359   typedef map<string, FileData> Files;
    360 
    361  public:
    362   Files files_;
    363   string cwd_;
    364 
    365   void CloneFrom(const MemoryCache& mc) {
    366     for (Files::const_iterator i = mc.files_.begin();
    367          i != mc.files_.end();
    368          ++i) {
    369       Files::iterator out_i =
    370         files_.insert(make_pair(i->first, FileData())).first;
    371       out_i->second.CopyFrom(i->second);
    372       cwd_ = mc.cwd_;
    373     }
    374   }
    375 
    376   void AddFiles() {
    377     LOG(INFO) << "Adding files!";
    378     deque<string> paths;
    379     cwd_ = FLAGS_cache_base_dir;
    380     paths.push_back(cwd_ + "/GET_");
    381     DIR* current_dir = NULL;
    382     while (!paths.empty()) {
    383       while (current_dir == NULL && !paths.empty()) {
    384         string current_dir_name = paths.front();
    385         VLOG(1) << "Attempting to open dir: \"" << current_dir_name << "\"";
    386         current_dir = opendir(current_dir_name.c_str());
    387         paths.pop_front();
    388 
    389         if (current_dir == NULL) {
    390           perror("Unable to open directory. ");
    391           current_dir_name.clear();
    392           continue;
    393         }
    394 
    395         if (current_dir) {
    396           VLOG(1) << "Succeeded opening";
    397           for (struct dirent* dir_data = readdir(current_dir);
    398                dir_data != NULL;
    399                dir_data = readdir(current_dir)) {
    400             string current_entry_name =
    401               current_dir_name + "/" + dir_data->d_name;
    402             if (dir_data->d_type == DT_REG) {
    403               VLOG(1) << "Found file: " << current_entry_name;
    404               ReadAndStoreFileContents(current_entry_name.c_str());
    405             } else if (dir_data->d_type == DT_DIR) {
    406               VLOG(1) << "Found subdir: " << current_entry_name;
    407               if (string(dir_data->d_name) != "." &&
    408                   string(dir_data->d_name) != "..") {
    409                 VLOG(1) << "Adding to search path: " << current_entry_name;
    410                 paths.push_front(current_entry_name);
    411               }
    412             }
    413           }
    414           VLOG(1) << "Oops, no data left. Closing dir.";
    415           closedir(current_dir);
    416           current_dir = NULL;
    417         }
    418       }
    419     }
    420   }
    421 
    422   void ReadToString(const char* filename, string* output) {
    423     output->clear();
    424     int fd = open(filename, 0, "r");
    425     if (fd == -1)
    426       return;
    427     char buffer[4096];
    428     ssize_t read_status = read(fd, buffer, sizeof(buffer));
    429     while (read_status > 0) {
    430       output->append(buffer, static_cast<size_t>(read_status));
    431       do {
    432         read_status = read(fd, buffer, sizeof(buffer));
    433       } while (read_status <= 0 && errno == EINTR);
    434     }
    435     close(fd);
    436   }
    437 
    438   void ReadAndStoreFileContents(const char* filename) {
    439     StoreBodyAndHeadersVisitor visitor;
    440     BalsaFrame framer;
    441     framer.set_balsa_visitor(&visitor);
    442     framer.set_balsa_headers(&(visitor.headers));
    443     string filename_contents;
    444     ReadToString(filename, &filename_contents);
    445 
    446     // Ugly hack to make everything look like 1.1.
    447     if (filename_contents.find("HTTP/1.0") == 0)
    448       filename_contents[7] = '1';
    449 
    450     size_t pos = 0;
    451     size_t old_pos = 0;
    452     while (true) {
    453       old_pos = pos;
    454       pos += framer.ProcessInput(filename_contents.data() + pos,
    455                                  filename_contents.size() - pos);
    456       if (framer.Error() || pos == old_pos) {
    457         LOG(ERROR) << "Unable to make forward progress, or error"
    458           " framing file: " << filename;
    459         if (framer.Error()) {
    460           LOG(INFO) << "********************************************ERROR!";
    461           return;
    462         }
    463         return;
    464       }
    465       if (framer.MessageFullyRead()) {
    466         // If no Content-Length or Transfer-Encoding was captured in the
    467         // file, then the rest of the data is the body.  Many of the captures
    468         // from within Chrome don't have content-lengths.
    469         if (!visitor.body.length())
    470           visitor.body = filename_contents.substr(pos);
    471         break;
    472       }
    473     }
    474     visitor.headers.RemoveAllOfHeader("content-length");
    475     visitor.headers.RemoveAllOfHeader("transfer-encoding");
    476     visitor.headers.RemoveAllOfHeader("connection");
    477     visitor.headers.AppendHeader("transfer-encoding", "chunked");
    478     visitor.headers.AppendHeader("connection", "keep-alive");
    479 
    480     // Experiment with changing headers for forcing use of cached
    481     // versions of content.
    482     // TODO(mbelshe) REMOVE ME
    483 #if 0
    484     // TODO(mbelshe): append current date.
    485     visitor.headers.RemoveAllOfHeader("date");
    486     if (visitor.headers.HasHeader("expires")) {
    487       visitor.headers.RemoveAllOfHeader("expires");
    488       visitor.headers.AppendHeader("expires",
    489                                  "Fri, 30 Aug, 2019 12:00:00 GMT");
    490     }
    491 #endif
    492     BalsaHeaders* headers = new BalsaHeaders;
    493     headers->CopyFrom(visitor.headers);
    494     string filename_stripped =
    495       string(filename).substr(cwd_.size() + 1);
    496 //    LOG(INFO) << "Adding file (" << visitor.body.length() << " bytes): "
    497 //              << filename_stripped;
    498     files_[filename_stripped] = FileData();
    499     FileData& fd = files_[filename_stripped];
    500     fd = FileData(headers, visitor.body);
    501     fd.filename = string(filename_stripped,
    502                          filename_stripped.find_first_of('/'));
    503     if (headers->HasHeader("X-Associated-Content")) {
    504       string content =
    505         headers->GetHeader("X-Associated-Content").as_string();
    506       vector<StringPiece> urls_and_priorities;
    507       SplitStringPieceToVector(content, "||", &urls_and_priorities, true);
    508       VLOG(1) << "Examining X-Associated-Content header";
    509       for (unsigned int i = 0; i < urls_and_priorities.size(); ++i) {
    510         const StringPiece& url_and_priority_pair = urls_and_priorities[i];
    511         vector<StringPiece> url_and_priority;
    512         SplitStringPieceToVector(url_and_priority_pair, "??",
    513                                  &url_and_priority, true);
    514         if (url_and_priority.size() >= 2) {
    515           string priority_string(url_and_priority[0].data(),
    516                                  url_and_priority[0].size());
    517           string filename_string(url_and_priority[1].data(),
    518                                  url_and_priority[1].size());
    519           int priority;
    520           char* last_eaten_char;
    521           priority = strtol(priority_string.c_str(), &last_eaten_char, 0);
    522           if (last_eaten_char ==
    523               priority_string.c_str() + priority_string.size()) {
    524             pair<int, string> entry(priority, filename_string);
    525             VLOG(1) << "Adding associated content: " << filename_string;
    526             fd.related_files.push_back(entry);
    527           }
    528         }
    529       }
    530     }
    531   }
    532 
    533   // Called at runtime to update learned headers
    534   // |url| is a url which contains a referrer header.
    535   // |referrer| is the referring URL
    536   // Adds an X-Subresource or X-Associated-Content to |referer| for |url|
    537   void UpdateHeaders(string referrer, string file_url) {
    538     if (!FLAGS_use_xac && !FLAGS_use_xsub)
    539       return;
    540 
    541     string referrer_host_path =
    542       net::UrlToFilenameEncoder::Encode(referrer, "GET_/");
    543 
    544     FileData* fd1 = GetFileData(string("GET_") + file_url);
    545     if (!fd1) {
    546       LOG(ERROR) << "Updating headers for unknown url: " << file_url;
    547       return;
    548     }
    549     string url = fd1->headers->GetHeader("X-Original-Url").as_string();
    550     string content_type = fd1->headers->GetHeader("Content-Type").as_string();
    551     if (content_type.length() == 0) {
    552       LOG(ERROR) << "Skipping subresource with unknown content-type";
    553       return;
    554     }
    555 
    556     // Now, lets see if this is the same host or not
    557     bool same_host = (UrlUtilities::GetUrlHost(referrer) ==
    558                       UrlUtilities::GetUrlHost(url));
    559 
    560     // This is a hacked algorithm for figuring out what priority
    561     // to use with pushed content.
    562     int priority = 4;
    563     if (content_type.find("css") != string::npos)
    564       priority = 1;
    565     else if (content_type.find("cript") != string::npos)
    566       priority = 1;
    567     else if (content_type.find("html") != string::npos)
    568       priority = 2;
    569 
    570     LOG(ERROR) << "Attempting update for " << referrer_host_path;
    571 
    572     FileData* fd2 = GetFileData(referrer_host_path);
    573     if (fd2 != NULL) {
    574       // If they are on the same host, we'll use X-Associated-Content
    575       string header_name;
    576       string new_value;
    577       string delimiter;
    578       bool related_files = false;
    579       if (same_host && FLAGS_use_xac) {
    580         header_name = "X-Associated-Content";
    581         char pri_ch = priority + '0';
    582         new_value = pri_ch + string("??") + url;
    583         delimiter = "||";
    584         related_files = true;
    585       } else {
    586         if (!FLAGS_use_xsub)
    587           return;
    588         header_name = "X-Subresource";
    589         new_value = content_type + "!!" + url;
    590         delimiter = "!!";
    591       }
    592 
    593       if (fd2->headers->HasNonEmptyHeader(header_name)) {
    594         string existing_header =
    595             fd2->headers->GetHeader(header_name).as_string();
    596         if (existing_header.find(url) != string::npos)
    597           return;  // header already recorded
    598 
    599         // Don't let these lists grow too long for low pri stuff.
    600         // TODO(mbelshe) We need better algorithms for this.
    601         if (existing_header.length() > 256 && priority > 2)
    602           return;
    603 
    604         new_value = existing_header + delimiter + new_value;
    605       }
    606 
    607       LOG(INFO) << "Recording " << header_name << " for " << new_value;
    608       fd2->headers->ReplaceOrAppendHeader(header_name, new_value);
    609 
    610       // Add it to the related files so that it will actually get sent out.
    611       if (related_files) {
    612         pair<int, string> entry(4, file_url);
    613         fd2->related_files.push_back(entry);
    614       }
    615     } else {
    616       LOG(ERROR) << "Failed to update headers:";
    617       LOG(ERROR) << "FAIL url: " << url;
    618       LOG(ERROR) << "FAIL ref: " << referrer_host_path;
    619     }
    620   }
    621 
    622   FileData* GetFileData(const string& filename) {
    623     Files::iterator fi = files_.end();
    624     if (filename.compare(filename.length() - 5, 5, ".html", 5) == 0) {
    625       string new_filename(filename.data(), filename.size() - 5);
    626       new_filename += ".http";
    627       fi = files_.find(new_filename);
    628     }
    629     if (fi == files_.end())
    630       fi = files_.find(filename);
    631 
    632     if (fi == files_.end()) {
    633       return NULL;
    634     }
    635     return &(fi->second);
    636   }
    637 
    638   bool AssignFileData(const string& filename, MemCacheIter* mci) {
    639     mci->file_data = GetFileData(filename);
    640     if (mci->file_data == NULL) {
    641       LOG(ERROR) << "Could not find file data for " << filename;
    642       return false;
    643     }
    644     return true;
    645   }
    646 };
    647 
    648 ////////////////////////////////////////////////////////////////////////////////
    649 
// Minimal callback interface with a single Notify() hook.
// SMServerConnection derives from it.
class NotifierInterface {
 public:
  virtual ~NotifierInterface() {}
  // Invoked to deliver the notification; the meaning is defined by the
  // implementer.
  virtual void Notify() = 0;
};
    655 
    656 ////////////////////////////////////////////////////////////////////////////////
    657 
// Abstract interface to the protocol state machine attached to an
// SMServerConnection.  The connection creates one through its
// SMInterfaceFactory and calls Reset() on it whenever the connection is
// (re)initialized.
// NOTE(review): the per-method notes below are inferred from names and
// signatures only; the implementations are not visible here - confirm.
class SMInterface {
 public:
  // Presumably consumes up to |len| bytes of |data| and returns the number
  // of bytes processed.
  virtual size_t ProcessInput(const char* data, size_t len) = 0;
  virtual bool MessageFullyRead() const = 0;
  virtual bool Error() const = 0;
  virtual const char* ErrorAsString() const = 0;
  virtual void Reset() = 0;
  virtual void ResetForNewConnection() = 0;

  virtual void PostAcceptHook() = 0;

  // Stream-level operations, addressed by |stream_id|.
  virtual void NewStream(uint32 stream_id, uint32 priority,
                         const string& filename) = 0;
  virtual void SendEOF(uint32 stream_id) = 0;
  virtual void SendErrorNotFound(uint32 stream_id) = 0;
  virtual size_t SendSynStream(uint32 stream_id,
                              const BalsaHeaders& headers) = 0;
  virtual size_t SendSynReply(uint32 stream_id,
                              const BalsaHeaders& headers) = 0;
  virtual void SendDataFrame(uint32 stream_id, const char* data, int64 len,
                             uint32 flags, bool compress) = 0;
  virtual void GetOutput() = 0;

  virtual ~SMInterface() {}
};
    683 
    684 ////////////////////////////////////////////////////////////////////////////////
    685 
// Forward declaration needed by the factory signature below.
class SMServerConnection;
// Function type that builds the SMInterface for the connection |conn|.
// SMServerConnection's constructor takes a pointer to such a function and
// calls it with |this| to create its sm_interface_.
typedef SMInterface*(SMInterfaceFactory)(SMServerConnection* conn);
    688 
    689 ////////////////////////////////////////////////////////////////////////////////
    690 
// List of DataFrames held by an SMServerConnection;
// SMServerConnection::EnqueueDataFrame() appends to the back of it.
typedef list<DataFrame> OutputList;
    692 
    693 ////////////////////////////////////////////////////////////////////////////////
    694 
class SMServerConnection;

// Callback interface passed to SMServerConnection::InitSMServerConnection()
// as |connection_pool|.
// NOTE(review): presumably a connection reports itself through
// SMServerConnectionDone() once it is finished - the call site is not
// visible in this part of the file.
class SMServerConnectionPoolInterface {
 public:
  virtual ~SMServerConnectionPoolInterface() {}
  // SMServerConnections will use this:
  virtual void SMServerConnectionDone(SMServerConnection* conn) = 0;
};
    703 
    704 ////////////////////////////////////////////////////////////////////////////////
    705 
    706 class SMServerConnection: public EpollCallbackInterface,
    707                           public NotifierInterface {
    708  private:
    709   SMServerConnection(SMInterfaceFactory* sm_interface_factory,
    710                      MemoryCache* memory_cache,
    711                      EpollServer* epoll_server) :
    712       fd_(-1),
    713       record_fd_(-1),
    714       events_(0),
    715 
    716       registered_in_epoll_server_(false),
    717       initialized_(false),
    718 
    719       connection_pool_(NULL),
    720       epoll_server_(epoll_server),
    721 
    722       read_buffer_(4096*10),
    723       memory_cache_(memory_cache),
    724       sm_interface_(sm_interface_factory(this)),
    725 
    726       max_bytes_sent_per_dowrite_(128),
    727 
    728       ssl_(NULL) {}
    729 
    730   int fd_;
    731   int record_fd_;
    732   int events_;
    733 
    734   bool registered_in_epoll_server_;
    735   bool initialized_;
    736 
    737   SMServerConnectionPoolInterface* connection_pool_;
    738   EpollServer* epoll_server_;
    739 
    740   RingBuffer read_buffer_;
    741 
    742   OutputList output_list_;
    743   MemoryCache* memory_cache_;
    744   SMInterface* sm_interface_;
    745 
    746   size_t max_bytes_sent_per_dowrite_;
    747 
    748   SSL* ssl_;
    749  public:
    750   EpollServer* epoll_server() { return epoll_server_; }
    751   OutputList* output_list() { return &output_list_; }
    752   MemoryCache* memory_cache() { return memory_cache_; }
    753   int record_fd() { return record_fd_; }
    754   void close_record_fd() {
    755     if (record_fd_ != -1) {
    756       close(record_fd_);
    757       record_fd_ = -1;
    758     }
    759   }
    760   void ReadyToSend() {
    761     epoll_server_->SetFDReady(fd_, EPOLLIN | EPOLLOUT);
    762   }
    763   void EnqueueDataFrame(const DataFrame& df) {
    764     output_list_.push_back(df);
    765     VLOG(2) << "EnqueueDataFrame. Setting FD ready.";
    766     ReadyToSend();
    767   }
    768 
    769  public:
    770   ~SMServerConnection() {
    771     if (initialized()) {
    772       Reset();
    773     }
    774   }
    775   static SMServerConnection* NewSMServerConnection(SMInterfaceFactory* smif,
    776                                                    MemoryCache* memory_cache,
    777                                                    EpollServer* epoll_server) {
    778     return new SMServerConnection(smif, memory_cache, epoll_server);
    779   }
    780 
     781   bool initialized() const { return initialized_; }  // true from InitSMServerConnection() until Reset()
    782 
    783   void InitSMServerConnection(SMServerConnectionPoolInterface* connection_pool,
    784                               EpollServer* epoll_server,
    785                               int fd) {
    786     if (initialized_) {
    787       LOG(FATAL) << "Attempted to initialize already initialized server";
    788       return;
    789     }
    790     if (epoll_server_ && registered_in_epoll_server_ && fd_ != -1) {
    791       epoll_server_->UnregisterFD(fd_);
    792     }
    793     if (fd_ != -1) {
    794       VLOG(2) << "Closing pre-existing fd";
    795       close(fd_);
    796       fd_ = -1;
    797     }
    798     if (FLAGS_record_mode) {
    799       char record_file_name[1024];
    800       snprintf(record_file_name, sizeof(record_file_name), "%s/%d_%ld",
    801               FLAGS_record_path.c_str(), fd, epoll_server->NowInUsec()/1000);
    802       record_fd_ = open(record_file_name, O_CREAT|O_APPEND|O_WRONLY, S_IRWXU);
    803       if (record_fd_ < 0) {
    804         LOG(ERROR) << "Open record file for fd " << fd << " failed";
    805         record_fd_ = -1;
    806       }
    807     }
    808 
    809     fd_ = fd;
    810 
    811     registered_in_epoll_server_ = false;
    812     initialized_ = true;
    813 
    814     connection_pool_ = connection_pool;
    815     epoll_server_ = epoll_server;
    816 
    817     sm_interface_->Reset();
    818     read_buffer_.Clear();
    819 
    820     epoll_server_->RegisterFD(fd_, this, EPOLLIN | EPOLLOUT | EPOLLET);
    821 
    822     if (global_ssl_state) {
    823       ssl_ = flip_new_ssl(global_ssl_state->ssl_ctx);
    824       SSL_set_fd(ssl_, fd_);
    825     }
    826     sm_interface_->PostAcceptHook();
    827   }
    828 
     829   int Send(const char* bytes, int len, int flags) {
             // Hands the buffer straight to send(2) on fd_ and returns its
             // result. The output list is not involved, and ssl_ is not
             // consulted here even when it is set.
     830     return send(fd_, bytes, len, flags);
     831   }
    832 
    833   // the following are from the EpollCallbackInterface
    834   virtual void OnRegistration(EpollServer* eps, int fd, int event_mask) {
    835     registered_in_epoll_server_ = true;
    836   }
    837   virtual void OnModification(int fd, int event_mask) { }
    838   virtual void OnEvent(int fd, EpollEvent* event) {
    839     events_ |= event->in_events;
    840     HandleEvents();
    841     if (events_) {
    842       event->out_ready_mask = events_;
    843       events_ = 0;
    844     }
    845   }
    846   virtual void OnUnregistration(int fd, bool replaced) {
    847     registered_in_epoll_server_ = false;
    848   }
    849   virtual void OnShutdown(EpollServer* eps, int fd) {
    850     Cleanup("OnShutdown");
    851     return;
    852   }
    853 
    854  private:
    855   void HandleEvents() {
    856     VLOG(1) << "Received: " << EpollServer::EventMaskToString(events_);
    857     if (events_ & EPOLLIN) {
    858       if (!DoRead())
    859         goto handle_close_or_error;
    860     }
    861 
    862     if (events_ & EPOLLOUT) {
    863       if (!DoWrite())
    864         goto handle_close_or_error;
    865     }
    866 
    867     if (events_ & (EPOLLHUP | EPOLLERR)) {
    868       VLOG(2) << "!!!! Got HUP or ERR";
    869       goto handle_close_or_error;
    870     }
    871     return;
    872 
    873  handle_close_or_error:
    874     Cleanup("HandleEvents");
    875   }
    876 
     877   bool DoRead() {
             // Reads from the socket (through SSL_read when ssl_ is set) into
             // read_buffer_ and passes the data to DoConsumeReadData(), looping
             // until the buffer is full, the read would block, or the
             // connection fails or closes.
             // Returns true if the connection should stay open. Returns false
             // if fd_ is invalid, or after Cleanup("DoRead") has been called.
     878     VLOG(2) << "DoRead()";
     879     if (fd_ == -1) {
     880       VLOG(2) << "DoRead(): fd_ == -1. Invalid FD. Returning false";
     881       return false;
     882     }
     883     while (!read_buffer_.Full()) {
     884       char* bytes;
     885       int size;
     886       read_buffer_.GetWritablePtr(&bytes, &size);
     887       ssize_t bytes_read = 0;
     888       if (ssl_) {
                 // NOTE(review): failures of SSL_read are classified below via
                 // errno; OpenSSL reports details through SSL_get_error() --
                 // confirm errno is adequate for the SSL path.
     889         bytes_read = SSL_read(ssl_, bytes, size);
     890       } else {
     891         bytes_read = recv(fd_, bytes, size, MSG_DONTWAIT);
     892       }
               // Saved immediately so later calls (logging etc.) cannot change
               // the value that is examined.
     893       int stored_errno = errno;
     894       if (bytes_read == -1) {
     895         switch (stored_errno) {
     896           case EAGAIN:
                     // Nothing more to read for now: drop EPOLLIN from the
                     // pending events and report success.
     897             events_ &= ~EPOLLIN;
     898             VLOG(2) << "Got EAGAIN while reading";
     899             goto done;
     900           case EINTR:
                     // Interrupted call: retry the read.
     901             VLOG(2) << "Got EINTR while reading";
     902             continue;
     903           default:
     904             VLOG(2) << "While calling recv, got error: " << stored_errno
     905               << " " << strerror(stored_errno);
     906             goto error_or_close;
     907         }
     908       } else if (bytes_read > 0) {
     909         VLOG(2) << "Read: " << bytes_read << " bytes from fd: " << fd_;
     910         read_buffer_.AdvanceWritablePtr(bytes_read);
     911         if (!DoConsumeReadData()) {
     912           goto error_or_close;
     913         }
     914         continue;
     915       } else {  // bytes_read == 0 (or any negative value other than -1)
     916         VLOG(2) << "0 bytes read with recv call.";
     917       }
               // Only reached from the final else branch above: treated as
               // the connection being closed.
     918       goto error_or_close;
     919     }
     920    done:
     921     return true;
     922 
     923    error_or_close:
     924     VLOG(2) << "DoRead(): error_or_close. Cleaning up, then returning false";
     925     Cleanup("DoRead");
     926     return false;
     927   }
    928 
    929   bool DoConsumeReadData() {
    930     char* bytes;
    931     int size;
    932     read_buffer_.GetReadablePtr(&bytes, &size);
    933     while (size != 0) {
    934       size_t bytes_consumed = sm_interface_->ProcessInput(bytes, size);
    935       VLOG(2) << "consumed: " << bytes_consumed << " from socket fd: " << fd_;
    936       if (bytes_consumed == 0) {
    937         break;
    938       }
    939       read_buffer_.AdvanceReadablePtr(bytes_consumed);
    940       if (sm_interface_->MessageFullyRead()) {
    941         VLOG(2) << "HandleRequestFullyRead";
    942         HandleRequestFullyRead();
    943         sm_interface_->Reset();
    944         events_ |= EPOLLOUT;
    945       } else if (sm_interface_->Error()) {
    946         LOG(ERROR) << "Framer error detected: "
    947                    << sm_interface_->ErrorAsString();
    948         // this causes everything to be closed/cleaned up.
    949         events_ |= EPOLLOUT;
    950         return false;
    951       }
    952       read_buffer_.GetReadablePtr(&bytes, &size);
    953     }
    954     return true;
    955   }
    956 
     957   void WriteResponse() {
             // Intentionally empty.
     958     // this happens asynchronously from separate threads
     959     // feeding files into the output buffer.
     960   }
     961 
     962   void HandleRequestFullyRead() {
             // Empty hook: DoConsumeReadData() calls it once sm_interface_
             // reports MessageFullyRead().
     963   }
     964 
     965   void Notify() {
             // Empty; no caller is visible in this part of the file.
     966   }
    967 
     968   bool DoWrite() {
             // Writes queued frames from output_list_ to the socket (through
             // SSL_write when ssl_ is set) until the queue is empty, the
             // per-call budget max_bytes_sent_per_dowrite_ is reached, the
             // write would block, or an error occurs.
             // Returns true if the connection should stay open. Returns false
             // if fd_ is invalid, or after Cleanup("DoWrite") has been called.
     969     size_t bytes_sent = 0;
             // This initial value is reassigned before every send() below.
     970     int flags = MSG_NOSIGNAL | MSG_DONTWAIT;
     971     if (fd_ == -1) {
     972       VLOG(2) << "DoWrite: fd == -1. Returning false.";
     973       return false;
     974     }
     975     if (output_list_.empty()) {
               // Ask the state machine for output; if the queue is still empty
               // afterwards there is nothing to write, so drop EPOLLOUT.
     976       sm_interface_->GetOutput();
     977       if (output_list_.empty())
     978         events_ &= ~EPOLLOUT;
     979     }
     980     while (!output_list_.empty()) {
     981       if (bytes_sent >= max_bytes_sent_per_dowrite_) {
                 // Budget used up: keep EPOLLOUT pending so writing resumes on
                 // the next pass.
     982         events_ |= EPOLLOUT;
     983         break;
     984       }
     985       if (output_list_.size() < 2) {
                 // Queue is running low: request more output before sending.
     986         sm_interface_->GetOutput();
     987       }
     988       DataFrame& data_frame = output_list_.front();
     989       const char*  bytes = data_frame.data;
     990       int size = data_frame.size;
               // Skip the part of the frame that earlier calls already sent.
     991       bytes += data_frame.index;
     992       size -= data_frame.index;
     993       DCHECK_GE(size, 0);
     994       if (size <= 0) {
                 // Frame fully sent: release it and move on to the next one.
     995         data_frame.MaybeDelete();
     996         output_list_.pop_front();
     997         continue;
     998       }
     999 
    1000       flags = MSG_NOSIGNAL | MSG_DONTWAIT;
    1001       if (output_list_.size() > 1) {
                 // More frames are queued behind this one.
    1002         flags |= MSG_MORE;
    1003       }
    1004       ssize_t bytes_written = 0;
    1005       if (ssl_) {
                 // The send() flags above are not used on the SSL path.
    1006         bytes_written = SSL_write(ssl_, bytes, size);
    1007       } else {
    1008         bytes_written = send(fd_, bytes, size, flags);
    1009       }
               // Saved immediately so later calls cannot change the value that
               // is examined.
    1010       int stored_errno = errno;
    1011       if (bytes_written == -1) {
    1012         switch (stored_errno) {
    1013           case EAGAIN:
                     // Socket cannot take more right now: drop EPOLLOUT and
                     // report success.
    1014             events_ &= ~EPOLLOUT;
    1015             VLOG(2) << " Got EAGAIN while writing";
    1016             goto done;
    1017           case EINTR:
                     // Interrupted call: retry with the same frame.
    1018             VLOG(2) << " Got EINTR while writing";
    1019             continue;
    1020           default:
    1021             VLOG(2) << "While calling send, got error: " << stored_errno
    1022               << " " << strerror(stored_errno);
    1023             goto error_or_close;
    1024         }
    1025       } else if (bytes_written > 0) {
    1026         VLOG(1) << "Wrote: " << bytes_written  << " bytes to socket fd: "
    1027           << fd_;
    1028         data_frame.index += bytes_written;
    1029         bytes_sent += bytes_written;
    1030         continue;
    1031       }
               // Reached when nothing was written (and for any negative value
               // other than -1): treated as a failure.
    1032       VLOG(2) << "0 bytes written to socket " << fd_ << " with send call.";
    1033       goto error_or_close;
    1034     }
    1035    done:
    1036     return true;
    1037 
    1038    error_or_close:
    1039     VLOG(2) << "DoWrite: error_or_close. Returning false after cleaning up";
    1040     Cleanup("DoWrite");
    1041     return false;
    1042   }
   1043 
   1044   friend ostream& operator<<(ostream& os, const SMServerConnection& c) {
   1045     os << &c << "\n";
   1046     return os;
   1047   }
   1048 
   1049   void Reset() {
   1050     VLOG(2) << "Resetting";
   1051     if (ssl_) {
   1052       SSL_shutdown(ssl_);
   1053       SSL_free(ssl_);
   1054     }
   1055     if (registered_in_epoll_server_) {
   1056       epoll_server_->UnregisterFD(fd_);
   1057       registered_in_epoll_server_ = false;
   1058     }
   1059     if (fd_ >= 0) {
   1060       VLOG(2) << "Closing connection";
   1061       close(fd_);
   1062       fd_ = -1;
   1063     }
   1064     sm_interface_->ResetForNewConnection();
   1065     read_buffer_.Clear();
   1066     initialized_ = false;
   1067     events_ = 0;
   1068     output_list_.clear();
   1069   }
   1070 
   1071   void Cleanup(const char* cleanup) {
   1072     VLOG(2) << "Cleaning up: " << cleanup;
   1073     if (!initialized_) {
   1074       return;
   1075     }
   1076     Reset();
   1077     connection_pool_->SMServerConnectionDone(this);
   1078   }
   1079 };
   1080 
   1081 ////////////////////////////////////////////////////////////////////////////////
   1082 
   1083 class OutputOrdering {
   1084  public:
   1085   typedef list<MemCacheIter> PriorityRing;
   1086 
   1087   typedef map<uint32, PriorityRing> PriorityMap;
   1088 
   1089   struct PriorityMapPointer {
   1090     PriorityMapPointer(): ring(NULL), alarm_enabled(false) {}
   1091     PriorityRing* ring;
   1092     PriorityRing::iterator it;
   1093     bool alarm_enabled;
   1094     EpollServer::AlarmRegToken alarm_token;
   1095   };
   1096   typedef map<uint32, PriorityMapPointer> StreamIdToPriorityMap;
   1097 
   1098   StreamIdToPriorityMap stream_ids_;
   1099   PriorityMap priority_map_;
   1100   PriorityRing first_data_senders_;
   1101   uint32 first_data_senders_threshold_;  // when you've passed this, you're no
   1102                                          // longer a first_data_sender...
   1103   SMServerConnection* connection_;
   1104   EpollServer* epoll_server_;
   1105 
   1106   explicit OutputOrdering(SMServerConnection* connection) :
   1107       first_data_senders_threshold_(kInitialDataSendersThreshold),
   1108       connection_(connection),
   1109       epoll_server_(connection->epoll_server()) {
   1110   }
   1111 
   1112   void Reset() {
   1113     while (!stream_ids_.empty()) {
   1114       StreamIdToPriorityMap::iterator sitpmi = stream_ids_.begin();
   1115       PriorityMapPointer& pmp = sitpmi->second;
   1116       if (pmp.alarm_enabled) {
   1117         epoll_server_->UnregisterAlarm(pmp.alarm_token);
   1118       }
   1119       stream_ids_.erase(sitpmi);
   1120     }
   1121     priority_map_.clear();
   1122     first_data_senders_.clear();
   1123   }
   1124 
   1125   bool ExistsInPriorityMaps(uint32 stream_id) {
   1126     StreamIdToPriorityMap::iterator sitpmi = stream_ids_.find(stream_id);
   1127     return sitpmi != stream_ids_.end();
   1128   }
   1129 
   1130   struct BeginOutputtingAlarm : public EpollAlarmCallbackInterface {
   1131    public:
   1132     BeginOutputtingAlarm(OutputOrdering* oo,
   1133                          OutputOrdering::PriorityMapPointer* pmp,
   1134                          const MemCacheIter& mci) :
   1135         output_ordering_(oo), pmp_(pmp), mci_(mci), epoll_server_(NULL) {}
   1136 
   1137     int64 OnAlarm() {
   1138       OnUnregistration();
   1139       output_ordering_->MoveToActive(pmp_, mci_);
   1140       VLOG(1) << "ON ALARM! Should now start to output...";
   1141       delete this;
   1142       return 0;
   1143     }
   1144     void OnRegistration(const EpollServer::AlarmRegToken& tok,
   1145                         EpollServer* eps) {
   1146       epoll_server_ = eps;
   1147       pmp_->alarm_token = tok;
   1148       pmp_->alarm_enabled = true;
   1149     }
   1150     void OnUnregistration() {
   1151       pmp_->alarm_enabled = false;
   1152     }
   1153     void OnShutdown(EpollServer* eps) {
   1154       OnUnregistration();
   1155     }
   1156     ~BeginOutputtingAlarm() {
   1157       if (epoll_server_ && pmp_->alarm_enabled)
   1158         epoll_server_->UnregisterAlarm(pmp_->alarm_token);
   1159     }
   1160    private:
   1161     OutputOrdering* output_ordering_;
   1162     OutputOrdering::PriorityMapPointer* pmp_;
   1163     MemCacheIter mci_;
   1164     EpollServer* epoll_server_;
   1165   };
   1166 
   1167   void MoveToActive(PriorityMapPointer* pmp, MemCacheIter mci) {
   1168     VLOG(1) <<"Moving to active!";
   1169     first_data_senders_.push_back(mci);
   1170     pmp->ring = &first_data_senders_;
   1171     pmp->it = first_data_senders_.end();
   1172     --pmp->it;
   1173     connection_->ReadyToSend();
   1174   }
   1175 
   1176   void AddToOutputOrder(const MemCacheIter& mci) {
   1177     if (ExistsInPriorityMaps(mci.stream_id))
   1178       LOG(FATAL) << "OOps, already was inserted here?!";
   1179 
   1180     StreamIdToPriorityMap::iterator sitpmi;
   1181     sitpmi = stream_ids_.insert(
   1182         pair<uint32, PriorityMapPointer>(mci.stream_id,
   1183                                          PriorityMapPointer())).first;
   1184     PriorityMapPointer& pmp = sitpmi->second;
   1185 
   1186     BeginOutputtingAlarm* boa = new BeginOutputtingAlarm(this, &pmp, mci);
   1187     epoll_server_->RegisterAlarmApproximateDelta(
   1188         FLAGS_server_think_time_in_s * 1000000, boa);
   1189   }
   1190 
   1191   void SpliceToPriorityRing(PriorityRing::iterator pri) {
   1192     MemCacheIter& mci = *pri;
   1193     PriorityMap::iterator pmi = priority_map_.find(mci.priority);
   1194     if (pmi == priority_map_.end()) {
   1195       pmi = priority_map_.insert(
   1196           pair<uint32, PriorityRing>(mci.priority, PriorityRing())).first;
   1197     }
   1198 
   1199     pmi->second.splice(pmi->second.end(),
   1200                        first_data_senders_,
   1201                        pri);
   1202     StreamIdToPriorityMap::iterator sitpmi = stream_ids_.find(mci.stream_id);
   1203     sitpmi->second.ring = &(pmi->second);
   1204   }
   1205 
   1206   MemCacheIter* GetIter() {
   1207     while (!first_data_senders_.empty()) {
   1208       MemCacheIter& mci = first_data_senders_.front();
   1209       if (mci.bytes_sent >= first_data_senders_threshold_) {
   1210         SpliceToPriorityRing(first_data_senders_.begin());
   1211       } else {
   1212         first_data_senders_.splice(first_data_senders_.end(),
   1213                                   first_data_senders_,
   1214                                   first_data_senders_.begin());
   1215         mci.max_segment_size = kInitialDataSendersThreshold;
   1216         return &mci;
   1217       }
   1218     }
   1219     while (!priority_map_.empty()) {
   1220       PriorityRing& first_ring = priority_map_.begin()->second;
   1221       if (first_ring.empty()) {
   1222         priority_map_.erase(priority_map_.begin());
   1223         continue;
   1224       }
   1225       MemCacheIter& mci = first_ring.front();
   1226       first_ring.splice(first_ring.end(),
   1227                         first_ring,
   1228                         first_ring.begin());
   1229       mci.max_segment_size = kNormalSegmentSize;
   1230       return &mci;
   1231     }
   1232     return NULL;
   1233   }
   1234 
   1235   void RemoveStreamId(uint32 stream_id) {
   1236     StreamIdToPriorityMap::iterator sitpmi = stream_ids_.find(stream_id);
   1237     if (sitpmi == stream_ids_.end())
   1238       return;
   1239     PriorityMapPointer& pmp = sitpmi->second;
   1240     if (pmp.alarm_enabled) {
   1241       epoll_server_->UnregisterAlarm(pmp.alarm_token);
   1242     } else {
   1243       pmp.ring->erase(pmp.it);
   1244     }
   1245 
   1246     stream_ids_.erase(sitpmi);
   1247   }
   1248 };
   1249 
   1250 ////////////////////////////////////////////////////////////////////////////////
   1251 
   1252 class FlipSM : public FlipFramerVisitorInterface, public SMInterface {
   1253  private:
    1254   uint64 seq_num_;  // deliberately kept across ResetForNewConnection()
    1255   FlipFramer* framer_;  // allocated with new; replaced by ResetForNewConnection()
    1256 
    1257   SMServerConnection* connection_;  // connection this state machine serves
    1258   OutputList* output_list_;  // connection_->output_list(); read by GetOutput()
    1259   OutputOrdering output_ordering_;  // picks which stream sends next
    1260   MemoryCache* memory_cache_;  // connection_->memory_cache(); source of file data
    1261   uint32 next_outgoing_stream_id_;  // starts at 2 and advances by 2 per associated stream
   1262  public:
   1263   explicit FlipSM(SMServerConnection* connection) :
   1264       seq_num_(0),
   1265       framer_(new FlipFramer),
   1266       connection_(connection),
   1267       output_list_(connection->output_list()),
   1268       output_ordering_(connection),
   1269       memory_cache_(connection->memory_cache()),
   1270       next_outgoing_stream_id_(2) {
   1271     framer_->set_visitor(this);
   1272   }
   1273  private:
    1274   virtual void OnError(FlipFramer* framer) {
    1275     /* Framer error callback: intentionally ignored for now. */
    1276   }
   1277 
    1278   virtual void OnControl(const FlipControlFrame* frame) {
             // Framer callback for a control frame.
             //   SYN_STREAM: parse the headers, optionally record them, then
             //               answer "/testing" URLs directly and hand every
             //               other request to NewStream().
             //   SYN_REPLY:  parse the headers and log only.
             //   FIN_STREAM: stop scheduling output for the stream.
    1279     FlipHeaderBlock headers;
    1280     bool parsed_headers = false;
    1281     switch (frame->type()) {
    1282       case SYN_STREAM:
    1283         {
    1284         parsed_headers = framer_->ParseHeaderBlock(frame, &headers);
    1285         VLOG(2) << "OnSyn(" << frame->stream_id() << ")";
    1286         VLOG(2) << "headers parsed?: " << (parsed_headers? "yes": "no");
    1287         if (parsed_headers) {
    1288           VLOG(2) << "# headers: " << headers.size();
    1289         }
                 // j counts the headers written to the record file, so the
                 // last one can be followed by a blank line.
    1290         unsigned int j = 0;
    1291         for (FlipHeaderBlock::iterator i = headers.begin();
    1292              i != headers.end();
    1293              ++i) {
    1294           VLOG(2) << i->first << ": " << i->second;
    1295           if (FLAGS_record_mode && connection_->record_fd() > 0) {
    1296             // If record mode is enabled and the corresponding server
    1297             // connection has a file open, save the request headers into
    1298             // that file. All requests from the same connection are saved
    1299             // in one file. This file will be used to replay and generate
    1300             // FLIP request load.
    1301             string header = i->first + ": " + i->second + "\n";
    1302             ++j;
    1303             if (j == headers.size()) {
    1304               header += "\n";  // add an additional empty line
    1305             }
    1306             int r = write(
    1307                 connection_->record_fd(), header.c_str(), header.size());
    1308             if (r < 0) {
    1309               perror("unable to write to record file:");
    1310             }
    1311           }
    1312         }
    1313 
                 // Both "method" and "url" are required; without them the
                 // frame is dropped and no stream is created.
    1314         FlipHeaderBlock::iterator method = headers.find("method");
    1315         FlipHeaderBlock::iterator url = headers.find("url");
    1316         if (url == headers.end() || method == headers.end()) {
    1317           VLOG(2) << "didn't find method or url or method. Not creating stream";
    1318           break;
    1319         }
    1320 
    1321         FlipHeaderBlock::iterator referer = headers.find("referer");
    1322         if (referer != headers.end() && method->second == "GET") {
    1323           memory_cache_->UpdateHeaders(referer->second, url->second);
    1324         }
    1325         string uri = UrlUtilities::GetUrlPath(url->second);
    1326         string host = UrlUtilities::GetUrlHost(url->second);
    1327         // requests started with /testing are loadtime measurement related
    1328         // urls, use LoadtimeMeasurement class to handle them.
    1329         if (uri.find("/testing") == 0) {
    1330           string output;
    1331           global_loadtime_measurement.ProcessRequest(uri, output);
    1332           SendOKResponse(frame->stream_id(), &output);
    1333         } else {
                   // Build the name under which the response is looked up in
                   // the memory cache.
    1334           string filename;
    1335           if (FLAGS_need_to_encode_url) {
    1336             filename = net::UrlToFilenameEncoder::Encode(
    1337                 "http://" + host + uri, method->second + "_/");
    1338           } else {
    1339             filename = string(method->second + "_" + url->second);
    1340           }
    1341 
    1342           NewStream(frame->stream_id(),
    1343                     reinterpret_cast<const FlipSynStreamControlFrame*>(frame)->
    1344                       priority(),
    1345                     filename);
    1346           }
    1347         }
    1348         break;
    1349 
    1350       case SYN_REPLY:
                 // The parsed headers are not used any further here.
    1351         parsed_headers = framer_->ParseHeaderBlock(frame, &headers);
    1352         VLOG(2) << "OnSynReply(" << frame->stream_id() << ")";
    1353         break;
    1354       case FIN_STREAM:
    1355         VLOG(2) << "OnFin(" << frame->stream_id() << ")";
    1356         output_ordering_.RemoveStreamId(frame->stream_id());
    1357 
    1358         break;
    1359       default:
    1360         LOG(DFATAL) << "Unknown control frame type";
    1361     }
    1362   }
   1363   virtual void OnStreamFrameData(
   1364     FlipStreamId stream_id,
   1365     const char* data, size_t len) {
   1366     VLOG(2) << "StreamData(" << stream_id << ", [" << len << "])";
   1367     /* do nothing with this right now */
   1368   }
   1369   virtual void OnLameDuck() {
   1370     /* do nothing with this right now */
   1371   }
   1372 
   1373  public:
   1374   ~FlipSM() {
   1375     Reset();
   1376   }
    1377   size_t ProcessInput(const char* data, size_t len) {  // forwards to the framer
    1378     return framer_->ProcessInput(data, len);
    1379   }
    1380 
    1381   bool MessageFullyRead() const {  // framer's view of whether a message is complete
    1382     return framer_->MessageFullyRead();
    1383   }
    1384 
    1385   bool Error() const {  // true when the framer reports an error
    1386     return framer_->HasError();
    1387   }
    1388 
    1389   const char* ErrorAsString() const {  // text for the framer's current error code
    1390     return FlipFramer::ErrorCodeToString(framer_->error_code());
    1391   }
   1392 
   1393   void Reset() {}
   1394   void ResetForNewConnection() {
   1395     // seq_num is not cleared, intentionally.
   1396     delete framer_;
   1397     framer_ = new FlipFramer;
   1398     framer_->set_visitor(this);
   1399     output_ordering_.Reset();
   1400     next_outgoing_stream_id_ = 2;
   1401   }
   1402 
   1403   // Send a couple of NOOP packets to force opening of cwnd.
   1404   void PostAcceptHook() {
   1405     if (!FLAGS_use_cwnd_opener)
   1406       return;
   1407 
   1408     // We send 2 because that is the initial cwnd, and also because
   1409     // we have to in order to get an ACK back from the client due to
   1410     // delayed ACK.
   1411     const int kPkts = 2;
   1412 
   1413     LOG(ERROR) << "Sending NOP FRAMES";
   1414 
   1415     scoped_ptr<FlipControlFrame> frame(FlipFramer::CreateNopFrame());
   1416     for (int i = 0; i < kPkts; ++i) {
   1417       char* bytes = frame->data();
   1418       size_t size = FlipFrame::size();
   1419       ssize_t bytes_written = connection_->Send(bytes, size, MSG_DONTWAIT);
   1420       if (bytes_written > 0 && static_cast<size_t>(bytes_written) != size) {
   1421         LOG(ERROR) << "Trouble sending Nop packet! (" << errno << ")";
   1422         if (errno == EAGAIN)
   1423           break;
   1424       }
   1425     }
   1426   }
   1427 
   1428   void AddAssociatedContent(FileData* file_data) {
   1429     for (unsigned int i = 0; i < file_data->related_files.size(); ++i) {
   1430       pair<int, string>& related_file = file_data->related_files[i];
   1431       MemCacheIter mci;
   1432       string filename  = "GET_";
   1433       filename += related_file.second;
   1434       if (!memory_cache_->AssignFileData(filename, &mci)) {
   1435         VLOG(1) << "Unable to find associated content for: " << filename;
   1436         continue;
   1437       }
   1438       VLOG(1) << "Adding associated content: " << filename;
   1439       mci.stream_id = next_outgoing_stream_id_;
   1440       next_outgoing_stream_id_ += 2;
   1441       mci.priority =  related_file.first;
   1442       AddToOutputOrder(mci);
   1443     }
   1444   }
   1445 
   1446   void NewStream(uint32 stream_id, uint32 priority, const string& filename) {
   1447     MemCacheIter mci;
   1448     mci.stream_id = stream_id;
   1449     mci.priority = priority;
   1450     if (!memory_cache_->AssignFileData(filename, &mci)) {
   1451       // error creating new stream.
   1452       VLOG(2) << "Sending ErrorNotFound";
   1453       SendErrorNotFound(stream_id);
   1454     } else {
   1455       AddToOutputOrder(mci);
   1456       if (FLAGS_use_xac) {
   1457         AddAssociatedContent(mci.file_data);
   1458       }
   1459     }
   1460   }
   1461 
    1462   void AddToOutputOrder(const MemCacheIter& mci) {  // hands the stream to output_ordering_
    1463     output_ordering_.AddToOutputOrder(mci);
    1464   }
    1465 
    1466   void SendEOF(uint32 stream_id) {  // public entry point for SendEOFImpl()
    1467     SendEOFImpl(stream_id);
    1468   }
    1469 
    1470   void SendErrorNotFound(uint32 stream_id) {  // public entry point for SendErrorNotFoundImpl()
    1471     SendErrorNotFoundImpl(stream_id);
    1472   }
    1473 
    1474   void SendOKResponse(uint32 stream_id, string* output) {  // public entry point for SendOKResponseImpl()
    1475     SendOKResponseImpl(stream_id, output);
    1476   }
    1477 
    1478   size_t SendSynStream(uint32 stream_id, const BalsaHeaders& headers) {  // returns SendSynStreamImpl()'s result
    1479     return SendSynStreamImpl(stream_id, headers);
    1480   }
    1481 
    1482   size_t SendSynReply(uint32 stream_id, const BalsaHeaders& headers) {  // returns SendSynReplyImpl()'s result
    1483     return SendSynReplyImpl(stream_id, headers);
    1484   }
    1485 
    1486   void SendDataFrame(uint32 stream_id, const char* data, int64 len,
    1487                      uint32 flags, bool compress) {
             // Converts the raw flag bits to FlipDataFlags and forwards
             // everything to SendDataFrameImpl().
    1488     FlipDataFlags flip_flags = static_cast<FlipDataFlags>(flags);
    1489     SendDataFrameImpl(stream_id, data, len, flip_flags, compress);
    1490   }
    1491 
    1492   FlipFramer* flip_framer() { return framer_; }  // current framer; replaced by ResetForNewConnection()
   1493 
   1494  private:
   1495   void SendEOFImpl(uint32 stream_id) {
   1496     SendDataFrame(stream_id, NULL, 0, DATA_FLAG_FIN, false);
   1497     VLOG(2) << "Sending EOF: " << stream_id;
   1498     KillStream(stream_id);
   1499   }
   1500 
   1501   void SendErrorNotFoundImpl(uint32 stream_id) {
   1502     BalsaHeaders my_headers;
   1503     my_headers.SetFirstlineFromStringPieces("HTTP/1.1", "404", "Not Found");
   1504     SendSynReplyImpl(stream_id, my_headers);
   1505     SendDataFrame(stream_id, "wtf?", 4, DATA_FLAG_FIN, false);
   1506     output_ordering_.RemoveStreamId(stream_id);
   1507   }
   1508 
   1509   void SendOKResponseImpl(uint32 stream_id, string* output) {
   1510     BalsaHeaders my_headers;
   1511     my_headers.SetFirstlineFromStringPieces("HTTP/1.1", "200", "OK");
   1512     SendSynReplyImpl(stream_id, my_headers);
   1513     SendDataFrame(
   1514         stream_id, output->c_str(), output->size(), DATA_FLAG_FIN, false);
   1515     output_ordering_.RemoveStreamId(stream_id);
   1516   }
   1517 
    1518   void KillStream(uint32 stream_id) {
             // Removes the stream from the output scheduler.
    1519     output_ordering_.RemoveStreamId(stream_id);
    1520   }
   1521 
   1522   void CopyHeaders(FlipHeaderBlock& dest, const BalsaHeaders& headers) {
   1523     for (BalsaHeaders::const_header_lines_iterator hi =
   1524          headers.header_lines_begin();
   1525          hi != headers.header_lines_end();
   1526          ++hi) {
   1527       FlipHeaderBlock::iterator fhi = dest.find(hi->first.as_string());
   1528       if (fhi == dest.end()) {
   1529         dest[hi->first.as_string()] = hi->second.as_string();
   1530       } else {
   1531         dest[hi->first.as_string()] = (
   1532             string(fhi->second.data(), fhi->second.size()) + "," +
   1533             string(hi->second.data(), hi->second.size()));
   1534       }
   1535     }
   1536 
   1537     // These headers have no value
   1538     dest.erase("X-Associated-Content");  // TODO(mbelshe): case-sensitive
   1539     dest.erase("X-Original-Url");  // TODO(mbelshe): case-sensitive
   1540   }
   1541 
   1542   size_t SendSynStreamImpl(uint32 stream_id, const BalsaHeaders& headers) {
   1543     FlipHeaderBlock block;
   1544     block["method"] = headers.request_method().as_string();
   1545     if (!headers.HasHeader("status"))
   1546       block["status"] = headers.response_code().as_string();
   1547     if (!headers.HasHeader("version"))
   1548       block["version"] =headers.response_version().as_string();
   1549     if (headers.HasHeader("X-Original-Url")) {
   1550       string original_url = headers.GetHeader("X-Original-Url").as_string();
   1551       block["path"] = UrlUtilities::GetUrlPath(original_url);
   1552     } else {
   1553       block["path"] = headers.request_uri().as_string();
   1554     }
   1555     CopyHeaders(block, headers);
   1556 
   1557     FlipSynStreamControlFrame* fsrcf =
   1558       framer_->CreateSynStream(stream_id, 0, CONTROL_FLAG_NONE, true, &block);
   1559     DataFrame df;
   1560     df.size = fsrcf->length() + FlipFrame::size();
   1561     size_t df_size = df.size;
   1562     df.data = fsrcf->data();
   1563     df.delete_when_done = true;
   1564     EnqueueDataFrame(df);
   1565 
   1566     VLOG(2) << "Sending SynStreamheader " << stream_id;
   1567     return df_size;
   1568   }
   1569 
   1570   size_t SendSynReplyImpl(uint32 stream_id, const BalsaHeaders& headers) {
   1571     FlipHeaderBlock block;
   1572     CopyHeaders(block, headers);
   1573     block["status"] = headers.response_code().as_string() + " " +
   1574                       headers.response_reason_phrase().as_string();
   1575     block["version"] = headers.response_version().as_string();
   1576 
   1577     FlipSynReplyControlFrame* fsrcf =
   1578       framer_->CreateSynReply(stream_id, CONTROL_FLAG_NONE, true, &block);
   1579     DataFrame df;
   1580     df.size = fsrcf->length() + FlipFrame::size();
   1581     size_t df_size = df.size;
   1582     df.data = fsrcf->data();
   1583     df.delete_when_done = true;
   1584     EnqueueDataFrame(df);
   1585 
   1586     VLOG(2) << "Sending SynReplyheader " << stream_id;
   1587     return df_size;
   1588   }
   1589 
   1590   void SendDataFrameImpl(uint32 stream_id, const char* data, int64 len,
   1591                          FlipDataFlags flags, bool compress) {
   1592     // Force compression off if disabled via command line.
   1593     if (!FLAGS_use_compression)
   1594       flags = static_cast<FlipDataFlags>(flags & ~DATA_FLAG_COMPRESSED);
   1595 
   1596     // TODO(mbelshe):  We can't compress here - before going into the
   1597     //                 priority queue.  Compression needs to be done
   1598     //                 with late binding.
   1599     FlipDataFrame* fdf = framer_->CreateDataFrame(stream_id, data, len,
   1600                                                   flags);
   1601     DataFrame df;
   1602     df.size = fdf->length() + FlipFrame::size();
   1603     df.data = fdf->data();
   1604     df.delete_when_done = true;
   1605     EnqueueDataFrame(df);
   1606 
   1607     VLOG(2) << "Sending data frame" << stream_id << " [" << len << "]"
   1608             << " shrunk to " << fdf->length();
   1609   }
   1610 
   1611   void EnqueueDataFrame(const DataFrame& df) {
   1612     connection_->EnqueueDataFrame(df);
   1613   }
   1614 
   1615   void GetOutput() {
   1616     while (output_list_->size() < 2) {
   1617       MemCacheIter* mci = output_ordering_.GetIter();
   1618       if (mci == NULL) {
   1619         VLOG(2) << "GetOutput: nothing to output!?";
   1620         return;
   1621       }
   1622       if (!mci->transformed_header) {
   1623         mci->transformed_header = true;
   1624         VLOG(2) << "GetOutput transformed header stream_id: ["
   1625           << mci->stream_id << "]";
   1626         if ((mci->stream_id % 2) == 0) {
   1627           // this is a server initiated stream.
   1628           // Ideally, we'd do a 'syn-push' here, instead of a syn-reply.
   1629           BalsaHeaders headers;
   1630           headers.CopyFrom(*(mci->file_data->headers));
   1631           headers.ReplaceOrAppendHeader("status", "200");
   1632           headers.ReplaceOrAppendHeader("version", "http/1.1");
   1633           headers.SetRequestFirstlineFromStringPieces("PUSH",
   1634                                                       mci->file_data->filename,
   1635                                                       "");
   1636           mci->bytes_sent = SendSynStream(mci->stream_id, headers);
   1637         } else {
   1638           BalsaHeaders headers;
   1639           headers.CopyFrom(*(mci->file_data->headers));
   1640           mci->bytes_sent = SendSynReply(mci->stream_id, headers);
   1641         }
   1642         return;
   1643       }
   1644       if (mci->body_bytes_consumed >= mci->file_data->body.size()) {
   1645         VLOG(2) << "GetOutput remove_stream_id: [" << mci->stream_id << "]";
   1646         SendEOF(mci->stream_id);
   1647         return;
   1648       }
   1649       size_t num_to_write =
   1650         mci->file_data->body.size() - mci->body_bytes_consumed;
   1651       if (num_to_write > mci->max_segment_size)
   1652         num_to_write = mci->max_segment_size;
   1653 
   1654       bool should_compress = false;
   1655       if (!mci->file_data->headers->HasHeader("content-encoding")) {
   1656         if (mci->file_data->headers->HasHeader("content-type")) {
   1657           string content_type =
   1658               mci->file_data->headers->GetHeader("content-type").as_string();
   1659           if (content_type.find("image") == content_type.npos)
   1660             should_compress = true;
   1661         }
   1662       }
   1663 
   1664       SendDataFrame(mci->stream_id,
   1665                     mci->file_data->body.data() + mci->body_bytes_consumed,
   1666                     num_to_write, 0, should_compress);
   1667       VLOG(2) << "GetOutput SendDataFrame[" << mci->stream_id
   1668         << "]: " << num_to_write;
   1669       mci->body_bytes_consumed += num_to_write;
   1670       mci->bytes_sent += num_to_write;
   1671     }
   1672   }
   1673 };
   1674 
   1675 ////////////////////////////////////////////////////////////////////////////////
   1676 
   1677 class HTTPSM : public BalsaVisitorInterface, public SMInterface {
   1678  private:
   1679   uint64 seq_num_;
   1680   BalsaFrame* framer_;
   1681   BalsaHeaders headers_;
   1682   uint32 stream_id_;
   1683 
   1684   SMServerConnection* connection_;
   1685   OutputList* output_list_;
   1686   OutputOrdering output_ordering_;
   1687   MemoryCache* memory_cache_;
   1688  public:
   1689   explicit HTTPSM(SMServerConnection* connection) :
   1690       seq_num_(0),
   1691       framer_(new BalsaFrame),
   1692       stream_id_(1),
   1693       connection_(connection),
   1694       output_list_(connection->output_list()),
   1695       output_ordering_(connection),
   1696       memory_cache_(connection->memory_cache()) {
   1697     framer_->set_balsa_visitor(this);
   1698     framer_->set_balsa_headers(&headers_);
   1699   }
   1700  private:
   1701   typedef map<string, uint32> ClientTokenMap;
   1702  private:
   1703     virtual void ProcessBodyInput(const char *input, size_t size) {
   1704     }
   1705     virtual void ProcessBodyData(const char *input, size_t size) {
   1706       // ignoring this.
   1707     }
   1708     virtual void ProcessHeaderInput(const char *input, size_t size) {
   1709     }
   1710     virtual void ProcessTrailerInput(const char *input, size_t size) {}
   1711     virtual void ProcessHeaders(const BalsaHeaders& headers) {
   1712       VLOG(2) << "Got new request!";
   1713       // requests started with /testing are loadtime measurement related
   1714       // urls, use LoadtimeMeasurement class to handle them.
   1715       if (headers.request_uri().as_string().find("/testing") == 0) {
   1716         string output;
   1717         global_loadtime_measurement.ProcessRequest(
   1718             headers.request_uri().as_string(), output);
   1719         SendOKResponse(stream_id_, &output);
   1720         stream_id_ += 2;
   1721       } else {
   1722         string filename;
   1723         if (FLAGS_need_to_encode_url) {
   1724           filename = net::UrlToFilenameEncoder::Encode(
   1725               headers.GetHeader("Host").as_string() +
   1726               headers.request_uri().as_string(),
   1727               headers.request_method().as_string() + "_/");
   1728         } else {
   1729          filename = headers.request_method().as_string() + "_" +
   1730                     headers.request_uri().as_string();
   1731         }
   1732         NewStream(stream_id_, 0, filename);
   1733         stream_id_ += 2;
   1734       }
   1735     }
   1736     virtual void ProcessRequestFirstLine(const char* line_input,
   1737                                          size_t line_length,
   1738                                          const char* method_input,
   1739                                          size_t method_length,
   1740                                          const char* request_uri_input,
   1741                                          size_t request_uri_length,
   1742                                          const char* version_input,
   1743                                          size_t version_length) {}
   1744     virtual void ProcessResponseFirstLine(const char *line_input,
   1745                                           size_t line_length,
   1746                                           const char *version_input,
   1747                                           size_t version_length,
   1748                                           const char *status_input,
   1749                                           size_t status_length,
   1750                                           const char *reason_input,
   1751                                           size_t reason_length) {}
   1752     virtual void ProcessChunkLength(size_t chunk_length) {}
   1753     virtual void ProcessChunkExtensions(const char *input, size_t size) {}
   1754     virtual void HeaderDone() {}
   1755     virtual void MessageDone() {
   1756       VLOG(2) << "MessageDone!";
   1757     }
   1758     virtual void HandleHeaderError(BalsaFrame* framer) {
   1759       HandleError();
   1760     }
   1761     virtual void HandleHeaderWarning(BalsaFrame* framer) {}
   1762     virtual void HandleChunkingError(BalsaFrame* framer) {
   1763       HandleError();
   1764     }
   1765     virtual void HandleBodyError(BalsaFrame* framer) {
   1766       HandleError();
   1767     }
   1768 
   1769     void HandleError() {
   1770       VLOG(2) << "Error detected";
   1771     }
   1772 
   1773  public:
   1774   ~HTTPSM() {
   1775     Reset();
   1776   }
   1777   size_t ProcessInput(const char* data, size_t len) {
   1778     return framer_->ProcessInput(data, len);
   1779   }
   1780 
   1781   bool MessageFullyRead() const {
   1782     return framer_->MessageFullyRead();
   1783   }
   1784 
   1785   bool Error() const {
   1786     return framer_->Error();
   1787   }
   1788 
   1789   const char* ErrorAsString() const {
   1790     return BalsaFrameEnums::ErrorCodeToString(framer_->ErrorCode());
   1791   }
   1792 
   1793   void Reset() {
   1794     framer_->Reset();
   1795   }
   1796 
   1797   void ResetForNewConnection() {
   1798     seq_num_ = 0;
   1799     output_ordering_.Reset();
   1800     framer_->Reset();
   1801   }
   1802 
   1803   void PostAcceptHook() {
   1804   }
   1805 
   1806   void NewStream(uint32 stream_id, uint32 priority, const string& filename) {
   1807     MemCacheIter mci;
   1808     mci.stream_id = stream_id;
   1809     mci.priority = priority;
   1810     if (!memory_cache_->AssignFileData(filename, &mci)) {
   1811       SendErrorNotFound(stream_id);
   1812     } else {
   1813       AddToOutputOrder(mci);
   1814     }
   1815   }
   1816 
   1817   void AddToOutputOrder(const MemCacheIter& mci) {
   1818     output_ordering_.AddToOutputOrder(mci);
   1819   }
   1820 
   1821   void SendEOF(uint32 stream_id) {
   1822     SendEOFImpl(stream_id);
   1823   }
   1824 
   1825   void SendErrorNotFound(uint32 stream_id) {
   1826     SendErrorNotFoundImpl(stream_id);
   1827   }
   1828 
   1829   void SendOKResponse(uint32 stream_id, string* output) {
   1830     SendOKResponseImpl(stream_id, output);
   1831   }
   1832 
   1833   size_t SendSynStream(uint32 stream_id, const BalsaHeaders& headers) {
   1834     return 0;
   1835   }
   1836 
   1837   size_t SendSynReply(uint32 stream_id, const BalsaHeaders& headers) {
   1838     return SendSynReplyImpl(stream_id, headers);
   1839   }
   1840 
   1841   void SendDataFrame(uint32 stream_id, const char* data, int64 len,
   1842                      uint32 flags, bool compress) {
   1843     SendDataFrameImpl(stream_id, data, len, flags, compress);
   1844   }
   1845 
   1846   BalsaFrame* flip_framer() { return framer_; }
   1847 
   1848  private:
   1849   void SendEOFImpl(uint32 stream_id) {
   1850     DataFrame df;
   1851     df.data = "0\r\n\r\n";
   1852     df.size = 5;
   1853     df.delete_when_done = false;
   1854     EnqueueDataFrame(df);
   1855   }
   1856 
   1857   void SendErrorNotFoundImpl(uint32 stream_id) {
   1858     BalsaHeaders my_headers;
   1859     my_headers.SetFirstlineFromStringPieces("HTTP/1.1", "404", "Not Found");
   1860     my_headers.RemoveAllOfHeader("content-length");
   1861     my_headers.HackHeader("transfer-encoding", "chunked");
   1862     SendSynReplyImpl(stream_id, my_headers);
   1863     SendDataFrame(stream_id, "wtf?", 4, 0, false);
   1864     SendEOFImpl(stream_id);
   1865     output_ordering_.RemoveStreamId(stream_id);
   1866   }
   1867 
   1868   void SendOKResponseImpl(uint32 stream_id, string* output) {
   1869     BalsaHeaders my_headers;
   1870     my_headers.SetFirstlineFromStringPieces("HTTP/1.1", "200", "OK");
   1871     my_headers.RemoveAllOfHeader("content-length");
   1872     my_headers.HackHeader("transfer-encoding", "chunked");
   1873     SendSynReplyImpl(stream_id, my_headers);
   1874     SendDataFrame(stream_id, output->c_str(), output->size(), 0, false);
   1875     SendEOFImpl(stream_id);
   1876     output_ordering_.RemoveStreamId(stream_id);
   1877   }
   1878 
   1879   size_t SendSynReplyImpl(uint32 stream_id, const BalsaHeaders& headers) {
   1880     SimpleBuffer sb;
   1881     headers.WriteHeaderAndEndingToBuffer(&sb);
   1882     DataFrame df;
   1883     df.size = sb.ReadableBytes();
   1884     char* buffer = new char[df.size];
   1885     df.data = buffer;
   1886     df.delete_when_done = true;
   1887     sb.Read(buffer, df.size);
   1888     VLOG(2) << "******************Sending HTTP Reply header " << stream_id;
   1889     size_t df_size = df.size;
   1890     EnqueueDataFrame(df);
   1891     return df_size;
   1892   }
   1893 
   1894   size_t SendSynStreamImpl(uint32 stream_id, const BalsaHeaders& headers) {
   1895     SimpleBuffer sb;
   1896     headers.WriteHeaderAndEndingToBuffer(&sb);
   1897     DataFrame df;
   1898     df.size = sb.ReadableBytes();
   1899     char* buffer = new char[df.size];
   1900     df.data = buffer;
   1901     df.delete_when_done = true;
   1902     sb.Read(buffer, df.size);
   1903     VLOG(2) << "******************Sending HTTP Reply header " << stream_id;
   1904     size_t df_size = df.size;
   1905     EnqueueDataFrame(df);
   1906     return df_size;
   1907   }
   1908 
   1909   void SendDataFrameImpl(uint32 stream_id, const char* data, int64 len,
   1910                          uint32 flags, bool compress) {
   1911     char chunk_buf[128];
   1912     snprintf(chunk_buf, sizeof(chunk_buf), "%x\r\n", (unsigned int)len);
   1913     string chunk_description(chunk_buf);
   1914     DataFrame df;
   1915     df.size = chunk_description.size() + len + 2;
   1916     char* buffer = new char[df.size];
   1917     df.data = buffer;
   1918     df.delete_when_done = true;
   1919     memcpy(buffer, chunk_description.data(), chunk_description.size());
   1920     memcpy(buffer + chunk_description.size(), data, len);
   1921     memcpy(buffer + chunk_description.size() + len, "\r\n", 2);
   1922     EnqueueDataFrame(df);
   1923   }
   1924 
   1925   void EnqueueDataFrame(const DataFrame& df) {
   1926     connection_->EnqueueDataFrame(df);
   1927   }
   1928 
   1929   void GetOutput() {
   1930     MemCacheIter* mci = output_ordering_.GetIter();
   1931     if (mci == NULL) {
   1932       VLOG(2) << "GetOutput: nothing to output!?";
   1933       return;
   1934     }
   1935     if (!mci->transformed_header) {
   1936       mci->bytes_sent = SendSynReply(mci->stream_id,
   1937                                      *(mci->file_data->headers));
   1938       mci->transformed_header = true;
   1939       VLOG(2) << "GetOutput transformed header stream_id: ["
   1940         << mci->stream_id << "]";
   1941       return;
   1942     }
   1943     if (mci->body_bytes_consumed >= mci->file_data->body.size()) {
   1944       SendEOF(mci->stream_id);
   1945       output_ordering_.RemoveStreamId(mci->stream_id);
   1946       VLOG(2) << "GetOutput remove_stream_id: [" << mci->stream_id << "]";
   1947       return;
   1948     }
   1949     size_t num_to_write =
   1950       mci->file_data->body.size() - mci->body_bytes_consumed;
   1951     if (num_to_write > mci->max_segment_size)
   1952       num_to_write = mci->max_segment_size;
   1953     SendDataFrame(mci->stream_id,
   1954                   mci->file_data->body.data() + mci->body_bytes_consumed,
   1955                   num_to_write, 0, true);
   1956     VLOG(2) << "GetOutput SendDataFrame[" << mci->stream_id
   1957       << "]: " << num_to_write;
   1958     mci->body_bytes_consumed += num_to_write;
   1959     mci->bytes_sent += num_to_write;
   1960   }
   1961 };
   1962 
   1963 ////////////////////////////////////////////////////////////////////////////////
   1964 
   1965 class Notification {
   1966  public:
   1967   explicit Notification(bool value) : value_(value) {}
   1968 
   1969   void Notify() {
   1970     AutoLock al(lock_);
   1971     value_ = true;
   1972   }
   1973   bool HasBeenNotified() {
   1974     AutoLock al(lock_);
   1975     return value_;
   1976   }
   1977   bool value_;
   1978   Lock lock_;
   1979 };
   1980 
   1981 ////////////////////////////////////////////////////////////////////////////////
   1982 
   1983 class SMAcceptorThread : public SimpleThread,
   1984                          public EpollCallbackInterface,
   1985                          public SMServerConnectionPoolInterface {
   1986   EpollServer epoll_server_;
   1987   int listen_fd_;
   1988   int accepts_per_wake_;
   1989 
   1990   vector<SMServerConnection*> unused_server_connections_;
   1991   vector<SMServerConnection*> tmp_unused_server_connections_;
   1992   vector<SMServerConnection*> allocated_server_connections_;
   1993   Notification quitting_;
   1994   SMInterfaceFactory* sm_interface_factory_;
   1995   MemoryCache* memory_cache_;
   1996  public:
   1997 
   1998   SMAcceptorThread(int listen_fd,
   1999                    int accepts_per_wake,
   2000                    SMInterfaceFactory* smif,
   2001                    MemoryCache* memory_cache) :
   2002       SimpleThread("SMAcceptorThread"),
   2003       listen_fd_(listen_fd),
   2004       accepts_per_wake_(accepts_per_wake),
   2005       quitting_(false),
   2006       sm_interface_factory_(smif),
   2007       memory_cache_(memory_cache) {
   2008   }
   2009 
   2010   ~SMAcceptorThread() {
   2011     for (vector<SMServerConnection*>::iterator i =
   2012            allocated_server_connections_.begin();
   2013          i != allocated_server_connections_.end();
   2014          ++i) {
   2015       delete *i;
   2016     }
   2017   }
   2018 
   2019   SMServerConnection* NewConnection() {
   2020     SMServerConnection* server =
   2021       SMServerConnection::NewSMServerConnection(sm_interface_factory_,
   2022                                                 memory_cache_,
   2023                                                 &epoll_server_);
   2024     allocated_server_connections_.push_back(server);
   2025     VLOG(3) << "Making new server: " << server;
   2026     return server;
   2027   }
   2028 
   2029   SMServerConnection* FindOrMakeNewSMServerConnection() {
   2030     if (unused_server_connections_.empty()) {
   2031       return NewConnection();
   2032     }
   2033     SMServerConnection* retval = unused_server_connections_.back();
   2034     unused_server_connections_.pop_back();
   2035     return retval;
   2036   }
   2037 
   2038 
   2039   void InitWorker() {
   2040     epoll_server_.RegisterFD(listen_fd_, this, EPOLLIN | EPOLLET);
   2041   }
   2042 
   2043   void HandleConnection(int client_fd) {
   2044     SMServerConnection* server_connection = FindOrMakeNewSMServerConnection();
   2045     if (server_connection == NULL) {
   2046       VLOG(2) << "Closing " << client_fd;
   2047       close(client_fd);
   2048       return;
   2049     }
   2050     server_connection->InitSMServerConnection(this,
   2051                                             &epoll_server_,
   2052                                             client_fd);
   2053   }
   2054 
   2055   void AcceptFromListenFD() {
   2056     if (accepts_per_wake_ > 0) {
   2057       for (int i = 0; i < accepts_per_wake_; ++i) {
   2058         struct sockaddr address;
   2059         socklen_t socklen = sizeof(address);
   2060         int fd = accept(listen_fd_, &address, &socklen);
   2061         if (fd == -1) {
   2062           VLOG(2) << "accept fail(" << listen_fd_ << "): " << errno;
   2063           break;
   2064         }
   2065         VLOG(2) << "********************Accepted fd: " << fd << "\n\n\n";
   2066         HandleConnection(fd);
   2067       }
   2068     } else {
   2069       while (true) {
   2070         struct sockaddr address;
   2071         socklen_t socklen = sizeof(address);
   2072         int fd = accept(listen_fd_, &address, &socklen);
   2073         if (fd == -1) {
   2074           VLOG(2) << "accept fail(" << listen_fd_ << "): " << errno;
   2075           break;
   2076         }
   2077         VLOG(2) << "********************Accepted fd: " << fd << "\n\n\n";
   2078         HandleConnection(fd);
   2079       }
   2080     }
   2081   }
   2082 
   2083   // EpollCallbackInteface virtual functions.
   2084   virtual void OnRegistration(EpollServer* eps, int fd, int event_mask) { }
   2085   virtual void OnModification(int fd, int event_mask) { }
   2086   virtual void OnEvent(int fd, EpollEvent* event) {
   2087     if (event->in_events | EPOLLIN) {
   2088       VLOG(2) << "Accepting based upon epoll events";
   2089       AcceptFromListenFD();
   2090     }
   2091   }
   2092   virtual void OnUnregistration(int fd, bool replaced) { }
   2093   virtual void OnShutdown(EpollServer* eps, int fd) { }
   2094 
   2095   void Quit() {
   2096     quitting_.Notify();
   2097   }
   2098 
   2099   void Run() {
   2100     while (!quitting_.HasBeenNotified()) {
   2101       epoll_server_.set_timeout_in_us(10 * 1000);  // 10 ms
   2102       epoll_server_.WaitForEventsAndExecuteCallbacks();
   2103       unused_server_connections_.insert(unused_server_connections_.end(),
   2104                                         tmp_unused_server_connections_.begin(),
   2105                                         tmp_unused_server_connections_.end());
   2106       tmp_unused_server_connections_.clear();
   2107     }
   2108   }
   2109 
   2110   // SMServerConnections will use this:
   2111   virtual void SMServerConnectionDone(SMServerConnection* sc) {
   2112     VLOG(3) << "Done with server connection: " << sc;
   2113     sc->close_record_fd();
   2114     tmp_unused_server_connections_.push_back(sc);
   2115   }
   2116 };
   2117 
   2118 ////////////////////////////////////////////////////////////////////////////////
   2119 
   2120 SMInterface* NewFlipSM(SMServerConnection* connection) {
   2121   return new FlipSM(connection);
   2122 }
   2123 
   2124 SMInterface* NewHTTPSM(SMServerConnection* connection) {
   2125   return new HTTPSM(connection);
   2126 }
   2127 
   2128 ////////////////////////////////////////////////////////////////////////////////
   2129 
   2130 int CreateListeningSocket(int port, int backlog_size,
   2131                           bool reuseport, bool no_nagle) {
   2132   int listening_socket = 0;
   2133   char port_buf[256];
   2134   snprintf(port_buf, sizeof(port_buf), "%d", port);
   2135   cerr <<" Attempting to listen on port: " << port_buf << "\n";
   2136   cerr <<" input port: " << port << "\n";
   2137   net::CreateListeningSocket("",
   2138                               port_buf,
   2139                               true,
   2140                               backlog_size,
   2141                               &listening_socket,
   2142                               true,
   2143                               reuseport,
   2144                               &cerr);
   2145   SetNonBlocking(listening_socket);
   2146   if (no_nagle) {
   2147     // set SO_REUSEADDR on the listening socket.
   2148     int on = 1;
   2149     int rc;
   2150     rc = setsockopt(listening_socket, IPPROTO_TCP,  TCP_NODELAY,
   2151                     reinterpret_cast<char *>(&on), sizeof(on));
   2152     if (rc < 0) {
   2153       close(listening_socket);
   2154       LOG(FATAL) << "setsockopt() failed fd=" << listening_socket << "\n";
   2155     }
   2156   }
   2157   return listening_socket;
   2158 }
   2159 
   2160 ////////////////////////////////////////////////////////////////////////////////
   2161 
   2162 bool GotQuitFromStdin() {
   2163   // Make stdin nonblocking. Yes this is done each time. Oh well.
   2164   fcntl(0, F_SETFL, O_NONBLOCK);
   2165   char c;
   2166   string maybequit;
   2167   while (read(0, &c, 1) > 0) {
   2168     maybequit += c;
   2169   }
   2170   if (maybequit.size()) {
   2171     VLOG(2) << "scanning string: \"" << maybequit << "\"";
   2172   }
   2173   return (maybequit.size() > 1 &&
   2174           (maybequit.c_str()[0] == 'q' ||
   2175            maybequit.c_str()[0] == 'Q'));
   2176 }
   2177 
   2178 
   2179 ////////////////////////////////////////////////////////////////////////////////
   2180 
   2181 const char* BoolToStr(bool b) {
   2182   if (b)
   2183     return "true";
   2184   return "false";
   2185 }
   2186 
   2187 ////////////////////////////////////////////////////////////////////////////////
   2188 
   2189 int main(int argc, char**argv) {
   2190   bool use_ssl = FLAGS_use_ssl;
   2191   int response_count_until_close = FLAGS_response_count_until_close;
   2192   int flip_port = FLAGS_flip_port;
   2193   int port = FLAGS_port;
   2194   int backlog_size = FLAGS_accept_backlog_size;
   2195   bool reuseport = FLAGS_reuseport;
   2196   bool no_nagle = FLAGS_no_nagle;
   2197   double server_think_time_in_s = FLAGS_server_think_time_in_s;
   2198   int accepts_per_wake = FLAGS_accepts_per_wake;
   2199   int num_threads = 1;
   2200 
   2201   MemoryCache flip_memory_cache;
   2202   flip_memory_cache.AddFiles();
   2203 
   2204   MemoryCache http_memory_cache;
   2205   http_memory_cache.CloneFrom(flip_memory_cache);
   2206 
   2207   LOG(INFO) <<
   2208     "Starting up with the following state: \n"
   2209     "                      use_ssl: " << use_ssl << "\n"
   2210     "   response_count_until_close: " << response_count_until_close << "\n"
   2211     "                         port: " << port << "\n"
   2212     "                    flip_port: " << flip_port << "\n"
   2213     "                 backlog_size: " << backlog_size << "\n"
   2214     "                    reuseport: " << BoolToStr(reuseport) << "\n"
   2215     "                     no_nagle: " << BoolToStr(no_nagle) << "\n"
   2216     "       server_think_time_in_s: " << server_think_time_in_s << "\n"
   2217     "             accepts_per_wake: " << accepts_per_wake << "\n"
   2218     "                  num_threads: " << num_threads << "\n"
   2219     "                     use_xsub: " << BoolToStr(FLAGS_use_xsub) << "\n"
   2220     "                      use_xac: " << BoolToStr(FLAGS_use_xac) << "\n";
   2221 
   2222   if (use_ssl) {
   2223     global_ssl_state = new GlobalSSLState;
   2224     flip_init_ssl(global_ssl_state);
   2225   } else {
   2226     global_ssl_state = NULL;
   2227   }
   2228   EpollServer epoll_server;
   2229   vector<SMAcceptorThread*> sm_worker_threads_;
   2230 
   2231   {
   2232     // flip
   2233     int listen_fd = -1;
   2234 
   2235     if (reuseport || listen_fd == -1) {
   2236       listen_fd = CreateListeningSocket(flip_port, backlog_size,
   2237                                         reuseport, no_nagle);
   2238       if (listen_fd < 0) {
   2239         LOG(FATAL) << "Unable to open listening socket on flip_port: "
   2240           << flip_port;
   2241       } else {
   2242         LOG(INFO) << "Listening for flip on port: " << flip_port;
   2243       }
   2244     }
   2245     sm_worker_threads_.push_back(
   2246         new SMAcceptorThread(listen_fd,
   2247                              accepts_per_wake,
   2248                              &NewFlipSM,
   2249                              &flip_memory_cache));
   2250     // Note that flip_memory_cache is not threadsafe, it is merely
   2251     // thread compatible. Thus, if ever we are to spawn multiple threads,
   2252     // we either must make the MemoryCache threadsafe, or use
   2253     // a separate MemoryCache for each thread.
   2254     //
   2255     // The latter is what is currently being done as we spawn
   2256     // two threads (one for flip, one for http).
   2257     sm_worker_threads_.back()->InitWorker();
   2258     sm_worker_threads_.back()->Start();
   2259   }
   2260 
   2261   {
   2262     // http
   2263     int listen_fd = -1;
   2264     if (reuseport || listen_fd == -1) {
   2265       listen_fd = CreateListeningSocket(port, backlog_size,
   2266                                         reuseport, no_nagle);
   2267       if (listen_fd < 0) {
   2268         LOG(FATAL) << "Unable to open listening socket on port: " << port;
   2269       } else {
   2270         LOG(INFO) << "Listening for HTTP on port: " << port;
   2271       }
   2272     }
   2273     sm_worker_threads_.push_back(
   2274         new SMAcceptorThread(listen_fd,
   2275                              accepts_per_wake,
   2276                              &NewHTTPSM,
   2277                              &http_memory_cache));
   2278     // Note that flip_memory_cache is not threadsafe, it is merely
   2279     // thread compatible. Thus, if ever we are to spawn multiple threads,
   2280     // we either must make the MemoryCache threadsafe, or use
   2281     // a separate MemoryCache for each thread.
   2282     //
   2283     // The latter is what is currently being done as we spawn
   2284     // two threads (one for flip, one for http).
   2285     sm_worker_threads_.back()->InitWorker();
   2286     sm_worker_threads_.back()->Start();
   2287   }
   2288 
   2289   while (true) {
   2290     if (GotQuitFromStdin()) {
   2291       for (unsigned int i = 0; i < sm_worker_threads_.size(); ++i) {
   2292         sm_worker_threads_[i]->Quit();
   2293       }
   2294       for (unsigned int i = 0; i < sm_worker_threads_.size(); ++i) {
   2295         sm_worker_threads_[i]->Join();
   2296       }
   2297       return 0;
   2298     }
   2299     usleep(1000*10);  // 10 ms
   2300   }
   2301   return 0;
   2302 }
   2303 
   2304