diff --git a/plugins/aiops/pkg/gateway/modeltraining.go b/plugins/aiops/pkg/gateway/modeltraining.go index d47ed023f4..dc10b8a29b 100644 --- a/plugins/aiops/pkg/gateway/modeltraining.go +++ b/plugins/aiops/pkg/gateway/modeltraining.go @@ -9,6 +9,7 @@ import ( "time" "github.com/google/uuid" + backoffv2 "github.com/lestrrat-go/backoff/v2" "github.com/nats-io/nats.go" corev1 "github.com/rancher/opni/pkg/apis/core/v1" "github.com/rancher/opni/plugins/aiops/apis/admin" @@ -67,27 +68,33 @@ func (p *AIOpsPlugin) requestModelTraining( if err != nil { return nil, status.Errorf(codes.Internal, "Failed to get model training parameters: %v", err) } - msg, err := natsConnection.Request(modelTrainingNatsSubject, parametersPayload, time.Minute) - if err != nil { - return nil, status.Errorf(codes.Unavailable, "Failed to train model: %v", err) + retrier := backoffv2.Exponential( + backoffv2.WithMaxRetries(10), + backoffv2.WithMinInterval(2*time.Second), + backoffv2.WithMaxInterval(5*time.Second), + backoffv2.WithMultiplier(1.2), + ) + signalStartTraining := retrier.Start(ctx) + var msgErr error + for backoffv2.Continue(signalStartTraining) { + msg, err := natsConnection.Request(modelTrainingNatsSubject, parametersPayload, time.Minute) + if err == nil { + return &modeltraining.ModelTrainingResponse{ + Response: string(msg.Data), + }, nil + } + msgErr = err } - return &modeltraining.ModelTrainingResponse{ - Response: string(msg.Data), - }, nil + return nil, fmt.Errorf("failed to post to training message to NATS : %s", msgErr) } -func (p *AIOpsPlugin) TrainModel(ctx context.Context, in *modeltraining.ModelTrainingParametersList) (*modeltraining.ModelTrainingResponse, error) { - _, err := p.LaunchAIServices(ctx) - if err != nil { - return nil, status.Error(codes.FailedPrecondition, fmt.Sprintf("failed to launch AI services : %s", err)) - } +func (p *AIOpsPlugin) persistInitialJobInfo(ctx context.Context, in *modeltraining.ModelTrainingParametersList) ([]byte, error) { ctxca, ca := context.WithTimeout(ctx, 10*time.Second) defer ca() modelTrainingKv, err := p.modelTrainingKv.GetContext(ctxca) if err != nil { return nil, status.Errorf(codes.FailedPrecondition, "Failed to get model training KV: %v", err) } - modelTrainingParameters, parametersBytes, err := modelTrainingParams(in) if err != nil { return nil, status.Errorf(codes.Internal, "Failed to marshall training parameters: %v", err) @@ -96,6 +103,7 @@ func (p *AIOpsPlugin) TrainModel(ctx context.Context, in *modeltraining.ModelTra if _, err := modelTrainingKv.Put(modelTrainingParametersKey, parametersBytes); err != nil { return nil, err } + initialStatus := modeltraining.ModelStatus{Status: "training", Statistics: &modeltraining.ModelTrainingStatistics{Stage: "fetching data"}} if len(modelTrainingParameters) == 0 { initialStatus = modeltraining.ModelStatus{Status: "no model trained"} @@ -104,15 +112,29 @@ func (p *AIOpsPlugin) TrainModel(ctx context.Context, in *modeltraining.ModelTra if err != nil { return nil, status.Errorf(codes.Internal, "Failed to put model training status: %v", err) } + return parametersBytes, nil +} +func (p *AIOpsPlugin) TrainModel(ctx context.Context, in *modeltraining.ModelTrainingParametersList) (*modeltraining.ModelTrainingResponse, error) { + parametersBytes, err := p.persistInitialJobInfo(ctx, in) + if err != nil { + return nil, err + } + _, err = p.LaunchAIServices(ctx) + if err != nil { + delErr := errors.Join(p.deleteTrainingJobInfo(ctx), err) + p.Logger.Error(fmt.Sprintf("failed to launch AI services : %s", delErr)) + return nil, status.Error(codes.FailedPrecondition, fmt.Sprintf("failed to launch AI services : %s", delErr)) + } resp, err := p.requestModelTraining(ctx, parametersBytes) if err != nil { // this fails if the request never made it to the training container // therefore we make a best effort to purge the stateful information associated // with the modeltraining request ctxca, ca := context.WithTimeout(ctx, 10*time.Second) defer ca() - delErr := p.deleteTrainingJobInfo(ctxca) - return nil, errors.Join(err, delErr) + delErr := errors.Join(p.deleteTrainingJobInfo(ctxca), err) + p.Logger.Error(fmt.Sprintf("failed to request model training : %s", delErr)) + return nil, delErr } return resp, nil