Home | History | Annotate | Download | only in sql
      1 // Copyright 2011 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package sql
      6 
      7 import (
      8 	"database/sql/driver"
      9 	"fmt"
     10 	"reflect"
     11 	"runtime"
     12 	"strings"
     13 	"sync"
     14 	"testing"
     15 	"time"
     16 )
     17 
     18 var someTime = time.Unix(123, 0)
     19 var answer int64 = 42
     20 
     21 type (
     22 	userDefined       float64
     23 	userDefinedSlice  []int
     24 	userDefinedString string
     25 )
     26 
     27 type conversionTest struct {
     28 	s, d interface{} // source and destination
     29 
     30 	// following are used if they're non-zero
     31 	wantint    int64
     32 	wantuint   uint64
     33 	wantstr    string
     34 	wantbytes  []byte
     35 	wantraw    RawBytes
     36 	wantf32    float32
     37 	wantf64    float64
     38 	wanttime   time.Time
     39 	wantbool   bool // used if d is of type *bool
     40 	wanterr    string
     41 	wantiface  interface{}
     42 	wantptr    *int64 // if non-nil, *d's pointed value must be equal to *wantptr
     43 	wantnil    bool   // if true, *d must be *int64(nil)
     44 	wantusrdef userDefined
     45 	wantusrstr userDefinedString
     46 }
     47 
     48 // Target variables for scanning into.
     49 var (
     50 	scanstr    string
     51 	scanbytes  []byte
     52 	scanraw    RawBytes
     53 	scanint    int
     54 	scanint8   int8
     55 	scanint16  int16
     56 	scanint32  int32
     57 	scanuint8  uint8
     58 	scanuint16 uint16
     59 	scanbool   bool
     60 	scanf32    float32
     61 	scanf64    float64
     62 	scantime   time.Time
     63 	scanptr    *int64
     64 	scaniface  interface{}
     65 )
     66 
     67 func conversionTests() []conversionTest {
     68 	// Return a fresh instance to test so "go test -count 2" works correctly.
     69 	return []conversionTest{
     70 		// Exact conversions (destination pointer type matches source type)
     71 		{s: "foo", d: &scanstr, wantstr: "foo"},
     72 		{s: 123, d: &scanint, wantint: 123},
     73 		{s: someTime, d: &scantime, wanttime: someTime},
     74 
     75 		// To strings
     76 		{s: "string", d: &scanstr, wantstr: "string"},
     77 		{s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"},
     78 		{s: 123, d: &scanstr, wantstr: "123"},
     79 		{s: int8(123), d: &scanstr, wantstr: "123"},
     80 		{s: int64(123), d: &scanstr, wantstr: "123"},
     81 		{s: uint8(123), d: &scanstr, wantstr: "123"},
     82 		{s: uint16(123), d: &scanstr, wantstr: "123"},
     83 		{s: uint32(123), d: &scanstr, wantstr: "123"},
     84 		{s: uint64(123), d: &scanstr, wantstr: "123"},
     85 		{s: 1.5, d: &scanstr, wantstr: "1.5"},
     86 
     87 		// From time.Time:
     88 		{s: time.Unix(1, 0).UTC(), d: &scanstr, wantstr: "1970-01-01T00:00:01Z"},
     89 		{s: time.Unix(1453874597, 0).In(time.FixedZone("here", -3600*8)), d: &scanstr, wantstr: "2016-01-26T22:03:17-08:00"},
     90 		{s: time.Unix(1, 2).UTC(), d: &scanstr, wantstr: "1970-01-01T00:00:01.000000002Z"},
     91 		{s: time.Time{}, d: &scanstr, wantstr: "0001-01-01T00:00:00Z"},
     92 		{s: time.Unix(1, 2).UTC(), d: &scanbytes, wantbytes: []byte("1970-01-01T00:00:01.000000002Z")},
     93 		{s: time.Unix(1, 2).UTC(), d: &scaniface, wantiface: time.Unix(1, 2).UTC()},
     94 
     95 		// To []byte
     96 		{s: nil, d: &scanbytes, wantbytes: nil},
     97 		{s: "string", d: &scanbytes, wantbytes: []byte("string")},
     98 		{s: []byte("byteslice"), d: &scanbytes, wantbytes: []byte("byteslice")},
     99 		{s: 123, d: &scanbytes, wantbytes: []byte("123")},
    100 		{s: int8(123), d: &scanbytes, wantbytes: []byte("123")},
    101 		{s: int64(123), d: &scanbytes, wantbytes: []byte("123")},
    102 		{s: uint8(123), d: &scanbytes, wantbytes: []byte("123")},
    103 		{s: uint16(123), d: &scanbytes, wantbytes: []byte("123")},
    104 		{s: uint32(123), d: &scanbytes, wantbytes: []byte("123")},
    105 		{s: uint64(123), d: &scanbytes, wantbytes: []byte("123")},
    106 		{s: 1.5, d: &scanbytes, wantbytes: []byte("1.5")},
    107 
    108 		// To RawBytes
    109 		{s: nil, d: &scanraw, wantraw: nil},
    110 		{s: []byte("byteslice"), d: &scanraw, wantraw: RawBytes("byteslice")},
    111 		{s: "string", d: &scanraw, wantraw: RawBytes("string")},
    112 		{s: 123, d: &scanraw, wantraw: RawBytes("123")},
    113 		{s: int8(123), d: &scanraw, wantraw: RawBytes("123")},
    114 		{s: int64(123), d: &scanraw, wantraw: RawBytes("123")},
    115 		{s: uint8(123), d: &scanraw, wantraw: RawBytes("123")},
    116 		{s: uint16(123), d: &scanraw, wantraw: RawBytes("123")},
    117 		{s: uint32(123), d: &scanraw, wantraw: RawBytes("123")},
    118 		{s: uint64(123), d: &scanraw, wantraw: RawBytes("123")},
    119 		{s: 1.5, d: &scanraw, wantraw: RawBytes("1.5")},
    120 		// time.Time has been placed here to check that the RawBytes slice gets
    121 		// correctly reset when calling time.Time.AppendFormat.
    122 		{s: time.Unix(2, 5).UTC(), d: &scanraw, wantraw: RawBytes("1970-01-01T00:00:02.000000005Z")},
    123 
    124 		// Strings to integers
    125 		{s: "255", d: &scanuint8, wantuint: 255},
    126 		{s: "256", d: &scanuint8, wanterr: "converting driver.Value type string (\"256\") to a uint8: value out of range"},
    127 		{s: "256", d: &scanuint16, wantuint: 256},
    128 		{s: "-1", d: &scanint, wantint: -1},
    129 		{s: "foo", d: &scanint, wanterr: "converting driver.Value type string (\"foo\") to a int: invalid syntax"},
    130 
    131 		// int64 to smaller integers
    132 		{s: int64(5), d: &scanuint8, wantuint: 5},
    133 		{s: int64(256), d: &scanuint8, wanterr: "converting driver.Value type int64 (\"256\") to a uint8: value out of range"},
    134 		{s: int64(256), d: &scanuint16, wantuint: 256},
    135 		{s: int64(65536), d: &scanuint16, wanterr: "converting driver.Value type int64 (\"65536\") to a uint16: value out of range"},
    136 
    137 		// True bools
    138 		{s: true, d: &scanbool, wantbool: true},
    139 		{s: "True", d: &scanbool, wantbool: true},
    140 		{s: "TRUE", d: &scanbool, wantbool: true},
    141 		{s: "1", d: &scanbool, wantbool: true},
    142 		{s: 1, d: &scanbool, wantbool: true},
    143 		{s: int64(1), d: &scanbool, wantbool: true},
    144 		{s: uint16(1), d: &scanbool, wantbool: true},
    145 
    146 		// False bools
    147 		{s: false, d: &scanbool, wantbool: false},
    148 		{s: "false", d: &scanbool, wantbool: false},
    149 		{s: "FALSE", d: &scanbool, wantbool: false},
    150 		{s: "0", d: &scanbool, wantbool: false},
    151 		{s: 0, d: &scanbool, wantbool: false},
    152 		{s: int64(0), d: &scanbool, wantbool: false},
    153 		{s: uint16(0), d: &scanbool, wantbool: false},
    154 
    155 		// Not bools
    156 		{s: "yup", d: &scanbool, wanterr: `sql/driver: couldn't convert "yup" into type bool`},
    157 		{s: 2, d: &scanbool, wanterr: `sql/driver: couldn't convert 2 into type bool`},
    158 
    159 		// Floats
    160 		{s: float64(1.5), d: &scanf64, wantf64: float64(1.5)},
    161 		{s: int64(1), d: &scanf64, wantf64: float64(1)},
    162 		{s: float64(1.5), d: &scanf32, wantf32: float32(1.5)},
    163 		{s: "1.5", d: &scanf32, wantf32: float32(1.5)},
    164 		{s: "1.5", d: &scanf64, wantf64: float64(1.5)},
    165 
    166 		// Pointers
    167 		{s: interface{}(nil), d: &scanptr, wantnil: true},
    168 		{s: int64(42), d: &scanptr, wantptr: &answer},
    169 
    170 		// To interface{}
    171 		{s: float64(1.5), d: &scaniface, wantiface: float64(1.5)},
    172 		{s: int64(1), d: &scaniface, wantiface: int64(1)},
    173 		{s: "str", d: &scaniface, wantiface: "str"},
    174 		{s: []byte("byteslice"), d: &scaniface, wantiface: []byte("byteslice")},
    175 		{s: true, d: &scaniface, wantiface: true},
    176 		{s: nil, d: &scaniface},
    177 		{s: []byte(nil), d: &scaniface, wantiface: []byte(nil)},
    178 
    179 		// To a user-defined type
    180 		{s: 1.5, d: new(userDefined), wantusrdef: 1.5},
    181 		{s: int64(123), d: new(userDefined), wantusrdef: 123},
    182 		{s: "1.5", d: new(userDefined), wantusrdef: 1.5},
    183 		{s: []byte{1, 2, 3}, d: new(userDefinedSlice), wanterr: `unsupported Scan, storing driver.Value type []uint8 into type *sql.userDefinedSlice`},
    184 		{s: "str", d: new(userDefinedString), wantusrstr: "str"},
    185 
    186 		// Other errors
    187 		{s: complex(1, 2), d: &scanstr, wanterr: `unsupported Scan, storing driver.Value type complex128 into type *string`},
    188 	}
    189 }
    190 
    191 func intPtrValue(intptr interface{}) interface{} {
    192 	return reflect.Indirect(reflect.Indirect(reflect.ValueOf(intptr))).Int()
    193 }
    194 
    195 func intValue(intptr interface{}) int64 {
    196 	return reflect.Indirect(reflect.ValueOf(intptr)).Int()
    197 }
    198 
    199 func uintValue(intptr interface{}) uint64 {
    200 	return reflect.Indirect(reflect.ValueOf(intptr)).Uint()
    201 }
    202 
    203 func float64Value(ptr interface{}) float64 {
    204 	return *(ptr.(*float64))
    205 }
    206 
    207 func float32Value(ptr interface{}) float32 {
    208 	return *(ptr.(*float32))
    209 }
    210 
    211 func timeValue(ptr interface{}) time.Time {
    212 	return *(ptr.(*time.Time))
    213 }
    214 
    215 func TestConversions(t *testing.T) {
    216 	for n, ct := range conversionTests() {
    217 		err := convertAssign(ct.d, ct.s)
    218 		errstr := ""
    219 		if err != nil {
    220 			errstr = err.Error()
    221 		}
    222 		errf := func(format string, args ...interface{}) {
    223 			base := fmt.Sprintf("convertAssign #%d: for %v (%T) -> %T, ", n, ct.s, ct.s, ct.d)
    224 			t.Errorf(base+format, args...)
    225 		}
    226 		if errstr != ct.wanterr {
    227 			errf("got error %q, want error %q", errstr, ct.wanterr)
    228 		}
    229 		if ct.wantstr != "" && ct.wantstr != scanstr {
    230 			errf("want string %q, got %q", ct.wantstr, scanstr)
    231 		}
    232 		if ct.wantbytes != nil && string(ct.wantbytes) != string(scanbytes) {
    233 			errf("want byte %q, got %q", ct.wantbytes, scanbytes)
    234 		}
    235 		if ct.wantraw != nil && string(ct.wantraw) != string(scanraw) {
    236 			errf("want RawBytes %q, got %q", ct.wantraw, scanraw)
    237 		}
    238 		if ct.wantint != 0 && ct.wantint != intValue(ct.d) {
    239 			errf("want int %d, got %d", ct.wantint, intValue(ct.d))
    240 		}
    241 		if ct.wantuint != 0 && ct.wantuint != uintValue(ct.d) {
    242 			errf("want uint %d, got %d", ct.wantuint, uintValue(ct.d))
    243 		}
    244 		if ct.wantf32 != 0 && ct.wantf32 != float32Value(ct.d) {
    245 			errf("want float32 %v, got %v", ct.wantf32, float32Value(ct.d))
    246 		}
    247 		if ct.wantf64 != 0 && ct.wantf64 != float64Value(ct.d) {
    248 			errf("want float32 %v, got %v", ct.wantf64, float64Value(ct.d))
    249 		}
    250 		if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" {
    251 			errf("want bool %v, got %v", ct.wantbool, *bp)
    252 		}
    253 		if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) {
    254 			errf("want time %v, got %v", ct.wanttime, timeValue(ct.d))
    255 		}
    256 		if ct.wantnil && *ct.d.(**int64) != nil {
    257 			errf("want nil, got %v", intPtrValue(ct.d))
    258 		}
    259 		if ct.wantptr != nil {
    260 			if *ct.d.(**int64) == nil {
    261 				errf("want pointer to %v, got nil", *ct.wantptr)
    262 			} else if *ct.wantptr != intPtrValue(ct.d) {
    263 				errf("want pointer to %v, got %v", *ct.wantptr, intPtrValue(ct.d))
    264 			}
    265 		}
    266 		if ifptr, ok := ct.d.(*interface{}); ok {
    267 			if !reflect.DeepEqual(ct.wantiface, scaniface) {
    268 				errf("want interface %#v, got %#v", ct.wantiface, scaniface)
    269 				continue
    270 			}
    271 			if srcBytes, ok := ct.s.([]byte); ok {
    272 				dstBytes := (*ifptr).([]byte)
    273 				if len(srcBytes) > 0 && &dstBytes[0] == &srcBytes[0] {
    274 					errf("copy into interface{} didn't copy []byte data")
    275 				}
    276 			}
    277 		}
    278 		if ct.wantusrdef != 0 && ct.wantusrdef != *ct.d.(*userDefined) {
    279 			errf("want userDefined %f, got %f", ct.wantusrdef, *ct.d.(*userDefined))
    280 		}
    281 		if len(ct.wantusrstr) != 0 && ct.wantusrstr != *ct.d.(*userDefinedString) {
    282 			errf("want userDefined %q, got %q", ct.wantusrstr, *ct.d.(*userDefinedString))
    283 		}
    284 	}
    285 }
    286 
    287 func TestNullString(t *testing.T) {
    288 	var ns NullString
    289 	convertAssign(&ns, []byte("foo"))
    290 	if !ns.Valid {
    291 		t.Errorf("expecting not null")
    292 	}
    293 	if ns.String != "foo" {
    294 		t.Errorf("expecting foo; got %q", ns.String)
    295 	}
    296 	convertAssign(&ns, nil)
    297 	if ns.Valid {
    298 		t.Errorf("expecting null on nil")
    299 	}
    300 	if ns.String != "" {
    301 		t.Errorf("expecting blank on nil; got %q", ns.String)
    302 	}
    303 }
    304 
    305 type valueConverterTest struct {
    306 	c       driver.ValueConverter
    307 	in, out interface{}
    308 	err     string
    309 }
    310 
    311 var valueConverterTests = []valueConverterTest{
    312 	{driver.DefaultParameterConverter, NullString{"hi", true}, "hi", ""},
    313 	{driver.DefaultParameterConverter, NullString{"", false}, nil, ""},
    314 }
    315 
    316 func TestValueConverters(t *testing.T) {
    317 	for i, tt := range valueConverterTests {
    318 		out, err := tt.c.ConvertValue(tt.in)
    319 		goterr := ""
    320 		if err != nil {
    321 			goterr = err.Error()
    322 		}
    323 		if goterr != tt.err {
    324 			t.Errorf("test %d: %T(%T(%v)) error = %q; want error = %q",
    325 				i, tt.c, tt.in, tt.in, goterr, tt.err)
    326 		}
    327 		if tt.err != "" {
    328 			continue
    329 		}
    330 		if !reflect.DeepEqual(out, tt.out) {
    331 			t.Errorf("test %d: %T(%T(%v)) = %v (%T); want %v (%T)",
    332 				i, tt.c, tt.in, tt.in, out, out, tt.out, tt.out)
    333 		}
    334 	}
    335 }
    336 
    337 // Tests that assigning to RawBytes doesn't allocate (and also works).
    338 func TestRawBytesAllocs(t *testing.T) {
    339 	var tests = []struct {
    340 		name string
    341 		in   interface{}
    342 		want string
    343 	}{
    344 		{"uint64", uint64(12345678), "12345678"},
    345 		{"uint32", uint32(1234), "1234"},
    346 		{"uint16", uint16(12), "12"},
    347 		{"uint8", uint8(1), "1"},
    348 		{"uint", uint(123), "123"},
    349 		{"int", int(123), "123"},
    350 		{"int8", int8(1), "1"},
    351 		{"int16", int16(12), "12"},
    352 		{"int32", int32(1234), "1234"},
    353 		{"int64", int64(12345678), "12345678"},
    354 		{"float32", float32(1.5), "1.5"},
    355 		{"float64", float64(64), "64"},
    356 		{"bool", false, "false"},
    357 		{"time", time.Unix(2, 5).UTC(), "1970-01-01T00:00:02.000000005Z"},
    358 	}
    359 
    360 	buf := make(RawBytes, 10)
    361 	test := func(name string, in interface{}, want string) {
    362 		if err := convertAssign(&buf, in); err != nil {
    363 			t.Fatalf("%s: convertAssign = %v", name, err)
    364 		}
    365 		match := len(buf) == len(want)
    366 		if match {
    367 			for i, b := range buf {
    368 				if want[i] != b {
    369 					match = false
    370 					break
    371 				}
    372 			}
    373 		}
    374 		if !match {
    375 			t.Fatalf("%s: got %q (len %d); want %q (len %d)", name, buf, len(buf), want, len(want))
    376 		}
    377 	}
    378 
    379 	n := testing.AllocsPerRun(100, func() {
    380 		for _, tt := range tests {
    381 			test(tt.name, tt.in, tt.want)
    382 		}
    383 	})
    384 
    385 	// The numbers below are only valid for 64-bit interface word sizes,
    386 	// and gc. With 32-bit words there are more convT2E allocs, and
    387 	// with gccgo, only pointers currently go in interface data.
    388 	// So only care on amd64 gc for now.
    389 	measureAllocs := runtime.GOARCH == "amd64" && runtime.Compiler == "gc"
    390 
    391 	if n > 0.5 && measureAllocs {
    392 		t.Fatalf("allocs = %v; want 0", n)
    393 	}
    394 
    395 	// This one involves a convT2E allocation, string -> interface{}
    396 	n = testing.AllocsPerRun(100, func() {
    397 		test("string", "foo", "foo")
    398 	})
    399 	if n > 1.5 && measureAllocs {
    400 		t.Fatalf("allocs = %v; want max 1", n)
    401 	}
    402 }
    403 
    404 // https://golang.org/issues/13905
    405 func TestUserDefinedBytes(t *testing.T) {
    406 	type userDefinedBytes []byte
    407 	var u userDefinedBytes
    408 	v := []byte("foo")
    409 
    410 	convertAssign(&u, v)
    411 	if &u[0] == &v[0] {
    412 		t.Fatal("userDefinedBytes got potentially dirty driver memory")
    413 	}
    414 }
    415 
    416 type Valuer_V string
    417 
    418 func (v Valuer_V) Value() (driver.Value, error) {
    419 	return strings.ToUpper(string(v)), nil
    420 }
    421 
    422 type Valuer_P string
    423 
    424 func (p *Valuer_P) Value() (driver.Value, error) {
    425 	if p == nil {
    426 		return "nil-to-str", nil
    427 	}
    428 	return strings.ToUpper(string(*p)), nil
    429 }
    430 
    431 func TestDriverArgs(t *testing.T) {
    432 	var nilValuerVPtr *Valuer_V
    433 	var nilValuerPPtr *Valuer_P
    434 	var nilStrPtr *string
    435 	tests := []struct {
    436 		args []interface{}
    437 		want []driver.NamedValue
    438 	}{
    439 		0: {
    440 			args: []interface{}{Valuer_V("foo")},
    441 			want: []driver.NamedValue{
    442 				driver.NamedValue{
    443 					Ordinal: 1,
    444 					Value:   "FOO",
    445 				},
    446 			},
    447 		},
    448 		1: {
    449 			args: []interface{}{nilValuerVPtr},
    450 			want: []driver.NamedValue{
    451 				driver.NamedValue{
    452 					Ordinal: 1,
    453 					Value:   nil,
    454 				},
    455 			},
    456 		},
    457 		2: {
    458 			args: []interface{}{nilValuerPPtr},
    459 			want: []driver.NamedValue{
    460 				driver.NamedValue{
    461 					Ordinal: 1,
    462 					Value:   "nil-to-str",
    463 				},
    464 			},
    465 		},
    466 		3: {
    467 			args: []interface{}{"plain-str"},
    468 			want: []driver.NamedValue{
    469 				driver.NamedValue{
    470 					Ordinal: 1,
    471 					Value:   "plain-str",
    472 				},
    473 			},
    474 		},
    475 		4: {
    476 			args: []interface{}{nilStrPtr},
    477 			want: []driver.NamedValue{
    478 				driver.NamedValue{
    479 					Ordinal: 1,
    480 					Value:   nil,
    481 				},
    482 			},
    483 		},
    484 	}
    485 	for i, tt := range tests {
    486 		ds := &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{nil}}
    487 		got, err := driverArgsConnLocked(nil, ds, tt.args)
    488 		if err != nil {
    489 			t.Errorf("test[%d]: %v", i, err)
    490 			continue
    491 		}
    492 		if !reflect.DeepEqual(got, tt.want) {
    493 			t.Errorf("test[%d]: got %v, want %v", i, got, tt.want)
    494 		}
    495 	}
    496 }
    497