Home | History | Annotate | Download | only in jsonrpc
      1 // Copyright 2010 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package jsonrpc
      6 
      7 import (
      8 	"bytes"
      9 	"encoding/json"
     10 	"errors"
     11 	"fmt"
     12 	"io"
     13 	"io/ioutil"
     14 	"net"
     15 	"net/rpc"
     16 	"strings"
     17 	"testing"
     18 )
     19 
     20 type Args struct {
     21 	A, B int
     22 }
     23 
     24 type Reply struct {
     25 	C int
     26 }
     27 
     28 type Arith int
     29 
     30 type ArithAddResp struct {
     31 	Id     interface{} `json:"id"`
     32 	Result Reply       `json:"result"`
     33 	Error  interface{} `json:"error"`
     34 }
     35 
     36 func (t *Arith) Add(args *Args, reply *Reply) error {
     37 	reply.C = args.A + args.B
     38 	return nil
     39 }
     40 
     41 func (t *Arith) Mul(args *Args, reply *Reply) error {
     42 	reply.C = args.A * args.B
     43 	return nil
     44 }
     45 
     46 func (t *Arith) Div(args *Args, reply *Reply) error {
     47 	if args.B == 0 {
     48 		return errors.New("divide by zero")
     49 	}
     50 	reply.C = args.A / args.B
     51 	return nil
     52 }
     53 
     54 func (t *Arith) Error(args *Args, reply *Reply) error {
     55 	panic("ERROR")
     56 }
     57 
     58 func init() {
     59 	rpc.Register(new(Arith))
     60 }
     61 
     62 func TestServerNoParams(t *testing.T) {
     63 	cli, srv := net.Pipe()
     64 	defer cli.Close()
     65 	go ServeConn(srv)
     66 	dec := json.NewDecoder(cli)
     67 
     68 	fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "123"}`)
     69 	var resp ArithAddResp
     70 	if err := dec.Decode(&resp); err != nil {
     71 		t.Fatalf("Decode after no params: %s", err)
     72 	}
     73 	if resp.Error == nil {
     74 		t.Fatalf("Expected error, got nil")
     75 	}
     76 }
     77 
     78 func TestServerEmptyMessage(t *testing.T) {
     79 	cli, srv := net.Pipe()
     80 	defer cli.Close()
     81 	go ServeConn(srv)
     82 	dec := json.NewDecoder(cli)
     83 
     84 	fmt.Fprintf(cli, "{}")
     85 	var resp ArithAddResp
     86 	if err := dec.Decode(&resp); err != nil {
     87 		t.Fatalf("Decode after empty: %s", err)
     88 	}
     89 	if resp.Error == nil {
     90 		t.Fatalf("Expected error, got nil")
     91 	}
     92 }
     93 
     94 func TestServer(t *testing.T) {
     95 	cli, srv := net.Pipe()
     96 	defer cli.Close()
     97 	go ServeConn(srv)
     98 	dec := json.NewDecoder(cli)
     99 
    100 	// Send hand-coded requests to server, parse responses.
    101 	for i := 0; i < 10; i++ {
    102 		fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "\u%04d", "params": [{"A": %d, "B": %d}]}`, i, i, i+1)
    103 		var resp ArithAddResp
    104 		err := dec.Decode(&resp)
    105 		if err != nil {
    106 			t.Fatalf("Decode: %s", err)
    107 		}
    108 		if resp.Error != nil {
    109 			t.Fatalf("resp.Error: %s", resp.Error)
    110 		}
    111 		if resp.Id.(string) != string(i) {
    112 			t.Fatalf("resp: bad id %q want %q", resp.Id.(string), string(i))
    113 		}
    114 		if resp.Result.C != 2*i+1 {
    115 			t.Fatalf("resp: bad result: %d+%d=%d", i, i+1, resp.Result.C)
    116 		}
    117 	}
    118 }
    119 
    120 func TestClient(t *testing.T) {
    121 	// Assume server is okay (TestServer is above).
    122 	// Test client against server.
    123 	cli, srv := net.Pipe()
    124 	go ServeConn(srv)
    125 
    126 	client := NewClient(cli)
    127 	defer client.Close()
    128 
    129 	// Synchronous calls
    130 	args := &Args{7, 8}
    131 	reply := new(Reply)
    132 	err := client.Call("Arith.Add", args, reply)
    133 	if err != nil {
    134 		t.Errorf("Add: expected no error but got string %q", err.Error())
    135 	}
    136 	if reply.C != args.A+args.B {
    137 		t.Errorf("Add: got %d expected %d", reply.C, args.A+args.B)
    138 	}
    139 
    140 	args = &Args{7, 8}
    141 	reply = new(Reply)
    142 	err = client.Call("Arith.Mul", args, reply)
    143 	if err != nil {
    144 		t.Errorf("Mul: expected no error but got string %q", err.Error())
    145 	}
    146 	if reply.C != args.A*args.B {
    147 		t.Errorf("Mul: got %d expected %d", reply.C, args.A*args.B)
    148 	}
    149 
    150 	// Out of order.
    151 	args = &Args{7, 8}
    152 	mulReply := new(Reply)
    153 	mulCall := client.Go("Arith.Mul", args, mulReply, nil)
    154 	addReply := new(Reply)
    155 	addCall := client.Go("Arith.Add", args, addReply, nil)
    156 
    157 	addCall = <-addCall.Done
    158 	if addCall.Error != nil {
    159 		t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
    160 	}
    161 	if addReply.C != args.A+args.B {
    162 		t.Errorf("Add: got %d expected %d", addReply.C, args.A+args.B)
    163 	}
    164 
    165 	mulCall = <-mulCall.Done
    166 	if mulCall.Error != nil {
    167 		t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
    168 	}
    169 	if mulReply.C != args.A*args.B {
    170 		t.Errorf("Mul: got %d expected %d", mulReply.C, args.A*args.B)
    171 	}
    172 
    173 	// Error test
    174 	args = &Args{7, 0}
    175 	reply = new(Reply)
    176 	err = client.Call("Arith.Div", args, reply)
    177 	// expect an error: zero divide
    178 	if err == nil {
    179 		t.Error("Div: expected error")
    180 	} else if err.Error() != "divide by zero" {
    181 		t.Error("Div: expected divide by zero error; got", err)
    182 	}
    183 }
    184 
    185 func TestMalformedInput(t *testing.T) {
    186 	cli, srv := net.Pipe()
    187 	go cli.Write([]byte(`{id:1}`)) // invalid json
    188 	ServeConn(srv)                 // must return, not loop
    189 }
    190 
    191 func TestMalformedOutput(t *testing.T) {
    192 	cli, srv := net.Pipe()
    193 	go srv.Write([]byte(`{"id":0,"result":null,"error":null}`))
    194 	go ioutil.ReadAll(srv)
    195 
    196 	client := NewClient(cli)
    197 	defer client.Close()
    198 
    199 	args := &Args{7, 8}
    200 	reply := new(Reply)
    201 	err := client.Call("Arith.Add", args, reply)
    202 	if err == nil {
    203 		t.Error("expected error")
    204 	}
    205 }
    206 
    207 func TestServerErrorHasNullResult(t *testing.T) {
    208 	var out bytes.Buffer
    209 	sc := NewServerCodec(struct {
    210 		io.Reader
    211 		io.Writer
    212 		io.Closer
    213 	}{
    214 		Reader: strings.NewReader(`{"method": "Arith.Add", "id": "123", "params": []}`),
    215 		Writer: &out,
    216 		Closer: ioutil.NopCloser(nil),
    217 	})
    218 	r := new(rpc.Request)
    219 	if err := sc.ReadRequestHeader(r); err != nil {
    220 		t.Fatal(err)
    221 	}
    222 	const valueText = "the value we don't want to see"
    223 	const errorText = "some error"
    224 	err := sc.WriteResponse(&rpc.Response{
    225 		ServiceMethod: "Method",
    226 		Seq:           1,
    227 		Error:         errorText,
    228 	}, valueText)
    229 	if err != nil {
    230 		t.Fatal(err)
    231 	}
    232 	if !strings.Contains(out.String(), errorText) {
    233 		t.Fatalf("Response didn't contain expected error %q: %s", errorText, &out)
    234 	}
    235 	if strings.Contains(out.String(), valueText) {
    236 		t.Errorf("Response contains both an error and value: %s", &out)
    237 	}
    238 }
    239 
    240 func TestUnexpectedError(t *testing.T) {
    241 	cli, srv := myPipe()
    242 	go cli.PipeWriter.CloseWithError(errors.New("unexpected error!")) // reader will get this error
    243 	ServeConn(srv)                                                    // must return, not loop
    244 }
    245 
    246 // Copied from package net.
    247 func myPipe() (*pipe, *pipe) {
    248 	r1, w1 := io.Pipe()
    249 	r2, w2 := io.Pipe()
    250 
    251 	return &pipe{r1, w2}, &pipe{r2, w1}
    252 }
    253 
    254 type pipe struct {
    255 	*io.PipeReader
    256 	*io.PipeWriter
    257 }
    258 
    259 type pipeAddr int
    260 
    261 func (pipeAddr) Network() string {
    262 	return "pipe"
    263 }
    264 
    265 func (pipeAddr) String() string {
    266 	return "pipe"
    267 }
    268 
    269 func (p *pipe) Close() error {
    270 	err := p.PipeReader.Close()
    271 	err1 := p.PipeWriter.Close()
    272 	if err == nil {
    273 		err = err1
    274 	}
    275 	return err
    276 }
    277 
    278 func (p *pipe) LocalAddr() net.Addr {
    279 	return pipeAddr(0)
    280 }
    281 
    282 func (p *pipe) RemoteAddr() net.Addr {
    283 	return pipeAddr(0)
    284 }
    285 
    286 func (p *pipe) SetTimeout(nsec int64) error {
    287 	return errors.New("net.Pipe does not support timeouts")
    288 }
    289 
    290 func (p *pipe) SetReadTimeout(nsec int64) error {
    291 	return errors.New("net.Pipe does not support timeouts")
    292 }
    293 
    294 func (p *pipe) SetWriteTimeout(nsec int64) error {
    295 	return errors.New("net.Pipe does not support timeouts")
    296 }
    297