Home | History | Annotate | Download | only in sql
      1 // Copyright 2011 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package sql
      6 
      7 import (
      8 	"database/sql/driver"
      9 	"errors"
     10 	"fmt"
     11 	"io"
     12 	"log"
     13 	"sort"
     14 	"strconv"
     15 	"strings"
     16 	"sync"
     17 	"testing"
     18 	"time"
     19 )
     20 
// Blank reference that keeps the "log" import in use.
var _ = log.Printf

// fakeDriver is a fake database that implements Go's driver.Driver
// interface, just for testing.
//
// It speaks a query language that's semantically similar to but
// syntactically different and simpler than SQL.  The syntax is as
// follows:
//
//   WIPE
//   CREATE|<tablename>|<col>=<type>,<col>=<type>,...
//     CREATE itself does not validate type names; the types the rest
//     of the driver understands are "bool", "nullbool", "int32",
//     "int64", "nullint64", "float64", "nullfloat64", "string",
//     "nullstring" and "datetime" (see converterForType). "blob" is
//     additionally accepted for pre-bound INSERT values only.
//   INSERT|<tablename>|col=val,col2=val2,col3=?
//     pre-bound vals are supported only for "string", "blob" and
//     "int32" columns
//   NOSERT|<tablename>|col=val,col2=val2,col3=?
//     same as INSERT, but Exec does not store the row
//   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
//
// When opening a fakeDriver's database, it starts empty with no
// tables.  All tables and data are stored in memory only.
type fakeDriver struct {
	mu         sync.Mutex // guards openCount, closeCount and dbs
	openCount  int        // conn opens
	closeCount int        // conn closes
	// waitCh, when non-nil, makes Open send on waitingCh and then block
	// until it can receive from waitCh; afterwards both are reset to nil.
	waitCh    chan struct{}
	waitingCh chan struct{}
	dbs       map[string]*fakeDB // databases by dsn name; see getDB
}
     46 
// fakeDB is one named in-memory database. fakeDriver.getDB hands the
// same *fakeDB to every Open that uses the same database name.
type fakeDB struct {
	name string

	mu      sync.Mutex        // guards tables
	free    []*fakeConn       // NOTE(review): not referenced by any code in this file — confirm it is still needed
	tables  map[string]*table // nil until the first CREATE; WIPE resets it to nil
	badConn bool              // flipped by fakeConn.isBad for conns opened with the badConn option
}
     55 
     56 type table struct {
     57 	mu      sync.Mutex
     58 	colname []string
     59 	coltype []string
     60 	rows    []*row
     61 }
     62 
     63 func (t *table) columnIndex(name string) int {
     64 	for n, nname := range t.colname {
     65 		if name == nname {
     66 			return n
     67 		}
     68 	}
     69 	return -1
     70 }
     71 
     72 type row struct {
     73 	cols []interface{} // must be same size as its table colname + coltype
     74 }
     75 
     76 func (r *row) clone() *row {
     77 	nrow := &row{cols: make([]interface{}, len(r.cols))}
     78 	copy(nrow.cols, r.cols)
     79 	return nrow
     80 }
     81 
// fakeConn is a connection to a fakeDB. It implements driver.Conn
// (Prepare, Close, Begin); its Exec and Query only validate argument
// types and then return driver.ErrSkip.
type fakeConn struct {
	db *fakeDB // where to return ourselves to; nil once Close succeeds

	currTx *fakeTx // active transaction or nil; set by Begin, cleared by Commit/Rollback

	// Stats for tests:
	mu          sync.Mutex // held by incrStat while bumping one of the counters below
	stmtsMade   int
	stmtsClosed int
	numPrepare  int

	// bad connection tests; see isBad()
	bad       bool // set by Open for the ";badConn" dsn option
	stickyBad bool // when true, isBad, Prepare, Exec and Query always report driver.ErrBadConn
}
     97 
     98 func (c *fakeConn) incrStat(v *int) {
     99 	c.mu.Lock()
    100 	*v++
    101 	c.mu.Unlock()
    102 }
    103 
