Skip to content

Commit

Permalink
review: convert fileFormat to type FileSystemFileFormat
Browse files Browse the repository at this point in the history
  • Loading branch information
tenzen-y committed Jan 4, 2022
1 parent 1fb26bb commit 4eb6d6c
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 23 deletions.
24 changes: 13 additions & 11 deletions cmd/metricscollector/v1beta1/file-metricscollector/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func printMetricsFile(mFile string) {
}
}

func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string) {
func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string, fileFormat commonv1beta1.FileSystemFileFormat) {

// metricStartStep is the dict where key = metric name, value = start step.
// We should apply early stopping rule only if metric is reported at least "start_step" times.
Expand Down Expand Up @@ -174,7 +174,7 @@ func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string) {

// Get list of regural expressions from filters.
var metricRegList []*regexp.Regexp
if *metricsFileFormat == commonv1beta1.TextFormat.String() {
if fileFormat == commonv1beta1.TextFormat {
metricRegList = filemc.GetFilterRegexpList(filters)
}

Expand All @@ -185,8 +185,8 @@ func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string) {
// Print log line
klog.Info(logText)

switch *metricsFileFormat {
case commonv1beta1.TextFormat.String():
switch fileFormat {
case commonv1beta1.TextFormat:
// Check if log line contains metric from stop rules.
isRuleLine := false
for _, rule := range stopRules {
Expand Down Expand Up @@ -224,7 +224,7 @@ func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string) {
}
}
}
case commonv1beta1.JsonFormat.String():
case commonv1beta1.JsonFormat:
var logJsonObj map[string]interface{}
if err = json.Unmarshal([]byte(logText), &logJsonObj); err != nil {
klog.Fatalf("Failed to unmarshal logs in JSON format, log: %s, error: %v", logText, err)
Expand Down Expand Up @@ -256,7 +256,7 @@ func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string) {
stopRules = updateStopRules(objMetric, stopRules, optimalObjValue, metricValue, objType, metricStartStep, rule, idx)
}
default:
klog.Fatalf("format must be set %s or %s", commonv1beta1.TextFormat.String(), commonv1beta1.JsonFormat.String())
klog.Fatalf("format must be set %v or %v", commonv1beta1.TextFormat, commonv1beta1.JsonFormat)
}

// If stopRules array is empty, Trial is early stopped.
Expand Down Expand Up @@ -295,7 +295,7 @@ func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string) {
}

// Report metrics to DB.
reportMetrics(filters)
reportMetrics(filters, fileFormat)

// Wait until main process is completed.
timeout := 60 * time.Second
Expand Down Expand Up @@ -400,9 +400,11 @@ func main() {
filters = strings.Split(*metricFilters, ";")
}

fileFormat := commonv1beta1.FileSystemFileFormat(*metricsFileFormat)

// If stop rule is set we need to parse metrics during run.
if len(stopRules) != 0 {
go watchMetricsFile(*metricsFilePath, stopRules, filters)
go watchMetricsFile(*metricsFilePath, stopRules, filters, fileFormat)
} else {
go printMetricsFile(*metricsFilePath)
}
Expand All @@ -421,11 +423,11 @@ func main() {

// If training was not early stopped, report the metrics.
if !isEarlyStopped {
reportMetrics(filters)
reportMetrics(filters, fileFormat)
}
}

