diff --git a/flag_groups.go b/flag_groups.go index 560612fd3..62eb64041 100644 --- a/flag_groups.go +++ b/flag_groups.go @@ -23,9 +23,10 @@ import ( ) const ( - requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set" - oneRequiredAnnotation = "cobra_annotation_one_required" - mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive" + annotationGroupRequired = "cobra_annotation_required_if_others_set" + annotationRequiredOne = "cobra_annotation_one_required" + annotationMutuallyExclusive = "cobra_annotation_mutually_exclusive" + annotationGroupDependent = "cobra_annotation_if_present_then_others_required" ) // MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors @@ -37,7 +38,7 @@ func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) { if f == nil { panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v)) } - if err := c.Flags().SetAnnotation(v, requiredAsGroupAnnotation, append(f.Annotations[requiredAsGroupAnnotation], strings.Join(flagNames, " "))); err != nil { + if err := c.Flags().SetAnnotation(v, annotationGroupRequired, append(f.Annotations[annotationGroupRequired], strings.Join(flagNames, " "))); err != nil { // Only errs if the flag isn't found. panic(err) } @@ -53,7 +54,7 @@ func (c *Command) MarkFlagsOneRequired(flagNames ...string) { if f == nil { panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v)) } - if err := c.Flags().SetAnnotation(v, oneRequiredAnnotation, append(f.Annotations[oneRequiredAnnotation], strings.Join(flagNames, " "))); err != nil { + if err := c.Flags().SetAnnotation(v, annotationRequiredOne, append(f.Annotations[annotationRequiredOne], strings.Join(flagNames, " "))); err != nil { // Only errs if the flag isn't found. panic(err) } @@ -70,7 +71,26 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v)) } // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. - if err := c.Flags().SetAnnotation(v, mutuallyExclusiveAnnotation, append(f.Annotations[mutuallyExclusiveAnnotation], strings.Join(flagNames, " "))); err != nil { + if err := c.Flags().SetAnnotation(v, annotationMutuallyExclusive, append(f.Annotations[annotationMutuallyExclusive], strings.Join(flagNames, " "))); err != nil { + panic(err) + } + } +} + +// MarkIfFlagPresentThenOthersRequired marks the given flags so that if the first flag is set, +// all the other flags become required. +func (c *Command) MarkIfFlagPresentThenOthersRequired(flagNames ...string) { + if len(flagNames) < 2 { + panic("MarkIfFlagPresentThenRequired requires at least two flags") + } + c.mergePersistentFlags() + for _, v := range flagNames { + f := c.Flags().Lookup(v) + if f == nil { + panic(fmt.Sprintf("Failed to find flag %q and mark it as being in an if present then others required flag group", v)) + } + // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. + if err := c.Flags().SetAnnotation(v, annotationGroupDependent, append(f.Annotations[annotationGroupDependent], strings.Join(flagNames, " "))); err != nil { panic(err) } } @@ -90,10 +110,12 @@ func (c *Command) ValidateFlagGroups() error { groupStatus := map[string]map[string]bool{} oneRequiredGroupStatus := map[string]map[string]bool{} mutuallyExclusiveGroupStatus := map[string]map[string]bool{} + ifPresentThenOthersRequiredGroupStatus := map[string]map[string]bool{} flags.VisitAll(func(pflag *flag.Flag) { - processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus) - processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus) - processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationGroupRequired, groupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationRequiredOne, oneRequiredGroupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationMutuallyExclusive, mutuallyExclusiveGroupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationGroupDependent, ifPresentThenOthersRequiredGroupStatus) }) if err := validateRequiredFlagGroups(groupStatus); err != nil { @@ -105,6 +127,9 @@ func (c *Command) ValidateFlagGroups() error { if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil { return err } + if err := validateIfPresentThenRequiredFlagGroups(ifPresentThenOthersRequiredGroupStatus); err != nil { + return err + } return nil } @@ -206,6 +231,38 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error { return nil } +func validateIfPresentThenRequiredFlagGroups(data map[string]map[string]bool) error { + for flagList, flagnameAndStatus := range data { + flags := strings.Split(flagList, " ") + primaryFlag := flags[0] + remainingFlags := flags[1:] + + // Handle missing primary flag entry + if _, exists := flagnameAndStatus[primaryFlag]; !exists { + flagnameAndStatus[primaryFlag] = false + } + + // Check if the primary flag is set + if flagnameAndStatus[primaryFlag] { + var unset []string + for _, flag := range remainingFlags { + if !flagnameAndStatus[flag] { + unset = append(unset, flag) + } + } + + // If any dependent flags are unset, trigger an error + if len(unset) > 0 { + return fmt.Errorf( + "%v is set, the following flags must be provided: %v", + primaryFlag, unset, + ) + } + } + } + return nil +} + func sortedKeys(m map[string]map[string]bool) []string { keys := make([]string, len(m)) i := 0 @@ -221,6 +278,7 @@ func sortedKeys(m map[string]map[string]bool) []string { // - when a flag in a group is present, other flags in the group will be marked required // - when none of the flags in a one-required group are present, all flags in the group will be marked required // - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden +// - when the first flag in an if-present-then-required group is present, the other flags will be marked as required // This allows the standard completion logic to behave appropriately for flag groups func (c *Command) enforceFlagGroupsForCompletion() { if c.DisableFlagParsing { @@ -231,10 +289,12 @@ func (c *Command) enforceFlagGroupsForCompletion() { groupStatus := map[string]map[string]bool{} oneRequiredGroupStatus := map[string]map[string]bool{} mutuallyExclusiveGroupStatus := map[string]map[string]bool{} + ifPresentThenRequiredGroupStatus := map[string]map[string]bool{} c.Flags().VisitAll(func(pflag *flag.Flag) { - processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus) - processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus) - processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationGroupRequired, groupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationRequiredOne, oneRequiredGroupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationMutuallyExclusive, mutuallyExclusiveGroupStatus) + processFlagForGroupAnnotation(flags, pflag, annotationGroupDependent, ifPresentThenRequiredGroupStatus) }) // If a flag that is part of a group is present, we make all the other flags @@ -287,4 +347,17 @@ func (c *Command) enforceFlagGroupsForCompletion() { } } } + + // If a flag that is marked as if-present-then-required is present, make other flags in the group required + for flagList, flagnameAndStatus := range ifPresentThenRequiredGroupStatus { + flags := strings.Split(flagList, " ") + primaryFlag := flags[0] + remainingFlags := flags[1:] + + if flagnameAndStatus[primaryFlag] { + for _, fName := range remainingFlags { + _ = c.MarkFlagRequired(fName) + } + } + } } diff --git a/flag_groups_test.go b/flag_groups_test.go index cffa85525..7b602bc94 100644 --- a/flag_groups_test.go +++ b/flag_groups_test.go @@ -43,22 +43,25 @@ func TestValidateFlagGroups(t *testing.T) { // Each test case uses a unique command from the function above. testcases := []struct { - desc string - flagGroupsRequired []string - flagGroupsOneRequired []string - flagGroupsExclusive []string - subCmdFlagGroupsRequired []string - subCmdFlagGroupsOneRequired []string - subCmdFlagGroupsExclusive []string - args []string - expectErr string + desc string + flagGroupsRequired []string + flagGroupsOneRequired []string + flagGroupsExclusive []string + flagGroupsIfPresentThenRequired []string + subCmdFlagGroupsRequired []string + subCmdFlagGroupsOneRequired []string + subCmdFlagGroupsExclusive []string + subCmdFlagGroupsIfPresentThenRequired []string + args []string + expectErr string }{ { desc: "No flags no problem", }, { - desc: "No flags no problem even with conflicting groups", - flagGroupsRequired: []string{"a b"}, - flagGroupsExclusive: []string{"a b"}, + desc: "No flags no problem even with conflicting groups", + flagGroupsRequired: []string{"a b"}, + flagGroupsExclusive: []string{"a b"}, + flagGroupsIfPresentThenRequired: []string{"a b", "b a"}, }, { desc: "Required flag group not satisfied", flagGroupsRequired: []string{"a b c"}, @@ -74,6 +77,11 @@ func TestValidateFlagGroups(t *testing.T) { flagGroupsExclusive: []string{"a b c"}, args: []string{"--a=foo", "--b=foo"}, expectErr: "if any flags in the group [a b c] are set none of the others can be; [a b] were all set", + }, { + desc: "If present then others required flag group not satisfied", + flagGroupsIfPresentThenRequired: []string{"a b"}, + args: []string{"--a=foo"}, + expectErr: "a is set, the following flags must be provided: [b]", }, { desc: "Multiple required flag group not satisfied returns first error", flagGroupsRequired: []string{"a b c", "a d"}, @@ -89,6 +97,12 @@ func TestValidateFlagGroups(t *testing.T) { flagGroupsExclusive: []string{"a b c", "a d"}, args: []string{"--a=foo", "--c=foo", "--d=foo"}, expectErr: `if any flags in the group [a b c] are set none of the others can be; [a c] were all set`, + }, + { + desc: "Multiple if present then others required flag group not satisfied returns first error", + flagGroupsIfPresentThenRequired: []string{"a b", "d e"}, + args: []string{"--a=foo", "--f=foo"}, + expectErr: `a is set, the following flags must be provided: [b]`, }, { desc: "Validation of required groups occurs on groups in sorted order", flagGroupsRequired: []string{"a d", "a b", "a c"}, @@ -182,6 +196,12 @@ func TestValidateFlagGroups(t *testing.T) { for _, flagGroup := range tc.subCmdFlagGroupsExclusive { sub.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...) } + for _, flagGroup := range tc.flagGroupsIfPresentThenRequired { + c.MarkIfFlagPresentThenOthersRequired(strings.Split(flagGroup, " ")...) + } + for _, flagGroup := range tc.subCmdFlagGroupsIfPresentThenRequired { + sub.MarkIfFlagPresentThenOthersRequired(strings.Split(flagGroup, " ")...) + } c.SetArgs(tc.args) err := c.Execute() switch { @@ -193,3 +213,72 @@ func TestValidateFlagGroups(t *testing.T) { }) } } + +func TestMarkIfFlagPresentThenOthersRequiredAnnotations(t *testing.T) { + // Create a new command with some flags. + cmd := &Command{ + Use: "testcmd", + } + f := cmd.Flags() + f.String("a", "", "flag a") + f.String("b", "", "flag b") + f.String("c", "", "flag c") + + // Call the function with one group: ["a", "b"]. + cmd.MarkIfFlagPresentThenOthersRequired("a", "b") + + // Check that flag "a" has the correct annotation. + aFlag := f.Lookup("a") + if aFlag == nil { + t.Fatal("Flag 'a' not found") + } + annA := aFlag.Annotations[annotationGroupDependent] + expected1 := "a b" // since strings.Join(["a","b"], " ") yields "a b" + if len(annA) != 1 || annA[0] != expected1 { + t.Errorf("Expected flag 'a' annotation to be [%q], got %v", expected1, annA) + } + + // Also check that flag "b" has the correct annotation. + bFlag := f.Lookup("b") + if bFlag == nil { + t.Fatal("Flag 'b' not found") + } + annB := bFlag.Annotations[annotationGroupDependent] + if len(annB) != 1 || annB[0] != expected1 { + t.Errorf("Expected flag 'b' annotation to be [%q], got %v", expected1, annB) + } + + // Now, call MarkIfFlagPresentThenOthersRequired again with a different group involving "a" and "c". + cmd.MarkIfFlagPresentThenOthersRequired("a", "c") + + // The annotation for flag "a" should now have both groups: "a b" and "a c" + annA = aFlag.Annotations[annotationGroupDependent] + expectedAnnotations := []string{"a b", "a c"} + if len(annA) != 2 { + t.Errorf("Expected 2 annotations on flag 'a', got %v", annA) + } + // Check that both expected annotation strings are present. + for _, expected := range expectedAnnotations { + found := false + for _, ann := range annA { + if ann == expected { + found = true + break + } + } + if !found { + t.Errorf("Expected annotation %q not found on flag 'a': %v", expected, annA) + } + } + + // Similarly, check that flag "c" now has the annotation "a c". + cFlag := f.Lookup("c") + if cFlag == nil { + t.Fatal("Flag 'c' not found") + } + annC := cFlag.Annotations[annotationGroupDependent] + expected2 := "a c" + if len(annC) != 1 || annC[0] != expected2 { + t.Errorf("Expected flag 'c' annotation to be [%q], got %v", expected2, annC) + } +}