Home | History | Annotate | Download | only in bpfmt
      1 // Mostly copied from Go's src/cmd/gofmt:
      2 // Copyright 2009 The Go Authors. All rights reserved.
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 package main
      7 
      8 import (
      9 	"bytes"
     10 	"flag"
     11 	"fmt"
     12 	"github.com/google/blueprint/parser"
     13 	"io"
     14 	"io/ioutil"
     15 	"os"
     16 	"os/exec"
     17 	"path/filepath"
     18 )
     19 
     20 var (
     21 	// main operation modes
     22 	list      = flag.Bool("l", false, "list files whose formatting differs from bpfmt's")
     23 	write     = flag.Bool("w", false, "write result to (source) file instead of stdout")
     24 	doDiff    = flag.Bool("d", false, "display diffs instead of rewriting files")
     25 	sortLists = flag.Bool("s", false, "sort arrays")
     26 )
     27 
     28 var (
     29 	exitCode = 0
     30 )
     31 
     32 func report(err error) {
     33 	fmt.Fprintln(os.Stderr, err)
     34 	exitCode = 2
     35 }
     36 
     37 func usage() {
     38 	fmt.Fprintf(os.Stderr, "usage: bpfmt [flags] [path ...]\n")
     39 	flag.PrintDefaults()
     40 	os.Exit(2)
     41 }
     42 
     43 // If in == nil, the source is the contents of the file with the given filename.
     44 func processFile(filename string, in io.Reader, out io.Writer) error {
     45 	if in == nil {
     46 		f, err := os.Open(filename)
     47 		if err != nil {
     48 			return err
     49 		}
     50 		defer f.Close()
     51 		in = f
     52 	}
     53 
     54 	src, err := ioutil.ReadAll(in)
     55 	if err != nil {
     56 		return err
     57 	}
     58 
     59 	r := bytes.NewBuffer(src)
     60 
     61 	file, errs := parser.Parse(filename, r, parser.NewScope(nil))
     62 	if len(errs) > 0 {
     63 		for _, err := range errs {
     64 			fmt.Fprintln(os.Stderr, err)
     65 		}
     66 		return fmt.Errorf("%d parsing errors", len(errs))
     67 	}
     68 
     69 	if *sortLists {
     70 		parser.SortLists(file)
     71 	}
     72 
     73 	res, err := parser.Print(file)
     74 	if err != nil {
     75 		return err
     76 	}
     77 
     78 	if !bytes.Equal(src, res) {
     79 		// formatting has changed
     80 		if *list {
     81 			fmt.Fprintln(out, filename)
     82 		}
     83 		if *write {
     84 			err = ioutil.WriteFile(filename, res, 0644)
     85 			if err != nil {
     86 				return err
     87 			}
     88 		}
     89 		if *doDiff {
     90 			data, err := diff(src, res)
     91 			if err != nil {
     92 				return fmt.Errorf("computing diff: %s", err)
     93 			}
     94 			fmt.Printf("diff %s bpfmt/%s\n", filename, filename)
     95 			out.Write(data)
     96 		}
     97 	}
     98 
     99 	if !*list && !*write && !*doDiff {
    100 		_, err = out.Write(res)
    101 	}
    102 
    103 	return err
    104 }
    105 
    106 func visitFile(path string, f os.FileInfo, err error) error {
    107 	if err == nil && f.Name() == "Blueprints" {
    108 		err = processFile(path, nil, os.Stdout)
    109 	}
    110 	if err != nil {
    111 		report(err)
    112 	}
    113 	return nil
    114 }
    115 
    116 func walkDir(path string) {
    117 	filepath.Walk(path, visitFile)
    118 }
    119 
    120 func main() {
    121 	flag.Parse()
    122 
    123 	if flag.NArg() == 0 {
    124 		if *write {
    125 			fmt.Fprintln(os.Stderr, "error: cannot use -w with standard input")
    126 			exitCode = 2
    127 			return
    128 		}
    129 		if err := processFile("<standard input>", os.Stdin, os.Stdout); err != nil {
    130 			report(err)
    131 		}
    132 		return
    133 	}
    134 
    135 	for i := 0; i < flag.NArg(); i++ {
    136 		path := flag.Arg(i)
    137 		switch dir, err := os.Stat(path); {
    138 		case err != nil:
    139 			report(err)
    140 		case dir.IsDir():
    141 			walkDir(path)
    142 		default:
    143 			if err := processFile(path, nil, os.Stdout); err != nil {
    144 				report(err)
    145 			}
    146 		}
    147 	}
    148 }
    149 
    150 func diff(b1, b2 []byte) (data []byte, err error) {
    151 	f1, err := ioutil.TempFile("", "bpfmt")
    152 	if err != nil {
    153 		return
    154 	}
    155 	defer os.Remove(f1.Name())
    156 	defer f1.Close()
    157 
    158 	f2, err := ioutil.TempFile("", "bpfmt")
    159 	if err != nil {
    160 		return
    161 	}
    162 	defer os.Remove(f2.Name())
    163 	defer f2.Close()
    164 
    165 	f1.Write(b1)
    166 	f2.Write(b2)
    167 
    168 	data, err = exec.Command("diff", "-u", f1.Name(), f2.Name()).CombinedOutput()
    169 	if len(data) > 0 {
    170 		// diff exits with a non-zero status when the files don't match.
    171 		// Ignore that failure as long as we get output.
    172 		err = nil
    173 	}
    174 	return
    175 
    176 }
    177