Home | History | Annotate | Download | only in sort
      1 // Copyright 2016 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // +build ignore
      6 
      7 // This program is run via "go generate" (via a directive in sort.go)
      8 // to generate zfuncversion.go.
      9 //
     10 // It copies sort.go to zfuncversion.go, only retaining unexported, non-method
     11 // funcs whose first parameter is "data Interface"; each is renamed with a
     12 // "_func" suffix and that parameter becomes "data lessSwap". It then rewrites
     13 // each internal function call to the appropriate _func variants.
     14 
     15 package main
     16 
     17 import (
     18 	"bytes"
     19 	"go/ast"
     20 	"go/format"
     21 	"go/parser"
     22 	"go/token"
     23 	"io/ioutil"
     24 	"log"
     25 	"regexp"
     26 )
     27 
     28 var fset = token.NewFileSet()
     29 
     30 func main() {
     31 	af, err := parser.ParseFile(fset, "sort.go", nil, 0)
     32 	if err != nil {
     33 		log.Fatal(err)
     34 	}
     35 	af.Doc = nil
     36 	af.Imports = nil
     37 	af.Comments = nil
     38 
     39 	var newDecl []ast.Decl
     40 	for _, d := range af.Decls {
     41 		fd, ok := d.(*ast.FuncDecl)
     42 		if !ok {
     43 			continue
     44 		}
     45 		if fd.Recv != nil || fd.Name.IsExported() {
     46 			continue
     47 		}
     48 		typ := fd.Type
     49 		if len(typ.Params.List) < 1 {
     50 			continue
     51 		}
     52 		arg0 := typ.Params.List[0]
     53 		arg0Name := arg0.Names[0].Name
     54 		arg0Type := arg0.Type.(*ast.Ident)
     55 		if arg0Name != "data" || arg0Type.Name != "Interface" {
     56 			continue
     57 		}
     58 		arg0Type.Name = "lessSwap"
     59 
     60 		newDecl = append(newDecl, fd)
     61 	}
     62 	af.Decls = newDecl
     63 	ast.Walk(visitFunc(rewriteCalls), af)
     64 
     65 	var out bytes.Buffer
     66 	if err := format.Node(&out, fset, af); err != nil {
     67 		log.Fatalf("format.Node: %v", err)
     68 	}
     69 
     70 	// Get rid of blank lines after removal of comments.
     71 	src := regexp.MustCompile(`\n{2,}`).ReplaceAll(out.Bytes(), []byte("\n"))
     72 
     73 	// Add comments to each func, for the lost reader.
     74 	// This is so much easier than adding comments via the AST
     75 	// and trying to get position info correct.
     76 	src = regexp.MustCompile(`(?m)^func (\w+)`).ReplaceAll(src, []byte("\n// Auto-generated variant of sort.go:$1\nfunc ${1}_func"))
     77 
     78 	// Final gofmt.
     79 	src, err = format.Source(src)
     80 	if err != nil {
     81 		log.Fatalf("format.Source: %v on\n%s", err, src)
     82 	}
     83 
     84 	out.Reset()
     85 	out.WriteString(`// DO NOT EDIT; AUTO-GENERATED from sort.go using genzfunc.go
     86 
     87 // Copyright 2016 The Go Authors. All rights reserved.
     88 // Use of this source code is governed by a BSD-style
     89 // license that can be found in the LICENSE file.
     90 
     91 `)
     92 	out.Write(src)
     93 
     94 	const target = "zfuncversion.go"
     95 	if err := ioutil.WriteFile(target, out.Bytes(), 0644); err != nil {
     96 		log.Fatal(err)
     97 	}
     98 }
     99 
    100 type visitFunc func(ast.Node) ast.Visitor
    101 
    102 func (f visitFunc) Visit(n ast.Node) ast.Visitor { return f(n) }
    103 
    104 func rewriteCalls(n ast.Node) ast.Visitor {
    105 	ce, ok := n.(*ast.CallExpr)
    106 	if ok {
    107 		rewriteCall(ce)
    108 	}
    109 	return visitFunc(rewriteCalls)
    110 }
    111 
    112 func rewriteCall(ce *ast.CallExpr) {
    113 	ident, ok := ce.Fun.(*ast.Ident)
    114 	if !ok {
    115 		// e.g. skip SelectorExpr (data.Less(..) calls)
    116 		return
    117 	}
    118 	if len(ce.Args) < 1 {
    119 		return
    120 	}
    121 	ident.Name += "_func"
    122 }
    123