Skip to content

Commit

Permalink
Merge pull request #1830 from rancher/fix-modeltraining-ordering
Browse files Browse the repository at this point in the history
Fix ordering of API tasks for training models
  • Loading branch information
alexandreLamarre authored Nov 8, 2023
2 parents b826f44 + d5c2fb2 commit 353ff3a
Showing 1 changed file with 36 additions and 14 deletions.
50 changes: 36 additions & 14 deletions plugins/aiops/pkg/gateway/modeltraining.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/google/uuid"
backoffv2 "github.com/lestrrat-go/backoff/v2"
"github.com/nats-io/nats.go"
corev1 "github.com/rancher/opni/pkg/apis/core/v1"
"github.com/rancher/opni/plugins/aiops/apis/admin"
Expand Down Expand Up @@ -67,27 +68,33 @@ func (p *AIOpsPlugin) requestModelTraining(
if err != nil {
return nil, status.Errorf(codes.Internal, "Failed to get model training parameters: %v", err)
}
msg, err := natsConnection.Request(modelTrainingNatsSubject, parametersPayload, time.Minute)
if err != nil {
return nil, status.Errorf(codes.Unavailable, "Failed to train model: %v", err)
retrier := backoffv2.Exponential(
backoffv2.WithMaxRetries(10),
backoffv2.WithMinInterval(2*time.Second),
backoffv2.WithMaxInterval(5*time.Second),
backoffv2.WithMultiplier(1.2),
)
signalStartTraining := retrier.Start(ctx)
var msgErr error
for backoffv2.Continue(signalStartTraining) {
msg, err := natsConnection.Request(modelTrainingNatsSubject, parametersPayload, time.Minute)
if err == nil {
return &modeltraining.ModelTrainingResponse{
Response: string(msg.Data),
}, nil
}
msgErr = err
}
return &modeltraining.ModelTrainingResponse{
Response: string(msg.Data),
}, nil
return nil, fmt.Errorf("failed to post to training message to NATS : %s", msgErr)
}

func (p *AIOpsPlugin) TrainModel(ctx context.Context, in *modeltraining.ModelTrainingParametersList) (*modeltraining.ModelTrainingResponse, error) {
_, err := p.LaunchAIServices(ctx)
if err != nil {
return nil, status.Error(codes.FailedPrecondition, fmt.Sprintf("failed to launch AI services : %s", err))
}
func (p *AIOpsPlugin) persistInitialJobInfo(ctx context.Context, in *modeltraining.ModelTrainingParametersList) ([]byte, error) {
ctxca, ca := context.WithTimeout(ctx, 10*time.Second)
defer ca()
modelTrainingKv, err := p.modelTrainingKv.GetContext(ctxca)
if err != nil {
return nil, status.Errorf(codes.FailedPrecondition, "Failed to get model training KV: %v", err)
}

modelTrainingParameters, parametersBytes, err := modelTrainingParams(in)
if err != nil {
return nil, status.Errorf(codes.Internal, "Failed to marshall training parameters: %v", err)
Expand All @@ -96,6 +103,7 @@ func (p *AIOpsPlugin) TrainModel(ctx context.Context, in *modeltraining.ModelTra
if _, err := modelTrainingKv.Put(modelTrainingParametersKey, parametersBytes); err != nil {
return nil, err
}

initialStatus := modeltraining.ModelStatus{Status: "training", Statistics: &modeltraining.ModelTrainingStatistics{Stage: "fetching data"}}
if len(modelTrainingParameters) == 0 {
initialStatus = modeltraining.ModelStatus{Status: "no model trained"}
Expand All @@ -104,15 +112,29 @@ func (p *AIOpsPlugin) TrainModel(ctx context.Context, in *modeltraining.ModelTra
if err != nil {
return nil, status.Errorf(codes.Internal, "Failed to put model training status: %v", err)
}
return parametersBytes, nil
}

func (p *AIOpsPlugin) TrainModel(ctx context.Context, in *modeltraining.ModelTrainingParametersList) (*modeltraining.ModelTrainingResponse, error) {
parametersBytes, err := p.persistInitialJobInfo(ctx, in)
if err != nil {
return nil, err
}
_, err = p.LaunchAIServices(ctx)
if err != nil {
delErr := errors.Join(p.deleteTrainingJobInfo(ctx), err)
p.Logger.Error(fmt.Sprintf("failed to launch AI services : %s", delErr))
return nil, status.Error(codes.FailedPrecondition, fmt.Sprintf("failed to launch AI services : %s", delErr))
}
resp, err := p.requestModelTraining(ctx, parametersBytes)
if err != nil { // this fails if the request never made it to the training container
// therefore we make a best effort to purge the stateful information associated
// with the modeltraining request
ctxca, ca := context.WithTimeout(ctx, 10*time.Second)
defer ca()
delErr := p.deleteTrainingJobInfo(ctxca)
return nil, errors.Join(err, delErr)
delErr := errors.Join(p.deleteTrainingJobInfo(ctxca), err)
p.Logger.Error(fmt.Sprintf("failed to request model training : %s", delErr))
return nil, delErr
}

return resp, nil
Expand Down

0 comments on commit 353ff3a

Please sign in to comment.