// fakeTx is the driver.Tx returned by fakeConn.Begin. Commit and
// Rollback only clear the owning conn's currTx.
type fakeTx struct {
	c *fakeConn // the conn whose currTx this is
}
    107 
// fakeStmt is the driver.Stmt produced by fakeConn.Prepare. Which
// fields are populated depends on cmd.
type fakeStmt struct {
	c *fakeConn
	q string // just for debugging

	cmd   string // first "|"-separated query field: WIPE, CREATE, INSERT, NOSERT or SELECT
	table string

	closed bool // set by Close; Exec and Query then return errClosed

	colName      []string      // used by CREATE, INSERT, SELECT (selected columns)
	colType      []string      // used by CREATE
	colValue     []interface{} // used by INSERT: pre-bound values ([]byte or int64), or the string "?" for each bound param
	placeholders int           // used by INSERT/SELECT: number of ? params

	whereCol []string // used by SELECT (all placeholders)

	placeholderConverter []driver.ValueConverter // used by INSERT: one converter per ? column, in order; see ColumnConverter
}
    126 
    127 var fdriver driver.Driver = &fakeDriver{}
    128 
    129 func init() {
    130 	Register("test", fdriver)
    131 }
    132 
    133 func contains(list []string, y string) bool {
    134 	for _, x := range list {
    135 		if x == y {
    136 			return true
    137 		}
    138 	}
    139 	return false
    140 }
    141 
// Dummy is a driver whose embedded driver.Driver is left nil.
// TestDrivers registers it as "invalid" only to check the list returned
// by Drivers; it has no working Open (the embedded interface is nil).
type Dummy struct {
	driver.Driver
}
    145 
    146 func TestDrivers(t *testing.T) {
    147 	unregisterAllDrivers()
    148 	Register("test", fdriver)
    149 	Register("invalid", Dummy{})
    150 	all := Drivers()
    151 	if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
    152 		t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
    153 	}
    154 }
    155 
    156 // Supports dsn forms:
    157 //    <dbname>
    158 //    <dbname>;<opts>  (only currently supported option is `badConn`,
    159 //                      which causes driver.ErrBadConn to be returned on
    160 //                      every other conn.Begin())
    161 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
    162 	parts := strings.Split(dsn, ";")
    163 	if len(parts) < 1 {
    164 		return nil, errors.New("fakedb: no database name")
    165 	}
    166 	name := parts[0]
    167 
    168 	db := d.getDB(name)
    169 
    170 	d.mu.Lock()
    171 	d.openCount++
    172 	d.mu.Unlock()
    173 	conn := &fakeConn{db: db}
    174 
    175 	if len(parts) >= 2 && parts[1] == "badConn" {
    176 		conn.bad = true
    177 	}
    178 	if d.waitCh != nil {
    179 		d.waitingCh <- struct{}{}
    180 		<-d.waitCh
    181 		d.waitCh = nil
    182 		d.waitingCh = nil
    183 	}
    184 	return conn, nil
    185 }
    186 
    187 func (d *fakeDriver) getDB(name string) *fakeDB {
    188 	d.mu.Lock()
    189 	defer d.mu.Unlock()
    190 	if d.dbs == nil {
    191 		d.dbs = make(map[string]*fakeDB)
    192 	}
    193 	db, ok := d.dbs[name]
    194 	if !ok {
    195 		db = &fakeDB{name: name}
    196 		d.dbs[name] = db
    197 	}
    198 	return db
    199 }
    200 
    201 func (db *fakeDB) wipe() {
    202 	db.mu.Lock()
    203 	defer db.mu.Unlock()
    204 	db.tables = nil
    205 }
    206 
    207 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
    208 	db.mu.Lock()
    209 	defer db.mu.Unlock()
    210 	if db.tables == nil {
    211 		db.tables = make(map[string]*table)
    212 	}
    213 	if _, exist := db.tables[name]; exist {
    214 		return fmt.Errorf("table %q already exists", name)
    215 	}
    216 	if len(columnNames) != len(columnTypes) {
    217 		return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d",
    218 			name, len(columnNames), len(columnTypes))
    219 	}
    220 	db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
    221 	return nil
    222 }
    223 
    224 // must be called with db.mu lock held
    225 func (db *fakeDB) table(table string) (*table, bool) {
    226 	if db.tables == nil {
    227 		return nil, false
    228 	}
    229 	t, ok := db.tables[table]
    230 	return t, ok
    231 }
    232 
    233 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
    234 	db.mu.Lock()
    235 	defer db.mu.Unlock()
    236 	t, ok := db.table(table)
    237 	if !ok {
    238 		return
    239 	}
    240 	for n, cname := range t.colname {
    241 		if cname == column {
    242 			return t.coltype[n], true
    243 		}
    244 	}
    245 	return "", false
    246 }
    247 
    248 func (c *fakeConn) isBad() bool {
    249 	if c.stickyBad {
    250 		return true
    251 	} else if c.bad {
    252 		// alternate between bad conn and not bad conn
    253 		c.db.badConn = !c.db.badConn
    254 		return c.db.badConn
    255 	} else {
    256 		return false
    257 	}
    258 }
    259 
    260 func (c *fakeConn) Begin() (driver.Tx, error) {
    261 	if c.isBad() {
    262 		return nil, driver.ErrBadConn
    263 	}
    264 	if c.currTx != nil {
    265 		return nil, errors.New("already in a transaction")
    266 	}
    267 	c.currTx = &fakeTx{c: c}
    268 	return c.currTx, nil
    269 }
    270 
    271 var hookPostCloseConn struct {
    272 	sync.Mutex
    273 	fn func(*fakeConn, error)
    274 }
    275 
    276 func setHookpostCloseConn(fn func(*fakeConn, error)) {
    277 	hookPostCloseConn.Lock()
    278 	defer hookPostCloseConn.Unlock()
    279 	hookPostCloseConn.fn = fn
    280 }
    281 
    282 var testStrictClose *testing.T
    283 
    284 // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close
    285 // fails to close. If nil, the check is disabled.
    286 func setStrictFakeConnClose(t *testing.T) {
    287 	testStrictClose = t
    288 }
    289 
