1 /* 2 * libjingle 3 * Copyright 2004--2010, Google Inc. 4 * 5 * Redistribution and use in source and binary forms, with or without 6 * modification, are permitted provided that the following conditions are met: 7 * 8 * 1. Redistributions of source code must retain the above copyright notice, 9 * this list of conditions and the following disclaimer. 10 * 2. Redistributions in binary form must reproduce the above copyright notice, 11 * this list of conditions and the following disclaimer in the documentation 12 * and/or other materials provided with the distribution. 13 * 3. The name of the author may not be used to endorse or promote products 14 * derived from this software without specific prior written permission. 15 * 16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED 17 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 18 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO 19 * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 20 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 22 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 23 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 24 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 25 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 */ 27 28 #include "talk/base/asynctcpsocket.h" 29 30 #include <cstring> 31 32 #include "talk/base/byteorder.h" 33 #include "talk/base/common.h" 34 #include "talk/base/logging.h" 35 36 #ifdef POSIX 37 #include <errno.h> 38 #endif // POSIX 39 40 namespace talk_base { 41 42 static const size_t kMaxPacketSize = 64 * 1024; 43 44 typedef uint16 PacketLength; 45 static const size_t kPacketLenSize = sizeof(PacketLength); 46 47 static const size_t kBufSize = kMaxPacketSize + kPacketLenSize; 48 49 static const int kListenBacklog = 5; 50 51 // Binds and connects |socket| 52 AsyncSocket* AsyncTCPSocketBase::ConnectSocket( 53 talk_base::AsyncSocket* socket, 54 const talk_base::SocketAddress& bind_address, 55 const talk_base::SocketAddress& remote_address) { 56 talk_base::scoped_ptr<talk_base::AsyncSocket> owned_socket(socket); 57 if (socket->Bind(bind_address) < 0) { 58 LOG(LS_ERROR) << "Bind() failed with error " << socket->GetError(); 59 return NULL; 60 } 61 if (socket->Connect(remote_address) < 0) { 62 LOG(LS_ERROR) << "Connect() failed with error " << socket->GetError(); 63 return NULL; 64 } 65 return owned_socket.release(); 66 } 67 68 AsyncTCPSocketBase::AsyncTCPSocketBase(AsyncSocket* socket, bool listen, 69 size_t max_packet_size) 70 : socket_(socket), 71 listen_(listen), 72 insize_(max_packet_size), 73 inpos_(0), 74 outsize_(max_packet_size), 75 outpos_(0) { 76 inbuf_ = new char[insize_]; 77 outbuf_ = new char[outsize_]; 78 79 ASSERT(socket_.get() != NULL); 80 socket_->SignalConnectEvent.connect( 81 this, &AsyncTCPSocketBase::OnConnectEvent); 82 socket_->SignalReadEvent.connect(this, &AsyncTCPSocketBase::OnReadEvent); 83 socket_->SignalWriteEvent.connect(this, &AsyncTCPSocketBase::OnWriteEvent); 84 socket_->SignalCloseEvent.connect(this, &AsyncTCPSocketBase::OnCloseEvent); 85 86 if (listen_) { 87 if (socket_->Listen(kListenBacklog) < 0) { 88 LOG(LS_ERROR) << "Listen() failed with error " << socket_->GetError(); 89 } 90 } 91 } 92 93 AsyncTCPSocketBase::~AsyncTCPSocketBase() { 94 delete [] inbuf_; 95 delete [] outbuf_; 96 } 97 98 SocketAddress AsyncTCPSocketBase::GetLocalAddress() const { 99 return socket_->GetLocalAddress(); 100 } 101 102 SocketAddress AsyncTCPSocketBase::GetRemoteAddress() const { 103 return socket_->GetRemoteAddress(); 104 } 105 106 int AsyncTCPSocketBase::Close() { 107 return socket_->Close(); 108 } 109 110 AsyncTCPSocket::State AsyncTCPSocketBase::GetState() const { 111 switch (socket_->GetState()) { 112 case Socket::CS_CLOSED: 113 return STATE_CLOSED; 114 case Socket::CS_CONNECTING: 115 if (listen_) { 116 return STATE_BOUND; 117 } else { 118 return STATE_CONNECTING; 119 } 120 case Socket::CS_CONNECTED: 121 return STATE_CONNECTED; 122 default: 123 ASSERT(false); 124 return STATE_CLOSED; 125 } 126 } 127 128 int AsyncTCPSocketBase::GetOption(Socket::Option opt, int* value) { 129 return socket_->GetOption(opt, value); 130 } 131 132 int AsyncTCPSocketBase::SetOption(Socket::Option opt, int value) { 133 return socket_->SetOption(opt, value); 134 } 135 136 int AsyncTCPSocketBase::GetError() const { 137 return socket_->GetError(); 138 } 139 140 void AsyncTCPSocketBase::SetError(int error) { 141 return socket_->SetError(error); 142 } 143 144 int AsyncTCPSocketBase::SendTo(const void *pv, size_t cb, 145 const SocketAddress& addr) { 146 if (addr == GetRemoteAddress()) 147 return Send(pv, cb); 148 149 ASSERT(false); 150 socket_->SetError(ENOTCONN); 151 return -1; 152 } 153 154 int AsyncTCPSocketBase::SendRaw(const void * pv, size_t cb) { 155 if (outpos_ + cb > outsize_) { 156 socket_->SetError(EMSGSIZE); 157 return -1; 158 } 159 160 memcpy(outbuf_ + outpos_, pv, cb); 161 outpos_ += cb; 162 163 return FlushOutBuffer(); 164 } 165 166 int AsyncTCPSocketBase::FlushOutBuffer() { 167 int res = socket_->Send(outbuf_, outpos_); 168 if (res <= 0) { 169 return res; 170 } 171 if (static_cast<size_t>(res) <= outpos_) { 172 outpos_ -= res; 173 } else { 174 ASSERT(false); 175 return -1; 176 } 177 if (outpos_ > 0) { 178 memmove(outbuf_, outbuf_ + res, outpos_); 179 } 180 return res; 181 } 182 183 void AsyncTCPSocketBase::AppendToOutBuffer(const void* pv, size_t cb) { 184 ASSERT(outpos_ + cb < outsize_); 185 memcpy(outbuf_ + outpos_, pv, cb); 186 outpos_ += cb; 187 } 188 189 void AsyncTCPSocketBase::OnConnectEvent(AsyncSocket* socket) { 190 SignalConnect(this); 191 } 192 193 void AsyncTCPSocketBase::OnReadEvent(AsyncSocket* socket) { 194 ASSERT(socket_.get() == socket); 195 196 if (listen_) { 197 talk_base::SocketAddress address; 198 talk_base::AsyncSocket* new_socket = socket->Accept(&address); 199 if (!new_socket) { 200 // TODO: Do something better like forwarding the error 201 // to the user. 202 LOG(LS_ERROR) << "TCP accept failed with error " << socket_->GetError(); 203 return; 204 } 205 206 HandleIncomingConnection(new_socket); 207 208 // Prime a read event in case data is waiting. 209 new_socket->SignalReadEvent(new_socket); 210 } else { 211 int len = socket_->Recv(inbuf_ + inpos_, insize_ - inpos_); 212 if (len < 0) { 213 // TODO: Do something better like forwarding the error to the user. 214 if (!socket_->IsBlocking()) { 215 LOG(LS_ERROR) << "Recv() returned error: " << socket_->GetError(); 216 } 217 return; 218 } 219 220 inpos_ += len; 221 222 ProcessInput(inbuf_, &inpos_); 223 224 if (inpos_ >= insize_) { 225 LOG(LS_ERROR) << "input buffer overflow"; 226 ASSERT(false); 227 inpos_ = 0; 228 } 229 } 230 } 231 232 void AsyncTCPSocketBase::OnWriteEvent(AsyncSocket* socket) { 233 ASSERT(socket_.get() == socket); 234 235 if (outpos_ > 0) { 236 FlushOutBuffer(); 237 } 238 239 if (outpos_ == 0) { 240 SignalReadyToSend(this); 241 } 242 } 243 244 void AsyncTCPSocketBase::OnCloseEvent(AsyncSocket* socket, int error) { 245 SignalClose(this, error); 246 } 247 248 // AsyncTCPSocket 249 // Binds and connects |socket| and creates AsyncTCPSocket for 250 // it. Takes ownership of |socket|. Returns NULL if bind() or 251 // connect() fail (|socket| is destroyed in that case). 252 AsyncTCPSocket* AsyncTCPSocket::Create( 253 AsyncSocket* socket, 254 const SocketAddress& bind_address, 255 const SocketAddress& remote_address) { 256 return new AsyncTCPSocket(AsyncTCPSocketBase::ConnectSocket( 257 socket, bind_address, remote_address), false); 258 } 259 260 AsyncTCPSocket::AsyncTCPSocket(AsyncSocket* socket, bool listen) 261 : AsyncTCPSocketBase(socket, listen, kBufSize) { 262 } 263 264 int AsyncTCPSocket::Send(const void *pv, size_t cb) { 265 if (cb > kBufSize) { 266 SetError(EMSGSIZE); 267 return -1; 268 } 269 270 // If we are blocking on send, then silently drop this packet 271 if (!IsOutBufferEmpty()) 272 return static_cast<int>(cb); 273 274 PacketLength pkt_len = HostToNetwork16(static_cast<PacketLength>(cb)); 275 AppendToOutBuffer(&pkt_len, kPacketLenSize); 276 AppendToOutBuffer(pv, cb); 277 278 int res = FlushOutBuffer(); 279 if (res <= 0) { 280 // drop packet if we made no progress 281 ClearOutBuffer(); 282 return res; 283 } 284 285 // We claim to have sent the whole thing, even if we only sent partial 286 return static_cast<int>(cb); 287 } 288 289 void AsyncTCPSocket::ProcessInput(char * data, size_t* len) { 290 SocketAddress remote_addr(GetRemoteAddress()); 291 292 while (true) { 293 if (*len < kPacketLenSize) 294 return; 295 296 PacketLength pkt_len = talk_base::GetBE16(data); 297 if (*len < kPacketLenSize + pkt_len) 298 return; 299 300 SignalReadPacket(this, data + kPacketLenSize, pkt_len, remote_addr); 301 302 *len -= kPacketLenSize + pkt_len; 303 if (*len > 0) { 304 memmove(data, data + kPacketLenSize + pkt_len, *len); 305 } 306 } 307 } 308 309 void AsyncTCPSocket::HandleIncomingConnection(AsyncSocket* socket) { 310 SignalNewConnection(this, new AsyncTCPSocket(socket, false)); 311 } 312 313 } // namespace talk_base 314