Home | History | Annotate | Download | only in fix
      1 // Copyright 2011 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package main
      6 
      7 import (
      8 	"go/ast"
      9 	"go/parser"
     10 	"strings"
     11 	"testing"
     12 )
     13 
     14 type testCase struct {
     15 	Name string
     16 	Fn   func(*ast.File) bool
     17 	In   string
     18 	Out  string
     19 }
     20 
     21 var testCases []testCase
     22 
     23 func addTestCases(t []testCase, fn func(*ast.File) bool) {
     24 	// Fill in fn to avoid repetition in definitions.
     25 	if fn != nil {
     26 		for i := range t {
     27 			if t[i].Fn == nil {
     28 				t[i].Fn = fn
     29 			}
     30 		}
     31 	}
     32 	testCases = append(testCases, t...)
     33 }
     34 
     35 func fnop(*ast.File) bool { return false }
     36 
     37 func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string, mustBeGofmt bool) (out string, fixed, ok bool) {
     38 	file, err := parser.ParseFile(fset, desc, in, parserMode)
     39 	if err != nil {
     40 		t.Errorf("%s: parsing: %v", desc, err)
     41 		return
     42 	}
     43 
     44 	outb, err := gofmtFile(file)
     45 	if err != nil {
     46 		t.Errorf("%s: printing: %v", desc, err)
     47 		return
     48 	}
     49 	if s := string(outb); in != s && mustBeGofmt {
     50 		t.Errorf("%s: not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s",
     51 			desc, desc, in, desc, s)
     52 		tdiff(t, in, s)
     53 		return
     54 	}
     55 
     56 	if fn == nil {
     57 		for _, fix := range fixes {
     58 			if fix.f(file) {
     59 				fixed = true
     60 			}
     61 		}
     62 	} else {
     63 		fixed = fn(file)
     64 	}
     65 
     66 	outb, err = gofmtFile(file)
     67 	if err != nil {
     68 		t.Errorf("%s: printing: %v", desc, err)
     69 		return
     70 	}
     71 
     72 	return string(outb), fixed, true
     73 }
     74 
     75 func TestRewrite(t *testing.T) {
     76 	for _, tt := range testCases {
     77 		// Apply fix: should get tt.Out.
     78 		out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In, true)
     79 		if !ok {
     80 			continue
     81 		}
     82 
     83 		// reformat to get printing right
     84 		out, _, ok = parseFixPrint(t, fnop, tt.Name, out, false)
     85 		if !ok {
     86 			continue
     87 		}
     88 
     89 		if out != tt.Out {
     90 			t.Errorf("%s: incorrect output.\n", tt.Name)
     91 			if !strings.HasPrefix(tt.Name, "testdata/") {
     92 				t.Errorf("--- have\n%s\n--- want\n%s", out, tt.Out)
     93 			}
     94 			tdiff(t, out, tt.Out)
     95 			continue
     96 		}
     97 
     98 		if changed := out != tt.In; changed != fixed {
     99 			t.Errorf("%s: changed=%v != fixed=%v", tt.Name, changed, fixed)
    100 			continue
    101 		}
    102 
    103 		// Should not change if run again.
    104 		out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out, true)
    105 		if !ok {
    106 			continue
    107 		}
    108 
    109 		if fixed2 {
    110 			t.Errorf("%s: applied fixes during second round", tt.Name)
    111 			continue
    112 		}
    113 
    114 		if out2 != out {
    115 			t.Errorf("%s: changed output after second round of fixes.\n--- output after first round\n%s\n--- output after second round\n%s",
    116 				tt.Name, out, out2)
    117 			tdiff(t, out, out2)
    118 		}
    119 	}
    120 }
    121 
    122 func tdiff(t *testing.T, a, b string) {
    123 	data, err := diff([]byte(a), []byte(b))
    124 	if err != nil {
    125 		t.Error(err)
    126 		return
    127 	}
    128 	t.Error(string(data))
    129 }
    130