Home | History | Annotate | Download | only in fix
      1 // Copyright 2011 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package main
      6 
      7 import "go/ast"
      8 
// init hands importTests to addTestCases, which is defined elsewhere in this
// package. The second argument is passed as nil.
// NOTE(review): its meaning is decided by addTestCases — presumably "no
// default fix function", since every case in importTests sets its own Fn.
func init() {
	addTestCases(importTests, nil)
}
     12 
// importTests exercises the import-editing helpers through the Fn wrappers
// defined later in this file (addImportFn, deleteImportFn, addDelImportFn,
// rewriteImportFn). Each case supplies source text In and expected text Out.
// NOTE(review): presumably the harness fed by addTestCases parses In,
// applies Fn, prints the result and compares it with Out — confirm against
// the harness. The raw-string contents, including tabs and alignment spaces,
// are significant and must not be reformatted.
var importTests = []testCase{
	// import.0: adding "os" when it is already imported leaves the file
	// unchanged.
	{
		Name: "import.0",
		Fn:   addImportFn("os"),
		In: `package main

import (
	"os"
)
`,
		Out: `package main

import (
	"os"
)
`,
	},
	// import.1: adding to a file with no imports creates a single,
	// unparenthesized import declaration.
	{
		Name: "import.1",
		Fn:   addImportFn("os"),
		In: `package main
`,
		Out: `package main

import "os"
`,
	},
	// import.2: when the only existing import is the cgo import "C", the new
	// import goes into a separate declaration after it, not merged into it.
	{
		Name: "import.2",
		Fn:   addImportFn("os"),
		In: `package main

// Comment
import "C"
`,
		Out: `package main

// Comment
import "C"
import "os"
`,
	},
	// import.3: the new import is merged, in sorted position, into the
	// parenthesized block; the separate import "C" declaration is untouched.
	{
		Name: "import.3",
		Fn:   addImportFn("os"),
		In: `package main

// Comment
import "C"

import (
	"io"
	"utf8"
)
`,
		Out: `package main

// Comment
import "C"

import (
	"io"
	"os"
	"utf8"
)
`,
	},
	// import.4: deleting the only import removes the whole declaration,
	// leaving just the package clause.
	{
		Name: "import.4",
		Fn:   deleteImportFn("os"),
		In: `package main

import (
	"os"
)
`,
		Out: `package main
`,
	},
	// import.5: deleting a single-spec declaration drops it while import "C"
	// and its comment remain.
	{
		Name: "import.5",
		Fn:   deleteImportFn("os"),
		In: `package main

// Comment
import "C"
import "os"
`,
		Out: `package main

// Comment
import "C"
`,
	},
	// import.6: deleting from a parenthesized block keeps the block and its
	// remaining specs.
	{
		Name: "import.6",
		Fn:   deleteImportFn("os"),
		In: `package main

// Comment
import "C"

import (
	"io"
	"os"
	"utf8"
)
`,
		Out: `package main

// Comment
import "C"

import (
	"io"
	"utf8"
)
`,
	},
	// import.7 - import.9: deleting a spec that carries a trailing line
	// comment keeps that comment on a line of its own (first, middle and last
	// spec respectively); surviving trailing comments are re-aligned to the
	// remaining specs.
	{
		Name: "import.7",
		Fn:   deleteImportFn("io"),
		In: `package main

import (
	"io"   // a
	"os"   // b
	"utf8" // c
)
`,
		Out: `package main

import (
	// a
	"os"   // b
	"utf8" // c
)
`,
	},
	{
		Name: "import.8",
		Fn:   deleteImportFn("os"),
		In: `package main

import (
	"io"   // a
	"os"   // b
	"utf8" // c
)
`,
		Out: `package main

import (
	"io" // a
	// b
	"utf8" // c
)
`,
	},
	{
		Name: "import.9",
		Fn:   deleteImportFn("utf8"),
		In: `package main

import (
	"io"   // a
	"os"   // b
	"utf8" // c
)
`,
		Out: `package main

import (
	"io" // a
	"os" // b
	// c
)
`,
	},
	// import.10 - import.12: deleting the first, middle and last spec of an
	// uncommented block.
	{
		Name: "import.10",
		Fn:   deleteImportFn("io"),
		In: `package main

import (
	"io"
	"os"
	"utf8"
)
`,
		Out: `package main

import (
	"os"
	"utf8"
)
`,
	},
	{
		Name: "import.11",
		Fn:   deleteImportFn("os"),
		In: `package main

import (
	"io"
	"os"
	"utf8"
)
`,
		Out: `package main

import (
	"io"
	"utf8"
)
`,
	},
	{
		Name: "import.12",
		Fn:   deleteImportFn("utf8"),
		In: `package main

import (
	"io"
	"os"
	"utf8"
)
`,
		Out: `package main

import (
	"io"
	"os"
)
`,
	},
	// import.13: rewriting a path re-sorts the spec within its block and
	// carries its trailing comment along.
	{
		Name: "import.13",
		Fn:   rewriteImportFn("utf8", "encoding/utf8"),
		In: `package main

import (
	"io"
	"os"
	"utf8" // thanks ken
)
`,
		Out: `package main

import (
	"encoding/utf8" // thanks ken
	"io"
	"os"
)
`,
	},
	// import.14: a rewritten path is re-sorted among the other specs,
	// including a blank (_) import; the declaration after the block is kept.
	{
		Name: "import.14",
		Fn:   rewriteImportFn("asn1", "encoding/asn1"),
		In: `package main

import (
	"asn1"
	"crypto"
	"crypto/rsa"
	_ "crypto/sha1"
	"crypto/x509"
	"crypto/x509/pkix"
	"time"
)

var x = 1
`,
		Out: `package main

import (
	"crypto"
	"crypto/rsa"
	_ "crypto/sha1"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/asn1"
	"time"
)

var x = 1
`,
	},
	// import.15: a trailing comment on the declaration after the import block
	// stays with that declaration, not with the rewritten spec.
	{
		Name: "import.15",
		Fn:   rewriteImportFn("url", "net/url"),
		In: `package main

import (
	"bufio"
	"net"
	"path"
	"url"
)

var x = 1 // comment on x, not on url
`,
		Out: `package main

import (
	"bufio"
	"net"
	"net/url"
	"path"
)

var x = 1 // comment on x, not on url
`,
	},
	// import.16: a single Fn may rewrite several (old, new) pairs.
	{
		Name: "import.16",
		Fn:   rewriteImportFn("http", "net/http", "template", "text/template"),
		In: `package main

import (
	"flag"
	"http"
	"log"
	"template"
)

var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18
`,
		Out: `package main

import (
	"flag"
	"log"
	"net/http"
	"text/template"
)

var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18
`,
	},
	// import.17: adding two paths places both, sorted, into the
	// blank-line-separated group containing "x/w"; the "a"/"b" and "d/f"
	// groups and the import "C" declaration are unchanged.
	{
		Name: "import.17",
		Fn:   addImportFn("x/y/z", "x/a/c"),
		In: `package main

// Comment
import "C"

import (
	"a"
	"b"

	"x/w"

	"d/f"
)
`,
		Out: `package main

// Comment
import "C"

import (
	"a"
	"b"

	"x/a/c"
	"x/w"
	"x/y/z"

	"d/f"
)
`,
	},
	// import.18: addDelImportFn adds "e" and deletes "o" in one pass.
	{
		Name: "import.18",
		Fn:   addDelImportFn("e", "o"),
		In: `package main

import (
	"f"
	"o"
	"z"
)
`,
		Out: `package main

import (
	"e"
	"f"
	"z"
)
`,
	},
}
    408 
    409 func addImportFn(path ...string) func(*ast.File) bool {
    410 	return func(f *ast.File) bool {
    411 		fixed := false
    412 		for _, p := range path {
    413 			if !imports(f, p) {
    414 				addImport(f, p)
    415 				fixed = true
    416 			}
    417 		}
    418 		return fixed
    419 	}
    420 }
    421 
    422 func deleteImportFn(path string) func(*ast.File) bool {
    423 	return func(f *ast.File) bool {
    424 		if imports(f, path) {
    425 			deleteImport(f, path)
    426 			return true
    427 		}
    428 		return false
    429 	}
    430 }
    431 
    432 func addDelImportFn(p1 string, p2 string) func(*ast.File) bool {
    433 	return func(f *ast.File) bool {
    434 		fixed := false
    435 		if !imports(f, p1) {
    436 			addImport(f, p1)
    437 			fixed = true
    438 		}
    439 		if imports(f, p2) {
    440 			deleteImport(f, p2)
    441 			fixed = true
    442 		}
    443 		return fixed
    444 	}
    445 }
    446 
    447 func rewriteImportFn(oldnew ...string) func(*ast.File) bool {
    448 	return func(f *ast.File) bool {
    449 		fixed := false
    450 		for i := 0; i < len(oldnew); i += 2 {
    451 			if imports(f, oldnew[i]) {
    452 				rewriteImport(f, oldnew[i], oldnew[i+1])
    453 				fixed = true
    454 			}
    455 		}
    456 		return fixed
    457 	}
    458 }
    459