Home | History | Annotate | Download | only in rpc
      1 // Copyright 2009 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package rpc
      6 
      7 import (
      8 	"bufio"
      9 	"encoding/gob"
     10 	"errors"
     11 	"io"
     12 	"log"
     13 	"net"
     14 	"net/http"
     15 	"sync"
     16 )
     17 
     18 // ServerError represents an error that has been returned from
     19 // the remote side of the RPC connection.
     20 type ServerError string
     21 
     22 func (e ServerError) Error() string {
     23 	return string(e)
     24 }
     25 
// ErrShutdown is the error given to calls issued after the client has been
// closed or its connection has shut down. It is also returned by a repeated
// Close, and reported to pending calls when the connection hits EOF after
// the user called Close.
var ErrShutdown = errors.New("connection is shut down")
     27 
// Call represents an active RPC.
type Call struct {
	ServiceMethod string      // The name of the service and method to call.
	Args          interface{} // The argument to the function (*struct); handed to the codec's WriteRequest.
	Reply         interface{} // The reply from the function (*struct); the response body is decoded into it.

	// Error is the error status after completion: a ServerError when the
	// server reported an error, otherwise a local error (e.g. ErrShutdown,
	// a write failure, or a body-decoding failure).
	Error error

	// Done receives this Call when it completes. The send is non-blocking,
	// so the channel needs free buffer space or the notification is dropped.
	Done chan *Call // Strobes when call is complete.
}
     36 
// Client represents an RPC Client.
// There may be multiple outstanding Calls associated
// with a single Client, and a Client may be used by
// multiple goroutines simultaneously.
type Client struct {
	codec ClientCodec

	// reqMutex serializes senders (see send) so the single, reused request
	// header below is filled in and written by one goroutine at a time. The
	// input goroutine also takes it before failing pending calls on shutdown.
	reqMutex sync.Mutex // protects following
	request  Request

	mutex    sync.Mutex       // protects following
	seq      uint64           // next sequence number handed out by send
	pending  map[uint64]*Call // calls awaiting a response, keyed by sequence number
	closing  bool             // user has called Close
	shutdown bool             // input loop has exited after a read error; no more responses will be read
}
     53 
// A ClientCodec implements writing of RPC requests and
// reading of RPC responses for the client side of an RPC session.
// The client calls WriteRequest to write a request to the connection
// and calls ReadResponseHeader and ReadResponseBody in pairs
// to read responses. The client calls Close when finished with the
// connection. ReadResponseBody may be called with a nil
// argument to force the body of the response to be read and then
// discarded.
type ClientCodec interface {
	// WriteRequest must be safe for concurrent use by multiple goroutines.
	// It writes the request header followed by the argument value.
	WriteRequest(*Request, interface{}) error

	// ReadResponseHeader reads the next response header into the given
	// Response. Client calls it only from its single input goroutine.
	ReadResponseHeader(*Response) error

	// ReadResponseBody reads the body that follows the header just read
	// into the given value, or reads and discards it when the value is nil.
	ReadResponseBody(interface{}) error

	// Close releases the underlying connection; Client.Close calls it.
	Close() error
}
     70 
// send registers call in client.pending under a fresh sequence number and
// writes the request through the codec. If the client is closing or shut
// down, the call is completed immediately with ErrShutdown; if WriteRequest
// fails, the call is completed with that error (unless input already
// claimed it).
//
// reqMutex is held for the whole function: client.request is a single,
// reused header value, so only one sender may fill it in and write it at a
// time. Lock order is reqMutex then mutex, matching input's shutdown path.
func (client *Client) send(call *Call) {
	client.reqMutex.Lock()
	defer client.reqMutex.Unlock()

	// Register this call.
	client.mutex.Lock()
	if client.shutdown || client.closing {
		call.Error = ErrShutdown
		client.mutex.Unlock()
		call.done()
		return
	}
	seq := client.seq
	client.seq++
	client.pending[seq] = call
	client.mutex.Unlock()

	// Encode and send the request.
	client.request.Seq = seq
	client.request.ServiceMethod = call.ServiceMethod
	err := client.codec.WriteRequest(&client.request, call.Args)
	if err != nil {
		// Re-fetch under the lock: the input goroutine may already have
		// removed this sequence number from pending (it deletes on every
		// response header it reads). Completing only when the entry is
		// still present avoids signalling the call twice.
		client.mutex.Lock()
		call = client.pending[seq]
		delete(client.pending, seq)
		client.mutex.Unlock()
		if call != nil {
			call.Error = err
			call.done()
		}
	}
}
    103 
