From 469141fb2c52e1f493aae3520f8e53a40a3fd896 Mon Sep 17 00:00:00 2001 From: Vadim Berezniker Date: Wed, 15 Jun 2022 08:39:15 -0700 Subject: [PATCH 1/3] Remove stale executors if the registration is older than 10 minutes. (#2137) Executors will now continously update their registration instead of only when idle. Scheduler will remove any executors that have not updated their registration within the last 10 minutes. Fixes https://github.com/buildbuddy-io/buildbuddy-internal/issues/1419 --- .../scheduler_client/scheduler_client.go | 12 +++++------- .../server/scheduling/scheduler_server/BUILD | 1 + .../scheduler_server/scheduler_server.go | 16 ++++++++++++++++ proto/BUILD | 1 + proto/scheduler.proto | 2 ++ 5 files changed, 25 insertions(+), 7 deletions(-) diff --git a/enterprise/server/scheduling/scheduler_client/scheduler_client.go b/enterprise/server/scheduling/scheduler_client/scheduler_client.go index d27549ff804..82c29c9bf9d 100644 --- a/enterprise/server/scheduling/scheduler_client/scheduler_client.go +++ b/enterprise/server/scheduling/scheduler_client/scheduler_client.go @@ -109,14 +109,11 @@ func (r *Registration) Check(ctx context.Context) error { return errors.New("not registered to scheduler yet") } -func (r *Registration) processWorkStream(ctx context.Context, stream scpb.Scheduler_RegisterAndStreamWorkClient, schedulerMsgs chan *scpb.RegisterAndStreamWorkResponse) (bool, error) { +func (r *Registration) processWorkStream(ctx context.Context, stream scpb.Scheduler_RegisterAndStreamWorkClient, schedulerMsgs chan *scpb.RegisterAndStreamWorkResponse, registrationTicker *time.Ticker) (bool, error) { registrationMsg := &scpb.RegisterAndStreamWorkRequest{ RegisterExecutorRequest: &scpb.RegisterExecutorRequest{Node: r.node}, } - idleTimer := time.NewTimer(schedulerCheckInInterval) - defer idleTimer.Stop() - select { case <-ctx.Done(): log.Debugf("Context cancelled, cancelling node registration.") @@ -157,9 +154,9 @@ func (r *Registration) processWorkStream(ctx context.Context, stream scpb.Schedu if err := stream.Send(rspMsg); err != nil { return false, status.UnavailableErrorf("could not send task reservation response: %s", err) } - case <-idleTimer.C: + case <-registrationTicker.C: if err := stream.Send(registrationMsg); err != nil { - return false, status.UnavailableErrorf("could not send idle registration message: %s", err) + return false, status.UnavailableErrorf("could not send registration message: %s", err) } } return false, nil @@ -207,8 +204,9 @@ func (r *Registration) maintainRegistrationAndStreamWork(ctx context.Context) { } }() + registrationTicker := time.NewTicker(schedulerCheckInInterval) for { - done, err := r.processWorkStream(ctx, stream, schedulerMsgs) + done, err := r.processWorkStream(ctx, stream, schedulerMsgs, registrationTicker) if err != nil { _ = stream.CloseSend() log.Warningf("Error maintaining registration with scheduler, will retry: %s", err) diff --git a/enterprise/server/scheduling/scheduler_server/BUILD b/enterprise/server/scheduling/scheduler_server/BUILD index 3db86447295..24efe88d0f1 100644 --- a/enterprise/server/scheduling/scheduler_server/BUILD +++ b/enterprise/server/scheduling/scheduler_server/BUILD @@ -29,6 +29,7 @@ go_library( "@org_golang_google_grpc//peer", "@org_golang_google_protobuf//encoding/prototext", "@org_golang_google_protobuf//proto", + "@org_golang_google_protobuf//types/known/timestamppb", ], ) diff --git a/enterprise/server/scheduling/scheduler_server/scheduler_server.go b/enterprise/server/scheduling/scheduler_server/scheduler_server.go index a56d58e8c53..d80b3c6528a 100644 --- a/enterprise/server/scheduling/scheduler_server/scheduler_server.go +++ b/enterprise/server/scheduling/scheduler_server/scheduler_server.go @@ -28,6 +28,7 @@ import ( "google.golang.org/grpc/peer" "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" remote_execution_config "github.com/buildbuddy-io/buildbuddy/enterprise/server/remote_execution/config" scheduler_server_config "github.com/buildbuddy-io/buildbuddy/enterprise/server/scheduling/scheduler_server/config" @@ -41,6 +42,7 @@ var ( defaultPoolName = flag.String("remote_execution.default_pool_name", "", "The default executor pool to use if one is not specified.") sharedExecutorPoolGroupID = flag.String("remote_execution.shared_executor_pool_group_id", "", "Group ID that owns the shared executor pool.") requireExecutorAuthorization = flag.Bool("remote_execution.require_executor_authorization", false, "If true, executors connecting to this server must provide a valid executor API key.") + removeStaleExecutors = flag.Bool("remote_execution.remove_stale_executors", false, "If true, executors are removed if they are not heard from for a prolonged amount of time.") ) const ( @@ -60,6 +62,10 @@ const ( // was fetched more than this duration ago, they will be re-fetched. maxAllowedExecutionNodesStaleness = 10 * time.Second + // An executor is removed if it does not refresh its registration within + // this amount of time. + executorMaxRegistrationStaleness = 10 * time.Minute + // The maximum number of times a task may be re-enqueued. maxTaskAttemptCount = 5 @@ -440,6 +446,15 @@ func (np *nodePool) fetchExecutionNodes(ctx context.Context) ([]*executionNode, if err != nil { return nil, err } + + if *removeStaleExecutors && time.Since(node.GetLastPingTime().AsTime()) > executorMaxRegistrationStaleness { + log.Infof("Removing stale executor %q from pool %+v", id, np.key) + if err := np.rdb.HDel(ctx, np.key.redisPoolKey(), id).Err(); err != nil { + log.Warningf("could not remove stale executor: %s", err) + } + continue + } + executors = append(executors, &executionNode{ executorID: id, schedulerHostPort: node.GetSchedulerHostPort(), @@ -900,6 +915,7 @@ func (s *SchedulerServer) insertOrUpdateNode(ctx context.Context, executorHandle SchedulerHostPort: s.ownHostPort, GroupId: groupID, Acl: acl, + LastPingTime: timestamppb.Now(), } b, err := proto.Marshal(r) if err != nil { diff --git a/proto/BUILD b/proto/BUILD index 0d2edd997dc..3837ac1d138 100644 --- a/proto/BUILD +++ b/proto/BUILD @@ -394,6 +394,7 @@ proto_library( ":acl_proto", ":context_proto", ":trace_proto", + "@com_google_protobuf//:timestamp_proto", ], ) diff --git a/proto/scheduler.proto b/proto/scheduler.proto index a54d5c91852..0360d5fa3b9 100644 --- a/proto/scheduler.proto +++ b/proto/scheduler.proto @@ -1,5 +1,6 @@ syntax = "proto3"; +import "google/protobuf/timestamp.proto"; import "proto/acl.proto"; import "proto/context.proto"; import "proto/trace.proto"; @@ -238,4 +239,5 @@ message RegisteredExecutionNode { string scheduler_host_port = 2; string group_id = 3; acl.ACL acl = 4; + google.protobuf.Timestamp last_ping_time = 5; } From 2c2786e49acc5f201dc17cba786e22b7136e9231 Mon Sep 17 00:00:00 2001 From: Zoey Greer Date: Wed, 15 Jun 2022 14:19:07 -0400 Subject: [PATCH 2/3] Enable deprecation of flags (#2130) * Enable deprecation of flags * address review comments * address review comments, clean up params --- server/util/flagutil/common/common.go | 131 ++++++++++++++++----- server/util/flagutil/common/common_test.go | 64 +++++----- server/util/flagutil/flagutil.go | 8 +- server/util/flagutil/types/BUILD | 1 - server/util/flagutil/types/types.go | 82 ++++++++++--- server/util/flagutil/types/types_test.go | 96 ++++++++++++--- server/util/flagutil/yaml/yaml.go | 47 ++++++-- server/util/flagutil/yaml/yaml_test.go | 28 ++++- server/util/testing/flags/flags.go | 4 +- 9 files changed, 345 insertions(+), 116 deletions(-) diff --git a/server/util/flagutil/common/common.go b/server/util/flagutil/common/common.go index d999e7c9b4f..23eb91d7103 100644 --- a/server/util/flagutil/common/common.go +++ b/server/util/flagutil/common/common.go @@ -38,64 +38,124 @@ func flagTypeFromFlagFuncName(name string) reflect.Type { } type TypeAliased interface { + // AliasedType returns the type this flag.Value aliases. AliasedType() reflect.Type } type IsNameAliasing interface { + // AliasedName returns the flag name this flag.Value aliases. AliasedName() string } +type WrappingValue interface { + // WrappedValue returns the value this flag.Value wraps. + WrappedValue() flag.Value +} + type Appendable interface { + // AppendSlice appends the passed slice to this flag.Value. AppendSlice(any) error } type DocumentNodeOption interface { + // Transform transforms the passed yaml.Node in place. Transform(in any, n *yaml.Node) + // Passthrough returns whether this option should be passed to child nodes. Passthrough() bool } -// GetTypeForFlag returns the (pointer) Type this flag aliases; this is the same +type SetValueForFlagNameHooked interface { + // SetValueForFlagNameHooked is the hook for flags that is called when the + // flag.Value is set by name. + SetValueForFlagNameHook() +} + +// GetTypeForFlagValue returns the (pointer) Type this flag aliases; this is the same // type returned when defining the flag initially. -func GetTypeForFlag(flg *flag.Flag) (reflect.Type, error) { - if t, ok := flagTypeMap[reflect.TypeOf(flg.Value)]; ok { +func GetTypeForFlagValue(value flag.Value) (reflect.Type, error) { + if v, ok := value.(WrappingValue); ok { + return GetTypeForFlagValue(v.WrappedValue()) + } + if t, ok := flagTypeMap[reflect.TypeOf(value)]; ok { return t, nil - } else if v, ok := flg.Value.(TypeAliased); ok { + } else if v, ok := value.(TypeAliased); ok { return v.AliasedType(), nil } - return nil, status.UnimplementedErrorf("Unsupported flag type at %s: %T", flg.Name, flg.Value) + return nil, status.UnimplementedErrorf("Unsupported flag type : %T", value) } -// SetValueForFlagName sets the value for a flag by name. -func SetValueForFlagName(name string, i any, setFlags map[string]struct{}, appendSlice bool, strict bool) error { +// SetValueForFlagName sets the value for a flag by name. setFlags is the set of +// flags that have already been set on the command line; those flags will not be +// set again except to append to them, in the case of slices. To force the +// setting of a flag, pass a nil map. If appendSlice is true, a slice value will +// be appended to the current slice value; otherwise, a slice value will replace +// the current slice value. appendSlice has no effect if the values in question +// are not slices. +func SetValueForFlagName(name string, newValue any, setFlags map[string]struct{}, appendSlice bool) error { flg := DefaultFlagSet.Lookup(name) if flg == nil { - if strict { - return status.NotFoundErrorf("Undefined flag: %s", name) - } - return nil + return status.NotFoundErrorf("Undefined flag: %s", name) } - // For slice flags, append the YAML values to the existing values if appendSlice is true - if v, ok := flg.Value.(Appendable); ok && appendSlice { - if err := v.AppendSlice(i); err != nil { - return status.InternalErrorf("Error encountered appending to flag %s: %s", flg.Name, err) + return setValueFromFlagName(flg.Value, name, newValue, setFlags, appendSlice) +} + +func setValueFromFlagName(flagValue flag.Value, name string, newValue any, setFlags map[string]struct{}, appendSlice bool, setHooks ...func()) error { + if v, ok := flagValue.(SetValueForFlagNameHooked); ok { + setHooks = append(setHooks, v.SetValueForFlagNameHook) + } + return SetValueWithCustomIndirectBehavior(flagValue, name, newValue, setFlags, appendSlice, setValueFromFlagName, setHooks...) +} + +type SetValueForIndirectFxn func(flagValue flag.Value, name string, newValue any, setFlags map[string]struct{}, appendSlice bool, setHooks ...func()) error + +// SetValueWithCustomIndirectBehavior sets the value for a flag, but if the flag +// passed is an alias for another flag or wraps another flag.Value, it instead +// calls setValueForIndirect with the new flag.Value. setFlags is the set of +// flags that have already been set on the command line; those flags will not be +// set again except to append to them, in the case of slices. To force the +// setting of a flag, pass a nil map. If appendSlice is true, a slice value will +// be appended to the current slice value; otherwise, a slice value will replace +// the current slice value. appendSlice has no effect if the values in question +// are not slices. setHooks is a slice of functions to call in order if the +// flag.Value will be set. +func SetValueWithCustomIndirectBehavior(flagValue flag.Value, name string, newValue any, setFlags map[string]struct{}, appendSlice bool, setValueForIndirect SetValueForIndirectFxn, setHooks ...func()) error { + if v, ok := flagValue.(IsNameAliasing); ok { + aliasedFlag := DefaultFlagSet.Lookup(v.AliasedName()) + if aliasedFlag == nil { + return status.NotFoundErrorf("Flag %s aliases undefined flag: %s", name, v.AliasedName()) } + return setValueForIndirect(aliasedFlag.Value, v.AliasedName(), newValue, setFlags, appendSlice, setHooks...) + } + // Unwrap any wrapper values (e.g. DeprecatedFlag) + if v, ok := flagValue.(WrappingValue); ok { + return setValueForIndirect(v.WrappedValue(), name, newValue, setFlags, appendSlice, setHooks...) + } + var appendFlag Appendable + // For slice flags, append the values to the existing values if appendSlice is true + if v, ok := flagValue.(Appendable); ok && appendSlice { + appendFlag = v + } + // For non-append flags, skip the value if it has already been set + if _, ok := setFlags[name]; appendFlag == nil && ok { return nil } - if v, ok := flg.Value.(IsNameAliasing); ok { - return SetValueForFlagName(v.AliasedName(), i, setFlags, appendSlice, strict) + for _, setHook := range setHooks { + setHook() } - // For non-append flags, skip the YAML values if it was set on the command line - if _, ok := setFlags[name]; ok { + if appendFlag != nil { + if err := appendFlag.AppendSlice(newValue); err != nil { + return status.InternalErrorf("Error encountered appending to flag %s: %s", name, err) + } return nil } - t, err := GetTypeForFlag(flg) + t, err := GetTypeForFlagValue(flagValue) if err != nil { - return status.UnimplementedErrorf("Error encountered setting flag: %s", err) + return status.UnimplementedErrorf("Error encountered setting flag %s: %s", name, err) } - if !reflect.ValueOf(i).CanConvert(t.Elem()) { - return status.FailedPreconditionErrorf("Cannot convert value %v of type %T into type %v for flag %s.", i, i, t.Elem(), flg.Name) + if !reflect.ValueOf(newValue).CanConvert(t.Elem()) { + return status.FailedPreconditionErrorf("Cannot convert value %v of type %T into type %v for flag %s.", newValue, newValue, t.Elem(), name) } - reflect.ValueOf(flg.Value).Convert(t).Elem().Set(reflect.ValueOf(i).Convert(t.Elem())) + reflect.ValueOf(flagValue).Convert(t).Elem().Set(reflect.ValueOf(newValue).Convert(t.Elem())) return nil } @@ -107,28 +167,37 @@ func GetDereferencedValue[T any](name string) (T, error) { if flg == nil { return *zeroT, status.NotFoundErrorf("Undefined flag: %s", name) } - if v, ok := flg.Value.(IsNameAliasing); ok { + return getDereferencedValueFrom[T](flg.Value, flg.Name) +} + +func getDereferencedValueFrom[T any](value flag.Value, name string) (T, error) { + zeroT := reflect.New(reflect.TypeOf((*T)(nil)).Elem()).Interface().(*T) + if v, ok := value.(IsNameAliasing); ok { return GetDereferencedValue[T](v.AliasedName()) } + // Unwrap any wrapper values (e.g. DeprecatedFlag) + if v, ok := value.(WrappingValue); ok { + return getDereferencedValueFrom[T](v.WrappedValue(), name) + } t := reflect.TypeOf((*T)(nil)) - addr := reflect.ValueOf(flg.Value) + addr := reflect.ValueOf(value) if t == reflect.TypeOf((*any)(nil)) { var err error - t, err = GetTypeForFlag(flg) + t, err = GetTypeForFlagValue(value) if err != nil { - return *zeroT, status.InternalErrorf("Error dereferencing flag to unspecified type: %s.", err) + return *zeroT, status.InternalErrorf("Error dereferencing flag %s to unspecified type: %s.", name, err) } if !addr.CanConvert(t) { - return *zeroT, status.InvalidArgumentErrorf("Flag %s of type %T could not be converted to %s.", name, flg.Value, t) + return *zeroT, status.InvalidArgumentErrorf("Flag %s of type %T could not be converted to %s.", name, value, t) } return addr.Convert(t).Elem().Interface().(T), nil } if !addr.CanConvert(t) { - return *zeroT, status.InvalidArgumentErrorf("Flag %s of type %T could not be converted to %s.", name, flg.Value, t) + return *zeroT, status.InvalidArgumentErrorf("Flag %s of type %T could not be converted to %s.", name, value, t) } v, ok := addr.Convert(t).Interface().(*T) if !ok { - return *zeroT, status.InternalErrorf("Failed to assert flag %s of type %T as type %s.", name, flg.Value, t) + return *zeroT, status.InternalErrorf("Failed to assert flag %s of type %T as type %s.", name, value, t) } return *v, nil } diff --git a/server/util/flagutil/common/common_test.go b/server/util/flagutil/common/common_test.go index 337c0ec92d7..c18ba81b5b3 100644 --- a/server/util/flagutil/common/common_test.go +++ b/server/util/flagutil/common/common_test.go @@ -36,89 +36,85 @@ func replaceFlagsForTesting(t *testing.T) *flag.FlagSet { func TestSetValueForFlagName(t *testing.T) { flags := replaceFlagsForTesting(t) flagBool := flags.Bool("bool", false, "") - err := common.SetValueForFlagName("bool", true, map[string]struct{}{}, true, true) + err := common.SetValueForFlagName("bool", true, map[string]struct{}{}, true) require.NoError(t, err) assert.Equal(t, true, *flagBool) flags = replaceFlagsForTesting(t) flagBool = flags.Bool("bool", false, "") - err = common.SetValueForFlagName("bool", true, map[string]struct{}{"bool": {}}, true, true) + err = common.SetValueForFlagName("bool", true, map[string]struct{}{"bool": {}}, true) require.NoError(t, err) assert.Equal(t, false, *flagBool) - flags = replaceFlagsForTesting(t) - err = common.SetValueForFlagName("bool", true, map[string]struct{}{}, true, false) - require.NoError(t, err) - flags = replaceFlagsForTesting(t) flagInt := flags.Int("int", 2, "") - err = common.SetValueForFlagName("int", 1, map[string]struct{}{}, true, true) + err = common.SetValueForFlagName("int", 1, map[string]struct{}{}, true) require.NoError(t, err) assert.Equal(t, 1, *flagInt) flags = replaceFlagsForTesting(t) flagInt = flags.Int("int", 2, "") - err = common.SetValueForFlagName("int", 1, map[string]struct{}{"int": {}}, true, true) + err = common.SetValueForFlagName("int", 1, map[string]struct{}{"int": {}}, true) require.NoError(t, err) assert.Equal(t, 2, *flagInt) flags = replaceFlagsForTesting(t) flagInt64 := flags.Int64("int64", 2, "") - err = common.SetValueForFlagName("int64", 1, map[string]struct{}{}, true, true) + err = common.SetValueForFlagName("int64", 1, map[string]struct{}{}, true) require.NoError(t, err) assert.Equal(t, int64(1), *flagInt64) flags = replaceFlagsForTesting(t) flagInt64 = flags.Int64("int64", 2, "") - err = common.SetValueForFlagName("int64", 1, map[string]struct{}{"int64": {}}, true, true) + err = common.SetValueForFlagName("int64", 1, map[string]struct{}{"int64": {}}, true) require.NoError(t, err) assert.Equal(t, int64(2), *flagInt64) flags = replaceFlagsForTesting(t) flagUint := flags.Uint("uint", 2, "") - err = common.SetValueForFlagName("uint", 1, map[string]struct{}{}, true, true) + err = common.SetValueForFlagName("uint", 1, map[string]struct{}{}, true) require.NoError(t, err) assert.Equal(t, uint(1), *flagUint) flags = replaceFlagsForTesting(t) flagUint = flags.Uint("uint", 2, "") - err = common.SetValueForFlagName("uint", 1, map[string]struct{}{"uint": {}}, true, true) + err = common.SetValueForFlagName("uint", 1, map[string]struct{}{"uint": {}}, true) require.NoError(t, err) assert.Equal(t, uint(2), *flagUint) flags = replaceFlagsForTesting(t) flagUint64 := flags.Uint64("uint64", 2, "") - err = common.SetValueForFlagName("uint64", 1, map[string]struct{}{}, true, true) + err = common.SetValueForFlagName("uint64", 1, map[string]struct{}{}, true) require.NoError(t, err) assert.Equal(t, uint64(1), *flagUint64) flags = replaceFlagsForTesting(t) flagUint64 = flags.Uint64("uint64", 2, "") - err = common.SetValueForFlagName("uint64", 1, map[string]struct{}{"uint64": {}}, true, true) + err = common.SetValueForFlagName("uint64", 1, map[string]struct{}{"uint64": {}}, true) require.NoError(t, err) assert.Equal(t, uint64(2), *flagUint64) flags = replaceFlagsForTesting(t) flagFloat64 := flags.Float64("float64", 2, "") - err = common.SetValueForFlagName("float64", 1, map[string]struct{}{}, true, true) + err = common.SetValueForFlagName("float64", 1, map[string]struct{}{}, true) require.NoError(t, err) assert.Equal(t, float64(1), *flagFloat64) flags = replaceFlagsForTesting(t) flagFloat64 = flags.Float64("float64", 2, "") - err = common.SetValueForFlagName("float64", 1, map[string]struct{}{"float64": {}}, true, true) + err = common.SetValueForFlagName("float64", 1, map[string]struct{}{"float64": {}}, true) require.NoError(t, err) assert.Equal(t, float64(2), *flagFloat64) flags = replaceFlagsForTesting(t) flagString := flags.String("string", "2", "") - err = common.SetValueForFlagName("string", "1", map[string]struct{}{}, true, true) + err = common.SetValueForFlagName("string", "1", map[string]struct{}{}, true) require.NoError(t, err) assert.Equal(t, "1", *flagString) flags = replaceFlagsForTesting(t) flagString = flags.String("string", "2", "") - err = common.SetValueForFlagName("string", "1", map[string]struct{}{"string": {}}, true, true) + err = common.SetValueForFlagName("string", "1", map[string]struct{}{"string": {}}, true) require.NoError(t, err) assert.Equal(t, "2", *flagString) @@ -126,7 +122,7 @@ func TestSetValueForFlagName(t *testing.T) { flagURL := flagtypes.URLFromString("url", "https://www.example.com", "") u, err := url.Parse("https://www.example.com:8080") require.NoError(t, err) - err = common.SetValueForFlagName("url", *u, map[string]struct{}{}, true, true) + err = common.SetValueForFlagName("url", *u, map[string]struct{}{}, true) require.NoError(t, err) assert.Equal(t, url.URL{Scheme: "https", Host: "www.example.com:8080"}, *flagURL) @@ -134,7 +130,7 @@ func TestSetValueForFlagName(t *testing.T) { flagURL = flagtypes.URLFromString("url", "https://www.example.com", "") u, err = url.Parse("https://www.example.com:8080") require.NoError(t, err) - err = common.SetValueForFlagName("url", *u, map[string]struct{}{"url": {}}, true, true) + err = common.SetValueForFlagName("url", *u, map[string]struct{}{"url": {}}, true) require.NoError(t, err) assert.Equal(t, url.URL{Scheme: "https", Host: "www.example.com"}, *flagURL) @@ -143,53 +139,53 @@ func TestSetValueForFlagName(t *testing.T) { string_slice[0] = "1" string_slice[1] = "2" flagtypes.SliceVar(&string_slice, "string_slice", "") - err = common.SetValueForFlagName("string_slice", []string{"3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, map[string]struct{}{}, true, true) + err = common.SetValueForFlagName("string_slice", []string{"3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, map[string]struct{}{}, true) require.NoError(t, err) assert.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, string_slice) flags = replaceFlagsForTesting(t) flagStringSlice := flagtypes.Slice("string_slice", []string{"1", "2"}, "") - err = common.SetValueForFlagName("string_slice", []string{"3"}, map[string]struct{}{"string_slice": {}}, true, true) + err = common.SetValueForFlagName("string_slice", []string{"3"}, map[string]struct{}{"string_slice": {}}, true) require.NoError(t, err) assert.Equal(t, []string{"1", "2", "3"}, *flagStringSlice) flags = replaceFlagsForTesting(t) flagStringSlice = flagtypes.Slice("string_slice", []string{"1", "2"}, "") - err = common.SetValueForFlagName("string_slice", []string{"3"}, map[string]struct{}{}, false, true) + err = common.SetValueForFlagName("string_slice", []string{"3"}, map[string]struct{}{}, false) require.NoError(t, err) assert.Equal(t, []string{"3"}, *flagStringSlice) flags = replaceFlagsForTesting(t) flagStringSlice = flagtypes.Slice("string_slice", []string{"1", "2"}, "") - err = common.SetValueForFlagName("string_slice", []string{"3"}, map[string]struct{}{"string_slice": {}}, false, true) + err = common.SetValueForFlagName("string_slice", []string{"3"}, map[string]struct{}{"string_slice": {}}, false) require.NoError(t, err) assert.Equal(t, []string{"1", "2"}, *flagStringSlice) flags = replaceFlagsForTesting(t) flagStructSlice := []testStruct{{Field: 1}, {Field: 2}} flagtypes.SliceVar(&flagStructSlice, "struct_slice", "") - err = common.SetValueForFlagName("struct_slice", []testStruct{{Field: 3}}, map[string]struct{}{}, true, true) + err = common.SetValueForFlagName("struct_slice", []testStruct{{Field: 3}}, map[string]struct{}{}, true) require.NoError(t, err) assert.Equal(t, []testStruct{{Field: 1}, {Field: 2}, {Field: 3}}, flagStructSlice) flags = replaceFlagsForTesting(t) flagStructSlice = []testStruct{{Field: 1}, {Field: 2}} flagtypes.SliceVar(&flagStructSlice, "struct_slice", "") - err = common.SetValueForFlagName("struct_slice", []testStruct{{Field: 3}}, map[string]struct{}{"struct_slice": {}}, true, true) + err = common.SetValueForFlagName("struct_slice", []testStruct{{Field: 3}}, map[string]struct{}{"struct_slice": {}}, true) require.NoError(t, err) assert.Equal(t, []testStruct{{Field: 1}, {Field: 2}, {Field: 3}}, flagStructSlice) flags = replaceFlagsForTesting(t) flagStructSlice = []testStruct{{Field: 1}, {Field: 2}} flagtypes.SliceVar(&flagStructSlice, "struct_slice", "") - err = common.SetValueForFlagName("struct_slice", []testStruct{{Field: 3}}, map[string]struct{}{}, false, true) + err = common.SetValueForFlagName("struct_slice", []testStruct{{Field: 3}}, map[string]struct{}{}, false) require.NoError(t, err) assert.Equal(t, []testStruct{{Field: 3}}, flagStructSlice) flags = replaceFlagsForTesting(t) flagStructSlice = []testStruct{{Field: 1}, {Field: 2}} flagtypes.SliceVar(&flagStructSlice, "struct_slice", "") - err = common.SetValueForFlagName("struct_slice", []testStruct{{Field: 3}}, map[string]struct{}{"struct_slice": struct{}{}}, false, true) + err = common.SetValueForFlagName("struct_slice", []testStruct{{Field: 3}}, map[string]struct{}{"struct_slice": {}}, false) require.NoError(t, err) assert.Equal(t, []testStruct{{Field: 1}, {Field: 2}}, flagStructSlice) } @@ -197,33 +193,33 @@ func TestSetValueForFlagName(t *testing.T) { func TestBadSetValueForFlagName(t *testing.T) { flags := replaceFlagsForTesting(t) _ = flags.Bool("bool", false, "") - err := common.SetValueForFlagName("bool", 0, map[string]struct{}{}, true, true) + err := common.SetValueForFlagName("bool", 0, map[string]struct{}{}, true) require.Error(t, err) flags = replaceFlagsForTesting(t) - err = common.SetValueForFlagName("bool", false, map[string]struct{}{}, true, true) + err = common.SetValueForFlagName("bool", false, map[string]struct{}{}, true) require.Error(t, err) flags = replaceFlagsForTesting(t) _ = flagtypes.Slice("string_slice", []string{"1", "2"}, "") - err = common.SetValueForFlagName("string_slice", "3", map[string]struct{}{}, true, true) + err = common.SetValueForFlagName("string_slice", "3", map[string]struct{}{}, true) require.Error(t, err) flags = replaceFlagsForTesting(t) _ = flagtypes.Slice("string_slice", []string{"1", "2"}, "") - err = common.SetValueForFlagName("string_slice", "3", map[string]struct{}{}, false, true) + err = common.SetValueForFlagName("string_slice", "3", map[string]struct{}{}, false) require.Error(t, err) flags = replaceFlagsForTesting(t) flagStructSlice := []testStruct{{Field: 1}, {Field: 2}} flagtypes.SliceVar(&flagStructSlice, "struct_slice", "") - err = common.SetValueForFlagName("struct_slice", testStruct{Field: 3}, map[string]struct{}{}, true, true) + err = common.SetValueForFlagName("struct_slice", testStruct{Field: 3}, map[string]struct{}{}, true) require.Error(t, err) flags = replaceFlagsForTesting(t) flagStructSlice = []testStruct{{Field: 1}, {Field: 2}} flagtypes.SliceVar(&flagStructSlice, "struct_slice", "") - err = common.SetValueForFlagName("struct_slice", testStruct{Field: 3}, map[string]struct{}{}, false, true) + err = common.SetValueForFlagName("struct_slice", testStruct{Field: 3}, map[string]struct{}{}, false) require.Error(t, err) } diff --git a/server/util/flagutil/flagutil.go b/server/util/flagutil/flagutil.go index c7e39ed02fe..089a60f5e06 100644 --- a/server/util/flagutil/flagutil.go +++ b/server/util/flagutil/flagutil.go @@ -4,7 +4,13 @@ import ( "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/common" ) -// SetValueForFlagName sets the value for a flag by name. +// SetValueForFlagName sets the value for a flag by name. setFlags is the set of +// flags that have already been set on the command line; those flags will not be +// set again except to append to them, in the case of slices. To force the +// setting of a flag, pass a nil map. If appendSlice is true, a slice value will +// be appended to the current slice value; otherwise, a slice value will replace +// the current slice value. appendSlice has no effect if the values in question +// are not slices. var SetValueForFlagName = common.SetValueForFlagName // GetDereferencedValue retypes and returns the dereferenced Value for diff --git a/server/util/flagutil/types/BUILD b/server/util/flagutil/types/BUILD index 21bb0dc0b39..feb694a1e03 100644 --- a/server/util/flagutil/types/BUILD +++ b/server/util/flagutil/types/BUILD @@ -8,7 +8,6 @@ go_library( deps = [ "//server/util/alert", "//server/util/flagutil/common", - "//server/util/flagutil/yaml", "//server/util/log", "//server/util/status", "@in_gopkg_yaml_v3//:yaml_v3", diff --git a/server/util/flagutil/types/types.go b/server/util/flagutil/types/types.go index 773d92ad30d..a3d514c11c0 100644 --- a/server/util/flagutil/types/types.go +++ b/server/util/flagutil/types/types.go @@ -7,16 +7,44 @@ import ( "net/url" "reflect" "strings" + "time" "github.com/buildbuddy-io/buildbuddy/server/util/alert" "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/common" "github.com/buildbuddy-io/buildbuddy/server/util/log" "github.com/buildbuddy-io/buildbuddy/server/util/status" "gopkg.in/yaml.v3" - - flagyaml "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/yaml" ) +// NewPrimitiveFlagVar returns a flag.Value derived from the given primitive pointer. +func NewPrimitiveFlagVar[T bool | time.Duration | float64 | int | int64 | uint | uint64 | string](value *T) flag.Value { + fs := flag.NewFlagSet("", flag.ContinueOnError) + switch v := any(value).(type) { + case *bool: + fs.BoolVar(v, "", *v, "") + case *time.Duration: + fs.DurationVar(v, "", *v, "") + case *float64: + fs.Float64Var(v, "", *v, "") + case *int: + fs.IntVar(v, "", *v, "") + case *int64: + fs.Int64Var(v, "", *v, "") + case *uint: + fs.UintVar(v, "", *v, "") + case *uint64: + fs.Uint64Var(v, "", *v, "") + case *string: + fs.StringVar(v, "", *v, "") + } + return fs.Lookup("").Value +} + +// NewPrimitiveFlagVar returns a flag.Value derived from the given primitive. +func NewPrimitiveFlag[T bool | time.Duration | float64 | int | int64 | uint | uint64 | string](value T) flag.Value { + return NewPrimitiveFlagVar(&value) +} + type SliceFlag[T any] []T func NewSliceFlag[T any](slice *[]T) *SliceFlag[T] { @@ -169,7 +197,7 @@ func Alias[T any](newName, name string) *T { } } addr := reflect.ValueOf(flg.Value) - if t, err := common.GetTypeForFlag(flg); err == nil { + if t, err := common.GetTypeForFlagValue(flg.Value); err == nil { if !addr.CanConvert(t) { log.Fatalf("Error aliasing flag %s as %s: Flag %s of type %T could not be converted to %s.", name, newName, flg.Name, flg.Value, t) } @@ -188,27 +216,45 @@ func (f *FlagAlias) Set(value string) error { } func (f *FlagAlias) String() string { - return common.DefaultFlagSet.Lookup(f.name).Value.String() + return f.WrappedValue().String() } func (f *FlagAlias) AliasedName() string { return f.name } -func (f *FlagAlias) AliasedType() reflect.Type { - flg := common.DefaultFlagSet.Lookup(f.name) - t, err := common.GetTypeForFlag(flg) - if err != nil { - return reflect.TypeOf(flg.Value) - } - return t +func (f *FlagAlias) WrappedValue() flag.Value { + return common.DefaultFlagSet.Lookup(f.name).Value } -func (f *FlagAlias) YAMLTypeAlias() reflect.Type { - flg := common.DefaultFlagSet.Lookup(f.name) - t, err := flagyaml.GetYAMLTypeForFlag(flg) - if err != nil { - return reflect.TypeOf(flg.Value) - } - return t +type DeprecatedFlag struct { + flag.Value + name string + migrationPlan string +} + +// DeprecatedVar takes a flag.Value (which can be obtained for primitive types +// via the NewPrimitiveFlag or NewPrimitiveFlagVar functions), the customary +// name and usage parameters, and a migration plan, and defines a flag that will +// notify users that it is deprecated when it is set. +func DeprecatedVar[T any](value flag.Value, name string, usage, migrationPlan string) *T { + common.DefaultFlagSet.Var(&DeprecatedFlag{value, name, migrationPlan}, name, usage+" **DEPRECATED** "+migrationPlan) + return reflect.ValueOf(value).Convert(reflect.TypeOf((*T)(nil))).Interface().(*T) +} + +func (d *DeprecatedFlag) Set(value string) error { + log.Warningf("Flag \"%s\" was set on the command line but has been deprecated: %s", d.name, d.migrationPlan) + return d.Value.Set(value) +} + +func (d *DeprecatedFlag) WrappedValue() flag.Value { + return d.Value +} + +func (d *DeprecatedFlag) SetValueForFlagNameHook() { + log.Warningf("Flag \"%s\" was set programmatically by name but has been deprecated: %s", d.name, d.migrationPlan) +} + +func (d *DeprecatedFlag) YAMLSetValueHook() { + log.Warningf("Flag \"%s\" was set through the YAML config but has been deprecated: %s", d.name, d.migrationPlan) } diff --git a/server/util/flagutil/types/types_test.go b/server/util/flagutil/types/types_test.go index 5adc630c7e1..82d708a156f 100644 --- a/server/util/flagutil/types/types_test.go +++ b/server/util/flagutil/types/types_test.go @@ -178,14 +178,22 @@ func TestFlagAlias(t *testing.T) { asf := flags.Lookup("string_alias").Value.(*FlagAlias) assert.Equal(t, "meow", asf.String()) assert.Equal(t, "string", asf.AliasedName()) - assert.Equal(t, reflect.TypeOf((*string)(nil)), asf.AliasedType()) - assert.Equal(t, reflect.TypeOf((*string)(nil)), asf.YAMLTypeAlias()) + asfType, err := common.GetTypeForFlagValue(asf) + require.NoError(t, err) + assert.Equal(t, reflect.TypeOf((*string)(nil)), asfType) + asfYAMLType, err := flagyaml.GetYAMLTypeForFlagValue(asf) + require.NoError(t, err) + assert.Equal(t, reflect.TypeOf((*string)(nil)), asfYAMLType) aasf := flags.Lookup("string_alias").Value.(*FlagAlias) assert.Equal(t, "meow", aasf.String()) assert.Equal(t, "string", aasf.AliasedName()) - assert.Equal(t, reflect.TypeOf((*string)(nil)), aasf.AliasedType()) - assert.Equal(t, reflect.TypeOf((*string)(nil)), aasf.YAMLTypeAlias()) + aasfType, err := common.GetTypeForFlagValue(asf) + require.NoError(t, err) + assert.Equal(t, reflect.TypeOf((*string)(nil)), aasfType) + aasfYAMLType, err := flagyaml.GetYAMLTypeForFlagValue(asf) + require.NoError(t, err) + assert.Equal(t, reflect.TypeOf((*string)(nil)), aasfYAMLType) flags = replaceFlagsForTesting(t) @@ -199,7 +207,7 @@ string_alias2: "moo" string_alias3: "oink" string_alias: "meow" ` - err := flagyaml.PopulateFlagsFromData([]byte(yamlData)) + err = flagyaml.PopulateFlagsFromData([]byte(yamlData)) require.NoError(t, err) assert.Equal(t, "meow", *flagString) @@ -209,6 +217,7 @@ string_alias: "meow" Alias[[]string]("string_slice_alias", "string_slice") Alias[[]string]("string_slice_alias2", "string_slice") Alias[[]string]("string_slice_alias3", "string_slice") + flags.Set("string_slice", "squeak") yamlData = ` string_slice: - "woof" @@ -222,7 +231,7 @@ string_slice_alias: ` err = flagyaml.PopulateFlagsFromData([]byte(yamlData)) require.NoError(t, err) - assert.Equal(t, []string{"test", "woof", "moo", "oink", "ribbit", "meow"}, *flagStringSlice) + assert.Equal(t, []string{"test", "squeak", "woof", "moo", "oink", "ribbit", "meow"}, *flagStringSlice) flags = replaceFlagsForTesting(t) @@ -274,14 +283,14 @@ string_alias: "meow" flags = replaceFlagsForTesting(t) flagString = flags.String("string", "2", "") Alias[string]("string_alias", "string") - err = common.SetValueForFlagName("string_alias", "1", map[string]struct{}{}, true, true) + err = common.SetValueForFlagName("string_alias", "1", map[string]struct{}{}, true) require.NoError(t, err) assert.Equal(t, "1", *flagString) flags = replaceFlagsForTesting(t) flagString = flags.String("string", "2", "") Alias[string]("string_alias", "string") - err = common.SetValueForFlagName("string_alias", "1", map[string]struct{}{"string": {}}, true, true) + err = common.SetValueForFlagName("string_alias", "1", map[string]struct{}{"string": {}}, true) require.NoError(t, err) assert.Equal(t, "2", *flagString) @@ -291,28 +300,28 @@ string_alias: "meow" string_slice[1] = "2" SliceVar(&string_slice, "string_slice", "") Alias[[]string]("string_slice_alias", "string_slice") - err = common.SetValueForFlagName("string_slice_alias", []string{"3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, map[string]struct{}{}, true, true) + err = common.SetValueForFlagName("string_slice_alias", []string{"3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, map[string]struct{}{}, true) require.NoError(t, err) assert.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, string_slice) flags = replaceFlagsForTesting(t) flagStringSlice = Slice("string_slice", []string{"1", "2"}, "") Alias[[]string]("string_slice_alias", "string_slice") - err = common.SetValueForFlagName("string_slice_alias", []string{"3"}, map[string]struct{}{"string_slice": {}}, true, true) + err = common.SetValueForFlagName("string_slice_alias", []string{"3"}, map[string]struct{}{"string_slice": {}}, true) require.NoError(t, err) assert.Equal(t, []string{"1", "2", "3"}, *flagStringSlice) flags = replaceFlagsForTesting(t) flagStringSlice = Slice("string_slice", []string{"1", "2"}, "") Alias[[]string]("string_slice_alias", "string_slice") - err = common.SetValueForFlagName("string_slice_alias", []string{"3"}, map[string]struct{}{}, false, true) + err = common.SetValueForFlagName("string_slice_alias", []string{"3"}, map[string]struct{}{}, false) require.NoError(t, err) assert.Equal(t, []string{"3"}, *flagStringSlice) flags = replaceFlagsForTesting(t) flagStringSlice = Slice("string_slice", []string{"1", "2"}, "") Alias[[]string]("string_slice_alias", "string_slice") - err = common.SetValueForFlagName("string_slice_alias", []string{"3"}, map[string]struct{}{"string_slice": {}}, false, true) + err = common.SetValueForFlagName("string_slice_alias", []string{"3"}, map[string]struct{}{"string_slice": {}}, false) require.NoError(t, err) assert.Equal(t, []string{"1", "2"}, *flagStringSlice) @@ -323,7 +332,7 @@ string_alias: "meow" SliceVar(&string_slice, "string_slice", "") Alias[[]string]("string_slice_alias", "string_slice") Alias[[]string]("string_slice_alias_alias", "string_slice_alias") - err = common.SetValueForFlagName("string_slice_alias_alias", []string{"3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, map[string]struct{}{}, true, true) + err = common.SetValueForFlagName("string_slice_alias_alias", []string{"3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, map[string]struct{}{}, true) require.NoError(t, err) assert.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, string_slice) @@ -331,7 +340,7 @@ string_alias: "meow" flagStringSlice = Slice("string_slice", []string{"1", "2"}, "") Alias[[]string]("string_slice_alias", "string_slice") Alias[[]string]("string_slice_alias_alias", "string_slice_alias") - err = common.SetValueForFlagName("string_slice_alias_alias", []string{"3"}, map[string]struct{}{"string_slice": {}}, true, true) + err = common.SetValueForFlagName("string_slice_alias_alias", []string{"3"}, map[string]struct{}{"string_slice": {}}, true) require.NoError(t, err) assert.Equal(t, []string{"1", "2", "3"}, *flagStringSlice) @@ -339,7 +348,7 @@ string_alias: "meow" flagStringSlice = Slice("string_slice", []string{"1", "2"}, "") Alias[[]string]("string_slice_alias", "string_slice") Alias[[]string]("string_slice_alias_alias", "string_slice_alias") - err = common.SetValueForFlagName("string_slice_alias_alias", []string{"3"}, map[string]struct{}{}, false, true) + err = common.SetValueForFlagName("string_slice_alias_alias", []string{"3"}, map[string]struct{}{}, false) require.NoError(t, err) assert.Equal(t, []string{"3"}, *flagStringSlice) @@ -347,7 +356,7 @@ string_alias: "meow" flagStringSlice = Slice("string_slice", []string{"1", "2"}, "") Alias[[]string]("string_slice_alias", "string_slice") Alias[[]string]("string_slice_alias_alias", "string_slice_alias") - err = common.SetValueForFlagName("string_slice_alias_alias", []string{"3"}, map[string]struct{}{"string_slice": {}}, false, true) + err = common.SetValueForFlagName("string_slice_alias_alias", []string{"3"}, map[string]struct{}{"string_slice": {}}, false) require.NoError(t, err) assert.Equal(t, []string{"1", "2"}, *flagStringSlice) @@ -373,3 +382,58 @@ string_alias: "meow" require.NoError(t, err) assert.Equal(t, []testStruct{{Field: 1}, {Field: 2}}, structSlice) } + +func TestDeprecateFlag(t *testing.T) { + flags := replaceFlagsForTesting(t) + flagInt := DeprecatedVar[int](NewPrimitiveFlag(5), "deprecated_int", "", "migration plan") + flagStringSlice := DeprecatedVar[[]string](NewSliceFlag(&[]string{"hi"}), "deprecated_string_slice", "", "migration plan") + assert.Equal(t, *flagInt, 5) + assert.Equal(t, *flagStringSlice, []string{"hi"}) + flags.Set("deprecated_int", "7") + flags.Set("deprecated_string_slice", "hello") + assert.Equal(t, *flagStringSlice, []string{"hi", "hello"}) + assert.Equal(t, *flagInt, 7) + testInt, err := common.GetDereferencedValue[int]("deprecated_int") + require.NoError(t, err) + assert.Equal(t, testInt, 7) + testStringSlice, err := common.GetDereferencedValue[[]string]("deprecated_string_slice") + require.NoError(t, err) + assert.Equal(t, testStringSlice, []string{"hi", "hello"}) + + flags = replaceFlagsForTesting(t) + + flagInt = DeprecatedVar[int](NewPrimitiveFlag(5), "deprecated_int", "", "migration plan") + flagString := DeprecatedVar[string](NewPrimitiveFlag(""), "deprecated_string", "", "migration plan") + flagStringSlice = DeprecatedVar[[]string](NewSliceFlag(&[]string{"hi"}), "deprecated_string_slice", "", "migration plan") + flags.Set("deprecated_int", "7") + flags.Set("deprecated_string_slice", "hello") + yamlData := ` +deprecated_int: 9 +deprecated_string: "moo" +deprecated_string_slice: + - "hey" +` + err = flagyaml.PopulateFlagsFromData([]byte(yamlData)) + require.NoError(t, err) + assert.Equal(t, *flagInt, 7) + assert.Equal(t, *flagString, "moo") + assert.Equal(t, *flagStringSlice, []string{"hi", "hello", "hey"}) + testInt, err = common.GetDereferencedValue[int]("deprecated_int") + require.NoError(t, err) + assert.Equal(t, testInt, 7) + testString, err := common.GetDereferencedValue[string]("deprecated_string") + require.NoError(t, err) + assert.Equal(t, testString, "moo") + testStringSlice, err = common.GetDereferencedValue[[]string]("deprecated_string_slice") + require.NoError(t, err) + assert.Equal(t, testStringSlice, []string{"hi", "hello", "hey"}) + + d := any(&DeprecatedFlag{}) + _, ok := d.(common.WrappingValue) + assert.True(t, ok) + _, ok = d.(common.SetValueForFlagNameHooked) + assert.True(t, ok) + _, ok = d.(flagyaml.YAMLSetValueHooked) + assert.True(t, ok) + +} diff --git a/server/util/flagutil/yaml/yaml.go b/server/util/flagutil/yaml/yaml.go index 255e09d288f..8f085121089 100644 --- a/server/util/flagutil/yaml/yaml.go +++ b/server/util/flagutil/yaml/yaml.go @@ -50,25 +50,37 @@ func IgnoreFilter(flg *flag.Flag) bool { } type YAMLTypeAliasable interface { + // YAMLTypeAlias returns the type alias we use in YAML for this flag.Value. YAMLTypeAlias() reflect.Type } type YAMLTypeStringable interface { + // YAMLTypeString returns the name to print for this type in YAML docs. YAMLTypeString() string } +type YAMLSetValueHooked interface { + // YAMLSetValueHook is the hook for flags that is called when the flag.Value + // is set through the YAML config. + YAMLSetValueHook() +} + type DocumentedMarshaler interface { + // DocumentNode documents the yaml.Node representing this value. DocumentNode(n *yaml.Node, opts ...common.DocumentNodeOption) error } -// GetYAMLTypeForFlag returns the type alias to use in YAML contexts for the flag. -func GetYAMLTypeForFlag(flg *flag.Flag) (reflect.Type, error) { - if v, ok := flg.Value.(YAMLTypeAliasable); ok { +// GetYAMLTypeForFlagValue returns the type alias to use in YAML contexts for the flag. +func GetYAMLTypeForFlagValue(value flag.Value) (reflect.Type, error) { + if v, ok := value.(common.WrappingValue); ok { + return GetYAMLTypeForFlagValue(v.WrappedValue()) + } + if v, ok := value.(YAMLTypeAliasable); ok { return v.YAMLTypeAlias(), nil - } else if t, err := common.GetTypeForFlag(flg); err == nil { + } else if t, err := common.GetTypeForFlagValue(value); err == nil { return t, nil } - return nil, status.UnimplementedErrorf("Unsupported flag type at %s: %T", flg.Name, flg.Value) + return nil, status.UnimplementedErrorf("Unsupported flag type: %T", value) } type HeadComment string @@ -223,9 +235,9 @@ func DocumentNode(in any, n *yaml.Node, opts ...common.DocumentNodeOption) error // GenerateDocumentedYAMLNodeFromFlag produces a documented yaml.Node which // represents the value contained in the flag. func GenerateDocumentedYAMLNodeFromFlag(flg *flag.Flag) (*yaml.Node, error) { - t, err := GetYAMLTypeForFlag(flg) + t, err := GetYAMLTypeForFlagValue(flg.Value) if err != nil { - return nil, status.InternalErrorf("Error encountered generating default YAML from flags: %s", err) + return nil, status.InternalErrorf("Error encountered generating default YAML from flags when processing flag %s: %s", flg.Name, err) } v, err := common.GetDereferencedValue[any](flg.Name) if err != nil { @@ -425,7 +437,12 @@ func PopulateFlagsFromData(data []byte) error { if len(node.Content) > 0 { node = node.Content[0] } - typeMap, err := GenerateYAMLMapWithValuesFromFlags(GetYAMLTypeForFlag, IgnoreFilter) + typeMap, err := GenerateYAMLMapWithValuesFromFlags( + func(flg *flag.Flag) (reflect.Type, error) { + return GetYAMLTypeForFlagValue(flg.Value) + }, + IgnoreFilter, + ) if err != nil { return err } @@ -502,5 +519,17 @@ func populateFlagsFromYAML(a any, prefix []string, node *yaml.Node, setFlags map if _, ok := ignoreSet[name]; ok { return nil } - return common.SetValueForFlagName(name, a, setFlags, true, false) + + flg := common.DefaultFlagSet.Lookup(name) + if flg == nil { + return nil + } + return setValueForYAML(flg.Value, name, a, setFlags, true) +} + +func setValueForYAML(flagValue flag.Value, name string, newValue any, setFlags map[string]struct{}, appendSlice bool, setHooks ...func()) error { + if v, ok := flagValue.(YAMLSetValueHooked); ok { + setHooks = append(setHooks, v.YAMLSetValueHook) + } + return common.SetValueWithCustomIndirectBehavior(flagValue, name, newValue, setFlags, appendSlice, setValueForYAML, setHooks...) } diff --git a/server/util/flagutil/yaml/yaml_test.go b/server/util/flagutil/yaml/yaml_test.go index f416dbc5270..a478c487f47 100644 --- a/server/util/flagutil/yaml/yaml_test.go +++ b/server/util/flagutil/yaml/yaml_test.go @@ -47,7 +47,12 @@ func TestGenerateYAMLTypeMapFromFlags(t *testing.T) { flagtypes.Slice("one.two.three.struct_slice", []testStruct{{Field: 4, Meadow: "Great"}}, "") flags.String("a.b.string", "xxx", "") flagtypes.URLFromString("a.b.url", "https://www.example.com", "") - actual, err := flagyaml.GenerateYAMLMapWithValuesFromFlags(flagyaml.GetYAMLTypeForFlag, flagyaml.IgnoreFilter) + actual, err := flagyaml.GenerateYAMLMapWithValuesFromFlags( + func(flg *flag.Flag) (reflect.Type, error) { + return flagyaml.GetYAMLTypeForFlagValue(flg.Value) + }, + flagyaml.IgnoreFilter, + ) require.NoError(t, err) expected := map[string]any{ "bool": reflect.TypeOf((*bool)(nil)), @@ -80,20 +85,35 @@ func TestBadGenerateYAMLTypeMapFromFlags(t *testing.T) { flags.Int("one.two.int", 10, "") flags.Int("one.two", 10, "") - _, err := flagyaml.GenerateYAMLMapWithValuesFromFlags(flagyaml.GetYAMLTypeForFlag, flagyaml.IgnoreFilter) + _, err := flagyaml.GenerateYAMLMapWithValuesFromFlags( + func(flg *flag.Flag) (reflect.Type, error) { + return flagyaml.GetYAMLTypeForFlagValue(flg.Value) + }, + flagyaml.IgnoreFilter, + ) require.Error(t, err) flags = replaceFlagsForTesting(t) flags.Int("one.two", 10, "") flags.Int("one.two.int", 10, "") - _, err = flagyaml.GenerateYAMLMapWithValuesFromFlags(flagyaml.GetYAMLTypeForFlag, flagyaml.IgnoreFilter) + _, err = flagyaml.GenerateYAMLMapWithValuesFromFlags( + func(flg *flag.Flag) (reflect.Type, error) { + return flagyaml.GetYAMLTypeForFlagValue(flg.Value) + }, + flagyaml.IgnoreFilter, + ) require.Error(t, err) flags = replaceFlagsForTesting(t) flags.Var(&unsupportedFlagValue{}, "unsupported", "") - _, err = flagyaml.GenerateYAMLMapWithValuesFromFlags(flagyaml.GetYAMLTypeForFlag, flagyaml.IgnoreFilter) + _, err = flagyaml.GenerateYAMLMapWithValuesFromFlags( + func(flg *flag.Flag) (reflect.Type, error) { + return flagyaml.GetYAMLTypeForFlagValue(flg.Value) + }, + flagyaml.IgnoreFilter, + ) require.Error(t, err) } diff --git a/server/util/testing/flags/flags.go b/server/util/testing/flags/flags.go index d63103d7553..e07a0617d1d 100644 --- a/server/util/testing/flags/flags.go +++ b/server/util/testing/flags/flags.go @@ -28,11 +28,11 @@ func PopulateFlagsFromData(t testing.TB, testConfigData []byte) { func Set(t testing.TB, name string, value any) { origValue, err := flagutil.GetDereferencedValue[any](name) require.NoError(t, err) - err = flagutil_common.SetValueForFlagName(name, value, nil, false, true) + err = flagutil.SetValueForFlagName(name, value, nil, false) require.NoError(t, err) t.Cleanup(func() { - err = flagutil_common.SetValueForFlagName(name, origValue, nil, false, true) + err = flagutil.SetValueForFlagName(name, origValue, nil, false) require.NoError(t, err) }) } From 67a224c257901f55212a3e0ad22687bd5cbcecc0 Mon Sep 17 00:00:00 2001 From: Brandon Duffany Date: Wed, 15 Jun 2022 20:05:22 -0400 Subject: [PATCH 3/3] Bump version v2.10.3 -> v2.10.4 (release.py) (#2142) --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 28793914751..f0c43f342c0 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -v2.10.3 +v2.10.4