Home | History | Annotate | Download | only in fix
      1 // Copyright 2011 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package main
      6 
      7 import (
      8 	"bytes"
      9 	"flag"
     10 	"fmt"
     11 	"go/ast"
     12 	"go/format"
     13 	"go/parser"
     14 	"go/scanner"
     15 	"go/token"
     16 	"io/ioutil"
     17 	"os"
     18 	"os/exec"
     19 	"path/filepath"
     20 	"sort"
     21 	"strings"
     22 )
     23 
     24 var (
     25 	fset     = token.NewFileSet()
     26 	exitCode = 0
     27 )
     28 
     29 var allowedRewrites = flag.String("r", "",
     30 	"restrict the rewrites to this comma-separated list")
     31 
     32 var forceRewrites = flag.String("force", "",
     33 	"force these fixes to run even if the code looks updated")
     34 
     35 var allowed, force map[string]bool
     36 
     37 var doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files")
     38 
     39 // enable for debugging fix failures
     40 const debug = false // display incorrectly reformatted source and exit
     41 
     42 func usage() {
     43 	fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
     44 	flag.PrintDefaults()
     45 	fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
     46 	sort.Sort(byName(fixes))
     47 	for _, f := range fixes {
     48 		if f.disabled {
     49 			fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name)
     50 		} else {
     51 			fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
     52 		}
     53 		desc := strings.TrimSpace(f.desc)
     54 		desc = strings.Replace(desc, "\n", "\n\t", -1)
     55 		fmt.Fprintf(os.Stderr, "\t%s\n", desc)
     56 	}
     57 	os.Exit(2)
     58 }
     59 
     60 func main() {
     61 	flag.Usage = usage
     62 	flag.Parse()
     63 
     64 	sort.Sort(byDate(fixes))
     65 
     66 	if *allowedRewrites != "" {
     67 		allowed = make(map[string]bool)
     68 		for _, f := range strings.Split(*allowedRewrites, ",") {
     69 			allowed[f] = true
     70 		}
     71 	}
     72 
     73 	if *forceRewrites != "" {
     74 		force = make(map[string]bool)
     75 		for _, f := range strings.Split(*forceRewrites, ",") {
     76 			force[f] = true
     77 		}
     78 	}
     79 
     80 	if flag.NArg() == 0 {
     81 		if err := processFile("standard input", true); err != nil {
     82 			report(err)
     83 		}
     84 		os.Exit(exitCode)
     85 	}
     86 
     87 	for i := 0; i < flag.NArg(); i++ {
     88 		path := flag.Arg(i)
     89 		switch dir, err := os.Stat(path); {
     90 		case err != nil:
     91 			report(err)
     92 		case dir.IsDir():
     93 			walkDir(path)
     94 		default:
     95 			if err := processFile(path, false); err != nil {
     96 				report(err)
     97 			}
     98 		}
     99 	}
    100 
    101 	os.Exit(exitCode)
    102 }
    103 
    104 const parserMode = parser.ParseComments
    105 
    106 func gofmtFile(f *ast.File) ([]byte, error) {
    107 	var buf bytes.Buffer
    108 	if err := format.Node(&buf, fset, f); err != nil {
    109 		return nil, err
    110 	}
    111 	return buf.Bytes(), nil
    112 }
    113 
    114 func processFile(filename string, useStdin bool) error {
    115 	var f *os.File
    116 	var err error
    117 	var fixlog bytes.Buffer
    118 
    119 	if useStdin {
    120 		f = os.Stdin
    121 	} else {
    122 		f, err = os.Open(filename)
    123 		if err != nil {
    124 			return err
    125 		}
    126 		defer f.Close()
    127 	}
    128 
    129 	src, err := ioutil.ReadAll(f)
    130 	if err != nil {
    131 		return err
    132 	}
    133 
    134 	file, err := parser.ParseFile(fset, filename, src, parserMode)
    135 	if err != nil {
    136 		return err
    137 	}
    138 
    139 	// Apply all fixes to file.
    140 	newFile := file
    141 	fixed := false
    142 	for _, fix := range fixes {
    143 		if allowed != nil && !allowed[fix.name] {
    144 			continue
    145 		}
    146 		if fix.disabled && !force[fix.name] {
    147 			continue
    148 		}
    149 		if fix.f(newFile) {
    150 			fixed = true
    151 			fmt.Fprintf(&fixlog, " %s", fix.name)
    152 
    153 			// AST changed.
    154 			// Print and parse, to update any missing scoping
    155 			// or position information for subsequent fixers.
    156 			newSrc, err := gofmtFile(newFile)
    157 			if err != nil {
    158 				return err
    159 			}
    160 			newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
    161 			if err != nil {
    162 				if debug {
    163 					fmt.Printf("%s", newSrc)
    164 					report(err)
    165 					os.Exit(exitCode)
    166 				}
    167 				return err
    168 			}
    169 		}
    170 	}
    171 	if !fixed {
    172 		return nil
    173 	}
    174 	fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
    175 
    176 	// Print AST.  We did that after each fix, so this appears
    177 	// redundant, but it is necessary to generate gofmt-compatible
    178 	// source code in a few cases. The official gofmt style is the
    179 	// output of the printer run on a standard AST generated by the parser,
    180 	// but the source we generated inside the loop above is the
    181 	// output of the printer run on a mangled AST generated by a fixer.
    182 	newSrc, err := gofmtFile(newFile)
    183 	if err != nil {
    184 		return err
    185 	}
    186 
    187 	if *doDiff {
    188 		data, err := diff(src, newSrc)
    189 		if err != nil {
    190 			return fmt.Errorf("computing diff: %s", err)
    191 		}
    192 		fmt.Printf("diff %s fixed/%s\n", filename, filename)
    193 		os.Stdout.Write(data)
    194 		return nil
    195 	}
    196 
    197 	if useStdin {
    198 		os.Stdout.Write(newSrc)
    199 		return nil
    200 	}
    201 
    202 	return ioutil.WriteFile(f.Name(), newSrc, 0)
    203 }
    204 
    205 var gofmtBuf bytes.Buffer
    206 
    207 func gofmt(n interface{}) string {
    208 	gofmtBuf.Reset()
    209 	if err := format.Node(&gofmtBuf, fset, n); err != nil {
    210 		return "<" + err.Error() + ">"
    211 	}
    212 	return gofmtBuf.String()
    213 }
    214 
    215 func report(err error) {
    216 	scanner.PrintError(os.Stderr, err)
    217 	exitCode = 2
    218 }
    219 
    220 func walkDir(path string) {
    221 	filepath.Walk(path, visitFile)
    222 }
    223 
    224 func visitFile(path string, f os.FileInfo, err error) error {
    225 	if err == nil && isGoFile(f) {
    226 		err = processFile(path, false)
    227 	}
    228 	if err != nil {
    229 		report(err)
    230 	}
    231 	return nil
    232 }
    233 
    234 func isGoFile(f os.FileInfo) bool {
    235 	// ignore non-Go files
    236 	name := f.Name()
    237 	return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
    238 }
    239 
    240 func diff(b1, b2 []byte) (data []byte, err error) {
    241 	f1, err := ioutil.TempFile("", "go-fix")
    242 	if err != nil {
    243 		return nil, err
    244 	}
    245 	defer os.Remove(f1.Name())
    246 	defer f1.Close()
    247 
    248 	f2, err := ioutil.TempFile("", "go-fix")
    249 	if err != nil {
    250 		return nil, err
    251 	}
    252 	defer os.Remove(f2.Name())
    253 	defer f2.Close()
    254 
    255 	f1.Write(b1)
    256 	f2.Write(b2)
    257 
    258 	data, err = exec.Command("diff", "-u", f1.Name(), f2.Name()).CombinedOutput()
    259 	if len(data) > 0 {
    260 		// diff exits with a non-zero status when the files don't match.
    261 		// Ignore that failure as long as we get output.
    262 		err = nil
    263 	}
    264 	return
    265 }
    266