Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2009 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #ifndef NET_SOCKET_TCP_PINGER_H_
      6 #define NET_SOCKET_TCP_PINGER_H_
      7 
      8 #include "base/compiler_specific.h"
      9 #include "base/ref_counted.h"
     10 #include "base/scoped_ptr.h"
     11 #include "base/task.h"
     12 #include "base/thread.h"
     13 #include "base/waitable_event.h"
     14 #include "net/base/address_list.h"
     15 #include "net/base/completion_callback.h"
     16 #include "net/base/net_errors.h"
     17 #include "net/socket/tcp_client_socket.h"
     18 
     19 namespace base {
     20 class TimeDelta;
     21 }
     22 
     23 namespace net {
     24 
     25 // Simple class to wait until a TCP server is accepting connections.
     26 class TCPPinger {
     27  public:
     28   explicit TCPPinger(const net::AddressList& addr)
     29     : io_thread_("TCPPinger"),
     30       worker_(new Worker(addr)) {
     31     worker_->AddRef();
     32     // Start up a throwaway IO thread just for this.
     33     // TODO(dkegel): use some existing thread pool instead?
     34     base::Thread::Options options;
     35     options.message_loop_type = MessageLoop::TYPE_IO;
     36     io_thread_.StartWithOptions(options);
     37   }
     38 
     39   ~TCPPinger() {
     40     io_thread_.message_loop()->ReleaseSoon(FROM_HERE, worker_);
     41   }
     42 
     43   int Ping() {
     44     // Default is 10 tries, each with a timeout of 1000ms,
     45     // for a total max timeout of 10 seconds.
     46     return Ping(base::TimeDelta::FromMilliseconds(1000), 10);
     47   }
     48 
     49   int Ping(base::TimeDelta tryTimeout, int nTries) {
     50     int err = ERR_IO_PENDING;
     51     // Post a request to do the connect on that thread.
     52     for (int i = 0; i < nTries; i++) {
     53       io_thread_.message_loop()->PostTask(FROM_HERE,
     54         NewRunnableMethod(worker_,
     55         &net::TCPPinger::Worker::DoConnect));
     56       // Timeout here in case remote host offline
     57       err = worker_->TimedWaitForResult(tryTimeout);
     58       if (err == net::OK)
     59         break;
     60       PlatformThread::Sleep(static_cast<int>(tryTimeout.InMilliseconds()));
     61 
     62       // Cancel leftover activity, if any
     63       io_thread_.message_loop()->PostTask(FROM_HERE,
     64         NewRunnableMethod(worker_,
     65         &net::TCPPinger::Worker::DoDisconnect));
     66       worker_->WaitForResult();
     67     }
     68     return err;
     69   }
     70 
     71  private:
     72 
     73   // Inner class to handle all actual socket calls.
     74   // This makes the outer interface simpler,
     75   // and helps us obey the "all socket calls
     76   // must be on same thread" restriction.
     77   class Worker : public base::RefCountedThreadSafe<Worker> {
     78    public:
     79     explicit Worker(const net::AddressList& addr)
     80       : event_(false, false),
     81         net_error_(ERR_IO_PENDING),
     82         addr_(addr),
     83         ALLOW_THIS_IN_INITIALIZER_LIST(connect_callback_(this,
     84             &net::TCPPinger::Worker::ConnectDone)) {
     85     }
     86 
     87     void DoConnect() {
     88       sock_.reset(new TCPClientSocket(addr_));
     89       int rv = sock_->Connect(&connect_callback_, NULL);
     90       // Regardless of success or failure, if we're done now,
     91       // signal the customer.
     92       if (rv != ERR_IO_PENDING)
     93         ConnectDone(rv);
     94     }
     95 
     96     void DoDisconnect() {
     97       sock_.reset();
     98       event_.Signal();
     99     }
    100 
    101     void ConnectDone(int rv) {
    102       sock_.reset();
    103       net_error_ = rv;
    104       event_.Signal();
    105     }
    106 
    107     int TimedWaitForResult(base::TimeDelta tryTimeout) {
    108       event_.TimedWait(tryTimeout);
    109       return net_error_;
    110     }
    111 
    112     int WaitForResult() {
    113       event_.Wait();
    114       return net_error_;
    115     }
    116 
    117    private:
    118     friend class base::RefCountedThreadSafe<Worker>;
    119 
    120     ~Worker() {}
    121 
    122     base::WaitableEvent event_;
    123     int net_error_;
    124     net::AddressList addr_;
    125     scoped_ptr<TCPClientSocket> sock_;
    126     net::CompletionCallbackImpl<Worker> connect_callback_;
    127   };
    128 
    129   base::Thread io_thread_;
    130   Worker* worker_;
    131   DISALLOW_COPY_AND_ASSIGN(TCPPinger);
    132 };
    133 
    134 }  // namespace net
    135 
    136 #endif  // NET_SOCKET_TCP_PINGER_H_
    137