1 package server 2 3 import ( 4 "encoding/base64" 5 "encoding/json" 6 "net/http" 7 "net/http/httptest" 8 "net/url" 9 "reflect" 10 "strconv" 11 "testing" 12 ) 13 14 func composeQuery(path string, code int, headers http.Header, body []byte) (string, error) { 15 u, err := url.Parse(path) 16 if err != nil { 17 return "", err 18 } 19 q := u.Query() 20 if code > 0 { 21 q.Set("respStatus", strconv.Itoa(code)) 22 } 23 if headers != nil { 24 h, err := json.Marshal(headers) 25 if err != nil { 26 return "", err 27 } 28 q.Set("respHeader", base64.URLEncoding.EncodeToString(h)) 29 } 30 if len(body) > 0 { 31 q.Set("respBody", base64.URLEncoding.EncodeToString(body)) 32 } 33 u.RawQuery = q.Encode() 34 return u.String(), nil 35 } 36 37 func TestResponseOverride(t *testing.T) { 38 tests := []struct { 39 name string 40 code int 41 headers http.Header 42 body []byte 43 }{ 44 {name: "code", code: 204}, 45 {name: "body", body: []byte("new body")}, 46 { 47 name: "headers", 48 headers: http.Header{ 49 "Via": []string{"Via1", "Via2"}, 50 "Content-Type": []string{"random content"}, 51 }, 52 }, 53 { 54 name: "everything", 55 code: 204, 56 body: []byte("new body"), 57 headers: http.Header{ 58 "Via": []string{"Via1", "Via2"}, 59 "Content-Type": []string{"random content"}, 60 }, 61 }, 62 } 63 64 for _, test := range tests { 65 u, err := composeQuery("http://test.com/override", test.code, test.headers, test.body) 66 if err != nil { 67 t.Errorf("%s: composeQuery: %v", test.name, err) 68 return 69 } 70 req, err := http.NewRequest("GET", u, nil) 71 if err != nil { 72 t.Errorf("%s: http.NewRequest: %v", test.name, err) 73 return 74 } 75 w := httptest.NewRecorder() 76 defaultResponse(w, req) 77 if test.code > 0 { 78 if got, want := w.Code, test.code; got != want { 79 t.Errorf("%s: response code: got %d want %d", test.name, got, want) 80 return 81 } 82 } 83 if test.headers != nil { 84 for k, want := range test.headers { 85 got, ok := w.HeaderMap[k] 86 if !ok || !reflect.DeepEqual(got, want) { 87 t.Errorf("%s: header %s: code: got %v want %v", test.name, k, got, want) 88 return 89 } 90 } 91 } 92 if test.body != nil { 93 if got, want := string(w.Body.Bytes()), string(test.body); got != want { 94 t.Errorf("%s: body: got %s want %s", test.name, got, want) 95 return 96 } 97 } 98 } 99 } 100