1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package rpc 6 7 import ( 8 "errors" 9 "fmt" 10 "io" 11 "log" 12 "net" 13 "net/http/httptest" 14 "runtime" 15 "strings" 16 "sync" 17 "sync/atomic" 18 "testing" 19 "time" 20 ) 21 22 var ( 23 newServer *Server 24 serverAddr, newServerAddr string 25 httpServerAddr string 26 once, newOnce, httpOnce sync.Once 27 ) 28 29 const ( 30 newHttpPath = "/foo" 31 ) 32 33 type Args struct { 34 A, B int 35 } 36 37 type Reply struct { 38 C int 39 } 40 41 type Arith int 42 43 // Some of Arith's methods have value args, some have pointer args. That's deliberate. 44 45 func (t *Arith) Add(args Args, reply *Reply) error { 46 reply.C = args.A + args.B 47 return nil 48 } 49 50 func (t *Arith) Mul(args *Args, reply *Reply) error { 51 reply.C = args.A * args.B 52 return nil 53 } 54 55 func (t *Arith) Div(args Args, reply *Reply) error { 56 if args.B == 0 { 57 return errors.New("divide by zero") 58 } 59 reply.C = args.A / args.B 60 return nil 61 } 62 63 func (t *Arith) String(args *Args, reply *string) error { 64 *reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B) 65 return nil 66 } 67 68 func (t *Arith) Scan(args string, reply *Reply) (err error) { 69 _, err = fmt.Sscan(args, &reply.C) 70 return 71 } 72 73 func (t *Arith) Error(args *Args, reply *Reply) error { 74 panic("ERROR") 75 } 76 77 func listenTCP() (net.Listener, string) { 78 l, e := net.Listen("tcp", "127.0.0.1:0") // any available address 79 if e != nil { 80 log.Fatalf("net.Listen tcp :0: %v", e) 81 } 82 return l, l.Addr().String() 83 } 84 85 func startServer() { 86 Register(new(Arith)) 87 RegisterName("net.rpc.Arith", new(Arith)) 88 89 var l net.Listener 90 l, serverAddr = listenTCP() 91 log.Println("Test RPC server listening on", serverAddr) 92 go Accept(l) 93 94 HandleHTTP() 95 httpOnce.Do(startHttpServer) 96 } 97 98 func startNewServer() { 99 newServer = NewServer() 100 newServer.Register(new(Arith)) 101 newServer.RegisterName("net.rpc.Arith", new(Arith)) 102 newServer.RegisterName("newServer.Arith", new(Arith)) 103 104 var l net.Listener 105 l, newServerAddr = listenTCP() 106 log.Println("NewServer test RPC server listening on", newServerAddr) 107 go newServer.Accept(l) 108 109 newServer.HandleHTTP(newHttpPath, "/bar") 110 httpOnce.Do(startHttpServer) 111 } 112 113 func startHttpServer() { 114 server := httptest.NewServer(nil) 115 httpServerAddr = server.Listener.Addr().String() 116 log.Println("Test HTTP RPC server listening on", httpServerAddr) 117 } 118 119 func TestRPC(t *testing.T) { 120 once.Do(startServer) 121 testRPC(t, serverAddr) 122 newOnce.Do(startNewServer) 123 testRPC(t, newServerAddr) 124 testNewServerRPC(t, newServerAddr) 125 } 126 127 func testRPC(t *testing.T, addr string) { 128 client, err := Dial("tcp", addr) 129 if err != nil { 130 t.Fatal("dialing", err) 131 } 132 defer client.Close() 133 134 // Synchronous calls 135 args := &Args{7, 8} 136 reply := new(Reply) 137 err = client.Call("Arith.Add", args, reply) 138 if err != nil { 139 t.Errorf("Add: expected no error but got string %q", err.Error()) 140 } 141 if reply.C != args.A+args.B { 142 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 143 } 144 145 // Nonexistent method 146 args = &Args{7, 0} 147 reply = new(Reply) 148 err = client.Call("Arith.BadOperation", args, reply) 149 // expect an error 150 if err == nil { 151 t.Error("BadOperation: expected error") 152 } else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") { 153 t.Errorf("BadOperation: expected can't find method error; got %q", err) 154 } 155 156 // Unknown service 157 args = &Args{7, 8} 158 reply = new(Reply) 159 err = client.Call("Arith.Unknown", args, reply) 160 if err == nil { 161 t.Error("expected error calling unknown service") 162 } else if strings.Index(err.Error(), "method") < 0 { 163 t.Error("expected error about method; got", err) 164 } 165 166 // Out of order. 167 args = &Args{7, 8} 168 mulReply := new(Reply) 169 mulCall := client.Go("Arith.Mul", args, mulReply, nil) 170 addReply := new(Reply) 171 addCall := client.Go("Arith.Add", args, addReply, nil) 172 173 addCall = <-addCall.Done 174 if addCall.Error != nil { 175 t.Errorf("Add: expected no error but got string %q", addCall.Error.Error()) 176 } 177 if addReply.C != args.A+args.B { 178 t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B) 179 } 180 181 mulCall = <-mulCall.Done 182 if mulCall.Error != nil { 183 t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error()) 184 } 185 if mulReply.C != args.A*args.B { 186 t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B) 187 } 188 189 // Error test 190 args = &Args{7, 0} 191 reply = new(Reply) 192 err = client.Call("Arith.Div", args, reply) 193 // expect an error: zero divide 194 if err == nil { 195 t.Error("Div: expected error") 196 } else if err.Error() != "divide by zero" { 197 t.Error("Div: expected divide by zero error; got", err) 198 } 199 200 // Bad type. 201 reply = new(Reply) 202 err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use 203 if err == nil { 204 t.Error("expected error calling Arith.Add with wrong arg type") 205 } else if strings.Index(err.Error(), "type") < 0 { 206 t.Error("expected error about type; got", err) 207 } 208 209 // Non-struct argument 210 const Val = 12345 211 str := fmt.Sprint(Val) 212 reply = new(Reply) 213 err = client.Call("Arith.Scan", &str, reply) 214 if err != nil { 215 t.Errorf("Scan: expected no error but got string %q", err.Error()) 216 } else if reply.C != Val { 217 t.Errorf("Scan: expected %d got %d", Val, reply.C) 218 } 219 220 // Non-struct reply 221 args = &Args{27, 35} 222 str = "" 223 err = client.Call("Arith.String", args, &str) 224 if err != nil { 225 t.Errorf("String: expected no error but got string %q", err.Error()) 226 } 227 expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B) 228 if str != expect { 229 t.Errorf("String: expected %s got %s", expect, str) 230 } 231 232 args = &Args{7, 8} 233 reply = new(Reply) 234 err = client.Call("Arith.Mul", args, reply) 235 if err != nil { 236 t.Errorf("Mul: expected no error but got string %q", err.Error()) 237 } 238 if reply.C != args.A*args.B { 239 t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B) 240 } 241 242 // ServiceName contain "." character 243 args = &Args{7, 8} 244 reply = new(Reply) 245 err = client.Call("net.rpc.Arith.Add", args, reply) 246 if err != nil { 247 t.Errorf("Add: expected no error but got string %q", err.Error()) 248 } 249 if reply.C != args.A+args.B { 250 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 251 } 252 } 253 254 func testNewServerRPC(t *testing.T, addr string) { 255 client, err := Dial("tcp", addr) 256 if err != nil { 257 t.Fatal("dialing", err) 258 } 259 defer client.Close() 260 261 // Synchronous calls 262 args := &Args{7, 8} 263 reply := new(Reply) 264 err = client.Call("newServer.Arith.Add", args, reply) 265 if err != nil { 266 t.Errorf("Add: expected no error but got string %q", err.Error()) 267 } 268 if reply.C != args.A+args.B { 269 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 270 } 271 } 272 273 func TestHTTP(t *testing.T) { 274 once.Do(startServer) 275 testHTTPRPC(t, "") 276 newOnce.Do(startNewServer) 277 testHTTPRPC(t, newHttpPath) 278 } 279 280 func testHTTPRPC(t *testing.T, path string) { 281 var client *Client 282 var err error 283 if path == "" { 284 client, err = DialHTTP("tcp", httpServerAddr) 285 } else { 286 client, err = DialHTTPPath("tcp", httpServerAddr, path) 287 } 288 if err != nil { 289 t.Fatal("dialing", err) 290 } 291 defer client.Close() 292 293 // Synchronous calls 294 args := &Args{7, 8} 295 reply := new(Reply) 296 err = client.Call("Arith.Add", args, reply) 297 if err != nil { 298 t.Errorf("Add: expected no error but got string %q", err.Error()) 299 } 300 if reply.C != args.A+args.B { 301 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 302 } 303 } 304 305 // CodecEmulator provides a client-like api and a ServerCodec interface. 306 // Can be used to test ServeRequest. 307 type CodecEmulator struct { 308 server *Server 309 serviceMethod string 310 args *Args 311 reply *Reply 312 err error 313 } 314 315 func (codec *CodecEmulator) Call(serviceMethod string, args *Args, reply *Reply) error { 316 codec.serviceMethod = serviceMethod 317 codec.args = args 318 codec.reply = reply 319 codec.err = nil 320 var serverError error 321 if codec.server == nil { 322 serverError = ServeRequest(codec) 323 } else { 324 serverError = codec.server.ServeRequest(codec) 325 } 326 if codec.err == nil && serverError != nil { 327 codec.err = serverError 328 } 329 return codec.err 330 } 331 332 func (codec *CodecEmulator) ReadRequestHeader(req *Request) error { 333 req.ServiceMethod = codec.serviceMethod 334 req.Seq = 0 335 return nil 336 } 337 338 func (codec *CodecEmulator) ReadRequestBody(argv interface{}) error { 339 if codec.args == nil { 340 return io.ErrUnexpectedEOF 341 } 342 *(argv.(*Args)) = *codec.args 343 return nil 344 } 345 346 func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) error { 347 if resp.Error != "" { 348 codec.err = errors.New(resp.Error) 349 } else { 350 *codec.reply = *(reply.(*Reply)) 351 } 352 return nil 353 } 354 355 func (codec *CodecEmulator) Close() error { 356 return nil 357 } 358 359 func TestServeRequest(t *testing.T) { 360 once.Do(startServer) 361 testServeRequest(t, nil) 362 newOnce.Do(startNewServer) 363 testServeRequest(t, newServer) 364 } 365 366 func testServeRequest(t *testing.T, server *Server) { 367 client := CodecEmulator{server: server} 368 defer client.Close() 369 370 args := &Args{7, 8} 371 reply := new(Reply) 372 err := client.Call("Arith.Add", args, reply) 373 if err != nil { 374 t.Errorf("Add: expected no error but got string %q", err.Error()) 375 } 376 if reply.C != args.A+args.B { 377 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 378 } 379 380 err = client.Call("Arith.Add", nil, reply) 381 if err == nil { 382 t.Errorf("expected error calling Arith.Add with nil arg") 383 } 384 } 385 386 type ReplyNotPointer int 387 type ArgNotPublic int 388 type ReplyNotPublic int 389 type NeedsPtrType int 390 type local struct{} 391 392 func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error { 393 return nil 394 } 395 396 func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) error { 397 return nil 398 } 399 400 func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error { 401 return nil 402 } 403 404 func (t *NeedsPtrType) NeedsPtrType(args *Args, reply *Reply) error { 405 return nil 406 } 407 408 // Check that registration handles lots of bad methods and a type with no suitable methods. 409 func TestRegistrationError(t *testing.T) { 410 err := Register(new(ReplyNotPointer)) 411 if err == nil { 412 t.Error("expected error registering ReplyNotPointer") 413 } 414 err = Register(new(ArgNotPublic)) 415 if err == nil { 416 t.Error("expected error registering ArgNotPublic") 417 } 418 err = Register(new(ReplyNotPublic)) 419 if err == nil { 420 t.Error("expected error registering ReplyNotPublic") 421 } 422 err = Register(NeedsPtrType(0)) 423 if err == nil { 424 t.Error("expected error registering NeedsPtrType") 425 } else if !strings.Contains(err.Error(), "pointer") { 426 t.Error("expected hint when registering NeedsPtrType") 427 } 428 } 429 430 type WriteFailCodec int 431 432 func (WriteFailCodec) WriteRequest(*Request, interface{}) error { 433 // the panic caused by this error used to not unlock a lock. 434 return errors.New("fail") 435 } 436 437 func (WriteFailCodec) ReadResponseHeader(*Response) error { 438 select {} 439 } 440 441 func (WriteFailCodec) ReadResponseBody(interface{}) error { 442 select {} 443 } 444 445 func (WriteFailCodec) Close() error { 446 return nil 447 } 448 449 func TestSendDeadlock(t *testing.T) { 450 client := NewClientWithCodec(WriteFailCodec(0)) 451 defer client.Close() 452 453 done := make(chan bool) 454 go func() { 455 testSendDeadlock(client) 456 testSendDeadlock(client) 457 done <- true 458 }() 459 select { 460 case <-done: 461 return 462 case <-time.After(5 * time.Second): 463 t.Fatal("deadlock") 464 } 465 } 466 467 func testSendDeadlock(client *Client) { 468 defer func() { 469 recover() 470 }() 471 args := &Args{7, 8} 472 reply := new(Reply) 473 client.Call("Arith.Add", args, reply) 474 } 475 476 func dialDirect() (*Client, error) { 477 return Dial("tcp", serverAddr) 478 } 479 480 func dialHTTP() (*Client, error) { 481 return DialHTTP("tcp", httpServerAddr) 482 } 483 484 func countMallocs(dial func() (*Client, error), t *testing.T) float64 { 485 once.Do(startServer) 486 client, err := dial() 487 if err != nil { 488 t.Fatal("error dialing", err) 489 } 490 defer client.Close() 491 492 args := &Args{7, 8} 493 reply := new(Reply) 494 return testing.AllocsPerRun(100, func() { 495 err := client.Call("Arith.Add", args, reply) 496 if err != nil { 497 t.Errorf("Add: expected no error but got string %q", err.Error()) 498 } 499 if reply.C != args.A+args.B { 500 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 501 } 502 }) 503 } 504 505 func TestCountMallocs(t *testing.T) { 506 if testing.Short() { 507 t.Skip("skipping malloc count in short mode") 508 } 509 if runtime.GOMAXPROCS(0) > 1 { 510 t.Skip("skipping; GOMAXPROCS>1") 511 } 512 fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t)) 513 } 514 515 func TestCountMallocsOverHTTP(t *testing.T) { 516 if testing.Short() { 517 t.Skip("skipping malloc count in short mode") 518 } 519 if runtime.GOMAXPROCS(0) > 1 { 520 t.Skip("skipping; GOMAXPROCS>1") 521 } 522 fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t)) 523 } 524 525 type writeCrasher struct { 526 done chan bool 527 } 528 529 func (writeCrasher) Close() error { 530 return nil 531 } 532 533 func (w *writeCrasher) Read(p []byte) (int, error) { 534 <-w.done 535 return 0, io.EOF 536 } 537 538 func (writeCrasher) Write(p []byte) (int, error) { 539 return 0, errors.New("fake write failure") 540 } 541 542 func TestClientWriteError(t *testing.T) { 543 w := &writeCrasher{done: make(chan bool)} 544 c := NewClient(w) 545 defer c.Close() 546 547 res := false 548 err := c.Call("foo", 1, &res) 549 if err == nil { 550 t.Fatal("expected error") 551 } 552 if err.Error() != "fake write failure" { 553 t.Error("unexpected value of error:", err) 554 } 555 w.done <- true 556 } 557 558 func TestTCPClose(t *testing.T) { 559 once.Do(startServer) 560 561 client, err := dialHTTP() 562 if err != nil { 563 t.Fatalf("dialing: %v", err) 564 } 565 defer client.Close() 566 567 args := Args{17, 8} 568 var reply Reply 569 err = client.Call("Arith.Mul", args, &reply) 570 if err != nil { 571 t.Fatal("arith error:", err) 572 } 573 t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply) 574 if reply.C != args.A*args.B { 575 t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B) 576 } 577 } 578 579 func TestErrorAfterClientClose(t *testing.T) { 580 once.Do(startServer) 581 582 client, err := dialHTTP() 583 if err != nil { 584 t.Fatalf("dialing: %v", err) 585 } 586 err = client.Close() 587 if err != nil { 588 t.Fatal("close error:", err) 589 } 590 err = client.Call("Arith.Add", &Args{7, 9}, new(Reply)) 591 if err != ErrShutdown { 592 t.Errorf("Forever: expected ErrShutdown got %v", err) 593 } 594 } 595 596 func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) { 597 once.Do(startServer) 598 client, err := dial() 599 if err != nil { 600 b.Fatal("error dialing:", err) 601 } 602 defer client.Close() 603 604 // Synchronous calls 605 args := &Args{7, 8} 606 b.ResetTimer() 607 608 b.RunParallel(func(pb *testing.PB) { 609 reply := new(Reply) 610 for pb.Next() { 611 err := client.Call("Arith.Add", args, reply) 612 if err != nil { 613 b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error()) 614 } 615 if reply.C != args.A+args.B { 616 b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B) 617 } 618 } 619 }) 620 } 621 622 func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) { 623 const MaxConcurrentCalls = 100 624 once.Do(startServer) 625 client, err := dial() 626 if err != nil { 627 b.Fatal("error dialing:", err) 628 } 629 defer client.Close() 630 631 // Asynchronous calls 632 args := &Args{7, 8} 633 procs := 4 * runtime.GOMAXPROCS(-1) 634 send := int32(b.N) 635 recv := int32(b.N) 636 var wg sync.WaitGroup 637 wg.Add(procs) 638 gate := make(chan bool, MaxConcurrentCalls) 639 res := make(chan *Call, MaxConcurrentCalls) 640 b.ResetTimer() 641 642 for p := 0; p < procs; p++ { 643 go func() { 644 for atomic.AddInt32(&send, -1) >= 0 { 645 gate <- true 646 reply := new(Reply) 647 client.Go("Arith.Add", args, reply, res) 648 } 649 }() 650 go func() { 651 for call := range res { 652 A := call.Args.(*Args).A 653 B := call.Args.(*Args).B 654 C := call.Reply.(*Reply).C 655 if A+B != C { 656 b.Fatalf("incorrect reply: Add: expected %d got %d", A+B, C) 657 } 658 <-gate 659 if atomic.AddInt32(&recv, -1) == 0 { 660 close(res) 661 } 662 } 663 wg.Done() 664 }() 665 } 666 wg.Wait() 667 } 668 669 func BenchmarkEndToEnd(b *testing.B) { 670 benchmarkEndToEnd(dialDirect, b) 671 } 672 673 func BenchmarkEndToEndHTTP(b *testing.B) { 674 benchmarkEndToEnd(dialHTTP, b) 675 } 676 677 func BenchmarkEndToEndAsync(b *testing.B) { 678 benchmarkEndToEndAsync(dialDirect, b) 679 } 680 681 func BenchmarkEndToEndAsyncHTTP(b *testing.B) { 682 benchmarkEndToEndAsync(dialHTTP, b) 683 } 684