Home | History | Annotate | Download | only in websockets
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/websockets/websocket_deflate_stream.h"
      6 
      7 #include <algorithm>
      8 #include <string>
      9 
     10 #include "base/bind.h"
     11 #include "base/logging.h"
     12 #include "base/memory/ref_counted.h"
     13 #include "base/memory/scoped_ptr.h"
     14 #include "base/memory/scoped_vector.h"
     15 #include "net/base/completion_callback.h"
     16 #include "net/base/io_buffer.h"
     17 #include "net/base/net_errors.h"
     18 #include "net/websockets/websocket_deflate_predictor.h"
     19 #include "net/websockets/websocket_deflater.h"
     20 #include "net/websockets/websocket_errors.h"
     21 #include "net/websockets/websocket_frame.h"
     22 #include "net/websockets/websocket_inflater.h"
     23 #include "net/websockets/websocket_stream.h"
     24 
     25 class GURL;
     26 
     27 namespace net {
     28 
     29 namespace {
     30 
     31 const int kWindowBits = 15;
     32 const size_t kChunkSize = 4 * 1024;
     33 
     34 }  // namespace
     35 
     36 WebSocketDeflateStream::WebSocketDeflateStream(
     37     scoped_ptr<WebSocketStream> stream,
     38     WebSocketDeflater::ContextTakeOverMode mode,
     39     scoped_ptr<WebSocketDeflatePredictor> predictor)
     40     : stream_(stream.Pass()),
     41       deflater_(mode),
     42       inflater_(kChunkSize, kChunkSize),
     43       reading_state_(NOT_READING),
     44       writing_state_(NOT_WRITING),
     45       current_reading_opcode_(WebSocketFrameHeader::kOpCodeText),
     46       current_writing_opcode_(WebSocketFrameHeader::kOpCodeText),
     47       predictor_(predictor.Pass()) {
     48   DCHECK(stream_);
     49   deflater_.Initialize(kWindowBits);
     50   inflater_.Initialize(kWindowBits);
     51 }
     52 
     53 WebSocketDeflateStream::~WebSocketDeflateStream() {}
     54 
     55 int WebSocketDeflateStream::ReadFrames(ScopedVector<WebSocketFrame>* frames,
     56                                        const CompletionCallback& callback) {
     57   CompletionCallback callback_to_pass =
     58       base::Bind(&WebSocketDeflateStream::OnReadComplete,
     59                  base::Unretained(this),
     60                  base::Unretained(frames),
     61                  callback);
     62   int result = stream_->ReadFrames(frames, callback_to_pass);
     63   if (result < 0)
     64     return result;
     65   DCHECK_EQ(OK, result);
     66   return InflateAndReadIfNecessary(frames, callback_to_pass);
     67 }
     68 
     69 int WebSocketDeflateStream::WriteFrames(ScopedVector<WebSocketFrame>* frames,
     70                                         const CompletionCallback& callback) {
     71   int result = Deflate(frames);
     72   if (result != OK)
     73     return result;
     74   if (frames->empty())
     75     return OK;
     76   return stream_->WriteFrames(frames, callback);
     77 }
     78 
     79 void WebSocketDeflateStream::Close() { stream_->Close(); }
     80 
     81 std::string WebSocketDeflateStream::GetSubProtocol() const {
     82   return stream_->GetSubProtocol();
     83 }
     84 
     85 std::string WebSocketDeflateStream::GetExtensions() const {
     86   return stream_->GetExtensions();
     87 }
     88 
     89 void WebSocketDeflateStream::OnReadComplete(
     90     ScopedVector<WebSocketFrame>* frames,
     91     const CompletionCallback& callback,
     92     int result) {
     93   if (result != OK) {
     94     frames->clear();
     95     callback.Run(result);
     96     return;
     97   }
     98 
     99   int r = InflateAndReadIfNecessary(frames, callback);
    100   if (r != ERR_IO_PENDING)
    101     callback.Run(r);
    102 }
    103 
    104 int WebSocketDeflateStream::Deflate(ScopedVector<WebSocketFrame>* frames) {
    105   ScopedVector<WebSocketFrame> frames_to_write;
    106   // Store frames of the currently processed message if writing_state_ equals to
    107   // WRITING_POSSIBLY_COMPRESSED_MESSAGE.
    108   ScopedVector<WebSocketFrame> frames_of_message;
    109   for (size_t i = 0; i < frames->size(); ++i) {
    110     DCHECK(!(*frames)[i]->header.reserved1);
    111     if (!WebSocketFrameHeader::IsKnownDataOpCode((*frames)[i]->header.opcode)) {
    112       frames_to_write.push_back((*frames)[i]);
    113       (*frames)[i] = NULL;
    114       continue;
    115     }
    116     if (writing_state_ == NOT_WRITING)
    117       OnMessageStart(*frames, i);
    118 
    119     scoped_ptr<WebSocketFrame> frame((*frames)[i]);
    120     (*frames)[i] = NULL;
    121     predictor_->RecordInputDataFrame(frame.get());
    122 
    123     if (writing_state_ == WRITING_UNCOMPRESSED_MESSAGE) {
    124       if (frame->header.final)
    125         writing_state_ = NOT_WRITING;
    126       predictor_->RecordWrittenDataFrame(frame.get());
    127       frames_to_write.push_back(frame.release());
    128       current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
    129     } else {
    130       if (frame->data && !deflater_.AddBytes(frame->data->data(),
    131                                              frame->header.payload_length)) {
    132         DVLOG(1) << "WebSocket protocol error. "
    133                  << "deflater_.AddBytes() returns an error.";
    134         return ERR_WS_PROTOCOL_ERROR;
    135       }
    136       if (frame->header.final && !deflater_.Finish()) {
    137         DVLOG(1) << "WebSocket protocol error. "
    138                  << "deflater_.Finish() returns an error.";
    139         return ERR_WS_PROTOCOL_ERROR;
    140       }
    141 
    142       if (writing_state_ == WRITING_COMPRESSED_MESSAGE) {
    143         if (deflater_.CurrentOutputSize() >= kChunkSize ||
    144             frame->header.final) {
    145           int result = AppendCompressedFrame(frame->header, &frames_to_write);
    146           if (result != OK)
    147             return result;
    148         }
    149         if (frame->header.final)
    150           writing_state_ = NOT_WRITING;
    151       } else {
    152         DCHECK_EQ(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
    153         bool final = frame->header.final;
    154         frames_of_message.push_back(frame.release());
    155         if (final) {
    156           int result = AppendPossiblyCompressedMessage(&frames_of_message,
    157                                                        &frames_to_write);
    158           if (result != OK)
    159             return result;
    160           frames_of_message.clear();
    161           writing_state_ = NOT_WRITING;
    162         }
    163       }
    164     }
    165   }
    166   DCHECK_NE(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
    167   frames->swap(frames_to_write);
    168   return OK;
    169 }
    170 
    171 void WebSocketDeflateStream::OnMessageStart(
    172     const ScopedVector<WebSocketFrame>& frames, size_t index) {
    173   WebSocketFrame* frame = frames[index];
    174   current_writing_opcode_ = frame->header.opcode;
    175   DCHECK(current_writing_opcode_ == WebSocketFrameHeader::kOpCodeText ||
    176          current_writing_opcode_ == WebSocketFrameHeader::kOpCodeBinary);
    177   WebSocketDeflatePredictor::Result prediction =
    178       predictor_->Predict(frames, index);
    179 
    180   switch (prediction) {
    181     case WebSocketDeflatePredictor::DEFLATE:
    182       writing_state_ = WRITING_COMPRESSED_MESSAGE;
    183       return;
    184     case WebSocketDeflatePredictor::DO_NOT_DEFLATE:
    185       writing_state_ = WRITING_UNCOMPRESSED_MESSAGE;
    186       return;
    187     case WebSocketDeflatePredictor::TRY_DEFLATE:
    188       writing_state_ = WRITING_POSSIBLY_COMPRESSED_MESSAGE;
    189       return;
    190   }
    191   NOTREACHED();
    192 }
    193 
    194 int WebSocketDeflateStream::AppendCompressedFrame(
    195     const WebSocketFrameHeader& header,
    196     ScopedVector<WebSocketFrame>* frames_to_write) {
    197   const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
    198   scoped_refptr<IOBufferWithSize> compressed_payload =
    199       deflater_.GetOutput(deflater_.CurrentOutputSize());
    200   if (!compressed_payload) {
    201     DVLOG(1) << "WebSocket protocol error. "
    202              << "deflater_.GetOutput() returns an error.";
    203     return ERR_WS_PROTOCOL_ERROR;
    204   }
    205   scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode));
    206   compressed->header.CopyFrom(header);
    207   compressed->header.opcode = opcode;
    208   compressed->header.final = header.final;
    209   compressed->header.reserved1 =
    210       (opcode != WebSocketFrameHeader::kOpCodeContinuation);
    211   compressed->data = compressed_payload;
    212   compressed->header.payload_length = compressed_payload->size();
    213 
    214   current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
    215   predictor_->RecordWrittenDataFrame(compressed.get());
    216   frames_to_write->push_back(compressed.release());
    217   return OK;
    218 }
    219 
    220 int WebSocketDeflateStream::AppendPossiblyCompressedMessage(
    221     ScopedVector<WebSocketFrame>* frames,
    222     ScopedVector<WebSocketFrame>* frames_to_write) {
    223   DCHECK(!frames->empty());
    224 
    225   const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
    226   scoped_refptr<IOBufferWithSize> compressed_payload =
    227       deflater_.GetOutput(deflater_.CurrentOutputSize());
    228   if (!compressed_payload) {
    229     DVLOG(1) << "WebSocket protocol error. "
    230              << "deflater_.GetOutput() returns an error.";
    231     return ERR_WS_PROTOCOL_ERROR;
    232   }
    233 
    234   uint64 original_payload_length = 0;
    235   for (size_t i = 0; i < frames->size(); ++i) {
    236     WebSocketFrame* frame = (*frames)[i];
    237     // Asserts checking that frames represent one whole data message.
    238     DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode));
    239     DCHECK_EQ(i == 0,
    240               WebSocketFrameHeader::kOpCodeContinuation !=
    241               frame->header.opcode);
    242     DCHECK_EQ(i == frames->size() - 1, frame->header.final);
    243     original_payload_length += frame->header.payload_length;
    244   }
    245   if (original_payload_length <=
    246       static_cast<uint64>(compressed_payload->size())) {
    247     // Compression is not effective. Use the original frames.
    248     for (size_t i = 0; i < frames->size(); ++i) {
    249       WebSocketFrame* frame = (*frames)[i];
    250       frames_to_write->push_back(frame);
    251       predictor_->RecordWrittenDataFrame(frame);
    252       (*frames)[i] = NULL;
    253     }
    254     frames->weak_clear();
    255     return OK;
    256   }
    257   scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode));
    258   compressed->header.CopyFrom((*frames)[0]->header);
    259   compressed->header.opcode = opcode;
    260   compressed->header.final = true;
    261   compressed->header.reserved1 = true;
    262   compressed->data = compressed_payload;
    263   compressed->header.payload_length = compressed_payload->size();
    264 
    265   predictor_->RecordWrittenDataFrame(compressed.get());
    266   frames_to_write->push_back(compressed.release());
    267   return OK;
    268 }
    269 
    270 int WebSocketDeflateStream::Inflate(ScopedVector<WebSocketFrame>* frames) {
    271   ScopedVector<WebSocketFrame> frames_to_output;
    272   ScopedVector<WebSocketFrame> frames_passed;
    273   frames->swap(frames_passed);
    274   for (size_t i = 0; i < frames_passed.size(); ++i) {
    275     scoped_ptr<WebSocketFrame> frame(frames_passed[i]);
    276     frames_passed[i] = NULL;
    277     if (!WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode)) {
    278       frames_to_output.push_back(frame.release());
    279       continue;
    280     }
    281 
    282     if (reading_state_ == NOT_READING) {
    283       if (frame->header.reserved1)
    284         reading_state_ = READING_COMPRESSED_MESSAGE;
    285       else
    286         reading_state_ = READING_UNCOMPRESSED_MESSAGE;
    287       current_reading_opcode_ = frame->header.opcode;
    288     } else {
    289       if (frame->header.reserved1) {
    290         DVLOG(1) << "WebSocket protocol error. "
    291                  << "Receiving a non-first frame with RSV1 flag set.";
    292         return ERR_WS_PROTOCOL_ERROR;
    293       }
    294     }
    295 
    296     if (reading_state_ == READING_UNCOMPRESSED_MESSAGE) {
    297       if (frame->header.final)
    298         reading_state_ = NOT_READING;
    299       current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
    300       frames_to_output.push_back(frame.release());
    301     } else {
    302       DCHECK_EQ(reading_state_, READING_COMPRESSED_MESSAGE);
    303       if (frame->data && !inflater_.AddBytes(frame->data->data(),
    304                                              frame->header.payload_length)) {
    305         DVLOG(1) << "WebSocket protocol error. "
    306                  << "inflater_.AddBytes() returns an error.";
    307         return ERR_WS_PROTOCOL_ERROR;
    308       }
    309       if (frame->header.final) {
    310         if (!inflater_.Finish()) {
    311           DVLOG(1) << "WebSocket protocol error. "
    312                    << "inflater_.Finish() returns an error.";
    313           return ERR_WS_PROTOCOL_ERROR;
    314         }
    315       }
    316       // TODO(yhirano): Many frames can be generated by the inflater and
    317       // memory consumption can grow.
    318       // We could avoid it, but avoiding it makes this class much more
    319       // complicated.
    320       while (inflater_.CurrentOutputSize() >= kChunkSize ||
    321              frame->header.final) {
    322         size_t size = std::min(kChunkSize, inflater_.CurrentOutputSize());
    323         scoped_ptr<WebSocketFrame> inflated(
    324             new WebSocketFrame(WebSocketFrameHeader::kOpCodeText));
    325         scoped_refptr<IOBufferWithSize> data = inflater_.GetOutput(size);
    326         bool is_final = !inflater_.CurrentOutputSize();
    327         // |is_final| can't be true if |frame->header.final| is false.
    328         DCHECK(!(is_final && !frame->header.final));
    329         if (!data) {
    330           DVLOG(1) << "WebSocket protocol error. "
    331                    << "inflater_.GetOutput() returns an error.";
    332           return ERR_WS_PROTOCOL_ERROR;
    333         }
    334         inflated->header.CopyFrom(frame->header);
    335         inflated->header.opcode = current_reading_opcode_;
    336         inflated->header.final = is_final;
    337         inflated->header.reserved1 = false;
    338         inflated->data = data;
    339         inflated->header.payload_length = data->size();
    340 
    341         frames_to_output.push_back(inflated.release());
    342         current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
    343         if (is_final)
    344           break;
    345       }
    346       if (frame->header.final)
    347         reading_state_ = NOT_READING;
    348     }
    349   }
    350   frames->swap(frames_to_output);
    351   return frames->empty() ? ERR_IO_PENDING : OK;
    352 }
    353 
    354 int WebSocketDeflateStream::InflateAndReadIfNecessary(
    355     ScopedVector<WebSocketFrame>* frames,
    356     const CompletionCallback& callback) {
    357   int result = Inflate(frames);
    358   while (result == ERR_IO_PENDING) {
    359     DCHECK(frames->empty());
    360     result = stream_->ReadFrames(frames, callback);
    361     if (result < 0)
    362       break;
    363     DCHECK_EQ(OK, result);
    364     DCHECK(!frames->empty());
    365     result = Inflate(frames);
    366   }
    367   if (result < 0)
    368     frames->clear();
    369   return result;
    370 }
    371 
    372 }  // namespace net
    373