From 98f41d064a880fe6a9ce97d70d611300186649ec Mon Sep 17 00:00:00 2001 From: Peter Bourgon Date: Sun, 27 Jul 2025 13:17:28 +0200 Subject: [PATCH 1/2] Parse: mark as parsed, even on error --- command_test.go | 6 ++++++ flag_set.go | 41 ++++++++++++++++++++++------------------- flag_set_test.go | 10 ++++++++++ 3 files changed, 38 insertions(+), 19 deletions(-) diff --git a/command_test.go b/command_test.go index 8c4e957..6bbed33 100644 --- a/command_test.go +++ b/command_test.go @@ -18,15 +18,21 @@ func TestCommandNoFlags(t *testing.T) { cmd = &ff.Command{Name: "root"} ctx = context.Background() ) + if err := cmd.ParseAndRun(ctx, []string{"-h"}); !errors.Is(err, ff.ErrHelp) { t.Errorf("err: want %v, have %v", ff.ErrHelp, err) } + cmd.Reset() + if err := cmd.ParseAndRun(ctx, []string{"--help"}); !errors.Is(err, ff.ErrHelp) { t.Errorf("err: want %v, have %v", ff.ErrHelp, err) } + cmd.Reset() + if err := cmd.ParseAndRun(ctx, []string{}); !errors.Is(err, ff.ErrNoExec) { t.Errorf("err: want %v, have %v", ff.ErrNoExec, err) } + cmd.Reset() } func TestCommandReset(t *testing.T) { diff --git a/flag_set.go b/flag_set.go index 4b9e56a..83a6811 100644 --- a/flag_set.go +++ b/flag_set.go @@ -152,30 +152,33 @@ func (fs *FlagSet) GetName() string { // Parse the provided args against the flag set, assigning flag values as // appropriate. Args are matched to flags defined in this flag set, and, if a -// parent is set, all parent flag sets, recursively. If a specified flag can't -// be found, parse fails with [ErrUnknownFlag]. After a successful parse, -// subsequent calls to parse fail with [ErrAlreadyParsed], until and unless the -// flag set is reset. +// parent is set, all parent flag sets, recursively. +// +// Parse returns a nil error when it runs out of args to parse, or when it +// encounters the first non-flag argument. It returns a non-nil error when it +// encounters an unknown flag, or when setting a flag fails. Regardless of final +// outcome, once parse returns, any successfully-parsed args will have updated +// their corresponding flags, the flag set is marked as parsed, and any +// un-parsed args are made available via [FlagSet.GetArgs]. +// +// Use [FlagSet.Reset] to reset a parsed flag set back to its un-parsed state, +// including resetting all flags back to their defaults. func (fs *FlagSet) Parse(args []string) error { if fs.isParsed { return ErrAlreadyParsed } - err := fs.parseArgs(args) - switch { - case err == nil: - fs.isParsed = true - case err != nil: - fs.postParseArgs = []string{} - } + leftover, err := fs.parseArgs(args) + fs.postParseArgs = leftover + fs.isParsed = true return err } -func (fs *FlagSet) parseArgs(args []string) (err error) { +func (fs *FlagSet) parseArgs(args []string) ([]string, error) { // Credit where credit is due: this implementation is adapted from // https://pkg.go.dev/github.com/pborman/getopt/v2. - fs.postParseArgs = args + leftover := args for len(args) > 0 { arg := args[0] @@ -187,12 +190,12 @@ func (fs *FlagSet) parseArgs(args []string) (err error) { parseDone = isEmpty || noDash ) if parseDone { - return nil // fs.postParseArgs should include arg + return leftover, nil // leftover should include arg } if arg == "--" { - fs.postParseArgs = args // fs.postParseArgs should not include "--" - return nil + leftover = args // leftover should not include "--" + return leftover, nil } var ( @@ -217,13 +220,13 @@ func (fs *FlagSet) parseArgs(args []string) (err error) { args, parseErr = fs.parseLongFlag(arg, args) } if parseErr != nil { - return parseErr + return leftover, parseErr } - fs.postParseArgs = args // we parsed arg, so update fs.postParseArgs with the remainder + leftover = args // we parsed arg, so update leftover with the remainder } - return nil + return leftover, nil } // findFlag finds the first matching flag in the flags hierarchy. diff --git a/flag_set_test.go b/flag_set_test.go index 39493cf..5e106c0 100644 --- a/flag_set_test.go +++ b/flag_set_test.go @@ -507,6 +507,16 @@ func TestFlagSet_StructIgnoreReset(t *testing.T) { } } + { + args := []string{"--foo=1", "--baz=2"} + if err := fs.Parse(args); !errors.Is(err, ff.ErrAlreadyParsed) { + t.Errorf("ff.Parse(...): want %v, have %v", ff.ErrAlreadyParsed, err) + } + if err := fs.Reset(); err != nil { + t.Errorf("fs.Reset(): error: %v", err) + } + } + { args := []string{"--foo=1", "--baz=2"} if err := fs.Parse(args); err != nil { From 3d4d1351cd1b7a341ab6afc52468d922a2d8b2a8 Mon Sep 17 00:00:00 2001 From: Peter Bourgon Date: Sun, 27 Jul 2025 23:44:54 +0200 Subject: [PATCH 2/2] UnknownFlagError --- errors.go | 37 +++++++++++++++++++++++++++++++++++++ flag_set.go | 4 ++-- parse.go | 2 +- 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/errors.go b/errors.go index 3ed9080..62cac2e 100644 --- a/errors.go +++ b/errors.go @@ -3,6 +3,8 @@ package ff import ( "errors" "flag" + "fmt" + "strings" ) var ( @@ -34,3 +36,38 @@ var ( // ErrNoExec is returned when a command without an exec function is run. ErrNoExec = errors.New("no exec function") ) + +// + +// UnknownFlagError is an [ErrUnknownFlag] that wraps the name of the flag. +type UnknownFlagError struct { + flagName string // ideally includes leading -/-- but not required +} + +var _ error = (*UnknownFlagError)(nil) + +func newUnknownFlagError(flagName string) *UnknownFlagError { + return &UnknownFlagError{ + flagName: flagName, + } +} + +// Error implements the error interface. +func (e *UnknownFlagError) Error() string { + return fmt.Sprintf("%q: %v", e.flagName, ErrUnknownFlag) +} + +// Unwrap returns [ErrUnknownFlag]. +func (e *UnknownFlagError) Unwrap() error { + return ErrUnknownFlag +} + +// GetFlagName returns the unknown flag name, maybe including leading hyphens. +func (e *UnknownFlagError) GetFlagName() string { + return e.flagName +} + +// GetName returns the unknown flag name, trimmed of leading hyphens. +func (e *UnknownFlagError) GetName() string { + return strings.TrimPrefix(e.flagName, "-") +} diff --git a/flag_set.go b/flag_set.go index 83a6811..5f81328 100644 --- a/flag_set.go +++ b/flag_set.go @@ -268,7 +268,7 @@ func (fs *FlagSet) parseShortFlag(arg string, args []string) ([]string, error) { case r == 'h': return args, ErrHelp default: - return args, fmt.Errorf("%w %q", ErrUnknownFlag, string(r)) + return args, newUnknownFlagError("-" + string(r)) } } @@ -320,7 +320,7 @@ func (fs *FlagSet) parseLongFlag(arg string, args []string) ([]string, error) { case fs.isStdAdapter && strings.EqualFold(name, "h"): return nil, ErrHelp default: - return nil, fmt.Errorf("%w %q", ErrUnknownFlag, name) + return nil, newUnknownFlagError("--" + name) } } diff --git a/parse.go b/parse.go index 1bc19ce..a18fc07 100644 --- a/parse.go +++ b/parse.go @@ -183,7 +183,7 @@ func parse(fs Flags, args []string, options ...Option) error { case pc.configIgnoreUndefinedFlags: // not found, but that's OK return nil case !pc.configIgnoreUndefinedFlags: // not found, and that's not OK - return fmt.Errorf("%s: %w", name, ErrUnknownFlag) + return newUnknownFlagError(name) default: panic(fmt.Errorf("unexpected unreachable case for %s", name)) }