Home | History | Annotate | Download | only in flag
      1 // Copyright 2009 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package flag_test
      6 
      7 import (
      8 	"bytes"
      9 	. "flag"
     10 	"fmt"
     11 	"os"
     12 	"sort"
     13 	"strings"
     14 	"testing"
     15 	"time"
     16 )
     17 
     18 func boolString(s string) string {
     19 	if s == "0" {
     20 		return "false"
     21 	}
     22 	return "true"
     23 }
     24 
     25 func TestEverything(t *testing.T) {
     26 	ResetForTesting(nil)
     27 	Bool("test_bool", false, "bool value")
     28 	Int("test_int", 0, "int value")
     29 	Int64("test_int64", 0, "int64 value")
     30 	Uint("test_uint", 0, "uint value")
     31 	Uint64("test_uint64", 0, "uint64 value")
     32 	String("test_string", "0", "string value")
     33 	Float64("test_float64", 0, "float64 value")
     34 	Duration("test_duration", 0, "time.Duration value")
     35 
     36 	m := make(map[string]*Flag)
     37 	desired := "0"
     38 	visitor := func(f *Flag) {
     39 		if len(f.Name) > 5 && f.Name[0:5] == "test_" {
     40 			m[f.Name] = f
     41 			ok := false
     42 			switch {
     43 			case f.Value.String() == desired:
     44 				ok = true
     45 			case f.Name == "test_bool" && f.Value.String() == boolString(desired):
     46 				ok = true
     47 			case f.Name == "test_duration" && f.Value.String() == desired+"s":
     48 				ok = true
     49 			}
     50 			if !ok {
     51 				t.Error("Visit: bad value", f.Value.String(), "for", f.Name)
     52 			}
     53 		}
     54 	}
     55 	VisitAll(visitor)
     56 	if len(m) != 8 {
     57 		t.Error("VisitAll misses some flags")
     58 		for k, v := range m {
     59 			t.Log(k, *v)
     60 		}
     61 	}
     62 	m = make(map[string]*Flag)
     63 	Visit(visitor)
     64 	if len(m) != 0 {
     65 		t.Errorf("Visit sees unset flags")
     66 		for k, v := range m {
     67 			t.Log(k, *v)
     68 		}
     69 	}
     70 	// Now set all flags
     71 	Set("test_bool", "true")
     72 	Set("test_int", "1")
     73 	Set("test_int64", "1")
     74 	Set("test_uint", "1")
     75 	Set("test_uint64", "1")
     76 	Set("test_string", "1")
     77 	Set("test_float64", "1")
     78 	Set("test_duration", "1s")
     79 	desired = "1"
     80 	Visit(visitor)
     81 	if len(m) != 8 {
     82 		t.Error("Visit fails after set")
     83 		for k, v := range m {
     84 			t.Log(k, *v)
     85 		}
     86 	}
     87 	// Now test they're visited in sort order.
     88 	var flagNames []string
     89 	Visit(func(f *Flag) { flagNames = append(flagNames, f.Name) })
     90 	if !sort.StringsAreSorted(flagNames) {
     91 		t.Errorf("flag names not sorted: %v", flagNames)
     92 	}
     93 }
     94 
     95 func TestGet(t *testing.T) {
     96 	ResetForTesting(nil)
     97 	Bool("test_bool", true, "bool value")
     98 	Int("test_int", 1, "int value")
     99 	Int64("test_int64", 2, "int64 value")
    100 	Uint("test_uint", 3, "uint value")
    101 	Uint64("test_uint64", 4, "uint64 value")
    102 	String("test_string", "5", "string value")
    103 	Float64("test_float64", 6, "float64 value")
    104 	Duration("test_duration", 7, "time.Duration value")
    105 
    106 	visitor := func(f *Flag) {
    107 		if len(f.Name) > 5 && f.Name[0:5] == "test_" {
    108 			g, ok := f.Value.(Getter)
    109 			if !ok {
    110 				t.Errorf("Visit: value does not satisfy Getter: %T", f.Value)
    111 				return
    112 			}
    113 			switch f.Name {
    114 			case "test_bool":
    115 				ok = g.Get() == true
    116 			case "test_int":
    117 				ok = g.Get() == int(1)
    118 			case "test_int64":
    119 				ok = g.Get() == int64(2)
    120 			case "test_uint":
    121 				ok = g.Get() == uint(3)
    122 			case "test_uint64":
    123 				ok = g.Get() == uint64(4)
    124 			case "test_string":
    125 				ok = g.Get() == "5"
    126 			case "test_float64":
    127 				ok = g.Get() == float64(6)
    128 			case "test_duration":
    129 				ok = g.Get() == time.Duration(7)
    130 			}
    131 			if !ok {
    132 				t.Errorf("Visit: bad value %T(%v) for %s", g.Get(), g.Get(), f.Name)
    133 			}
    134 		}
    135 	}
    136 	VisitAll(visitor)
    137 }
    138 
    139 func TestUsage(t *testing.T) {
    140 	called := false
    141 	ResetForTesting(func() { called = true })
    142 	if CommandLine.Parse([]string{"-x"}) == nil {
    143 		t.Error("parse did not fail for unknown flag")
    144 	}
    145 	if !called {
    146 		t.Error("did not call Usage for unknown flag")
    147 	}
    148 }
    149 
    150 func testParse(f *FlagSet, t *testing.T) {
    151 	if f.Parsed() {
    152 		t.Error("f.Parse() = true before Parse")
    153 	}
    154 	boolFlag := f.Bool("bool", false, "bool value")
    155 	bool2Flag := f.Bool("bool2", false, "bool2 value")
    156 	intFlag := f.Int("int", 0, "int value")
    157 	int64Flag := f.Int64("int64", 0, "int64 value")
    158 	uintFlag := f.Uint("uint", 0, "uint value")
    159 	uint64Flag := f.Uint64("uint64", 0, "uint64 value")
    160 	stringFlag := f.String("string", "0", "string value")
    161 	float64Flag := f.Float64("float64", 0, "float64 value")
    162 	durationFlag := f.Duration("duration", 5*time.Second, "time.Duration value")
    163 	extra := "one-extra-argument"
    164 	args := []string{
    165 		"-bool",
    166 		"-bool2=true",
    167 		"--int", "22",
    168 		"--int64", "0x23",
    169 		"-uint", "24",
    170 		"--uint64", "25",
    171 		"-string", "hello",
    172 		"-float64", "2718e28",
    173 		"-duration", "2m",
    174 		extra,
    175 	}
    176 	if err := f.Parse(args); err != nil {
    177 		t.Fatal(err)
    178 	}
    179 	if !f.Parsed() {
    180 		t.Error("f.Parse() = false after Parse")
    181 	}
    182 	if *boolFlag != true {
    183 		t.Error("bool flag should be true, is ", *boolFlag)
    184 	}
    185 	if *bool2Flag != true {
    186 		t.Error("bool2 flag should be true, is ", *bool2Flag)
    187 	}
    188 	if *intFlag != 22 {
    189 		t.Error("int flag should be 22, is ", *intFlag)
    190 	}
    191 	if *int64Flag != 0x23 {
    192 		t.Error("int64 flag should be 0x23, is ", *int64Flag)
    193 	}
    194 	if *uintFlag != 24 {
    195 		t.Error("uint flag should be 24, is ", *uintFlag)
    196 	}
    197 	if *uint64Flag != 25 {
    198 		t.Error("uint64 flag should be 25, is ", *uint64Flag)
    199 	}
    200 	if *stringFlag != "hello" {
    201 		t.Error("string flag should be `hello`, is ", *stringFlag)
    202 	}
    203 	if *float64Flag != 2718e28 {
    204 		t.Error("float64 flag should be 2718e28, is ", *float64Flag)
    205 	}
    206 	if *durationFlag != 2*time.Minute {
    207 		t.Error("duration flag should be 2m, is ", *durationFlag)
    208 	}
    209 	if len(f.Args()) != 1 {
    210 		t.Error("expected one argument, got", len(f.Args()))
    211 	} else if f.Args()[0] != extra {
    212 		t.Errorf("expected argument %q got %q", extra, f.Args()[0])
    213 	}
    214 }
    215 
    216 func TestParse(t *testing.T) {
    217 	ResetForTesting(func() { t.Error("bad parse") })
    218 	testParse(CommandLine, t)
    219 }
    220 
    221 func TestFlagSetParse(t *testing.T) {
    222 	testParse(NewFlagSet("test", ContinueOnError), t)
    223 }
    224 
    225 // Declare a user-defined flag type.
    226 type flagVar []string
    227 
    228 func (f *flagVar) String() string {
    229 	return fmt.Sprint([]string(*f))
    230 }
    231 
    232 func (f *flagVar) Set(value string) error {
    233 	*f = append(*f, value)
    234 	return nil
    235 }
    236 
    237 func TestUserDefined(t *testing.T) {
    238 	var flags FlagSet
    239 	flags.Init("test", ContinueOnError)
    240 	var v flagVar
    241 	flags.Var(&v, "v", "usage")
    242 	if err := flags.Parse([]string{"-v", "1", "-v", "2", "-v=3"}); err != nil {
    243 		t.Error(err)
    244 	}
    245 	if len(v) != 3 {
    246 		t.Fatal("expected 3 args; got ", len(v))
    247 	}
    248 	expect := "[1 2 3]"
    249 	if v.String() != expect {
    250 		t.Errorf("expected value %q got %q", expect, v.String())
    251 	}
    252 }
    253 
    254 func TestUserDefinedForCommandLine(t *testing.T) {
    255 	const help = "HELP"
    256 	var result string
    257 	ResetForTesting(func() { result = help })
    258 	Usage()
    259 	if result != help {
    260 		t.Fatalf("got %q; expected %q", result, help)
    261 	}
    262 }
    263 
    264 // Declare a user-defined boolean flag type.
    265 type boolFlagVar struct {
    266 	count int
    267 }
    268 
    269 func (b *boolFlagVar) String() string {
    270 	return fmt.Sprintf("%d", b.count)
    271 }
    272 
    273 func (b *boolFlagVar) Set(value string) error {
    274 	if value == "true" {
    275 		b.count++
    276 	}
    277 	return nil
    278 }
    279 
    280 func (b *boolFlagVar) IsBoolFlag() bool {
    281 	return b.count < 4
    282 }
    283 
    284 func TestUserDefinedBool(t *testing.T) {
    285 	var flags FlagSet
    286 	flags.Init("test", ContinueOnError)
    287 	var b boolFlagVar
    288 	var err error
    289 	flags.Var(&b, "b", "usage")
    290 	if err = flags.Parse([]string{"-b", "-b", "-b", "-b=true", "-b=false", "-b", "barg", "-b"}); err != nil {
    291 		if b.count < 4 {
    292 			t.Error(err)
    293 		}
    294 	}
    295 
    296 	if b.count != 4 {
    297 		t.Errorf("want: %d; got: %d", 4, b.count)
    298 	}
    299 
    300 	if err == nil {
    301 		t.Error("expected error; got none")
    302 	}
    303 }
    304 
    305 func TestSetOutput(t *testing.T) {
    306 	var flags FlagSet
    307 	var buf bytes.Buffer
    308 	flags.SetOutput(&buf)
    309 	flags.Init("test", ContinueOnError)
    310 	flags.Parse([]string{"-unknown"})
    311 	if out := buf.String(); !strings.Contains(out, "-unknown") {
    312 		t.Logf("expected output mentioning unknown; got %q", out)
    313 	}
    314 }
    315 
    316 // This tests that one can reset the flags. This still works but not well, and is
    317 // superseded by FlagSet.
    318 func TestChangingArgs(t *testing.T) {
    319 	ResetForTesting(func() { t.Fatal("bad parse") })
    320 	oldArgs := os.Args
    321 	defer func() { os.Args = oldArgs }()
    322 	os.Args = []string{"cmd", "-before", "subcmd", "-after", "args"}
    323 	before := Bool("before", false, "")
    324 	if err := CommandLine.Parse(os.Args[1:]); err != nil {
    325 		t.Fatal(err)
    326 	}
    327 	cmd := Arg(0)
    328 	os.Args = Args()
    329 	after := Bool("after", false, "")
    330 	Parse()
    331 	args := Args()
    332 
    333 	if !*before || cmd != "subcmd" || !*after || len(args) != 1 || args[0] != "args" {
    334 		t.Fatalf("expected true subcmd true [args] got %v %v %v %v", *before, cmd, *after, args)
    335 	}
    336 }
    337 
    338 // Test that -help invokes the usage message and returns ErrHelp.
    339 func TestHelp(t *testing.T) {
    340 	var helpCalled = false
    341 	fs := NewFlagSet("help test", ContinueOnError)
    342 	fs.Usage = func() { helpCalled = true }
    343 	var flag bool
    344 	fs.BoolVar(&flag, "flag", false, "regular flag")
    345 	// Regular flag invocation should work
    346 	err := fs.Parse([]string{"-flag=true"})
    347 	if err != nil {
    348 		t.Fatal("expected no error; got ", err)
    349 	}
    350 	if !flag {
    351 		t.Error("flag was not set by -flag")
    352 	}
    353 	if helpCalled {
    354 		t.Error("help called for regular flag")
    355 		helpCalled = false // reset for next test
    356 	}
    357 	// Help flag should work as expected.
    358 	err = fs.Parse([]string{"-help"})
    359 	if err == nil {
    360 		t.Fatal("error expected")
    361 	}
    362 	if err != ErrHelp {
    363 		t.Fatal("expected ErrHelp; got ", err)
    364 	}
    365 	if !helpCalled {
    366 		t.Fatal("help was not called")
    367 	}
    368 	// If we define a help flag, that should override.
    369 	var help bool
    370 	fs.BoolVar(&help, "help", false, "help flag")
    371 	helpCalled = false
    372 	err = fs.Parse([]string{"-help"})
    373 	if err != nil {
    374 		t.Fatal("expected no error for defined -help; got ", err)
    375 	}
    376 	if helpCalled {
    377 		t.Fatal("help was called; should not have been for defined help flag")
    378 	}
    379 }
    380 
    381 const defaultOutput = `  -A	for bootstrapping, allow 'any' type
    382   -Alongflagname
    383     	disable bounds checking
    384   -C	a boolean defaulting to true (default true)
    385   -D path
    386     	set relative path for local imports
    387   -F number
    388     	a non-zero number (default 2.7)
    389   -G float
    390     	a float that defaults to zero
    391   -N int
    392     	a non-zero int (default 27)
    393   -Z int
    394     	an int that defaults to zero
    395   -maxT timeout
    396     	set timeout for dial
    397 `
    398 
    399 func TestPrintDefaults(t *testing.T) {
    400 	fs := NewFlagSet("print defaults test", ContinueOnError)
    401 	var buf bytes.Buffer
    402 	fs.SetOutput(&buf)
    403 	fs.Bool("A", false, "for bootstrapping, allow 'any' type")
    404 	fs.Bool("Alongflagname", false, "disable bounds checking")
    405 	fs.Bool("C", true, "a boolean defaulting to true")
    406 	fs.String("D", "", "set relative `path` for local imports")
    407 	fs.Float64("F", 2.7, "a non-zero `number`")
    408 	fs.Float64("G", 0, "a float that defaults to zero")
    409 	fs.Int("N", 27, "a non-zero int")
    410 	fs.Int("Z", 0, "an int that defaults to zero")
    411 	fs.Duration("maxT", 0, "set `timeout` for dial")
    412 	fs.PrintDefaults()
    413 	got := buf.String()
    414 	if got != defaultOutput {
    415 		t.Errorf("got %q want %q\n", got, defaultOutput)
    416 	}
    417 }
    418