Home | History | Annotate | Download | only in sql
      1 // Copyright 2011 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package sql
      6 
      7 import (
      8 	"database/sql/driver"
      9 	"fmt"
     10 	"reflect"
     11 	"runtime"
     12 	"strings"
     13 	"testing"
     14 	"time"
     15 )
     16 
     17 var someTime = time.Unix(123, 0)
     18 var answer int64 = 42
     19 
     20 type userDefined float64
     21 
     22 type userDefinedSlice []int
     23 
// conversionTest describes one convertAssign case: the source value s is
// assigned into the destination pointer d, and the want* fields describe
// the expected outcome.
type conversionTest struct {
	s, d interface{} // source and destination

	// following are used if they're non-zero
	wantint    int64       // expected value of a signed-integer destination
	wantuint   uint64      // expected value of an unsigned-integer destination
	wantstr    string      // expected contents of the string destination
	wantbytes  []byte      // expected contents of a []byte destination
	wantraw    RawBytes    // expected contents of a RawBytes destination
	wantf32    float32     // expected value of a float32 destination
	wantf64    float64     // expected value of a float64 destination
	wanttime   time.Time   // expected value of a time.Time destination
	wantbool   bool        // used if d is of type *bool
	wanterr    string      // expected error text; empty means no error expected
	wantiface  interface{} // expected value stored in an interface{} destination
	wantptr    *int64      // if non-nil, *d's pointed value must be equal to *wantptr
	wantnil    bool        // if true, *d must be *int64(nil)
	wantusrdef userDefined // expected value of a *userDefined destination
}
     43 
     44 // Target variables for scanning into.
     45 var (
     46 	scanstr    string
     47 	scanbytes  []byte
     48 	scanraw    RawBytes
     49 	scanint    int
     50 	scanint8   int8
     51 	scanint16  int16
     52 	scanint32  int32
     53 	scanuint8  uint8
     54 	scanuint16 uint16
     55 	scanbool   bool
     56 	scanf32    float32
     57 	scanf64    float64
     58 	scantime   time.Time
     59 	scanptr    *int64
     60 	scaniface  interface{}
     61 )
     62 
