Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,21 @@ func TestCommandNoFlags(t *testing.T) {
cmd = &ff.Command{Name: "root"}
ctx = context.Background()
)

if err := cmd.ParseAndRun(ctx, []string{"-h"}); !errors.Is(err, ff.ErrHelp) {
t.Errorf("err: want %v, have %v", ff.ErrHelp, err)
}
cmd.Reset()

if err := cmd.ParseAndRun(ctx, []string{"--help"}); !errors.Is(err, ff.ErrHelp) {
t.Errorf("err: want %v, have %v", ff.ErrHelp, err)
}
cmd.Reset()

if err := cmd.ParseAndRun(ctx, []string{}); !errors.Is(err, ff.ErrNoExec) {
t.Errorf("err: want %v, have %v", ff.ErrNoExec, err)
}
cmd.Reset()
}

func TestCommandReset(t *testing.T) {
Expand Down
37 changes: 37 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package ff
import (
"errors"
"flag"
"fmt"
"strings"
)

var (
Expand Down Expand Up @@ -34,3 +36,38 @@ var (
// ErrNoExec is returned when a command without an exec function is run.
ErrNoExec = errors.New("no exec function")
)

//

// UnknownFlagError is an [ErrUnknownFlag] that wraps the name of the flag.
type UnknownFlagError struct {
flagName string // ideally includes leading -/-- but not required
}

var _ error = (*UnknownFlagError)(nil)

func newUnknownFlagError(flagName string) *UnknownFlagError {
return &UnknownFlagError{
flagName: flagName,
}
}

// Error implements the error interface.
func (e *UnknownFlagError) Error() string {
return fmt.Sprintf("%q: %v", e.flagName, ErrUnknownFlag)
}

// Unwrap returns [ErrUnknownFlag].
func (e *UnknownFlagError) Unwrap() error {
return ErrUnknownFlag
}

// GetFlagName returns the unknown flag name, maybe including leading hyphens.
func (e *UnknownFlagError) GetFlagName() string {
return e.flagName
}

// GetName returns the unknown flag name, trimmed of leading hyphens.
func (e *UnknownFlagError) GetName() string {
return strings.TrimPrefix(e.flagName, "-")
}
45 changes: 24 additions & 21 deletions flag_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,30 +152,33 @@ func (fs *FlagSet) GetName() string {

// Parse the provided args against the flag set, assigning flag values as
// appropriate. Args are matched to flags defined in this flag set, and, if a
// parent is set, all parent flag sets, recursively. If a specified flag can't
// be found, parse fails with [ErrUnknownFlag]. After a successful parse,
// subsequent calls to parse fail with [ErrAlreadyParsed], until and unless the
// flag set is reset.
// parent is set, all parent flag sets, recursively.
//
// Parse returns a nil error when it runs out of args to parse, or when it
// encounters the first non-flag argument. It returns a non-nil error when it
// encounters an unknown flag, or when setting a flag fails. Regardless of final
// outcome, once parse returns, any successfully-parsed args will have updated
// their corresponding flags, the flag set is marked as parsed, and any
// un-parsed args are made available via [FlagSet.GetArgs].
//
// Use [FlagSet.Reset] to reset a parsed flag set back to its un-parsed state,
// including resetting all flags back to their defaults.
func (fs *FlagSet) Parse(args []string) error {
if fs.isParsed {
return ErrAlreadyParsed
}

err := fs.parseArgs(args)
switch {
case err == nil:
fs.isParsed = true
case err != nil:
fs.postParseArgs = []string{}
}
leftover, err := fs.parseArgs(args)
fs.postParseArgs = leftover
fs.isParsed = true
return err
}

func (fs *FlagSet) parseArgs(args []string) (err error) {
func (fs *FlagSet) parseArgs(args []string) ([]string, error) {
// Credit where credit is due: this implementation is adapted from
// https://pkg.go.dev/github.com/pborman/getopt/v2.

fs.postParseArgs = args
leftover := args

for len(args) > 0 {
arg := args[0]
Expand All @@ -187,12 +190,12 @@ func (fs *FlagSet) parseArgs(args []string) (err error) {
parseDone = isEmpty || noDash
)
if parseDone {
return nil // fs.postParseArgs should include arg
return leftover, nil // leftover should include arg
}

if arg == "--" {
fs.postParseArgs = args // fs.postParseArgs should not include "--"
return nil
leftover = args // leftover should not include "--"
return leftover, nil
}

var (
Expand All @@ -217,13 +220,13 @@ func (fs *FlagSet) parseArgs(args []string) (err error) {
args, parseErr = fs.parseLongFlag(arg, args)
}
if parseErr != nil {
return parseErr
return leftover, parseErr
}

fs.postParseArgs = args // we parsed arg, so update fs.postParseArgs with the remainder
leftover = args // we parsed arg, so update leftover with the remainder
}

return nil
return leftover, nil
}

// findFlag finds the first matching flag in the flags hierarchy.
Expand Down Expand Up @@ -265,7 +268,7 @@ func (fs *FlagSet) parseShortFlag(arg string, args []string) ([]string, error) {
case r == 'h':
return args, ErrHelp
default:
return args, fmt.Errorf("%w %q", ErrUnknownFlag, string(r))
return args, newUnknownFlagError("-" + string(r))
}
}

Expand Down Expand Up @@ -317,7 +320,7 @@ func (fs *FlagSet) parseLongFlag(arg string, args []string) ([]string, error) {
case fs.isStdAdapter && strings.EqualFold(name, "h"):
return nil, ErrHelp
default:
return nil, fmt.Errorf("%w %q", ErrUnknownFlag, name)
return nil, newUnknownFlagError("--" + name)
}
}

Expand Down
10 changes: 10 additions & 0 deletions flag_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,16 @@ func TestFlagSet_StructIgnoreReset(t *testing.T) {
}
}

{
args := []string{"--foo=1", "--baz=2"}
if err := fs.Parse(args); !errors.Is(err, ff.ErrAlreadyParsed) {
t.Errorf("ff.Parse(...): want %v, have %v", ff.ErrAlreadyParsed, err)
}
if err := fs.Reset(); err != nil {
t.Errorf("fs.Reset(): error: %v", err)
}
}

{
args := []string{"--foo=1", "--baz=2"}
if err := fs.Parse(args); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func parse(fs Flags, args []string, options ...Option) error {
case pc.configIgnoreUndefinedFlags: // not found, but that's OK
return nil
case !pc.configIgnoreUndefinedFlags: // not found, and that's not OK
return fmt.Errorf("%s: %w", name, ErrUnknownFlag)
return newUnknownFlagError(name)
default:
panic(fmt.Errorf("unexpected unreachable case for %s", name))
}
Expand Down
Loading