Home | History | Annotate | Download | only in testserver
      1 package server
      2 
      3 import (
      4 	"encoding/base64"
      5 	"encoding/json"
      6 	"net/http"
      7 	"net/http/httptest"
      8 	"net/url"
      9 	"reflect"
     10 	"strconv"
     11 	"testing"
     12 )
     13 
     14 func composeQuery(path string, code int, headers http.Header, body []byte) (string, error) {
     15 	u, err := url.Parse(path)
     16 	if err != nil {
     17 		return "", err
     18 	}
     19 	q := u.Query()
     20 	if code > 0 {
     21 		q.Set("respStatus", strconv.Itoa(code))
     22 	}
     23 	if headers != nil {
     24 		h, err := json.Marshal(headers)
     25 		if err != nil {
     26 			return "", err
     27 		}
     28 		q.Set("respHeader", base64.URLEncoding.EncodeToString(h))
     29 	}
     30 	if len(body) > 0 {
     31 		q.Set("respBody", base64.URLEncoding.EncodeToString(body))
     32 	}
     33 	u.RawQuery = q.Encode()
     34 	return u.String(), nil
     35 }
     36 
     37 func TestResponseOverride(t *testing.T) {
     38 	tests := []struct {
     39 		name    string
     40 		code    int
     41 		headers http.Header
     42 		body    []byte
     43 	}{
     44 		{name: "code", code: 204},
     45 		{name: "body", body: []byte("new body")},
     46 		{
     47 			name: "headers",
     48 			headers: http.Header{
     49 				"Via":          []string{"Via1", "Via2"},
     50 				"Content-Type": []string{"random content"},
     51 			},
     52 		},
     53 		{
     54 			name: "everything",
     55 			code: 204,
     56 			body: []byte("new body"),
     57 			headers: http.Header{
     58 				"Via":          []string{"Via1", "Via2"},
     59 				"Content-Type": []string{"random content"},
     60 			},
     61 		},
     62 	}
     63 
     64 	for _, test := range tests {
     65 		u, err := composeQuery("http://test.com/override", test.code, test.headers, test.body)
     66 		if err != nil {
     67 			t.Errorf("%s: composeQuery: %v", test.name, err)
     68 			return
     69 		}
     70 		req, err := http.NewRequest("GET", u, nil)
     71 		if err != nil {
     72 			t.Errorf("%s: http.NewRequest: %v", test.name, err)
     73 			return
     74 		}
     75 		w := httptest.NewRecorder()
     76 		defaultResponse(w, req)
     77 		if test.code > 0 {
     78 			if got, want := w.Code, test.code; got != want {
     79 				t.Errorf("%s: response code: got %d want %d", test.name, got, want)
     80 				return
     81 			}
     82 		}
     83 		if test.headers != nil {
     84 			for k, want := range test.headers {
     85 				got, ok := w.HeaderMap[k]
     86 				if !ok || !reflect.DeepEqual(got, want) {
     87 					t.Errorf("%s: header %s: code: got %v want %v", test.name, k, got, want)
     88 					return
     89 				}
     90 			}
     91 		}
     92 		if test.body != nil {
     93 			if got, want := string(w.Body.Bytes()), string(test.body); got != want {
     94 				t.Errorf("%s: body: got %s want %s", test.name, got, want)
     95 				return
     96 			}
     97 		}
     98 	}
     99 }
    100