Home | History | Annotate | Download | only in websockets
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/websockets/websocket_deflate_stream.h"
      6 
      7 #include <algorithm>
      8 #include <string>
      9 
     10 #include "base/bind.h"
     11 #include "base/logging.h"
     12 #include "base/memory/ref_counted.h"
     13 #include "base/memory/scoped_ptr.h"
     14 #include "base/memory/scoped_vector.h"
     15 #include "net/base/completion_callback.h"
     16 #include "net/base/io_buffer.h"
     17 #include "net/base/net_errors.h"
     18 #include "net/websockets/websocket_deflate_predictor.h"
     19 #include "net/websockets/websocket_deflater.h"
     20 #include "net/websockets/websocket_errors.h"
     21 #include "net/websockets/websocket_frame.h"
     22 #include "net/websockets/websocket_inflater.h"
     23 #include "net/websockets/websocket_stream.h"
     24 
     25 class GURL;
     26 
     27 namespace net {
     28 
     29 namespace {
     30 
     31 const int kWindowBits = 15;
     32 const size_t kChunkSize = 4 * 1024;
     33 
     34 }  // namespace
     35 
     36 WebSocketDeflateStream::WebSocketDeflateStream(
     37     scoped_ptr<WebSocketStream> stream,
     38     WebSocketDeflater::ContextTakeOverMode mode,
     39     int client_window_bits,
     40     scoped_ptr<WebSocketDeflatePredictor> predictor)
     41     : stream_(stream.Pass()),
     42       deflater_(mode),
     43       inflater_(kChunkSize, kChunkSize),
     44       reading_state_(NOT_READING),
     45       writing_state_(NOT_WRITING),
     46       current_reading_opcode_(WebSocketFrameHeader::kOpCodeText),
     47       current_writing_opcode_(WebSocketFrameHeader::kOpCodeText),
     48       predictor_(predictor.Pass()) {
     49   DCHECK(stream_);
     50   DCHECK_GE(client_window_bits, 8);
     51   DCHECK_LE(client_window_bits, 15);
     52   deflater_.Initialize(client_window_bits);
     53   inflater_.Initialize(kWindowBits);
     54 }
     55 
     56 WebSocketDeflateStream::~WebSocketDeflateStream() {}
     57 
     58 int WebSocketDeflateStream::ReadFrames(ScopedVector<WebSocketFrame>* frames,
     59                                        const CompletionCallback& callback) {
     60   int result = stream_->ReadFrames(
     61       frames,
     62       base::Bind(&WebSocketDeflateStream::OnReadComplete,
     63                  base::Unretained(this),
     64                  base::Unretained(frames),
     65                  callback));
     66   if (result < 0)
     67     return result;
     68   DCHECK_EQ(OK, result);
     69   DCHECK(!frames->empty());
     70 
     71   return InflateAndReadIfNecessary(frames, callback);
     72 }
     73 
     74 int WebSocketDeflateStream::WriteFrames(ScopedVector<WebSocketFrame>* frames,
     75                                         const CompletionCallback& callback) {
     76   int result = Deflate(frames);
     77   if (result != OK)
     78     return result;
     79   if (frames->empty())
     80     return OK;
     81   return stream_->WriteFrames(frames, callback);
     82 }
     83 
     84 void WebSocketDeflateStream::Close() { stream_->Close(); }
     85 
     86 std::string WebSocketDeflateStream::GetSubProtocol() const {
     87   return stream_->GetSubProtocol();
     88 }
     89 
     90 std::string WebSocketDeflateStream::GetExtensions() const {
     91   return stream_->GetExtensions();
     92 }
     93 
     94 void WebSocketDeflateStream::OnReadComplete(
     95     ScopedVector<WebSocketFrame>* frames,
     96     const CompletionCallback& callback,
     97     int result) {
     98   if (result != OK) {
     99     frames->clear();
    100     callback.Run(result);
    101     return;
    102   }
    103 
    104   int r = InflateAndReadIfNecessary(frames, callback);
    105   if (r != ERR_IO_PENDING)
    106     callback.Run(r);
    107 }
    108 
    109 int WebSocketDeflateStream::Deflate(ScopedVector<WebSocketFrame>* frames) {
    110   ScopedVector<WebSocketFrame> frames_to_write;
    111   // Store frames of the currently processed message if writing_state_ equals to
    112   // WRITING_POSSIBLY_COMPRESSED_MESSAGE.
    113   ScopedVector<WebSocketFrame> frames_of_message;
    114   for (size_t i = 0; i < frames->size(); ++i) {
    115     DCHECK(!(*frames)[i]->header.reserved1);
    116     if (!WebSocketFrameHeader::IsKnownDataOpCode((*frames)[i]->header.opcode)) {
    117       frames_to_write.push_back((*frames)[i]);
    118       (*frames)[i] = NULL;
    119       continue;
    120     }
    121     if (writing_state_ == NOT_WRITING)
    122       OnMessageStart(*frames, i);
    123 
    124     scoped_ptr<WebSocketFrame> frame((*frames)[i]);
    125     (*frames)[i] = NULL;
    126     predictor_->RecordInputDataFrame(frame.get());
    127 
    128     if (writing_state_ == WRITING_UNCOMPRESSED_MESSAGE) {
    129       if (frame->header.final)
    130         writing_state_ = NOT_WRITING;
    131       predictor_->RecordWrittenDataFrame(frame.get());
    132       frames_to_write.push_back(frame.release());
    133       current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
    134     } else {
    135       if (frame->data.get() &&
    136           !deflater_.AddBytes(frame->data->data(),
    137                               frame->header.payload_length)) {
    138         DVLOG(1) << "WebSocket protocol error. "
    139                  << "deflater_.AddBytes() returns an error.";
    140         return ERR_WS_PROTOCOL_ERROR;
    141       }
    142       if (frame->header.final && !deflater_.Finish()) {
    143         DVLOG(1) << "WebSocket protocol error. "
    144                  << "deflater_.Finish() returns an error.";
    145         return ERR_WS_PROTOCOL_ERROR;
    146       }
    147 
    148       if (writing_state_ == WRITING_COMPRESSED_MESSAGE) {
    149         if (deflater_.CurrentOutputSize() >= kChunkSize ||
    150             frame->header.final) {
    151           int result = AppendCompressedFrame(frame->header, &frames_to_write);
    152           if (result != OK)
    153             return result;
    154         }
    155         if (frame->header.final)
    156           writing_state_ = NOT_WRITING;
    157       } else {
    158         DCHECK_EQ(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
    159         bool final = frame->header.final;
    160         frames_of_message.push_back(frame.release());
    161         if (final) {
    162           int result = AppendPossiblyCompressedMessage(&frames_of_message,
    163                                                        &frames_to_write);
    164           if (result != OK)
    165             return result;
    166           frames_of_message.clear();
    167           writing_state_ = NOT_WRITING;
    168         }
    169       }
    170     }
    171   }
    172   DCHECK_NE(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
    173   frames->swap(frames_to_write);
    174   return OK;
    175 }
    176 
    177 void WebSocketDeflateStream::OnMessageStart(
    178     const ScopedVector<WebSocketFrame>& frames, size_t index) {
    179   WebSocketFrame* frame = frames[index];
    180   current_writing_opcode_ = frame->header.opcode;
    181   DCHECK(current_writing_opcode_ == WebSocketFrameHeader::kOpCodeText ||
    182          current_writing_opcode_ == WebSocketFrameHeader::kOpCodeBinary);
    183   WebSocketDeflatePredictor::Result prediction =
    184       predictor_->Predict(frames, index);
    185 
    186   switch (prediction) {
    187     case WebSocketDeflatePredictor::DEFLATE:
    188       writing_state_ = WRITING_COMPRESSED_MESSAGE;
    189       return;
    190     case WebSocketDeflatePredictor::DO_NOT_DEFLATE:
    191       writing_state_ = WRITING_UNCOMPRESSED_MESSAGE;
    192       return;
    193     case WebSocketDeflatePredictor::TRY_DEFLATE:
    194       writing_state_ = WRITING_POSSIBLY_COMPRESSED_MESSAGE;
    195       return;
    196   }
    197   NOTREACHED();
    198 }
    199 
    200 int WebSocketDeflateStream::AppendCompressedFrame(
    201     const WebSocketFrameHeader& header,
    202     ScopedVector<WebSocketFrame>* frames_to_write) {
    203   const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
    204   scoped_refptr<IOBufferWithSize> compressed_payload =
    205       deflater_.GetOutput(deflater_.CurrentOutputSize());
    206   if (!compressed_payload.get()) {
    207     DVLOG(1) << "WebSocket protocol error. "
    208              << "deflater_.GetOutput() returns an error.";
    209     return ERR_WS_PROTOCOL_ERROR;
    210   }
    211   scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode));
    212   compressed->header.CopyFrom(header);
    213   compressed->header.opcode = opcode;
    214   compressed->header.final = header.final;
    215   compressed->header.reserved1 =
    216       (opcode != WebSocketFrameHeader::kOpCodeContinuation);
    217   compressed->data = compressed_payload;
    218   compressed->header.payload_length = compressed_payload->size();
    219 
    220   current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
    221   predictor_->RecordWrittenDataFrame(compressed.get());
    222   frames_to_write->push_back(compressed.release());
    223   return OK;
    224 }
    225 
    226 int WebSocketDeflateStream::AppendPossiblyCompressedMessage(
    227     ScopedVector<WebSocketFrame>* frames,
    228     ScopedVector<WebSocketFrame>* frames_to_write) {
    229   DCHECK(!frames->empty());
    230 
    231   const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
    232   scoped_refptr<IOBufferWithSize> compressed_payload =
    233       deflater_.GetOutput(deflater_.CurrentOutputSize());
    234   if (!compressed_payload.get()) {
    235     DVLOG(1) << "WebSocket protocol error. "
    236              << "deflater_.GetOutput() returns an error.";
    237     return ERR_WS_PROTOCOL_ERROR;
    238   }
    239 
    240   uint64 original_payload_length = 0;
    241   for (size_t i = 0; i < frames->size(); ++i) {
    242     WebSocketFrame* frame = (*frames)[i];
    243     // Asserts checking that frames represent one whole data message.
    244     DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode));
    245     DCHECK_EQ(i == 0,
    246               WebSocketFrameHeader::kOpCodeContinuation !=
    247               frame->header.opcode);
    248     DCHECK_EQ(i == frames->size() - 1, frame->header.final);
    249     original_payload_length += frame->header.payload_length;
    250   }
    251   if (original_payload_length <=
    252       static_cast<uint64>(compressed_payload->size())) {
    253     // Compression is not effective. Use the original frames.
    254     for (size_t i = 0; i < frames->size(); ++i) {
    255       WebSocketFrame* frame = (*frames)[i];
    256       frames_to_write->push_back(frame);
    257       predictor_->RecordWrittenDataFrame(frame);
    258       (*frames)[i] = NULL;
    259     }
    260     frames->weak_clear();
    261     return OK;
    262   }
    263   scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode));
    264   compressed->header.CopyFrom((*frames)[0]->header);
    265   compressed->header.opcode = opcode;
    266   compressed->header.final = true;
    267   compressed->header.reserved1 = true;
    268   compressed->data = compressed_payload;
    269   compressed->header.payload_length = compressed_payload->size();
    270 
    271   predictor_->RecordWrittenDataFrame(compressed.get());
    272   frames_to_write->push_back(compressed.release());
    273   return OK;
    274 }
    275 
    276 int WebSocketDeflateStream::Inflate(ScopedVector<WebSocketFrame>* frames) {
    277   ScopedVector<WebSocketFrame> frames_to_output;
    278   ScopedVector<WebSocketFrame> frames_passed;
    279   frames->swap(frames_passed);
    280   for (size_t i = 0; i < frames_passed.size(); ++i) {
    281     scoped_ptr<WebSocketFrame> frame(frames_passed[i]);
    282     frames_passed[i] = NULL;
    283     DVLOG(3) << "Input frame: opcode=" << frame->header.opcode
    284              << " final=" << frame->header.final
    285              << " reserved1=" << frame->header.reserved1
    286              << " payload_length=" << frame->header.payload_length;
    287 
    288     if (!WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode)) {
    289       frames_to_output.push_back(frame.release());
    290       continue;
    291     }
    292 
    293     if (reading_state_ == NOT_READING) {
    294       if (frame->header.reserved1)
    295         reading_state_ = READING_COMPRESSED_MESSAGE;
    296       else
    297         reading_state_ = READING_UNCOMPRESSED_MESSAGE;
    298       current_reading_opcode_ = frame->header.opcode;
    299     } else {
    300       if (frame->header.reserved1) {
    301         DVLOG(1) << "WebSocket protocol error. "
    302                  << "Receiving a non-first frame with RSV1 flag set.";
    303         return ERR_WS_PROTOCOL_ERROR;
    304       }
    305     }
    306 
    307     if (reading_state_ == READING_UNCOMPRESSED_MESSAGE) {
    308       if (frame->header.final)
    309         reading_state_ = NOT_READING;
    310       current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
    311       frames_to_output.push_back(frame.release());
    312     } else {
    313       DCHECK_EQ(reading_state_, READING_COMPRESSED_MESSAGE);
    314       if (frame->data.get() &&
    315           !inflater_.AddBytes(frame->data->data(),
    316                               frame->header.payload_length)) {
    317         DVLOG(1) << "WebSocket protocol error. "
    318                  << "inflater_.AddBytes() returns an error.";
    319         return ERR_WS_PROTOCOL_ERROR;
    320       }
    321       if (frame->header.final) {
    322         if (!inflater_.Finish()) {
    323           DVLOG(1) << "WebSocket protocol error. "
    324                    << "inflater_.Finish() returns an error.";
    325           return ERR_WS_PROTOCOL_ERROR;
    326         }
    327       }
    328       // TODO(yhirano): Many frames can be generated by the inflater and
    329       // memory consumption can grow.
    330       // We could avoid it, but avoiding it makes this class much more
    331       // complicated.
    332       while (inflater_.CurrentOutputSize() >= kChunkSize ||
    333              frame->header.final) {
    334         size_t size = std::min(kChunkSize, inflater_.CurrentOutputSize());
    335         scoped_ptr<WebSocketFrame> inflated(
    336             new WebSocketFrame(WebSocketFrameHeader::kOpCodeText));
    337         scoped_refptr<IOBufferWithSize> data = inflater_.GetOutput(size);
    338         bool is_final = !inflater_.CurrentOutputSize() && frame->header.final;
    339         if (!data.get()) {
    340           DVLOG(1) << "WebSocket protocol error. "
    341                    << "inflater_.GetOutput() returns an error.";
    342           return ERR_WS_PROTOCOL_ERROR;
    343         }
    344         inflated->header.CopyFrom(frame->header);
    345         inflated->header.opcode = current_reading_opcode_;
    346         inflated->header.final = is_final;
    347         inflated->header.reserved1 = false;
    348         inflated->data = data;
    349         inflated->header.payload_length = data->size();
    350         DVLOG(3) << "Inflated frame: opcode=" << inflated->header.opcode
    351                  << " final=" << inflated->header.final
    352                  << " reserved1=" << inflated->header.reserved1
    353                  << " payload_length=" << inflated->header.payload_length;
    354         frames_to_output.push_back(inflated.release());
    355         current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
    356         if (is_final)
    357           break;
    358       }
    359       if (frame->header.final)
    360         reading_state_ = NOT_READING;
    361     }
    362   }
    363   frames->swap(frames_to_output);
    364   return frames->empty() ? ERR_IO_PENDING : OK;
    365 }
    366 
    367 int WebSocketDeflateStream::InflateAndReadIfNecessary(
    368     ScopedVector<WebSocketFrame>* frames,
    369     const CompletionCallback& callback) {
    370   int result = Inflate(frames);
    371   while (result == ERR_IO_PENDING) {
    372     DCHECK(frames->empty());
    373 
    374     result = stream_->ReadFrames(
    375         frames,
    376         base::Bind(&WebSocketDeflateStream::OnReadComplete,
    377                    base::Unretained(this),
    378                    base::Unretained(frames),
    379                    callback));
    380     if (result < 0)
    381       break;
    382     DCHECK_EQ(OK, result);
    383     DCHECK(!frames->empty());
    384 
    385     result = Inflate(frames);
    386   }
    387   if (result < 0)
    388     frames->clear();
    389   return result;
    390 }
    391 
    392 }  // namespace net
    393