// Close implements driver.Conn. It fails, leaving the conn usable, if a
// transaction is still active, if the conn was already closed, or if
// more statements were prepared on it than have been closed.
//
// On every return path a deferred function reports a failure to the
// test registered with setStrictFakeConnClose (if any). It then calls the
// hook installed with setHookpostCloseConn (if any) with the result.
// On success it increments the driver's closeCount.
func (c *fakeConn) Close() (err error) {
	drv := fdriver.(*fakeDriver)
	defer func() {
		// err is the named result, so this sees the value being returned.
		if err != nil && testStrictClose != nil {
			testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
		}
		// Copy the hook under its lock, then call it without the lock held.
		hookPostCloseConn.Lock()
		fn := hookPostCloseConn.fn
		hookPostCloseConn.Unlock()
		if fn != nil {
			fn(c, err)
		}
		if err == nil {
			drv.mu.Lock()
			drv.closeCount++
			drv.mu.Unlock()
		}
	}()
	if c.currTx != nil {
		return errors.New("can't close fakeConn; in a Transaction")
	}
	if c.db == nil {
		return errors.New("can't close fakeConn; already closed")
	}
	// NOTE(review): these counters are incremented under c.mu (incrStat)
	// but read here without it — confirm no concurrent Prepare/Close.
	if c.stmtsMade > c.stmtsClosed {
		return errors.New("can't close; dangling statement(s)")
	}
	// A nil db marks the conn as closed; a second Close hits the check above.
	c.db = nil
	return nil
}
    320 
    321 func checkSubsetTypes(args []driver.Value) error {
    322 	for n, arg := range args {
    323 		switch arg.(type) {
    324 		case int64, float64, bool, nil, []byte, string, time.Time:
    325 		default:
    326 			return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg)
    327 		}
    328 	}
    329 	return nil
    330 }
    331 
    332 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
    333 	// This is an optional interface, but it's implemented here
    334 	// just to check that all the args are of the proper types.
    335 	// ErrSkip is returned so the caller acts as if we didn't
    336 	// implement this at all.
    337 	err := checkSubsetTypes(args)
    338 	if err != nil {
    339 		return nil, err
    340 	}
    341 	return nil, driver.ErrSkip
    342 }
    343 
    344 func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
    345 	// This is an optional interface, but it's implemented here
    346 	// just to check that all the args are of the proper types.
    347 	// ErrSkip is returned so the caller acts as if we didn't
    348 	// implement this at all.
    349 	err := checkSubsetTypes(args)
    350 	if err != nil {
    351 		return nil, err
    352 	}
    353 	return nil, driver.ErrSkip
    354 }
    355 
    356 func errf(msg string, args ...interface{}) error {
    357 	return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
    358 }
    359 
    360 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
    361 // (note that where columns must always contain ? marks,
    362 //  just a limitation for fakedb)
    363 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
    364 	if len(parts) != 3 {
    365 		stmt.Close()
    366 		return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
    367 	}
    368 	stmt.table = parts[0]
    369 	stmt.colName = strings.Split(parts[1], ",")
    370 	for n, colspec := range strings.Split(parts[2], ",") {
    371 		if colspec == "" {
    372 			continue
    373 		}
    374 		nameVal := strings.Split(colspec, "=")
    375 		if len(nameVal) != 2 {
    376 			stmt.Close()
    377 			return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
    378 		}
    379 		column, value := nameVal[0], nameVal[1]
    380 		_, ok := c.db.columnType(stmt.table, column)
    381 		if !ok {
    382 			stmt.Close()
    383 			return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
    384 		}
    385 		if value != "?" {
    386 			stmt.Close()
    387 			return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
    388 				stmt.table, column)
    389 		}
    390 		stmt.whereCol = append(stmt.whereCol, column)
    391 		stmt.placeholders++
    392 	}
    393 	return stmt, nil
    394 }
    395 
    396 // parts are table|col=type,col2=type2
    397 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
    398 	if len(parts) != 2 {
    399 		stmt.Close()
    400 		return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
    401 	}
    402 	stmt.table = parts[0]
    403 	for n, colspec := range strings.Split(parts[1], ",") {
    404 		nameType := strings.Split(colspec, "=")
    405 		if len(nameType) != 2 {
    406 			stmt.Close()
    407 			return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
    408 		}
    409 		stmt.colName = append(stmt.colName, nameType[0])
    410 		stmt.colType = append(stmt.colType, nameType[1])
    411 	}
    412 	return stmt, nil
    413 }
    414 
    415 // parts are table|col=?,col2=val
    416 func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
    417 	if len(parts) != 2 {
    418 		stmt.Close()
    419 		return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
    420 	}
    421 	stmt.table = parts[0]
    422 	for n, colspec := range strings.Split(parts[1], ",") {
    423 		nameVal := strings.Split(colspec, "=")
    424 		if len(nameVal) != 2 {
    425 			stmt.Close()
    426 			return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
    427 		}
    428 		column, value := nameVal[0], nameVal[1]
    429 		ctype, ok := c.db.columnType(stmt.table, column)
    430 		if !ok {
    431 			stmt.Close()
    432 			return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
    433 		}
    434 		stmt.colName = append(stmt.colName, column)
    435 
    436 		if value != "?" {
    437 			var subsetVal interface{}
    438 			// Convert to driver subset type
    439 			switch ctype {
    440 			case "string":
    441 				subsetVal = []byte(value)
    442 			case "blob":
    443 				subsetVal = []byte(value)
    444 			case "int32":
    445 				i, err := strconv.Atoi(value)
    446 				if err != nil {
    447 					stmt.Close()
    448 					return nil, errf("invalid conversion to int32 from %q", value)
    449 				}
    450 				subsetVal = int64(i) // int64 is a subset type, but not int32
    451 			default:
    452 				stmt.Close()
    453 				return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
    454 			}
    455 			stmt.colValue = append(stmt.colValue, subsetVal)
    456 		} else {
    457 			stmt.placeholders++
    458 			stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
    459 			stmt.colValue = append(stmt.colValue, "?")
    460 		}
    461 	}
    462 	return stmt, nil
    463 }
    464 
    465 // hook to simulate broken connections
    466 var hookPrepareBadConn func() bool
    467 
    468 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
    469 	c.numPrepare++
    470 	if c.db == nil {
    471 		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
    472 	}
    473 
    474 	if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
    475 		return nil, driver.ErrBadConn
    476 	}
    477 
    478 	parts := strings.Split(query, "|")
    479 	if len(parts) < 1 {
    480 		return nil, errf("empty query")
    481 	}
    482 	cmd := parts[0]
    483 	parts = parts[1:]
    484 	stmt := &fakeStmt{q: query, c: c, cmd: cmd}
    485 	c.incrStat(&c.stmtsMade)
    486 	switch cmd {
    487 	case "WIPE":
    488 		// Nothing
    489 	case "SELECT":
    490 		return c.prepareSelect(stmt, parts)
    491 	case "CREATE":
    492 		return c.prepareCreate(stmt, parts)
    493 	case "INSERT":
    494 		return c.prepareInsert(stmt, parts)
    495 	case "NOSERT":
    496 		// Do all the prep-work like for an INSERT but don't actually insert the row.
    497 		// Used for some of the concurrent tests.
    498 		return c.prepareInsert(stmt, parts)
    499 	default:
    500 		stmt.Close()
    501 		return nil, errf("unsupported command type %q", cmd)
    502 	}
    503 	return stmt, nil
    504 }
    505 
    506 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
    507 	if len(s.placeholderConverter) == 0 {
    508 		return driver.DefaultParameterConverter
    509 	}
    510 	return s.placeholderConverter[idx]
    511 }
    512 
    513 func (s *fakeStmt) Close() error {
    514 	if s.c == nil {
    515 		panic("nil conn in fakeStmt.Close")
    516 	}
    517 	if s.c.db == nil {
    518 		panic("in fakeStmt.Close, conn's db is nil (already closed)")
    519 	}
    520 	if !s.closed {
    521 		s.c.incrStat(&s.c.stmtsClosed)
    522 		s.closed = true
    523 	}
    524 	return nil
    525 }
    526 
    527 var errClosed = errors.New("fakedb: statement has been closed")
    528 
    529 // hook to simulate broken connections
    530 var hookExecBadConn func() bool
    531 
    532 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
    533 	if s.closed {
    534 		return nil, errClosed
    535 	}
    536 
    537 	if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
    538 		return nil, driver.ErrBadConn
    539 	}
    540 
    541 	err := checkSubsetTypes(args)
    542 	if err != nil {
    543 		return nil, err
    544 	}
    545 
    546 	db := s.c.db
    547 	switch s.cmd {
    548 	case "WIPE":
    549 		db.wipe()
    550 		return driver.ResultNoRows, nil
    551 	case "CREATE":
    552 		if err := db.createTable(s.table, s.colName, s.colType); err != nil {
    553 			return nil, err
    554 		}
    555 		return driver.ResultNoRows, nil
    556 	case "INSERT":
    557 		return s.execInsert(args, true)
    558 	case "NOSERT":
    559 		// Do all the prep-work like for an INSERT but don't actually insert the row.
    560 		// Used for some of the concurrent tests.
    561 		return s.execInsert(args, false)
    562 	}
    563 	fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s)
    564 	return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd)
    565 }
    566 
    567 // When doInsert is true, add the row to the table.
    568 // When doInsert is false do prep-work and error checking, but don't
    569 // actually add the row to the table.
    570 func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result, error) {
    571 	db := s.c.db
    572 	if len(args) != s.placeholders {
    573 		panic("error in pkg db; should only get here if size is correct")
    574 	}
    575 	db.mu.Lock()
    576 	t, ok := db.table(s.table)
    577 	db.mu.Unlock()
    578 	if !ok {
    579 		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
    580 	}
    581 
    582 	t.mu.Lock()
    583 	defer t.mu.Unlock()
    584 
    585 	var cols []interface{}
    586 	if doInsert {
    587 		cols = make([]interface{}, len(t.colname))
    588 	}
    589 	argPos := 0
    590 	for n, colname := range s.colName {
    591 		colidx := t.columnIndex(colname)
    592 		if colidx == -1 {
    593 			return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
    594 		}
    595 		var val interface{}
    596 		if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" {
    597 			val = args[argPos]
    598 			argPos++
    599 		} else {
    600 			val = s.colValue[n]
    601 		}
    602 		if doInsert {
    603 			cols[colidx] = val
    604 		}
    605 	}
    606 
    607 	if doInsert {
    608 		t.rows = append(t.rows, &row{cols: cols})
    609 	}
    610 	return driver.RowsAffected(1), nil
    611 }
    612 