// input is the client's response-reading loop; NewClientWithCodec runs it
// in its own goroutine. It reads response headers and bodies in pairs,
// matches each response to its pending Call by sequence number, and
// completes that Call. When any read fails, it marks the client shut down
// and completes every still-pending Call with the (possibly translated)
// error.
func (client *Client) input() {
	var err error
	var response Response
	for err == nil {
		response = Response{}
		err = client.codec.ReadResponseHeader(&response)
		if err != nil {
			break
		}
		seq := response.Seq
		// Claim the call under the lock. Once removed from pending, neither
		// send's error path nor the shutdown loop below will complete it.
		client.mutex.Lock()
		call := client.pending[seq]
		delete(client.pending, seq)
		client.mutex.Unlock()

		switch {
		case call == nil:
			// We've got no pending call. That usually means that
			// WriteRequest partially failed, and call was already
			// removed; response is a server telling us about an
			// error reading request body. We should still attempt
			// to read error body, but there's no one to give it to.
			err = client.codec.ReadResponseBody(nil)
			if err != nil {
				err = errors.New("reading error body: " + err.Error())
			}
		case response.Error != "":
			// We've got an error response. Give this to the request;
			// any subsequent requests will get the ReadResponseBody
			// error if there is one.
			call.Error = ServerError(response.Error)
			// The body is still consumed (and discarded via nil) so the
			// next header read starts at the right place in the stream.
			err = client.codec.ReadResponseBody(nil)
			if err != nil {
				err = errors.New("reading error body: " + err.Error())
			}
			call.done()
		default:
			err = client.codec.ReadResponseBody(call.Reply)
			if err != nil {
				// The failure is recorded on this call and, via err, also
				// ends the read loop.
				call.Error = errors.New("reading body " + err.Error())
			}
			call.done()
		}
	}
	// Terminate pending calls.
	// reqMutex is taken first (same order as send), so no send is midway
	// through registering a call; after shutdown is set, send rejects new
	// calls with ErrShutdown.
	client.reqMutex.Lock()
	client.mutex.Lock()
	client.shutdown = true
	closing := client.closing
	// A plain EOF is expected after the user called Close; otherwise the
	// peer hung up on us.
	if err == io.EOF {
		if closing {
			err = ErrShutdown
		} else {
			err = io.ErrUnexpectedEOF
		}
	}
	for _, call := range client.pending {
		call.Error = err
		call.done()
	}
	client.mutex.Unlock()
	client.reqMutex.Unlock()
	// NOTE(review): io.EOF was rewritten to ErrShutdown/io.ErrUnexpectedEOF
	// above, so `err != io.EOF` is always true here. If the intent was to
	// skip logging for a clean EOF, this condition needs revisiting.
	if debugLog && err != io.EOF && !closing {
		log.Println("rpc: client protocol error:", err)
	}
}
    170 
    171 func (call *Call) done() {
    172 	select {
    173 	case call.Done <- call:
    174 		// ok
    175 	default:
    176 		// We don't want to block here. It is the caller's responsibility to make
    177 		// sure the channel has enough buffer space. See comment in Go().
    178 		if debugLog {
    179 			log.Println("rpc: discarding Call reply due to insufficient Done chan capacity")
    180 		}
    181 	}
    182 }
    183 
    184 // NewClient returns a new Client to handle requests to the
    185 // set of services at the other end of the connection.
    186 // It adds a buffer to the write side of the connection so
    187 // the header and payload are sent as a unit.
    188 func NewClient(conn io.ReadWriteCloser) *Client {
    189 	encBuf := bufio.NewWriter(conn)
    190 	client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf}
    191 	return NewClientWithCodec(client)
    192 }
    193 
    194 // NewClientWithCodec is like NewClient but uses the specified
    195 // codec to encode requests and decode responses.
    196 func NewClientWithCodec(codec ClientCodec) *Client {
    197 	client := &Client{
    198 		codec:   codec,
    199 		pending: make(map[uint64]*Call),
    200 	}
    201 	go client.input()
    202 	return client
    203 }
    204 
    205 type gobClientCodec struct {
    206 	rwc    io.ReadWriteCloser
    207 	dec    *gob.Decoder
    208 	enc    *gob.Encoder
    209 	encBuf *bufio.Writer
    210 }
    211 
    212 func (c *gobClientCodec) WriteRequest(r *Request, body interface{}) (err error) {
    213 	if err = c.enc.Encode(r); err != nil {
    214 		return
    215 	}
    216 	if err = c.enc.Encode(body); err != nil {
    217 		return
    218 	}
    219 	return c.encBuf.Flush()
    220 }
    221 
    222 func (c *gobClientCodec) ReadResponseHeader(r *Response) error {
    223 	return c.dec.Decode(r)
    224 }
    225 
    226 func (c *gobClientCodec) ReadResponseBody(body interface{}) error {
    227 	return c.dec.Decode(body)
    228 }
    229 
    230 func (c *gobClientCodec) Close() error {
    231 	return c.rwc.Close()
    232 }
    233 
    234 // DialHTTP connects to an HTTP RPC server at the specified network address
    235 // listening on the default HTTP RPC path.
    236 func DialHTTP(network, address string) (*Client, error) {
    237 	return DialHTTPPath(network, address, DefaultRPCPath)
    238 }
    239 
    240 // DialHTTPPath connects to an HTTP RPC server
    241 // at the specified network address and path.
    242 func DialHTTPPath(network, address, path string) (*Client, error) {
    243 	var err error
    244 	conn, err := net.Dial(network, address)
    245 	if err != nil {
    246 		return nil, err
    247 	}
    248 	io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n")
    249 
    250 	// Require successful HTTP response
    251 	// before switching to RPC protocol.
    252 	resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
    253 	if err == nil && resp.Status == connected {
    254 		return NewClient(conn), nil
    255 	}
    256 	if err == nil {
    257 		err = errors.New("unexpected HTTP response: " + resp.Status)
    258 	}
    259 	conn.Close()
    260 	return nil, &net.OpError{
    261 		Op:   "dial-http",
    262 		Net:  network + " " + address,
    263 		Addr: nil,
    264 		Err:  err,
    265 	}
    266 }
    267 
    268 // Dial connects to an RPC server at the specified network address.
    269 func Dial(network, address string) (*Client, error) {
    270 	conn, err := net.Dial(network, address)
    271 	if err != nil {
    272 		return nil, err
    273 	}
    274 	return NewClient(conn), nil
    275 }
    276 
    277 // Close calls the underlying codec's Close method. If the connection is already
    278 // shutting down, ErrShutdown is returned.
    279 func (client *Client) Close() error {
    280 	client.mutex.Lock()
    281 	if client.closing {
    282 		client.mutex.Unlock()
    283 		return ErrShutdown
    284 	}
    285 	client.closing = true
    286 	client.mutex.Unlock()
    287 	return client.codec.Close()
    288 }
    289 
    290 // Go invokes the function asynchronously. It returns the Call structure representing
    291 // the invocation. The done channel will signal when the call is complete by returning
    292 // the same Call object. If done is nil, Go will allocate a new channel.
    293 // If non-nil, done must be buffered or Go will deliberately crash.
    294 func (client *Client) Go(serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call {
    295 	call := new(Call)
    296 	call.ServiceMethod = serviceMethod
    297 	call.Args = args
    298 	call.Reply = reply
    299 	if done == nil {
    300 		done = make(chan *Call, 10) // buffered.
    301 	} else {
    302 		// If caller passes done != nil, it must arrange that
    303 		// done has enough buffer for the number of simultaneous
    304 		// RPCs that will be using that channel. If the channel
    305 		// is totally unbuffered, it's best not to run at all.
    306 		if cap(done) == 0 {
    307 			log.Panic("rpc: done channel is unbuffered")
    308 		}
    309 	}
    310 	call.Done = done
    311 	client.send(call)
    312 	return call
    313 }
    314 
    315 // Call invokes the named function, waits for it to complete, and returns its error status.
    316 func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) error {
    317 	call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
    318 	return call.Error
    319 }
    320