diff --git a/launcher/container_runner.go b/launcher/container_runner.go index 65f2745f..bba5d9a0 100644 --- a/launcher/container_runner.go +++ b/launcher/container_runner.go @@ -197,7 +197,9 @@ func (r *ContainerRunner) measureContainerClaims(ctx context.Context) error { if err := r.attestAgent.MeasureEvent(cel.CosTlv{EventType: cel.ImageDigestType, EventContent: []byte(image.Target().Digest)}); err != nil { return err } - r.attestAgent.MeasureEvent(cel.CosTlv{EventType: cel.RestartPolicyType, EventContent: []byte(r.launchSpec.RestartPolicy)}) + if err := r.attestAgent.MeasureEvent(cel.CosTlv{EventType: cel.RestartPolicyType, EventContent: []byte(r.launchSpec.RestartPolicy)}); err != nil { + return err + } if imageConfig, err := image.Config(ctx); err == nil { // if NO error if err := r.attestAgent.MeasureEvent(cel.CosTlv{EventType: cel.ImageIDType, EventContent: []byte(imageConfig.Digest)}); err != nil { return err @@ -286,7 +288,6 @@ func (r *ContainerRunner) fetchAndWriteToken(ctx context.Context) error { } // Run the container -// Doesn't support container restart yet // Container output will always be redirected to stdio for now func (r *ContainerRunner) Run(ctx context.Context) error { ctx, cancel := context.WithCancel(ctx) @@ -300,30 +301,36 @@ func (r *ContainerRunner) Run(ctx context.Context) error { return fmt.Errorf("failed to fetch and write OIDC token: %v", err) } - task, err := r.container.NewTask(ctx, cio.NewCreator(cio.WithStdio)) - if err != nil { - return err - } - defer task.Delete(ctx) - - exitStatus, err := task.Wait(ctx) - if err != nil { - return err - } - log.Println("task started") - - if err := task.Start(ctx); err != nil { - return err - } - status := <-exitStatus + for { + task, err := r.container.NewTask(ctx, cio.NewCreator(cio.WithStdio)) + if err != nil { + return err + } + exitStatus, err := task.Wait(ctx) + if err != nil { + return err + } + log.Println("task started") - code, _, err := status.Result() - if err != nil { - return err - } + if err := task.Start(ctx); err != nil { + return err + } + status := <-exitStatus - if code != 0 { - return fmt.Errorf("task ended with non-zero return code %d", code) + code, _, err := status.Result() + if err != nil { + return err + } + task.Delete(ctx) + + log.Printf("task ended with return code %d \n", code) + if r.launchSpec.RestartPolicy == spec.Always { + log.Println("restarting task") + } else if r.launchSpec.RestartPolicy == spec.OnFailure && code != 0 { + log.Println("restarting task on failure") + } else { + break + } } return nil