// conversionTests is the table of convertAssign cases. Each entry assigns
// source s into destination pointer d and records the expected result in
// its want* fields, or the expected error text in wanterr.
var conversionTests = []conversionTest{
	// Exact conversions (destination pointer type matches source type)
	{s: "foo", d: &scanstr, wantstr: "foo"},
	{s: 123, d: &scanint, wantint: 123},
	{s: someTime, d: &scantime, wanttime: someTime},

	// To strings
	{s: "string", d: &scanstr, wantstr: "string"},
	{s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"},
	{s: 123, d: &scanstr, wantstr: "123"},
	{s: int8(123), d: &scanstr, wantstr: "123"},
	{s: int64(123), d: &scanstr, wantstr: "123"},
	{s: uint8(123), d: &scanstr, wantstr: "123"},
	{s: uint16(123), d: &scanstr, wantstr: "123"},
	{s: uint32(123), d: &scanstr, wantstr: "123"},
	{s: uint64(123), d: &scanstr, wantstr: "123"},
	{s: 1.5, d: &scanstr, wantstr: "1.5"},

	// From time.Time: textual destinations receive an RFC 3339 timestamp,
	// with fractional seconds only when they are non-zero; an interface{}
	// destination keeps the time.Time itself.
	{s: time.Unix(1, 0).UTC(), d: &scanstr, wantstr: "1970-01-01T00:00:01Z"},
	{s: time.Unix(1453874597, 0).In(time.FixedZone("here", -3600*8)), d: &scanstr, wantstr: "2016-01-26T22:03:17-08:00"},
	{s: time.Unix(1, 2).UTC(), d: &scanstr, wantstr: "1970-01-01T00:00:01.000000002Z"},
	{s: time.Time{}, d: &scanstr, wantstr: "0001-01-01T00:00:00Z"},
	{s: time.Unix(1, 2).UTC(), d: &scanbytes, wantbytes: []byte("1970-01-01T00:00:01.000000002Z")},
	{s: time.Unix(1, 2).UTC(), d: &scaniface, wantiface: time.Unix(1, 2).UTC()},

	// To []byte
	{s: nil, d: &scanbytes, wantbytes: nil},
	{s: "string", d: &scanbytes, wantbytes: []byte("string")},
	{s: []byte("byteslice"), d: &scanbytes, wantbytes: []byte("byteslice")},
	{s: 123, d: &scanbytes, wantbytes: []byte("123")},
	{s: int8(123), d: &scanbytes, wantbytes: []byte("123")},
	{s: int64(123), d: &scanbytes, wantbytes: []byte("123")},
	{s: uint8(123), d: &scanbytes, wantbytes: []byte("123")},
	{s: uint16(123), d: &scanbytes, wantbytes: []byte("123")},
	{s: uint32(123), d: &scanbytes, wantbytes: []byte("123")},
	{s: uint64(123), d: &scanbytes, wantbytes: []byte("123")},
	{s: 1.5, d: &scanbytes, wantbytes: []byte("1.5")},

	// To RawBytes
	{s: nil, d: &scanraw, wantraw: nil},
	{s: []byte("byteslice"), d: &scanraw, wantraw: RawBytes("byteslice")},
	{s: 123, d: &scanraw, wantraw: RawBytes("123")},
	{s: int8(123), d: &scanraw, wantraw: RawBytes("123")},
	{s: int64(123), d: &scanraw, wantraw: RawBytes("123")},
	{s: uint8(123), d: &scanraw, wantraw: RawBytes("123")},
	{s: uint16(123), d: &scanraw, wantraw: RawBytes("123")},
	{s: uint32(123), d: &scanraw, wantraw: RawBytes("123")},
	{s: uint64(123), d: &scanraw, wantraw: RawBytes("123")},
	{s: 1.5, d: &scanraw, wantraw: RawBytes("1.5")},

	// Strings to integers; out-of-range and unparsable strings must fail
	// with an error naming the source type, value and destination type.
	{s: "255", d: &scanuint8, wantuint: 255},
	{s: "256", d: &scanuint8, wanterr: "converting driver.Value type string (\"256\") to a uint8: value out of range"},
	{s: "256", d: &scanuint16, wantuint: 256},
	{s: "-1", d: &scanint, wantint: -1},
	{s: "foo", d: &scanint, wanterr: "converting driver.Value type string (\"foo\") to a int: invalid syntax"},

	// int64 to smaller integers; values that do not fit must fail.
	{s: int64(5), d: &scanuint8, wantuint: 5},
	{s: int64(256), d: &scanuint8, wanterr: "converting driver.Value type int64 (\"256\") to a uint8: value out of range"},
	{s: int64(256), d: &scanuint16, wantuint: 256},
	{s: int64(65536), d: &scanuint16, wanterr: "converting driver.Value type int64 (\"65536\") to a uint16: value out of range"},

	// True bools
	{s: true, d: &scanbool, wantbool: true},
	{s: "True", d: &scanbool, wantbool: true},
	{s: "TRUE", d: &scanbool, wantbool: true},
	{s: "1", d: &scanbool, wantbool: true},
	{s: 1, d: &scanbool, wantbool: true},
	{s: int64(1), d: &scanbool, wantbool: true},
	{s: uint16(1), d: &scanbool, wantbool: true},

	// False bools
	{s: false, d: &scanbool, wantbool: false},
	{s: "false", d: &scanbool, wantbool: false},
	{s: "FALSE", d: &scanbool, wantbool: false},
	{s: "0", d: &scanbool, wantbool: false},
	{s: 0, d: &scanbool, wantbool: false},
	{s: int64(0), d: &scanbool, wantbool: false},
	{s: uint16(0), d: &scanbool, wantbool: false},

	// Not bools
	{s: "yup", d: &scanbool, wanterr: `sql/driver: couldn't convert "yup" into type bool`},
	{s: 2, d: &scanbool, wanterr: `sql/driver: couldn't convert 2 into type bool`},

	// Floats
	{s: float64(1.5), d: &scanf64, wantf64: float64(1.5)},
	{s: int64(1), d: &scanf64, wantf64: float64(1)},
	{s: float64(1.5), d: &scanf32, wantf32: float32(1.5)},
	{s: "1.5", d: &scanf32, wantf32: float32(1.5)},
	{s: "1.5", d: &scanf64, wantf64: float64(1.5)},

	// Pointers: a nil source must yield a nil *int64, and a non-nil source
	// must yield a pointer to the converted value.
	{s: interface{}(nil), d: &scanptr, wantnil: true},
	{s: int64(42), d: &scanptr, wantptr: &answer},

	// To interface{}: the expected value includes its dynamic type.
	{s: float64(1.5), d: &scaniface, wantiface: float64(1.5)},
	{s: int64(1), d: &scaniface, wantiface: int64(1)},
	{s: "str", d: &scaniface, wantiface: "str"},
	{s: []byte("byteslice"), d: &scaniface, wantiface: []byte("byteslice")},
	{s: true, d: &scaniface, wantiface: true},
	{s: nil, d: &scaniface},
	{s: []byte(nil), d: &scaniface, wantiface: []byte(nil)},

	// To a user-defined type
	{s: 1.5, d: new(userDefined), wantusrdef: 1.5},
	{s: int64(123), d: new(userDefined), wantusrdef: 123},
	{s: "1.5", d: new(userDefined), wantusrdef: 1.5},
	{s: []byte{1, 2, 3}, d: new(userDefinedSlice), wanterr: `unsupported Scan, storing driver.Value type []uint8 into type *sql.userDefinedSlice`},

	// Other errors
	{s: complex(1, 2), d: &scanstr, wanterr: `unsupported Scan, storing driver.Value type complex128 into type *string`},
}
    178 
    179 func intPtrValue(intptr interface{}) interface{} {
    180 	return reflect.Indirect(reflect.Indirect(reflect.ValueOf(intptr))).Int()
    181 }
    182 
    183 func intValue(intptr interface{}) int64 {
    184 	return reflect.Indirect(reflect.ValueOf(intptr)).Int()
    185 }
    186 
    187 func uintValue(intptr interface{}) uint64 {
    188 	return reflect.Indirect(reflect.ValueOf(intptr)).Uint()
    189 }
    190 
    191 func float64Value(ptr interface{}) float64 {
    192 	return *(ptr.(*float64))
    193 }
    194 
    195 func float32Value(ptr interface{}) float32 {
    196 	return *(ptr.(*float32))
    197 }
    198 
    199 func timeValue(ptr interface{}) time.Time {
    200 	return *(ptr.(*time.Time))
    201 }
    202 
    203 func TestConversions(t *testing.T) {
    204 	for n, ct := range conversionTests {
    205 		err := convertAssign(ct.d, ct.s)
    206 		errstr := ""
    207 		if err != nil {
    208 			errstr = err.Error()
    209 		}
    210 		errf := func(format string, args ...interface{}) {
    211 			base := fmt.Sprintf("convertAssign #%d: for %v (%T) -> %T, ", n, ct.s, ct.s, ct.d)
    212 			t.Errorf(base+format, args...)
    213 		}
    214 		if errstr != ct.wanterr {
    215 			errf("got error %q, want error %q", errstr, ct.wanterr)
    216 		}
    217 		if ct.wantstr != "" && ct.wantstr != scanstr {
    218 			errf("want string %q, got %q", ct.wantstr, scanstr)
    219 		}
    220 		if ct.wantint != 0 && ct.wantint != intValue(ct.d) {
    221 			errf("want int %d, got %d", ct.wantint, intValue(ct.d))
    222 		}
    223 		if ct.wantuint != 0 && ct.wantuint != uintValue(ct.d) {
    224 			errf("want uint %d, got %d", ct.wantuint, uintValue(ct.d))
    225 		}
    226 		if ct.wantf32 != 0 && ct.wantf32 != float32Value(ct.d) {
    227 			errf("want float32 %v, got %v", ct.wantf32, float32Value(ct.d))
    228 		}
    229 		if ct.wantf64 != 0 && ct.wantf64 != float64Value(ct.d) {
    230 			errf("want float32 %v, got %v", ct.wantf64, float64Value(ct.d))
    231 		}
    232 		if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" {
    233 			errf("want bool %v, got %v", ct.wantbool, *bp)
    234 		}
    235 		if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) {
    236 			errf("want time %v, got %v", ct.wanttime, timeValue(ct.d))
    237 		}
    238 		if ct.wantnil && *ct.d.(**int64) != nil {
    239 			errf("want nil, got %v", intPtrValue(ct.d))
    240 		}
    241 		if ct.wantptr != nil {
    242 			if *ct.d.(**int64) == nil {
    243 				errf("want pointer to %v, got nil", *ct.wantptr)
    244 			} else if *ct.wantptr != intPtrValue(ct.d) {
    245 				errf("want pointer to %v, got %v", *ct.wantptr, intPtrValue(ct.d))
    246 			}
    247 		}
    248 		if ifptr, ok := ct.d.(*interface{}); ok {
    249 			if !reflect.DeepEqual(ct.wantiface, scaniface) {
    250 				errf("want interface %#v, got %#v", ct.wantiface, scaniface)
    251 				continue
    252 			}
    253 			if srcBytes, ok := ct.s.([]byte); ok {
    254 				dstBytes := (*ifptr).([]byte)
    255 				if len(srcBytes) > 0 && &dstBytes[0] == &srcBytes[0] {
    256 					errf("copy into interface{} didn't copy []byte data")
    257 				}
    258 			}
    259 		}
    260 		if ct.wantusrdef != 0 && ct.wantusrdef != *ct.d.(*userDefined) {
    261 			errf("want userDefined %f, got %f", ct.wantusrdef, *ct.d.(*userDefined))
    262 		}
    263 	}
    264 }
    265 
    266 func TestNullString(t *testing.T) {
    267 	var ns NullString
    268 	convertAssign(&ns, []byte("foo"))
    269 	if !ns.Valid {
    270 		t.Errorf("expecting not null")
    271 	}
    272 	if ns.String != "foo" {
    273 		t.Errorf("expecting foo; got %q", ns.String)
    274 	}
    275 	convertAssign(&ns, nil)
    276 	if ns.Valid {
    277 		t.Errorf("expecting null on nil")
    278 	}
    279 	if ns.String != "" {
    280 		t.Errorf("expecting blank on nil; got %q", ns.String)
    281 	}
    282 }
    283 
    284 type valueConverterTest struct {
    285 	c       driver.ValueConverter
    286 	in, out interface{}
    287 	err     string
    288 }
    289 
    290 var valueConverterTests = []valueConverterTest{
    291 	{driver.DefaultParameterConverter, NullString{"hi", true}, "hi", ""},
    292 	{driver.DefaultParameterConverter, NullString{"", false}, nil, ""},
    293 }
    294 
    295 func TestValueConverters(t *testing.T) {
    296 	for i, tt := range valueConverterTests {
    297 		out, err := tt.c.ConvertValue(tt.in)
    298 		goterr := ""
    299 		if err != nil {
    300 			goterr = err.Error()
    301 		}
    302 		if goterr != tt.err {
    303 			t.Errorf("test %d: %T(%T(%v)) error = %q; want error = %q",
    304 				i, tt.c, tt.in, tt.in, goterr, tt.err)
    305 		}
    306 		if tt.err != "" {
    307 			continue
    308 		}
    309 		if !reflect.DeepEqual(out, tt.out) {
    310 			t.Errorf("test %d: %T(%T(%v)) = %v (%T); want %v (%T)",
    311 				i, tt.c, tt.in, tt.in, out, out, tt.out, tt.out)
    312 		}
    313 	}
    314 }
    315 
