Skip to content

Commit

Permalink
modify MetricStrategy specification
Browse files Browse the repository at this point in the history
  • Loading branch information
sperlingxx committed Jun 8, 2020
1 parent cebacc4 commit 5d2bba4
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 38 deletions.
9 changes: 5 additions & 4 deletions examples/v1beta1/metric-strategy-example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ spec:
type: maximize
goal: 0.99
objectiveMetricName: Validation-accuracy
additionalMetricNames:
- Train-accuracy
additionalMetricNames: [Train-accuracy]
metricStrategies:
Validation-accuracy: "max"
Train-accuracy: "latest"
- name: Train-accuracy
value: "latest"
- name: Validation-accuracy
value: "max"
algorithm:
algorithmName: tpe
parallelTrialCount: 3
Expand Down
21 changes: 13 additions & 8 deletions pkg/apis/controller/common/v1beta1/common_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ type ObjectiveSpec struct {
// Note: If we adopt a push instead of pull mechanism, this can be omitted completely.
AdditionalMetricNames []string `json:"additionalMetricNames,omitempty"`
// This field is allowed to missing, experiment defaulter (webhook) will fill it.
MetricStrategies map[string]MetricStrategy `json:"metricStrategies,omitempty"`
MetricStrategies []MetricStrategy `json:"metricStrategies,omitempty"`
}

type ObjectiveType string
Expand All @@ -63,18 +63,23 @@ const (
ObjectiveTypeMaximize ObjectiveType = "maximize"
)

type ParameterAssignment struct {
Name string `json:"name,omitempty"`
Value string `json:"value,omitempty"`
}

// ObjectiveExtractType describes the various approaches to extract objective value from metrics.
type MetricStrategy string
type MetricStrategyType string

const (
ExtractByMin MetricStrategy = "min"
ExtractByMax MetricStrategy = "max"
ExtractByLatest MetricStrategy = "latest"
ExtractByMin MetricStrategyType = "min"
ExtractByMax MetricStrategyType = "max"
ExtractByLatest MetricStrategyType = "latest"
)

type ParameterAssignment struct {
Name string `json:"name,omitempty"`
Value string `json:"value,omitempty"`
type MetricStrategy struct {
Name string `json:"name,omitempty"`
Value MetricStrategyType `json:"value,omitempty"`
}

type Metric struct {
Expand Down
22 changes: 18 additions & 4 deletions pkg/apis/controller/common/v1beta1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 20 additions & 8 deletions pkg/apis/controller/experiments/v1beta1/experiment_defaults.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion pkg/controller.v1beta1/experiment/util/status_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,14 @@ func getObjectiveMetricValue(trial trialsv1beta1.Trial) *string {
if trial.Status.Observation == nil {
return nil
}
var objectiveStrategy commonv1beta1.MetricStrategyType
objectiveMetricName := trial.Spec.Objective.ObjectiveMetricName
objectiveStrategy, _ := trial.Spec.Objective.MetricStrategies[objectiveMetricName]
for _, strategy := range trial.Spec.Objective.MetricStrategies {
if strategy.Name == objectiveMetricName {
objectiveStrategy = strategy.Value
break
}
}
for _, metric := range trial.Status.Observation.Metrics {
if objectiveMetricName == metric.Name {
switch objectiveStrategy {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,14 +248,18 @@ func convertTrialConditionType(conditionType trialsv1beta1.TrialConditionType) s
}

// convertTrialObservation convert Trial Observation Metrics CRD to the GRPC definition
func convertTrialObservation(strategies map[string]commonapiv1beta1.MetricStrategy, observation *commonapiv1beta1.Observation) *suggestionapi.Observation {
func convertTrialObservation(strategies []commonapiv1beta1.MetricStrategy, observation *commonapiv1beta1.Observation) *suggestionapi.Observation {
resObservation := &suggestionapi.Observation{
Metrics: make([]*suggestionapi.Metric, 0),
}
strategyMap := make(map[string]commonapiv1beta1.MetricStrategyType)
for _, strategy := range strategies {
strategyMap[strategy.Name] = strategy.Value
}
if observation != nil && observation.Metrics != nil {
for _, m := range observation.Metrics {
var value string
switch strategy, _ := strategies[m.Name]; strategy {
switch strategy, _ := strategyMap[m.Name]; strategy {
case commonapiv1beta1.ExtractByMin:
if math.IsNaN(m.Min) {
value = m.Latest
Expand All @@ -279,8 +283,8 @@ func convertTrialObservation(strategies map[string]commonapiv1beta1.MetricStrate
})
}
}
return resObservation

return resObservation
}

// convertTrialStatusTime convert Trial Status Time CRD to the GRPC definition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func init() {

func TestConvertTrialObservation(t *testing.T) {
g := gomega.NewGomegaWithT(t)
var strategies = map[string]commonv1beta1.MetricStrategy{
var strategies = map[string]commonv1beta1.MetricStrategyType{
"error": commonv1beta1.ExtractByMin,
"auc": commonv1beta1.ExtractByMax,
"accuracy": commonv1beta1.ExtractByLatest,
Expand Down
8 changes: 4 additions & 4 deletions pkg/controller.v1beta1/trial/trial_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ func TestGetObjectiveMetricValue(t *testing.T) {
{TimeStamp: "2020-04-12T14:47:42+08:00", Metric: &api_pb.Metric{Name: "accuracy", Value: "0.6"}},
}

getMetricsFromLogs := func(strategies map[string]commonv1beta1.MetricStrategy) (*commonv1beta1.Metric, *commonv1beta1.Metric, error) {
getMetricsFromLogs := func(strategies []commonv1beta1.MetricStrategy) (*commonv1beta1.Metric, *commonv1beta1.Metric, error) {
observation, err := getMetrics(metricLogs, strategies)
if err != nil {
return nil, nil, err
Expand All @@ -273,9 +273,9 @@ func TestGetObjectiveMetricValue(t *testing.T) {
return errMetric, accMetric, nil
}

metricStrategies := map[string]commonv1beta1.MetricStrategy{
"error": commonv1beta1.ExtractByMin,
"accuracy": commonv1beta1.ExtractByMax,
metricStrategies := []commonv1beta1.MetricStrategy{
{Name: "error", Value: commonv1beta1.ExtractByMin},
{Name: "accuracy", Value: commonv1beta1.ExtractByMax},
}
errMetric, accMetric, err := getMetricsFromLogs(metricStrategies)
g.Expect(err).ShouldNot(gomega.HaveOccurred())
Expand Down
10 changes: 5 additions & 5 deletions pkg/controller.v1beta1/trial/trial_controller_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,13 @@ func isJobSucceeded(jobCondition *commonv1.JobCondition) bool {
return false
}

func getMetrics(metricLogs []*api_pb.MetricLog, strategies map[string]commonv1beta1.MetricStrategy) (*commonv1beta1.Observation, error) {
func getMetrics(metricLogs []*api_pb.MetricLog, strategies []commonv1beta1.MetricStrategy) (*commonv1beta1.Observation, error) {
metrics := make(map[string]*commonv1beta1.Metric)
timestamps := make(map[string]*time.Time)
for name := range strategies {
timestamps[name] = nil
metrics[name] = &commonv1beta1.Metric{
Name: name,
for _, strategy := range strategies {
timestamps[strategy.Name] = nil
metrics[strategy.Name] = &commonv1beta1.Metric{
Name: strategy.Name,
Min: math.NaN(),
Max: math.NaN(),
Latest: "",
Expand Down

0 comments on commit 5d2bba4

Please sign in to comment.