Home | History | Annotate | Download | only in format
      1 // Copyright 2015 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // TODO(gri): This file and the file src/cmd/gofmt/internal.go are
      6 // the same (but for this comment and the package name). Do not modify
      7 // one without the other. Determine if we can factor out functionality
      8 // in a public API. See also #11844 for context.
      9 
     10 package format
     11 
     12 import (
     13 	"bytes"
     14 	"go/ast"
     15 	"go/parser"
     16 	"go/printer"
     17 	"go/token"
     18 	"strings"
     19 )
     20 
     21 // parse parses src, which was read from the named file,
     22 // as a Go source file, declaration, or statement list.
     23 func parse(fset *token.FileSet, filename string, src []byte, fragmentOk bool) (
     24 	file *ast.File,
     25 	sourceAdj func(src []byte, indent int) []byte,
     26 	indentAdj int,
     27 	err error,
     28 ) {
     29 	// Try as whole source file.
     30 	file, err = parser.ParseFile(fset, filename, src, parserMode)
     31 	// If there's no error, return. If the error is that the source file didn't begin with a
     32 	// package line and source fragments are ok, fall through to
     33 	// try as a source fragment. Stop and return on any other error.
     34 	if err == nil || !fragmentOk || !strings.Contains(err.Error(), "expected 'package'") {
     35 		return
     36 	}
     37 
     38 	// If this is a declaration list, make it a source file
     39 	// by inserting a package clause.
     40 	// Insert using a ';', not a newline, so that the line numbers
     41 	// in psrc match the ones in src.
     42 	psrc := append([]byte("package p;"), src...)
     43 	file, err = parser.ParseFile(fset, filename, psrc, parserMode)
     44 	if err == nil {
     45 		sourceAdj = func(src []byte, indent int) []byte {
     46 			// Remove the package clause.
     47 			// Gofmt has turned the ';' into a '\n'.
     48 			src = src[indent+len("package p\n"):]
     49 			return bytes.TrimSpace(src)
     50 		}
     51 		return
     52 	}
     53 	// If the error is that the source file didn't begin with a
     54 	// declaration, fall through to try as a statement list.
     55 	// Stop and return on any other error.
     56 	if !strings.Contains(err.Error(), "expected declaration") {
     57 		return
     58 	}
     59 
     60 	// If this is a statement list, make it a source file
     61 	// by inserting a package clause and turning the list
     62 	// into a function body. This handles expressions too.
     63 	// Insert using a ';', not a newline, so that the line numbers
     64 	// in fsrc match the ones in src. Add an extra '\n' before the '}'
     65 	// to make sure comments are flushed before the '}'.
     66 	fsrc := append(append([]byte("package p; func _() {"), src...), '\n', '\n', '}')
     67 	file, err = parser.ParseFile(fset, filename, fsrc, parserMode)
     68 	if err == nil {
     69 		sourceAdj = func(src []byte, indent int) []byte {
     70 			// Cap adjusted indent to zero.
     71 			if indent < 0 {
     72 				indent = 0
     73 			}
     74 			// Remove the wrapping.
     75 			// Gofmt has turned the ';' into a '\n'.
     76 			// There will be two non-blank lines with indent, hence 2*indent.
     77 			src = src[2*indent+len("package p\n\nfunc _() {"):]
     78 			// Remove only the "}\n" suffix: remaining whitespaces will be trimmed anyway
     79 			src = src[:len(src)-len("}\n")]
     80 			return bytes.TrimSpace(src)
     81 		}
     82 		// Gofmt has also indented the function body one level.
     83 		// Adjust that with indentAdj.
     84 		indentAdj = -1
     85 	}
     86 
     87 	// Succeeded, or out of options.
     88 	return
     89 }
     90 
     91 // format formats the given package file originally obtained from src
     92 // and adjusts the result based on the original source via sourceAdj
     93 // and indentAdj.
     94 func format(
     95 	fset *token.FileSet,
     96 	file *ast.File,
     97 	sourceAdj func(src []byte, indent int) []byte,
     98 	indentAdj int,
     99 	src []byte,
    100 	cfg printer.Config,
    101 ) ([]byte, error) {
    102 	if sourceAdj == nil {
    103 		// Complete source file.
    104 		var buf bytes.Buffer
    105 		err := cfg.Fprint(&buf, fset, file)
    106 		if err != nil {
    107 			return nil, err
    108 		}
    109 		return buf.Bytes(), nil
    110 	}
    111 
    112 	// Partial source file.
    113 	// Determine and prepend leading space.
    114 	i, j := 0, 0
    115 	for j < len(src) && isSpace(src[j]) {
    116 		if src[j] == '\n' {
    117 			i = j + 1 // byte offset of last line in leading space
    118 		}
    119 		j++
    120 	}
    121 	var res []byte
    122 	res = append(res, src[:i]...)
    123 
    124 	// Determine and prepend indentation of first code line.
    125 	// Spaces are ignored unless there are no tabs,
    126 	// in which case spaces count as one tab.
    127 	indent := 0
    128 	hasSpace := false
    129 	for _, b := range src[i:j] {
    130 		switch b {
    131 		case ' ':
    132 			hasSpace = true
    133 		case '\t':
    134 			indent++
    135 		}
    136 	}
    137 	if indent == 0 && hasSpace {
    138 		indent = 1
    139 	}
    140 	for i := 0; i < indent; i++ {
    141 		res = append(res, '\t')
    142 	}
    143 
    144 	// Format the source.
    145 	// Write it without any leading and trailing space.
    146 	cfg.Indent = indent + indentAdj
    147 	var buf bytes.Buffer
    148 	err := cfg.Fprint(&buf, fset, file)
    149 	if err != nil {
    150 		return nil, err
    151 	}
    152 	out := sourceAdj(buf.Bytes(), cfg.Indent)
    153 
    154 	// If the adjusted output is empty, the source
    155 	// was empty but (possibly) for white space.
    156 	// The result is the incoming source.
    157 	if len(out) == 0 {
    158 		return src, nil
    159 	}
    160 
    161 	// Otherwise, append output to leading space.
    162 	res = append(res, out...)
    163 
    164 	// Determine and append trailing space.
    165 	i = len(src)
    166 	for i > 0 && isSpace(src[i-1]) {
    167 		i--
    168 	}
    169 	return append(res, src[i:]...), nil
    170 }
    171 
    172 // isSpace reports whether the byte is a space character.
    173 // isSpace defines a space as being among the following bytes: ' ', '\t', '\n' and '\r'.
    174 func isSpace(b byte) bool {
    175 	return b == ' ' || b == '\t' || b == '\n' || b == '\r'
    176 }
    177