// hook to simulate broken connections
var hookQueryBadConn func() bool

// Query implements driver.Stmt for SELECT statements and returns a
// rowsCursor over the selected columns of the matching rows.
//
// Errors:
//   - errClosed if the statement has been closed;
//   - driver.ErrBadConn if the conn is sticky-bad or hookQueryBadConn
//     reports true.
//
// Rows are scanned under the table's lock. A row matches when each
// where column's value equals the corresponding arg, compared by their
// %v formatting (stored []byte values are first converted to string).
// Query panics if the number of args does not match the placeholder
// count.
//
// NOTE(review): Query does not check s.cmd and assumes a SELECT was
// prepared — confirm it is never called on other statements.
func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
	if s.closed {
		return nil, errClosed
	}

	if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
		return nil, driver.ErrBadConn
	}

	err := checkSubsetTypes(args)
	if err != nil {
		return nil, err
	}

	db := s.c.db
	if len(args) != s.placeholders {
		panic("error in pkg db; should only get here if size is correct")
	}

	db.mu.Lock()
	t, ok := db.table(s.table)
	db.mu.Unlock()
	if !ok {
		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
	}

	// Special table "magicquery": with where columns (op, millis) and
	// op == "sleep", sleep for millis milliseconds before taking the
	// table lock. args[1] must be an int64 or this panics.
	if s.table == "magicquery" {
		if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" {
			if args[0] == "sleep" {
				time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond)
			}
		}
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	colIdx := make(map[string]int) // select column name -> column index in table
	for _, name := range s.colName {
		idx := t.columnIndex(name)
		if idx == -1 {
			return nil, fmt.Errorf("fakedb: unknown column name %q", name)
		}
		colIdx[name] = idx
	}

	mrows := []*row{}