func reportMetrics(filters []string) {
func reportMetrics(filters []string, fileFormat commonv1beta1.FileSystemFileFormat) {

conn, err := grpc.Dial(*dbManagerServiceAddr, grpc.WithInsecure())
if err != nil {
Expand All @@ -438,7 +440,7 @@ func reportMetrics(filters []string) {
if len(*metricNames) != 0 {
metricList = strings.Split(*metricNames, ";")
}
olog, err := filemc.CollectObservationLog(*metricsFilePath, metricList, filters, *metricsFileFormat)
olog, err := filemc.CollectObservationLog(*metricsFilePath, metricList, filters, fileFormat)
if err != nil {
klog.Fatalf("Failed to collect logs: %v", err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import (
"github.com/kubeflow/katib/pkg/metricscollector/v1beta1/common"
)

func CollectObservationLog(fileName string, metrics []string, filters []string, format string) (*v1beta1.ObservationLog, error) {
func CollectObservationLog(fileName string, metrics []string, filters []string, fileFormat commonv1beta1.FileSystemFileFormat) (*v1beta1.ObservationLog, error) {
file, err := os.Open(fileName)
if err != nil {
return nil, err
Expand All @@ -46,13 +46,13 @@ func CollectObservationLog(fileName string, metrics []string, filters []string,
}
logs := string(content)

switch format {
case commonv1beta1.TextFormat.String():
switch fileFormat {
case commonv1beta1.TextFormat:
return parseLogsInTextFormat(strings.Split(logs, "\n"), metrics, filters)
case commonv1beta1.JsonFormat.String():
case commonv1beta1.JsonFormat:
return parseLogsInJsonFormat(strings.Split(logs, "\n"), metrics)
}
return nil, fmt.Errorf("format must be set %s or %s", commonv1beta1.TextFormat.String(), commonv1beta1.JsonFormat.String())
return nil, fmt.Errorf("format must be set %v or %s", commonv1beta1.TextFormat, commonv1beta1.JsonFormat)
}

func parseLogsInTextFormat(logs []string, metrics []string, filters []string) (*v1beta1.ObservationLog, error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ func TestCollectObservationLog(t *testing.T) {
fileName string
metrics []string
filters []string
format string
fileFormat commonv1beta1.FileSystemFileFormat
err bool
expected *v1beta1.ObservationLog
}{
{
description: "Positive case for logs in JSON format",
fileName: path.Join(testJsonDataPath, "good.json"),
metrics: []string{"acc", "loss"},
format: commonv1beta1.JsonFormat.String(),
fileFormat: commonv1beta1.JsonFormat,
expected: &v1beta1.ObservationLog{
MetricLogs: []*v1beta1.MetricLog{
{
Expand Down Expand Up @@ -95,19 +95,19 @@ func TestCollectObservationLog(t *testing.T) {
{
description: "Invalid file format",
fileName: path.Join(testJsonDataPath, "good.json"),
format: "invalid",
fileFormat: "invalid",
err: true,
},
{
description: "Invalid formatted file for logs in JSON format",
fileName: path.Join(testJsonDataPath, "invalid-format.json"),
format: commonv1beta1.JsonFormat.String(),
fileFormat: commonv1beta1.JsonFormat,
err: true,
},
{
description: "Invalid timestamp for logs in JSON format",
fileName: path.Join(testJsonDataPath, "invalid-timestamp.json"),
format: commonv1beta1.JsonFormat.String(),
fileFormat: commonv1beta1.JsonFormat,
metrics: []string{"acc", "loss"},
expected: &v1beta1.ObservationLog{
MetricLogs: []*v1beta1.MetricLog{
Expand All @@ -131,7 +131,7 @@ func TestCollectObservationLog(t *testing.T) {
{
description: "Missing objective metric in training logs",
fileName: path.Join(testJsonDataPath, "missing-objective-metric.json"),
format: commonv1beta1.JsonFormat.String(),
fileFormat: commonv1beta1.JsonFormat,
metrics: []string{"acc", "loss"},
expected: &v1beta1.ObservationLog{
MetricLogs: []*v1beta1.MetricLog{
Expand All @@ -149,7 +149,7 @@ func TestCollectObservationLog(t *testing.T) {

for _, test := range testCases {
t.Run(test.description, func(t *testing.T) {
actual, err := CollectObservationLog(test.fileName, test.metrics, test.filters, test.format)
actual, err := CollectObservationLog(test.fileName, test.metrics, test.filters, test.fileFormat)
if (err != nil) != test.err {
t.Errorf("\nGOT: \n%v\nWANT: %v\n", err, test.err)
} else {
Expand Down

0 comments on commit 4eb6d6c

Please sign in to comment.