Home | History | Annotate | Download | only in flag
      1 // Copyright 2009 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package flag_test
      6 
      7 import (
      8 	"bytes"
      9 	. "flag"
     10 	"fmt"
     11 	"io"
     12 	"os"
     13 	"sort"
     14 	"strconv"
     15 	"strings"
     16 	"testing"
     17 	"time"
     18 )
     19 
     20 func boolString(s string) string {
     21 	if s == "0" {
     22 		return "false"
     23 	}
     24 	return "true"
     25 }
     26 
     27 func TestEverything(t *testing.T) {
     28 	ResetForTesting(nil)
     29 	Bool("test_bool", false, "bool value")
     30 	Int("test_int", 0, "int value")
     31 	Int64("test_int64", 0, "int64 value")
     32 	Uint("test_uint", 0, "uint value")
     33 	Uint64("test_uint64", 0, "uint64 value")
     34 	String("test_string", "0", "string value")
     35 	Float64("test_float64", 0, "float64 value")
     36 	Duration("test_duration", 0, "time.Duration value")
     37 
     38 	m := make(map[string]*Flag)
     39 	desired := "0"
     40 	visitor := func(f *Flag) {
     41 		if len(f.Name) > 5 && f.Name[0:5] == "test_" {
     42 			m[f.Name] = f
     43 			ok := false
     44 			switch {
     45 			case f.Value.String() == desired:
     46 				ok = true
     47 			case f.Name == "test_bool" && f.Value.String() == boolString(desired):
     48 				ok = true
     49 			case f.Name == "test_duration" && f.Value.String() == desired+"s":
     50 				ok = true
     51 			}
     52 			if !ok {
     53 				t.Error("Visit: bad value", f.Value.String(), "for", f.Name)
     54 			}
     55 		}
     56 	}
     57 	VisitAll(visitor)
     58 	if len(m) != 8 {
     59 		t.Error("VisitAll misses some flags")
     60 		for k, v := range m {
     61 			t.Log(k, *v)
     62 		}
     63 	}
     64 	m = make(map[string]*Flag)
     65 	Visit(visitor)
     66 	if len(m) != 0 {
     67 		t.Errorf("Visit sees unset flags")
     68 		for k, v := range m {
     69 			t.Log(k, *v)
     70 		}
     71 	}
     72 	// Now set all flags
     73 	Set("test_bool", "true")
     74 	Set("test_int", "1")
     75 	Set("test_int64", "1")
     76 	Set("test_uint", "1")
     77 	Set("test_uint64", "1")
     78 	Set("test_string", "1")
     79 	Set("test_float64", "1")
     80 	Set("test_duration", "1s")
     81 	desired = "1"
     82 	Visit(visitor)
     83 	if len(m) != 8 {
     84 		t.Error("Visit fails after set")
     85 		for k, v := range m {
     86 			t.Log(k, *v)
     87 		}
     88 	}
     89 	// Now test they're visited in sort order.
     90 	var flagNames []string
     91 	Visit(func(f *Flag) { flagNames = append(flagNames, f.Name) })
     92 	if !sort.StringsAreSorted(flagNames) {
     93 		t.Errorf("flag names not sorted: %v", flagNames)
     94 	}
     95 }
     96 
     97 func TestGet(t *testing.T) {
     98 	ResetForTesting(nil)
     99 	Bool("test_bool", true, "bool value")
    100 	Int("test_int", 1, "int value")
    101 	Int64("test_int64", 2, "int64 value")
    102 	Uint("test_uint", 3, "uint value")
    103 	Uint64("test_uint64", 4, "uint64 value")
    104 	String("test_string", "5", "string value")
    105 	Float64("test_float64", 6, "float64 value")
    106 	Duration("test_duration", 7, "time.Duration value")
    107 
    108 	visitor := func(f *Flag) {
    109 		if len(f.Name) > 5 && f.Name[0:5] == "test_" {
    110 			g, ok := f.Value.(Getter)
    111 			if !ok {
    112 				t.Errorf("Visit: value does not satisfy Getter: %T", f.Value)
    113 				return
    114 			}
    115 			switch f.Name {
    116 			case "test_bool":
    117 				ok = g.Get() == true
    118 			case "test_int":
    119 				ok = g.Get() == int(1)
    120 			case "test_int64":
    121 				ok = g.Get() == int64(2)
    122 			case "test_uint":
    123 				ok = g.Get() == uint(3)
    124 			case "test_uint64":
    125 				ok = g.Get() == uint64(4)
    126 			case "test_string":
    127 				ok = g.Get() == "5"
    128 			case "test_float64":
    129 				ok = g.Get() == float64(6)
    130 			case "test_duration":
    131 				ok = g.Get() == time.Duration(7)
    132 			}
    133 			if !ok {
    134 				t.Errorf("Visit: bad value %T(%v) for %s", g.Get(), g.Get(), f.Name)
    135 			}
    136 		}
    137 	}
    138 	VisitAll(visitor)
    139 }
    140 
    141 func TestUsage(t *testing.T) {
    142 	called := false
    143 	ResetForTesting(func() { called = true })
    144 	if CommandLine.Parse([]string{"-x"}) == nil {
    145 		t.Error("parse did not fail for unknown flag")
    146 	}
    147 	if !called {
    148 		t.Error("did not call Usage for unknown flag")
    149 	}
    150 }
    151 
    152 func testParse(f *FlagSet, t *testing.T) {
    153 	if f.Parsed() {
    154 		t.Error("f.Parse() = true before Parse")
    155 	}
    156 	boolFlag := f.Bool("bool", false, "bool value")
    157 	bool2Flag := f.Bool("bool2", false, "bool2 value")
    158 	intFlag := f.Int("int", 0, "int value")
    159 	int64Flag := f.Int64("int64", 0, "int64 value")
    160 	uintFlag := f.Uint("uint", 0, "uint value")
    161 	uint64Flag := f.Uint64("uint64", 0, "uint64 value")
    162 	stringFlag := f.String("string", "0", "string value")
    163 	float64Flag := f.Float64("float64", 0, "float64 value")
    164 	durationFlag := f.Duration("duration", 5*time.Second, "time.Duration value")
    165 	extra := "one-extra-argument"
    166 	args := []string{
    167 		"-bool",
    168 		"-bool2=true",
    169 		"--int", "22",
    170 		"--int64", "0x23",
    171 		"-uint", "24",
    172 		"--uint64", "25",
    173 		"-string", "hello",
    174 		"-float64", "2718e28",
    175 		"-duration", "2m",
    176 		extra,
    177 	}
    178 	if err := f.Parse(args); err != nil {
    179 		t.Fatal(err)
    180 	}
    181 	if !f.Parsed() {
    182 		t.Error("f.Parse() = false after Parse")
    183 	}
    184 	if *boolFlag != true {
    185 		t.Error("bool flag should be true, is ", *boolFlag)
    186 	}
    187 	if *bool2Flag != true {
    188 		t.Error("bool2 flag should be true, is ", *bool2Flag)
    189 	}
    190 	if *intFlag != 22 {
    191 		t.Error("int flag should be 22, is ", *intFlag)
    192 	}
    193 	if *int64Flag != 0x23 {
    194 		t.Error("int64 flag should be 0x23, is ", *int64Flag)
    195 	}
    196 	if *uintFlag != 24 {
    197 		t.Error("uint flag should be 24, is ", *uintFlag)
    198 	}
    199 	if *uint64Flag != 25 {
    200 		t.Error("uint64 flag should be 25, is ", *uint64Flag)
    201 	}
    202 	if *stringFlag != "hello" {
    203 		t.Error("string flag should be `hello`, is ", *stringFlag)
    204 	}
    205 	if *float64Flag != 2718e28 {
    206 		t.Error("float64 flag should be 2718e28, is ", *float64Flag)
    207 	}
    208 	if *durationFlag != 2*time.Minute {
    209 		t.Error("duration flag should be 2m, is ", *durationFlag)
    210 	}
    211 	if len(f.Args()) != 1 {
    212 		t.Error("expected one argument, got", len(f.Args()))
    213 	} else if f.Args()[0] != extra {
    214 		t.Errorf("expected argument %q got %q", extra, f.Args()[0])
    215 	}
    216 }
    217 
    218 func TestParse(t *testing.T) {
    219 	ResetForTesting(func() { t.Error("bad parse") })
    220 	testParse(CommandLine, t)
    221 }
    222 
    223 func TestFlagSetParse(t *testing.T) {
    224 	testParse(NewFlagSet("test", ContinueOnError), t)
    225 }
    226 
    227 // Declare a user-defined flag type.
    228 type flagVar []string
    229 
    230 func (f *flagVar) String() string {
    231 	return fmt.Sprint([]string(*f))
    232 }
    233 
    234 func (f *flagVar) Set(value string) error {
    235 	*f = append(*f, value)
    236 	return nil
    237 }
    238 
    239 func TestUserDefined(t *testing.T) {
    240 	var flags FlagSet
    241 	flags.Init("test", ContinueOnError)
    242 	var v flagVar
    243 	flags.Var(&v, "v", "usage")
    244 	if err := flags.Parse([]string{"-v", "1", "-v", "2", "-v=3"}); err != nil {
    245 		t.Error(err)
    246 	}
    247 	if len(v) != 3 {
    248 		t.Fatal("expected 3 args; got ", len(v))
    249 	}
    250 	expect := "[1 2 3]"
    251 	if v.String() != expect {
    252 		t.Errorf("expected value %q got %q", expect, v.String())
    253 	}
    254 }
    255 
    256 func TestUserDefinedForCommandLine(t *testing.T) {
    257 	const help = "HELP"
    258 	var result string
    259 	ResetForTesting(func() { result = help })
    260 	Usage()
    261 	if result != help {
    262 		t.Fatalf("got %q; expected %q", result, help)
    263 	}
    264 }
    265 
    266 // Declare a user-defined boolean flag type.
    267 type boolFlagVar struct {
    268 	count int
    269 }
    270 
    271 func (b *boolFlagVar) String() string {
    272 	return fmt.Sprintf("%d", b.count)
    273 }
    274 
    275 func (b *boolFlagVar) Set(value string) error {
    276 	if value == "true" {
    277 		b.count++
    278 	}
    279 	return nil
    280 }
    281 
    282 func (b *boolFlagVar) IsBoolFlag() bool {
    283 	return b.count < 4
    284 }
    285 
    286 func TestUserDefinedBool(t *testing.T) {
    287 	var flags FlagSet
    288 	flags.Init("test", ContinueOnError)
    289 	var b boolFlagVar
    290 	var err error
    291 	flags.Var(&b, "b", "usage")
    292 	if err = flags.Parse([]string{"-b", "-b", "-b", "-b=true", "-b=false", "-b", "barg", "-b"}); err != nil {
    293 		if b.count < 4 {
    294 			t.Error(err)
    295 		}
    296 	}
    297 
    298 	if b.count != 4 {
    299 		t.Errorf("want: %d; got: %d", 4, b.count)
    300 	}
    301 
    302 	if err == nil {
    303 		t.Error("expected error; got none")
    304 	}
    305 }
    306 
    307 func TestSetOutput(t *testing.T) {
    308 	var flags FlagSet
    309 	var buf bytes.Buffer
    310 	flags.SetOutput(&buf)
    311 	flags.Init("test", ContinueOnError)
    312 	flags.Parse([]string{"-unknown"})
    313 	if out := buf.String(); !strings.Contains(out, "-unknown") {
    314 		t.Logf("expected output mentioning unknown; got %q", out)
    315 	}
    316 }
    317 
    318 // This tests that one can reset the flags. This still works but not well, and is
    319 // superseded by FlagSet.
    320 func TestChangingArgs(t *testing.T) {
    321 	ResetForTesting(func() { t.Fatal("bad parse") })
    322 	oldArgs := os.Args
    323 	defer func() { os.Args = oldArgs }()
    324 	os.Args = []string{"cmd", "-before", "subcmd", "-after", "args"}
    325 	before := Bool("before", false, "")
    326 	if err := CommandLine.Parse(os.Args[1:]); err != nil {
    327 		t.Fatal(err)
    328 	}
    329 	cmd := Arg(0)
    330 	os.Args = Args()
    331 	after := Bool("after", false, "")
    332 	Parse()
    333 	args := Args()
    334 
    335 	if !*before || cmd != "subcmd" || !*after || len(args) != 1 || args[0] != "args" {
    336 		t.Fatalf("expected true subcmd true [args] got %v %v %v %v", *before, cmd, *after, args)
    337 	}
    338 }
    339 
    340 // Test that -help invokes the usage message and returns ErrHelp.
    341 func TestHelp(t *testing.T) {
    342 	var helpCalled = false
    343 	fs := NewFlagSet("help test", ContinueOnError)
    344 	fs.Usage = func() { helpCalled = true }
    345 	var flag bool
    346 	fs.BoolVar(&flag, "flag", false, "regular flag")
    347 	// Regular flag invocation should work
    348 	err := fs.Parse([]string{"-flag=true"})
    349 	if err != nil {
    350 		t.Fatal("expected no error; got ", err)
    351 	}
    352 	if !flag {
    353 		t.Error("flag was not set by -flag")
    354 	}
    355 	if helpCalled {
    356 		t.Error("help called for regular flag")
    357 		helpCalled = false // reset for next test
    358 	}
    359 	// Help flag should work as expected.
    360 	err = fs.Parse([]string{"-help"})
    361 	if err == nil {
    362 		t.Fatal("error expected")
    363 	}
    364 	if err != ErrHelp {
    365 		t.Fatal("expected ErrHelp; got ", err)
    366 	}
    367 	if !helpCalled {
    368 		t.Fatal("help was not called")
    369 	}
    370 	// If we define a help flag, that should override.
    371 	var help bool
    372 	fs.BoolVar(&help, "help", false, "help flag")
    373 	helpCalled = false
    374 	err = fs.Parse([]string{"-help"})
    375 	if err != nil {
    376 		t.Fatal("expected no error for defined -help; got ", err)
    377 	}
    378 	if helpCalled {
    379 		t.Fatal("help was called; should not have been for defined help flag")
    380 	}
    381 }
    382 
    383 const defaultOutput = `  -A	for bootstrapping, allow 'any' type
    384   -Alongflagname
    385     	disable bounds checking
    386   -C	a boolean defaulting to true (default true)
    387   -D path
    388     	set relative path for local imports
    389   -F number
    390     	a non-zero number (default 2.7)
    391   -G float
    392     	a float that defaults to zero
    393   -M string
    394     	a multiline
    395     	help
    396     	string
    397   -N int
    398     	a non-zero int (default 27)
    399   -O	a flag
    400     	multiline help string (default true)
    401   -Z int
    402     	an int that defaults to zero
    403   -maxT timeout
    404     	set timeout for dial
    405 `
    406 
    407 func TestPrintDefaults(t *testing.T) {
    408 	fs := NewFlagSet("print defaults test", ContinueOnError)
    409 	var buf bytes.Buffer
    410 	fs.SetOutput(&buf)
    411 	fs.Bool("A", false, "for bootstrapping, allow 'any' type")
    412 	fs.Bool("Alongflagname", false, "disable bounds checking")
    413 	fs.Bool("C", true, "a boolean defaulting to true")
    414 	fs.String("D", "", "set relative `path` for local imports")
    415 	fs.Float64("F", 2.7, "a non-zero `number`")
    416 	fs.Float64("G", 0, "a float that defaults to zero")
    417 	fs.String("M", "", "a multiline\nhelp\nstring")
    418 	fs.Int("N", 27, "a non-zero int")
    419 	fs.Bool("O", true, "a flag\nmultiline help string")
    420 	fs.Int("Z", 0, "an int that defaults to zero")
    421 	fs.Duration("maxT", 0, "set `timeout` for dial")
    422 	fs.PrintDefaults()
    423 	got := buf.String()
    424 	if got != defaultOutput {
    425 		t.Errorf("got %q want %q\n", got, defaultOutput)
    426 	}
    427 }
    428 
    429 // Issue 19230: validate range of Int and Uint flag values.
    430 func TestIntFlagOverflow(t *testing.T) {
    431 	if strconv.IntSize != 32 {
    432 		return
    433 	}
    434 	ResetForTesting(nil)
    435 	Int("i", 0, "")
    436 	Uint("u", 0, "")
    437 	if err := Set("i", "2147483648"); err == nil {
    438 		t.Error("unexpected success setting Int")
    439 	}
    440 	if err := Set("u", "4294967296"); err == nil {
    441 		t.Error("unexpected success setting Uint")
    442 	}
    443 }
    444 
    445 // Issue 20998: Usage should respect CommandLine.output.
    446 func TestUsageOutput(t *testing.T) {
    447 	ResetForTesting(DefaultUsage)
    448 	var buf bytes.Buffer
    449 	CommandLine.SetOutput(&buf)
    450 	defer func(old []string) { os.Args = old }(os.Args)
    451 	os.Args = []string{"app", "-i=1", "-unknown"}
    452 	Parse()
    453 	const want = "flag provided but not defined: -i\nUsage of app:\n"
    454 	if got := buf.String(); got != want {
    455 		t.Errorf("output = %q; want %q", got, want)
    456 	}
    457 }
    458 
    459 func TestGetters(t *testing.T) {
    460 	expectedName := "flag set"
    461 	expectedErrorHandling := ContinueOnError
    462 	expectedOutput := io.Writer(os.Stderr)
    463 	fs := NewFlagSet(expectedName, expectedErrorHandling)
    464 
    465 	if fs.Name() != expectedName {
    466 		t.Errorf("unexpected name: got %s, expected %s", fs.Name(), expectedName)
    467 	}
    468 	if fs.ErrorHandling() != expectedErrorHandling {
    469 		t.Errorf("unexpected ErrorHandling: got %d, expected %d", fs.ErrorHandling(), expectedErrorHandling)
    470 	}
    471 	if fs.Output() != expectedOutput {
    472 		t.Errorf("unexpected output: got %#v, expected %#v", fs.Output(), expectedOutput)
    473 	}
    474 
    475 	expectedName = "gopher"
    476 	expectedErrorHandling = ExitOnError
    477 	expectedOutput = os.Stdout
    478 	fs.Init(expectedName, expectedErrorHandling)
    479 	fs.SetOutput(expectedOutput)
    480 
    481 	if fs.Name() != expectedName {
    482 		t.Errorf("unexpected name: got %s, expected %s", fs.Name(), expectedName)
    483 	}
    484 	if fs.ErrorHandling() != expectedErrorHandling {
    485 		t.Errorf("unexpected ErrorHandling: got %d, expected %d", fs.ErrorHandling(), expectedErrorHandling)
    486 	}
    487 	if fs.Output() != expectedOutput {
    488 		t.Errorf("unexpected output: got %v, expected %v", fs.Output(), expectedOutput)
    489 	}
    490 }
    491