Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/stream_listen_socket.h"
      6 
      7 #if defined(OS_WIN)
      8 // winsock2.h must be included first in order to ensure it is included before
      9 // windows.h.
     10 #include <winsock2.h>
     11 #elif defined(OS_POSIX)
     12 #include <arpa/inet.h>
     13 #include <errno.h>
     14 #include <netinet/in.h>
     15 #include <sys/socket.h>
     16 #include <sys/types.h>
     17 #include "net/base/net_errors.h"
     18 #endif
     19 
     20 #include "base/logging.h"
     21 #include "base/memory/ref_counted.h"
     22 #include "base/memory/scoped_ptr.h"
     23 #include "base/posix/eintr_wrapper.h"
     24 #include "base/sys_byteorder.h"
     25 #include "base/threading/platform_thread.h"
     26 #include "build/build_config.h"
     27 #include "net/base/ip_endpoint.h"
     28 #include "net/base/net_errors.h"
     29 #include "net/base/net_util.h"
     30 
     31 using std::string;
     32 
     33 #if defined(OS_WIN)
     34 typedef int socklen_t;
     35 #endif  // defined(OS_WIN)
     36 
     37 namespace net {
     38 
     39 namespace {
     40 
     41 const int kReadBufSize = 4096;
     42 
     43 }  // namespace
     44 
     45 #if defined(OS_WIN)
     46 const SocketDescriptor StreamListenSocket::kInvalidSocket = INVALID_SOCKET;
     47 const int StreamListenSocket::kSocketError = SOCKET_ERROR;
     48 #elif defined(OS_POSIX)
     49 const SocketDescriptor StreamListenSocket::kInvalidSocket = -1;
     50 const int StreamListenSocket::kSocketError = -1;
     51 #endif
     52 
     53 StreamListenSocket::StreamListenSocket(SocketDescriptor s,
     54                                        StreamListenSocket::Delegate* del)
     55     : socket_delegate_(del),
     56       socket_(s),
     57       reads_paused_(false),
     58       has_pending_reads_(false) {
     59 #if defined(OS_WIN)
     60   socket_event_ = WSACreateEvent();
     61   // TODO(ibrar): error handling in case of socket_event_ == WSA_INVALID_EVENT.
     62   WatchSocket(NOT_WAITING);
     63 #elif defined(OS_POSIX)
     64   wait_state_ = NOT_WAITING;
     65 #endif
     66 }
     67 
     68 StreamListenSocket::~StreamListenSocket() {
     69 #if defined(OS_WIN)
     70   if (socket_event_) {
     71     WSACloseEvent(socket_event_);
     72     socket_event_ = WSA_INVALID_EVENT;
     73   }
     74 #endif
     75   CloseSocket(socket_);
     76 }
     77 
     78 void StreamListenSocket::Send(const char* bytes, int len,
     79                               bool append_linefeed) {
     80   SendInternal(bytes, len);
     81   if (append_linefeed)
     82     SendInternal("\r\n", 2);
     83 }
     84 
     85 void StreamListenSocket::Send(const string& str, bool append_linefeed) {
     86   Send(str.data(), static_cast<int>(str.length()), append_linefeed);
     87 }
     88 
     89 int StreamListenSocket::GetLocalAddress(IPEndPoint* address) {
     90   SockaddrStorage storage;
     91   if (getsockname(socket_, storage.addr, &storage.addr_len)) {
     92 #if defined(OS_WIN)
     93     int err = WSAGetLastError();
     94 #else
     95     int err = errno;
     96 #endif
     97     return MapSystemError(err);
     98   }
     99   if (!address->FromSockAddr(storage.addr, storage.addr_len))
    100     return ERR_FAILED;
    101   return OK;
    102 }
    103 
    104 SocketDescriptor StreamListenSocket::AcceptSocket() {
    105   SocketDescriptor conn = HANDLE_EINTR(accept(socket_, NULL, NULL));
    106   if (conn == kInvalidSocket)
    107     LOG(ERROR) << "Error accepting connection.";
    108   else
    109     SetNonBlocking(conn);
    110   return conn;
    111 }
    112 
    113 void StreamListenSocket::SendInternal(const char* bytes, int len) {
    114   char* send_buf = const_cast<char *>(bytes);
    115   int len_left = len;
    116   while (true) {
    117     int sent = HANDLE_EINTR(send(socket_, send_buf, len_left, 0));
    118     if (sent == len_left) {  // A shortcut to avoid extraneous checks.
    119       break;
    120     }
    121     if (sent == kSocketError) {
    122 #if defined(OS_WIN)
    123       if (WSAGetLastError() != WSAEWOULDBLOCK) {
    124         LOG(ERROR) << "send failed: WSAGetLastError()==" << WSAGetLastError();
    125 #elif defined(OS_POSIX)
    126       if (errno != EWOULDBLOCK && errno != EAGAIN) {
    127         LOG(ERROR) << "send failed: errno==" << errno;
    128 #endif
    129         break;
    130       }
    131       // Otherwise we would block, and now we have to wait for a retry.
    132       // Fall through to PlatformThread::YieldCurrentThread()
    133     } else {
    134       // sent != len_left according to the shortcut above.
    135       // Shift the buffer start and send the remainder after a short while.
    136       send_buf += sent;
    137       len_left -= sent;
    138     }
    139     base::PlatformThread::YieldCurrentThread();
    140   }
    141 }
    142 
    143 void StreamListenSocket::Listen() {
    144   int backlog = 10;  // TODO(erikkay): maybe don't allow any backlog?
    145   if (listen(socket_, backlog) == -1) {
    146     // TODO(erikkay): error handling.
    147     LOG(ERROR) << "Could not listen on socket.";
    148     return;
    149   }
    150 #if defined(OS_POSIX)
    151   WatchSocket(WAITING_ACCEPT);
    152 #endif
    153 }
    154 
    155 void StreamListenSocket::Read() {
    156   char buf[kReadBufSize + 1];  // +1 for null termination.
    157   int len;
    158   do {
    159     len = HANDLE_EINTR(recv(socket_, buf, kReadBufSize, 0));
    160     if (len == kSocketError) {
    161 #if defined(OS_WIN)
    162       int err = WSAGetLastError();
    163       if (err == WSAEWOULDBLOCK) {
    164 #elif defined(OS_POSIX)
    165       if (errno == EWOULDBLOCK || errno == EAGAIN) {
    166 #endif
    167         break;
    168       } else {
    169         // TODO(ibrar): some error handling required here.
    170         break;
    171       }
    172     } else if (len == 0) {
    173       // In Windows, Close() is called by OnObjectSignaled. In POSIX, we need
    174       // to call it here.
    175 #if defined(OS_POSIX)
    176       Close();
    177 #endif
    178     } else {
    179       // TODO(ibrar): maybe change DidRead to take a length instead.
    180       DCHECK_GT(len, 0);
    181       DCHECK_LE(len, kReadBufSize);
    182       buf[len] = 0;  // Already create a buffer with +1 length.
    183       socket_delegate_->DidRead(this, buf, len);
    184     }
    185   } while (len == kReadBufSize);
    186 }
    187 
    188 void StreamListenSocket::Close() {
    189 #if defined(OS_POSIX)
    190   if (wait_state_ == NOT_WAITING)
    191     return;
    192   wait_state_ = NOT_WAITING;
    193 #endif
    194   UnwatchSocket();
    195   socket_delegate_->DidClose(this);
    196 }
    197 
    198 void StreamListenSocket::CloseSocket(SocketDescriptor s) {
    199   if (s && s != kInvalidSocket) {
    200     UnwatchSocket();
    201 #if defined(OS_WIN)
    202     closesocket(s);
    203 #elif defined(OS_POSIX)
    204     close(s);
    205 #endif
    206   }
    207 }
    208 
    209 void StreamListenSocket::WatchSocket(WaitState state) {
    210 #if defined(OS_WIN)
    211   WSAEventSelect(socket_, socket_event_, FD_ACCEPT | FD_CLOSE | FD_READ);
    212   watcher_.StartWatching(socket_event_, this);
    213 #elif defined(OS_POSIX)
    214   // Implicitly calls StartWatchingFileDescriptor().
    215   base::MessageLoopForIO::current()->WatchFileDescriptor(
    216       socket_, true, base::MessageLoopForIO::WATCH_READ, &watcher_, this);
    217   wait_state_ = state;
    218 #endif
    219 }
    220 
    221 void StreamListenSocket::UnwatchSocket() {
    222 #if defined(OS_WIN)
    223   watcher_.StopWatching();
    224 #elif defined(OS_POSIX)
    225   watcher_.StopWatchingFileDescriptor();
    226 #endif
    227 }
    228 
    229 // TODO(ibrar): We can add these functions into OS dependent files.
    230 #if defined(OS_WIN)
    231 // MessageLoop watcher callback.
    232 void StreamListenSocket::OnObjectSignaled(HANDLE object) {
    233   WSANETWORKEVENTS ev;
    234   if (kSocketError == WSAEnumNetworkEvents(socket_, socket_event_, &ev)) {
    235     // TODO
    236     return;
    237   }
    238 
    239   if (ev.lNetworkEvents & FD_CLOSE) {
    240     Close();
    241     // Close might have deleted this object. We should return immediately.
    242     return;
    243   }
    244 
    245   // The object was reset by WSAEnumNetworkEvents.  Watch for the next signal.
    246   watcher_.StartWatching(object, this);
    247 
    248   if (ev.lNetworkEvents == 0) {
    249     // Occasionally the event is set even though there is no new data.
    250     // The net seems to think that this is ignorable.
    251     return;
    252   }
    253   if (ev.lNetworkEvents & FD_ACCEPT) {
    254     Accept();
    255   }
    256   if (ev.lNetworkEvents & FD_READ) {
    257     if (reads_paused_) {
    258       has_pending_reads_ = true;
    259     } else {
    260       Read();
    261       // Read() might call Close() internally and 'this' can be invalid here
    262       return;
    263     }
    264   }
    265 }
    266 #elif defined(OS_POSIX)
    267 void StreamListenSocket::OnFileCanReadWithoutBlocking(int fd) {
    268   switch (wait_state_) {
    269     case WAITING_ACCEPT:
    270       Accept();
    271       break;
    272     case WAITING_READ:
    273       if (reads_paused_) {
    274         has_pending_reads_ = true;
    275       } else {
    276         Read();
    277       }
    278       break;
    279     default:
    280       // Close() is called by Read() in the Linux case.
    281       NOTREACHED();
    282       break;
    283   }
    284 }
    285 
    286 void StreamListenSocket::OnFileCanWriteWithoutBlocking(int fd) {
    287   // MessagePumpLibevent callback, we don't listen for write events
    288   // so we shouldn't ever reach here.
    289   NOTREACHED();
    290 }
    291 
    292 #endif
    293 
    294 void StreamListenSocket::PauseReads() {
    295   DCHECK(!reads_paused_);
    296   reads_paused_ = true;
    297 }
    298 
    299 void StreamListenSocket::ResumeReads() {
    300   DCHECK(reads_paused_);
    301   reads_paused_ = false;
    302   if (has_pending_reads_) {
    303     has_pending_reads_ = false;
    304     Read();
    305   }
    306 }
    307 
    308 }  // namespace net
    309