rows:
	for _, trow := range t.rows {
		// Process the where clause, skipping non-match rows. This is lazy
		// and just uses fmt.Sprintf("%v") to test equality.  Good enough
		// for test code.
		// NOTE(review): only the stored value gets the []byte->string
		// conversion; a []byte arg formats as "[...]" and never matches.
		for widx, wcol := range s.whereCol {
			idx := t.columnIndex(wcol)
			if idx == -1 {
				return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
			}
			tcol := trow.cols[idx]
			if bs, ok := tcol.([]byte); ok {
				// lazy hack to avoid sprintf %v on a []byte
				tcol = string(bs)
			}
			if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
				continue rows
			}
		}
		// Project the matching row onto the selected columns.
		mrow := &row{cols: make([]interface{}, len(s.colName))}
		for seli, name := range s.colName {
			mrow.cols[seli] = trow.cols[colIdx[name]]
		}
		mrows = append(mrows, mrow)
	}

	// pos starts at -1 so the first Next reads mrows[0]. errPos -1
	// disables the injected-error path in Next.
	cursor := &rowsCursor{
		pos:    -1,
		rows:   mrows,
		cols:   s.colName,
		errPos: -1,
	}
	return cursor, nil
}
    697 
    698 func (s *fakeStmt) NumInput() int {
    699 	return s.placeholders
    700 }
    701 
    702 func (tx *fakeTx) Commit() error {
    703 	tx.c.currTx = nil
    704 	return nil
    705 }
    706 
    707 func (tx *fakeTx) Rollback() error {
    708 	tx.c.currTx = nil
    709 	return nil
    710 }
    711 
