1 /* 2 * Copyright 2004 The WebRTC Project Authors. All rights reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include "webrtc/base/asynctcpsocket.h" 12 13 #include <string.h> 14 15 #include "webrtc/base/byteorder.h" 16 #include "webrtc/base/common.h" 17 #include "webrtc/base/logging.h" 18 19 #if defined(WEBRTC_POSIX) 20 #include <errno.h> 21 #endif // WEBRTC_POSIX 22 23 namespace rtc { 24 25 static const size_t kMaxPacketSize = 64 * 1024; 26 27 typedef uint16 PacketLength; 28 static const size_t kPacketLenSize = sizeof(PacketLength); 29 30 static const size_t kBufSize = kMaxPacketSize + kPacketLenSize; 31 32 static const int kListenBacklog = 5; 33 34 // Binds and connects |socket| 35 AsyncSocket* AsyncTCPSocketBase::ConnectSocket( 36 rtc::AsyncSocket* socket, 37 const rtc::SocketAddress& bind_address, 38 const rtc::SocketAddress& remote_address) { 39 rtc::scoped_ptr<rtc::AsyncSocket> owned_socket(socket); 40 if (socket->Bind(bind_address) < 0) { 41 LOG(LS_ERROR) << "Bind() failed with error " << socket->GetError(); 42 return NULL; 43 } 44 if (socket->Connect(remote_address) < 0) { 45 LOG(LS_ERROR) << "Connect() failed with error " << socket->GetError(); 46 return NULL; 47 } 48 return owned_socket.release(); 49 } 50 51 AsyncTCPSocketBase::AsyncTCPSocketBase(AsyncSocket* socket, bool listen, 52 size_t max_packet_size) 53 : socket_(socket), 54 listen_(listen), 55 insize_(max_packet_size), 56 inpos_(0), 57 outsize_(max_packet_size), 58 outpos_(0) { 59 inbuf_ = new char[insize_]; 60 outbuf_ = new char[outsize_]; 61 62 ASSERT(socket_.get() != NULL); 63 socket_->SignalConnectEvent.connect( 64 this, &AsyncTCPSocketBase::OnConnectEvent); 65 socket_->SignalReadEvent.connect(this, &AsyncTCPSocketBase::OnReadEvent); 66 socket_->SignalWriteEvent.connect(this, &AsyncTCPSocketBase::OnWriteEvent); 67 socket_->SignalCloseEvent.connect(this, &AsyncTCPSocketBase::OnCloseEvent); 68 69 if (listen_) { 70 if (socket_->Listen(kListenBacklog) < 0) { 71 LOG(LS_ERROR) << "Listen() failed with error " << socket_->GetError(); 72 } 73 } 74 } 75 76 AsyncTCPSocketBase::~AsyncTCPSocketBase() { 77 delete [] inbuf_; 78 delete [] outbuf_; 79 } 80 81 SocketAddress AsyncTCPSocketBase::GetLocalAddress() const { 82 return socket_->GetLocalAddress(); 83 } 84 85 SocketAddress AsyncTCPSocketBase::GetRemoteAddress() const { 86 return socket_->GetRemoteAddress(); 87 } 88 89 int AsyncTCPSocketBase::Close() { 90 return socket_->Close(); 91 } 92 93 AsyncTCPSocket::State AsyncTCPSocketBase::GetState() const { 94 switch (socket_->GetState()) { 95 case Socket::CS_CLOSED: 96 return STATE_CLOSED; 97 case Socket::CS_CONNECTING: 98 if (listen_) { 99 return STATE_BOUND; 100 } else { 101 return STATE_CONNECTING; 102 } 103 case Socket::CS_CONNECTED: 104 return STATE_CONNECTED; 105 default: 106 ASSERT(false); 107 return STATE_CLOSED; 108 } 109 } 110 111 int AsyncTCPSocketBase::GetOption(Socket::Option opt, int* value) { 112 return socket_->GetOption(opt, value); 113 } 114 115 int AsyncTCPSocketBase::SetOption(Socket::Option opt, int value) { 116 return socket_->SetOption(opt, value); 117 } 118 119 int AsyncTCPSocketBase::GetError() const { 120 return socket_->GetError(); 121 } 122 123 void AsyncTCPSocketBase::SetError(int error) { 124 return socket_->SetError(error); 125 } 126 127 int AsyncTCPSocketBase::SendTo(const void *pv, size_t cb, 128 const SocketAddress& addr, 129 const rtc::PacketOptions& options) { 130 if (addr == GetRemoteAddress()) 131 return Send(pv, cb, options); 132 133 ASSERT(false); 134 socket_->SetError(ENOTCONN); 135 return -1; 136 } 137 138 int AsyncTCPSocketBase::SendRaw(const void * pv, size_t cb) { 139 if (outpos_ + cb > outsize_) { 140 socket_->SetError(EMSGSIZE); 141 return -1; 142 } 143 144 memcpy(outbuf_ + outpos_, pv, cb); 145 outpos_ += cb; 146 147 return FlushOutBuffer(); 148 } 149 150 int AsyncTCPSocketBase::FlushOutBuffer() { 151 int res = socket_->Send(outbuf_, outpos_); 152 if (res <= 0) { 153 return res; 154 } 155 if (static_cast<size_t>(res) <= outpos_) { 156 outpos_ -= res; 157 } else { 158 ASSERT(false); 159 return -1; 160 } 161 if (outpos_ > 0) { 162 memmove(outbuf_, outbuf_ + res, outpos_); 163 } 164 return res; 165 } 166 167 void AsyncTCPSocketBase::AppendToOutBuffer(const void* pv, size_t cb) { 168 ASSERT(outpos_ + cb < outsize_); 169 memcpy(outbuf_ + outpos_, pv, cb); 170 outpos_ += cb; 171 } 172 173 void AsyncTCPSocketBase::OnConnectEvent(AsyncSocket* socket) { 174 SignalConnect(this); 175 } 176 177 void AsyncTCPSocketBase::OnReadEvent(AsyncSocket* socket) { 178 ASSERT(socket_.get() == socket); 179 180 if (listen_) { 181 rtc::SocketAddress address; 182 rtc::AsyncSocket* new_socket = socket->Accept(&address); 183 if (!new_socket) { 184 // TODO: Do something better like forwarding the error 185 // to the user. 186 LOG(LS_ERROR) << "TCP accept failed with error " << socket_->GetError(); 187 return; 188 } 189 190 HandleIncomingConnection(new_socket); 191 192 // Prime a read event in case data is waiting. 193 new_socket->SignalReadEvent(new_socket); 194 } else { 195 int len = socket_->Recv(inbuf_ + inpos_, insize_ - inpos_); 196 if (len < 0) { 197 // TODO: Do something better like forwarding the error to the user. 198 if (!socket_->IsBlocking()) { 199 LOG(LS_ERROR) << "Recv() returned error: " << socket_->GetError(); 200 } 201 return; 202 } 203 204 inpos_ += len; 205 206 ProcessInput(inbuf_, &inpos_); 207 208 if (inpos_ >= insize_) { 209 LOG(LS_ERROR) << "input buffer overflow"; 210 ASSERT(false); 211 inpos_ = 0; 212 } 213 } 214 } 215 216 void AsyncTCPSocketBase::OnWriteEvent(AsyncSocket* socket) { 217 ASSERT(socket_.get() == socket); 218 219 if (outpos_ > 0) { 220 FlushOutBuffer(); 221 } 222 223 if (outpos_ == 0) { 224 SignalReadyToSend(this); 225 } 226 } 227 228 void AsyncTCPSocketBase::OnCloseEvent(AsyncSocket* socket, int error) { 229 SignalClose(this, error); 230 } 231 232 // AsyncTCPSocket 233 // Binds and connects |socket| and creates AsyncTCPSocket for 234 // it. Takes ownership of |socket|. Returns NULL if bind() or 235 // connect() fail (|socket| is destroyed in that case). 236 AsyncTCPSocket* AsyncTCPSocket::Create( 237 AsyncSocket* socket, 238 const SocketAddress& bind_address, 239 const SocketAddress& remote_address) { 240 return new AsyncTCPSocket(AsyncTCPSocketBase::ConnectSocket( 241 socket, bind_address, remote_address), false); 242 } 243 244 AsyncTCPSocket::AsyncTCPSocket(AsyncSocket* socket, bool listen) 245 : AsyncTCPSocketBase(socket, listen, kBufSize) { 246 } 247 248 int AsyncTCPSocket::Send(const void *pv, size_t cb, 249 const rtc::PacketOptions& options) { 250 if (cb > kBufSize) { 251 SetError(EMSGSIZE); 252 return -1; 253 } 254 255 // If we are blocking on send, then silently drop this packet 256 if (!IsOutBufferEmpty()) 257 return static_cast<int>(cb); 258 259 PacketLength pkt_len = HostToNetwork16(static_cast<PacketLength>(cb)); 260 AppendToOutBuffer(&pkt_len, kPacketLenSize); 261 AppendToOutBuffer(pv, cb); 262 263 int res = FlushOutBuffer(); 264 if (res <= 0) { 265 // drop packet if we made no progress 266 ClearOutBuffer(); 267 return res; 268 } 269 270 // We claim to have sent the whole thing, even if we only sent partial 271 return static_cast<int>(cb); 272 } 273 274 void AsyncTCPSocket::ProcessInput(char * data, size_t* len) { 275 SocketAddress remote_addr(GetRemoteAddress()); 276 277 while (true) { 278 if (*len < kPacketLenSize) 279 return; 280 281 PacketLength pkt_len = rtc::GetBE16(data); 282 if (*len < kPacketLenSize + pkt_len) 283 return; 284 285 SignalReadPacket(this, data + kPacketLenSize, pkt_len, remote_addr, 286 CreatePacketTime(0)); 287 288 *len -= kPacketLenSize + pkt_len; 289 if (*len > 0) { 290 memmove(data, data + kPacketLenSize + pkt_len, *len); 291 } 292 } 293 } 294 295 void AsyncTCPSocket::HandleIncomingConnection(AsyncSocket* socket) { 296 SignalNewConnection(this, new AsyncTCPSocket(socket, false)); 297 } 298 299 } // namespace rtc 300