Skip to content

Commit

Permalink
fix: use custom marshaler for n_epochs
Browse files Browse the repository at this point in the history
  • Loading branch information
henomis committed Sep 25, 2023
1 parent 8e4b796 commit 2be7241
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 3 deletions.
35 changes: 34 additions & 1 deletion fine_tuning_job.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package openai

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
Expand All @@ -23,8 +24,40 @@ type FineTuningJob struct {
TrainedTokens int `json:"trained_tokens"`
}

type HyperparameterNEpochs struct {
IntValue *int `json:"-"`
StringValue *string `json:"-"`
}

func (h *HyperparameterNEpochs) UnmarshalJSON(data []byte) error {
var intValue int
var stringValue string

if err := json.Unmarshal(data, &intValue); err == nil {
h.IntValue = &intValue
return nil
}

Check warning on line 39 in fine_tuning_job.go

View check run for this annotation

Codecov / codecov/patch

fine_tuning_job.go#L37-L39

Added lines #L37 - L39 were not covered by tests

if err := json.Unmarshal(data, &stringValue); err != nil {
return err
}

Check warning on line 43 in fine_tuning_job.go

View check run for this annotation

Codecov / codecov/patch

fine_tuning_job.go#L42-L43

Added lines #L42 - L43 were not covered by tests

h.StringValue = &stringValue
return nil
}

func (h *HyperparameterNEpochs) MarshalJSON() ([]byte, error) {
if h.IntValue != nil {
return json.Marshal(*h.IntValue)

Check warning on line 51 in fine_tuning_job.go

View check run for this annotation

Codecov / codecov/patch

fine_tuning_job.go#L51

Added line #L51 was not covered by tests
} else if h.StringValue != nil {
return json.Marshal(*h.StringValue)
}

return nil, fmt.Errorf("invalid hyperparameter n_epochs")

Check warning on line 56 in fine_tuning_job.go

View check run for this annotation

Codecov / codecov/patch

fine_tuning_job.go#L56

Added line #L56 was not covered by tests
}

type Hyperparameters struct {
Epochs int `json:"n_epochs"`
Epochs *HyperparameterNEpochs `json:"n_epochs,omitempty"`
}

type FineTuningJobRequest struct {
Expand Down
23 changes: 21 additions & 2 deletions fine_tuning_job_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,27 @@ func TestFineTuningJob(t *testing.T) {
server.RegisterHandler(
"/v1/fine_tuning/jobs",
func(w http.ResponseWriter, r *http.Request) {
var resBytes []byte
resBytes, _ = json.Marshal(FineTuningJob{})
nEpochs := "auto"
resBytes, _ := json.Marshal(FineTuningJob{
Object: "fine_tuning.job",
ID: testFineTuninigJobID,
Model: "davinci-002",
CreatedAt: 1692661014,
FinishedAt: 1692661190,
FineTunedModel: "ft:davinci-002:my-org:custom_suffix:7q8mpxmy",
OrganizationID: "org-123",
ResultFiles: []string{"file-abc123"},
Status: "succeeded",
ValidationFile: "",
TrainingFile: "file-abc123",
Hyperparameters: Hyperparameters{
Epochs: &HyperparameterNEpochs{
IntValue: nil,
StringValue: &nEpochs,
},
},
TrainedTokens: 5768,
})
fmt.Fprintln(w, string(resBytes))
},
)
Expand Down

0 comments on commit 2be7241

Please sign in to comment.