// Tests that assigning to RawBytes doesn't allocate (and also works).
//
// Each case converts a non-string scalar into one RawBytes buffer that is
// shared by all conversions, and checks the textual result. The whole table
// is run under testing.AllocsPerRun. A string source is measured separately
// because it is allowed one allocation.
func TestRawBytesAllocs(t *testing.T) {
	var tests = []struct {
		name string
		in   interface{}
		want string
	}{
		{"uint64", uint64(12345678), "12345678"},
		{"uint32", uint32(1234), "1234"},
		{"uint16", uint16(12), "12"},
		{"uint8", uint8(1), "1"},
		{"uint", uint(123), "123"},
		{"int", int(123), "123"},
		{"int8", int8(1), "1"},
		{"int16", int16(12), "12"},
		{"int32", int32(1234), "1234"},
		{"int64", int64(12345678), "12345678"},
		{"float32", float32(1.5), "1.5"},
		{"float64", float64(64), "64"},
		{"bool", false, "false"},
	}

	// buf starts at length 10, enough for the longest expected result
	// ("12345678"). It is created once, outside the measured closures, and
	// reused by every conversion.
	buf := make(RawBytes, 10)
	// test converts in into buf and fails the test unless buf now holds
	// exactly the bytes of want.
	test := func(name string, in interface{}, want string) {
		if err := convertAssign(&buf, in); err != nil {
			t.Fatalf("%s: convertAssign = %v", name, err)
		}
		match := len(buf) == len(want)
		if match {
			for i, b := range buf {
				if want[i] != b {
					match = false
					break
				}
			}
		}
		if !match {
			t.Fatalf("%s: got %q (len %d); want %q (len %d)", name, buf, len(buf), want, len(want))
		}
	}

	// AllocsPerRun reports the average allocations per run, so the 0.5
	// threshold below effectively demands zero allocations.
	n := testing.AllocsPerRun(100, func() {
		for _, tt := range tests {
			test(tt.name, tt.in, tt.want)
		}
	})

	// The numbers below are only valid for 64-bit interface word sizes,
	// and gc. With 32-bit words there are more convT2E allocs, and
	// with gccgo, only pointers currently go in interface data.
	// So only care on amd64 gc for now.
	measureAllocs := runtime.GOARCH == "amd64" && runtime.Compiler == "gc"

	if n > 0.5 && measureAllocs {
		t.Fatalf("allocs = %v; want 0", n)
	}

	// This one involves a convT2E allocation, string -> interface{}
	n = testing.AllocsPerRun(100, func() {
		test("string", "foo", "foo")
	})
	if n > 1.5 && measureAllocs {
		t.Fatalf("allocs = %v; want max 1", n)
	}
}
    381 
    382 // https://github.com/golang/go/issues/13905
    383 func TestUserDefinedBytes(t *testing.T) {
    384 	type userDefinedBytes []byte
    385 	var u userDefinedBytes
    386 	v := []byte("foo")
    387 
    388 	convertAssign(&u, v)
    389 	if &u[0] == &v[0] {
    390 		t.Fatal("userDefinedBytes got potentially dirty driver memory")
    391 	}
    392 }
    393 
    394 type Valuer_V string
    395 
    396 func (v Valuer_V) Value() (driver.Value, error) {
    397 	return strings.ToUpper(string(v)), nil
    398 }
    399 
    400 type Valuer_P string
    401 
    402 func (p *Valuer_P) Value() (driver.Value, error) {
    403 	if p == nil {
    404 		return "nil-to-str", nil
    405 	}
    406 	return strings.ToUpper(string(*p)), nil
    407 }
    408 
    409 func TestDriverArgs(t *testing.T) {
    410 	var nilValuerVPtr *Valuer_V
    411 	var nilValuerPPtr *Valuer_P
    412 	var nilStrPtr *string
    413 	tests := []struct {
    414 		args []interface{}
    415 		want []driver.NamedValue
    416 	}{
    417 		0: {
    418 			args: []interface{}{Valuer_V("foo")},
    419 			want: []driver.NamedValue{
    420 				driver.NamedValue{
    421 					Ordinal: 1,
    422 					Value:   "FOO",
    423 				},
    424 			},
    425 		},
    426 		1: {
    427 			args: []interface{}{nilValuerVPtr},
    428 			want: []driver.NamedValue{
    429 				driver.NamedValue{
    430 					Ordinal: 1,
    431 					Value:   nil,
    432 				},
    433 			},
    434 		},
    435 		2: {
    436 			args: []interface{}{nilValuerPPtr},
    437 			want: []driver.NamedValue{
    438 				driver.NamedValue{
    439 					Ordinal: 1,
    440 					Value:   "nil-to-str",
    441 				},
    442 			},
    443 		},
    444 		3: {
    445 			args: []interface{}{"plain-str"},
    446 			want: []driver.NamedValue{
    447 				driver.NamedValue{
    448 					Ordinal: 1,
    449 					Value:   "plain-str",
    450 				},
    451 			},
    452 		},
    453 		4: {
    454 			args: []interface{}{nilStrPtr},
    455 			want: []driver.NamedValue{
    456 				driver.NamedValue{
    457 					Ordinal: 1,
    458 					Value:   nil,
    459 				},
    460 			},
    461 		},
    462 	}
    463 	for i, tt := range tests {
    464 		ds := new(driverStmt)
    465 		got, err := driverArgs(ds, tt.args)
    466 		if err != nil {
    467 			t.Errorf("test[%d]: %v", i, err)
    468 			continue
    469 		}
    470 		if !reflect.DeepEqual(got, tt.want) {
    471 			t.Errorf("test[%d]: got %v, want %v", i, got, tt.want)
    472 		}
    473 	}
    474 }
    475