// rowsCursor is the driver.Rows returned by fakeStmt.Query. rows holds
// the matching rows, already projected onto cols.
type rowsCursor struct {
	cols   []string
	pos    int // index of the current row; Query starts it at -1 so the first Next reads rows[0]
	rows   []*row
	closed bool // set by Close; Next then fails

	// errPos and err are for making Next return early with error.
	errPos int
	err    error

	// a clone of slices to give out to clients, indexed by the
	// original slice's first byte address.  we clone them
	// just so we're able to corrupt them on close.
	bytesClone map[*byte][]byte
}
    727 
    728 func (rc *rowsCursor) Close() error {
    729 	if !rc.closed {
    730 		for _, bs := range rc.bytesClone {
    731 			bs[0] = 255 // first byte corrupted
    732 		}
    733 	}
    734 	rc.closed = true
    735 	return nil
    736 }
    737 
    738 func (rc *rowsCursor) Columns() []string {
    739 	return rc.cols
    740 }
    741 
    742 var rowsCursorNextHook func(dest []driver.Value) error
    743 
    744 func (rc *rowsCursor) Next(dest []driver.Value) error {
    745 	if rowsCursorNextHook != nil {
    746 		return rowsCursorNextHook(dest)
    747 	}
    748 
    749 	if rc.closed {
    750 		return errors.New("fakedb: cursor is closed")
    751 	}
    752 	rc.pos++
    753 	if rc.pos == rc.errPos {
    754 		return rc.err
    755 	}
    756 	if rc.pos >= len(rc.rows) {
    757 		return io.EOF // per interface spec
    758 	}
    759 	for i, v := range rc.rows[rc.pos].cols {
    760 		// TODO(bradfitz): convert to subset types? naah, I
    761 		// think the subset types should only be input to
    762 		// driver, but the sql package should be able to handle
    763 		// a wider range of types coming out of drivers. all
    764 		// for ease of drivers, and to prevent drivers from
    765 		// messing up conversions or doing them differently.
    766 		dest[i] = v
    767 
    768 		if bs, ok := v.([]byte); ok {
    769 			if rc.bytesClone == nil {
    770 				rc.bytesClone = make(map[*byte][]byte)
    771 			}
    772 			clone, ok := rc.bytesClone[&bs[0]]
    773 			if !ok {
    774 				clone = make([]byte, len(bs))
    775 				copy(clone, bs)
    776 				rc.bytesClone[&bs[0]] = clone
    777 			}
    778 			dest[i] = clone
    779 		}
    780 	}
    781 	return nil
    782 }
    783 
    784 // fakeDriverString is like driver.String, but indirects pointers like
    785 // DefaultValueConverter.
    786 //
    787 // This could be surprising behavior to retroactively apply to
    788 // driver.String now that Go1 is out, but this is convenient for
    789 // our TestPointerParamsAndScans.
    790 //
    791 type fakeDriverString struct{}
    792 
    793 func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
    794 	switch c := v.(type) {
    795 	case string, []byte:
    796 		return v, nil
    797 	case *string:
    798 		if c == nil {
    799 			return nil, nil
    800 		}
    801 		return *c, nil
    802 	}
    803 	return fmt.Sprintf("%v", v), nil
    804 }
    805 
    806 func converterForType(typ string) driver.ValueConverter {
    807 	switch typ {
    808 	case "bool":
    809 		return driver.Bool
    810 	case "nullbool":
    811 		return driver.Null{Converter: driver.Bool}
    812 	case "int32":
    813 		return driver.Int32
    814 	case "string":
    815 		return driver.NotNull{Converter: fakeDriverString{}}
    816 	case "nullstring":
    817 		return driver.Null{Converter: fakeDriverString{}}
    818 	case "int64":
    819 		// TODO(coopernurse): add type-specific converter
    820 		return driver.NotNull{Converter: driver.DefaultParameterConverter}
    821 	case "nullint64":
    822 		// TODO(coopernurse): add type-specific converter
    823 		return driver.Null{Converter: driver.DefaultParameterConverter}
    824 	case "float64":
    825 		// TODO(coopernurse): add type-specific converter
    826 		return driver.NotNull{Converter: driver.DefaultParameterConverter}
    827 	case "nullfloat64":
    828 		// TODO(coopernurse): add type-specific converter
    829 		return driver.Null{Converter: driver.DefaultParameterConverter}
    830 	case "datetime":
    831 		return driver.DefaultParameterConverter
    832 	}
    833 	panic("invalid fakedb column type of " + typ)
    834 }
    835