Home | History | Annotate | Download | only in jsonrpc
      1 // Copyright 2010 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package jsonrpc
      6 
      7 import (
      8 	"bytes"
      9 	"encoding/json"
     10 	"errors"
     11 	"fmt"
     12 	"io"
     13 	"io/ioutil"
     14 	"net"
     15 	"net/rpc"
     16 	"reflect"
     17 	"strings"
     18 	"testing"
     19 )
     20 
     21 type Args struct {
     22 	A, B int
     23 }
     24 
     25 type Reply struct {
     26 	C int
     27 }
     28 
     29 type Arith int
     30 
     31 type ArithAddResp struct {
     32 	Id     interface{} `json:"id"`
     33 	Result Reply       `json:"result"`
     34 	Error  interface{} `json:"error"`
     35 }
     36 
     37 func (t *Arith) Add(args *Args, reply *Reply) error {
     38 	reply.C = args.A + args.B
     39 	return nil
     40 }
     41 
     42 func (t *Arith) Mul(args *Args, reply *Reply) error {
     43 	reply.C = args.A * args.B
     44 	return nil
     45 }
     46 
     47 func (t *Arith) Div(args *Args, reply *Reply) error {
     48 	if args.B == 0 {
     49 		return errors.New("divide by zero")
     50 	}
     51 	reply.C = args.A / args.B
     52 	return nil
     53 }
     54 
     55 func (t *Arith) Error(args *Args, reply *Reply) error {
     56 	panic("ERROR")
     57 }
     58 
     59 type BuiltinTypes struct{}
     60 
     61 func (BuiltinTypes) Map(i int, reply *map[int]int) error {
     62 	(*reply)[i] = i
     63 	return nil
     64 }
     65 
     66 func (BuiltinTypes) Slice(i int, reply *[]int) error {
     67 	*reply = append(*reply, i)
     68 	return nil
     69 }
     70 
     71 func (BuiltinTypes) Array(i int, reply *[1]int) error {
     72 	(*reply)[0] = i
     73 	return nil
     74 }
     75 
     76 func init() {
     77 	rpc.Register(new(Arith))
     78 	rpc.Register(BuiltinTypes{})
     79 }
     80 
     81 func TestServerNoParams(t *testing.T) {
     82 	cli, srv := net.Pipe()
     83 	defer cli.Close()
     84 	go ServeConn(srv)
     85 	dec := json.NewDecoder(cli)
     86 
     87 	fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "123"}`)
     88 	var resp ArithAddResp
     89 	if err := dec.Decode(&resp); err != nil {
     90 		t.Fatalf("Decode after no params: %s", err)
     91 	}
     92 	if resp.Error == nil {
     93 		t.Fatalf("Expected error, got nil")
     94 	}
     95 }
     96 
     97 func TestServerEmptyMessage(t *testing.T) {
     98 	cli, srv := net.Pipe()
     99 	defer cli.Close()
    100 	go ServeConn(srv)
    101 	dec := json.NewDecoder(cli)
    102 
    103 	fmt.Fprintf(cli, "{}")
    104 	var resp ArithAddResp
    105 	if err := dec.Decode(&resp); err != nil {
    106 		t.Fatalf("Decode after empty: %s", err)
    107 	}
    108 	if resp.Error == nil {
    109 		t.Fatalf("Expected error, got nil")
    110 	}
    111 }
    112 
    113 func TestServer(t *testing.T) {
    114 	cli, srv := net.Pipe()
    115 	defer cli.Close()
    116 	go ServeConn(srv)
    117 	dec := json.NewDecoder(cli)
    118 
    119 	// Send hand-coded requests to server, parse responses.
    120 	for i := 0; i < 10; i++ {
    121 		fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "\u%04d", "params": [{"A": %d, "B": %d}]}`, i, i, i+1)
    122 		var resp ArithAddResp
    123 		err := dec.Decode(&resp)
    124 		if err != nil {
    125 			t.Fatalf("Decode: %s", err)
    126 		}
    127 		if resp.Error != nil {
    128 			t.Fatalf("resp.Error: %s", resp.Error)
    129 		}
    130 		if resp.Id.(string) != string(i) {
    131 			t.Fatalf("resp: bad id %q want %q", resp.Id.(string), string(i))
    132 		}
    133 		if resp.Result.C != 2*i+1 {
    134 			t.Fatalf("resp: bad result: %d+%d=%d", i, i+1, resp.Result.C)
    135 		}
    136 	}
    137 }
    138 
    139 func TestClient(t *testing.T) {
    140 	// Assume server is okay (TestServer is above).
    141 	// Test client against server.
    142 	cli, srv := net.Pipe()
    143 	go ServeConn(srv)
    144 
    145 	client := NewClient(cli)
    146 	defer client.Close()
    147 
    148 	// Synchronous calls
    149 	args := &Args{7, 8}
    150 	reply := new(Reply)
    151 	err := client.Call("Arith.Add", args, reply)
    152 	if err != nil {
    153 		t.Errorf("Add: expected no error but got string %q", err.Error())
    154 	}
    155 	if reply.C != args.A+args.B {
    156 		t.Errorf("Add: got %d expected %d", reply.C, args.A+args.B)
    157 	}
    158 
    159 	args = &Args{7, 8}
    160 	reply = new(Reply)
    161 	err = client.Call("Arith.Mul", args, reply)
    162 	if err != nil {
    163 		t.Errorf("Mul: expected no error but got string %q", err.Error())
    164 	}
    165 	if reply.C != args.A*args.B {
    166 		t.Errorf("Mul: got %d expected %d", reply.C, args.A*args.B)
    167 	}
    168 
    169 	// Out of order.
    170 	args = &Args{7, 8}
    171 	mulReply := new(Reply)
    172 	mulCall := client.Go("Arith.Mul", args, mulReply, nil)
    173 	addReply := new(Reply)
    174 	addCall := client.Go("Arith.Add", args, addReply, nil)
    175 
    176 	addCall = <-addCall.Done
    177 	if addCall.Error != nil {
    178 		t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
    179 	}
    180 	if addReply.C != args.A+args.B {
    181 		t.Errorf("Add: got %d expected %d", addReply.C, args.A+args.B)
    182 	}
    183 
    184 	mulCall = <-mulCall.Done
    185 	if mulCall.Error != nil {
    186 		t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
    187 	}
    188 	if mulReply.C != args.A*args.B {
    189 		t.Errorf("Mul: got %d expected %d", mulReply.C, args.A*args.B)
    190 	}
    191 
    192 	// Error test
    193 	args = &Args{7, 0}
    194 	reply = new(Reply)
    195 	err = client.Call("Arith.Div", args, reply)
    196 	// expect an error: zero divide
    197 	if err == nil {
    198 		t.Error("Div: expected error")
    199 	} else if err.Error() != "divide by zero" {
    200 		t.Error("Div: expected divide by zero error; got", err)
    201 	}
    202 }
    203 
    204 func TestBuiltinTypes(t *testing.T) {
    205 	cli, srv := net.Pipe()
    206 	go ServeConn(srv)
    207 
    208 	client := NewClient(cli)
    209 	defer client.Close()
    210 
    211 	// Map
    212 	arg := 7
    213 	replyMap := map[int]int{}
    214 	err := client.Call("BuiltinTypes.Map", arg, &replyMap)
    215 	if err != nil {
    216 		t.Errorf("Map: expected no error but got string %q", err.Error())
    217 	}
    218 	if replyMap[arg] != arg {
    219 		t.Errorf("Map: expected %d got %d", arg, replyMap[arg])
    220 	}
    221 
    222 	// Slice
    223 	replySlice := []int{}
    224 	err = client.Call("BuiltinTypes.Slice", arg, &replySlice)
    225 	if err != nil {
    226 		t.Errorf("Slice: expected no error but got string %q", err.Error())
    227 	}
    228 	if e := []int{arg}; !reflect.DeepEqual(replySlice, e) {
    229 		t.Errorf("Slice: expected %v got %v", e, replySlice)
    230 	}
    231 
    232 	// Array
    233 	replyArray := [1]int{}
    234 	err = client.Call("BuiltinTypes.Array", arg, &replyArray)
    235 	if err != nil {
    236 		t.Errorf("Array: expected no error but got string %q", err.Error())
    237 	}
    238 	if e := [1]int{arg}; !reflect.DeepEqual(replyArray, e) {
    239 		t.Errorf("Array: expected %v got %v", e, replyArray)
    240 	}
    241 }
    242 
    243 func TestMalformedInput(t *testing.T) {
    244 	cli, srv := net.Pipe()
    245 	go cli.Write([]byte(`{id:1}`)) // invalid json
    246 	ServeConn(srv)                 // must return, not loop
    247 }
    248 
    249 func TestMalformedOutput(t *testing.T) {
    250 	cli, srv := net.Pipe()
    251 	go srv.Write([]byte(`{"id":0,"result":null,"error":null}`))
    252 	go ioutil.ReadAll(srv)
    253 
    254 	client := NewClient(cli)
    255 	defer client.Close()
    256 
    257 	args := &Args{7, 8}
    258 	reply := new(Reply)
    259 	err := client.Call("Arith.Add", args, reply)
    260 	if err == nil {
    261 		t.Error("expected error")
    262 	}
    263 }
    264 
    265 func TestServerErrorHasNullResult(t *testing.T) {
    266 	var out bytes.Buffer
    267 	sc := NewServerCodec(struct {
    268 		io.Reader
    269 		io.Writer
    270 		io.Closer
    271 	}{
    272 		Reader: strings.NewReader(`{"method": "Arith.Add", "id": "123", "params": []}`),
    273 		Writer: &out,
    274 		Closer: ioutil.NopCloser(nil),
    275 	})
    276 	r := new(rpc.Request)
    277 	if err := sc.ReadRequestHeader(r); err != nil {
    278 		t.Fatal(err)
    279 	}
    280 	const valueText = "the value we don't want to see"
    281 	const errorText = "some error"
    282 	err := sc.WriteResponse(&rpc.Response{
    283 		ServiceMethod: "Method",
    284 		Seq:           1,
    285 		Error:         errorText,
    286 	}, valueText)
    287 	if err != nil {
    288 		t.Fatal(err)
    289 	}
    290 	if !strings.Contains(out.String(), errorText) {
    291 		t.Fatalf("Response didn't contain expected error %q: %s", errorText, &out)
    292 	}
    293 	if strings.Contains(out.String(), valueText) {
    294 		t.Errorf("Response contains both an error and value: %s", &out)
    295 	}
    296 }
    297 
    298 func TestUnexpectedError(t *testing.T) {
    299 	cli, srv := myPipe()
    300 	go cli.PipeWriter.CloseWithError(errors.New("unexpected error!")) // reader will get this error
    301 	ServeConn(srv)                                                    // must return, not loop
    302 }
    303 
    304 // Copied from package net.
    305 func myPipe() (*pipe, *pipe) {
    306 	r1, w1 := io.Pipe()
    307 	r2, w2 := io.Pipe()
    308 
    309 	return &pipe{r1, w2}, &pipe{r2, w1}
    310 }
    311 
    312 type pipe struct {
    313 	*io.PipeReader
    314 	*io.PipeWriter
    315 }
    316 
    317 type pipeAddr int
    318 
    319 func (pipeAddr) Network() string {
    320 	return "pipe"
    321 }
    322 
    323 func (pipeAddr) String() string {
    324 	return "pipe"
    325 }
    326 
    327 func (p *pipe) Close() error {
    328 	err := p.PipeReader.Close()
    329 	err1 := p.PipeWriter.Close()
    330 	if err == nil {
    331 		err = err1
    332 	}
    333 	return err
    334 }
    335 
    336 func (p *pipe) LocalAddr() net.Addr {
    337 	return pipeAddr(0)
    338 }
    339 
    340 func (p *pipe) RemoteAddr() net.Addr {
    341 	return pipeAddr(0)
    342 }
    343 
    344 func (p *pipe) SetTimeout(nsec int64) error {
    345 	return errors.New("net.Pipe does not support timeouts")
    346 }
    347 
    348 func (p *pipe) SetReadTimeout(nsec int64) error {
    349 	return errors.New("net.Pipe does not support timeouts")
    350 }
    351 
    352 func (p *pipe) SetWriteTimeout(nsec int64) error {
    353 	return errors.New("net.Pipe does not support timeouts")
    354 }
    355