From 31dba6e720aa13af52d19d8c1e5d1b4342924f6a Mon Sep 17 00:00:00 2001 From: Jiankun Lu Date: Fri, 11 Nov 2022 19:31:34 -0800 Subject: [PATCH] [launcher] Change retry behavior to reboot Update scripts and service unit file Signed-off-by: Jiankun Lu --- .gitignore | 2 +- cloudbuild.yaml | 4 +- launcher/auth.go | 2 +- launcher/container_runner.go | 92 ++++++++------- launcher/container_runner_test.go | 2 +- launcher/errors.go | 28 +++++ launcher/image/container-runner.service | 7 +- launcher/image/entrypoint.sh | 8 +- launcher/image/exit_script.sh | 14 +++ launcher/image/preload.sh | 27 ++--- launcher/launcher/main.go | 142 ++++++++++++++++++++++++ launcher/launcher/main_test.go | 131 ++++++++++++++++++++++ launcher/main.go | 89 --------------- 13 files changed, 378 insertions(+), 170 deletions(-) create mode 100644 launcher/errors.go create mode 100755 launcher/image/exit_script.sh create mode 100644 launcher/launcher/main.go create mode 100644 launcher/launcher/main_test.go delete mode 100644 launcher/main.go diff --git a/.gitignore b/.gitignore index 25bbca454..aab2b1736 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -launcher/launcher +launcher/launcher/launcher *.test *.test.exe cmd/gotpm/gotpm diff --git a/cloudbuild.yaml b/cloudbuild.yaml index cd59c0a27..c53c70e21 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -11,8 +11,8 @@ steps: args: - -c - | - cd launcher - go build -o image/launcher + cd launcher/launcher + go build -o ../image/launcher - name: 'gcr.io/cos-cloud/cos-customizer' args: ['start-image-build', '-build-context=launcher/image', diff --git a/launcher/auth.go b/launcher/auth.go index ad291006f..31570014b 100644 --- a/launcher/auth.go +++ b/launcher/auth.go @@ -1,4 +1,4 @@ -package main +package launcher import ( "encoding/json" diff --git a/launcher/container_runner.go b/launcher/container_runner.go index f5f712363..61fa7bd49 100644 --- a/launcher/container_runner.go +++ b/launcher/container_runner.go @@ -1,4 +1,5 @@ -package 
main +// Package launcher contains functionalities to start a measured workload +package launcher import ( "context" @@ -86,7 +87,7 @@ func fetchImpersonatedToken(ctx context.Context, serviceAccount string, audience func NewRunner(ctx context.Context, cdClient *containerd.Client, token oauth2.Token, launchSpec spec.LaunchSpec, mdsClient *metadata.Client, tpm io.ReadWriteCloser, logger *log.Logger) (*ContainerRunner, error) { image, err := initImage(ctx, cdClient, launchSpec, token, logger) if err != nil { - return nil, err + return nil, &NonRetryableError{err} } mounts := make([]specs.Mount, 0) @@ -112,10 +113,10 @@ func NewRunner(ctx context.Context, cdClient *containerd.Client, token oauth2.To logger.Printf("Image Labels : %v\n", imageLabels) launchPolicy, err := spec.GetLaunchPolicy(imageLabels) if err != nil { - return nil, err + return nil, &NonRetryableError{err} } if err := launchPolicy.Verify(launchSpec); err != nil { - return nil, err + return nil, &NonRetryableError{err} } if imageConfig, err := image.Config(ctx); err != nil { @@ -127,7 +128,7 @@ func NewRunner(ctx context.Context, cdClient *containerd.Client, token oauth2.To hostname, err := os.Hostname() if err != nil { - return nil, fmt.Errorf("cannot get hostname: [%w]", err) + return nil, &RetryableError{fmt.Errorf("cannot get hostname: [%w]", err)} } container, err = cdClient.NewContainer( @@ -151,19 +152,21 @@ func NewRunner(ctx context.Context, cdClient *containerd.Client, token oauth2.To if container != nil { container.Delete(ctx, containerd.WithSnapshotCleanup) } - return nil, fmt.Errorf("failed to create a container: [%w]", err) + return nil, &RetryableError{fmt.Errorf("failed to create a container: [%w]", err)} } containerSpec, err := container.Spec(ctx) if err != nil { - return nil, err + return nil, &RetryableError{err} } // Container process Args length should be strictly longer than the Cmd // override length set by the operator, as we want the Entrypoint filed // to be mandatory for the 
image. // Roughly speaking, Args = Entrypoint + Cmd if len(containerSpec.Process.Args) <= len(launchSpec.Cmd) { - return nil, fmt.Errorf("length of Args [%d] is shorter or equal to the length of the given Cmd [%d], maybe the Entrypoint is set to empty in the image?", len(containerSpec.Process.Args), len(launchSpec.Cmd)) + return nil, &NonRetryableError{ + fmt.Errorf("length of Args [%d] is shorter or equal to the length of the given Cmd [%d], maybe the Entrypoint is set to empty in the image?", + len(containerSpec.Process.Args), len(launchSpec.Cmd))} } // Fetch ID token with specific audience. @@ -207,7 +210,7 @@ func NewRunner(ctx context.Context, cdClient *containerd.Client, token oauth2.To verifierClient, conn, err = getGRPCClient(asAddr, logger) } if err != nil { - return nil, fmt.Errorf("failed to create verifier client: %v", err) + return nil, &NonRetryableError{fmt.Errorf("failed to create verifier client: %v", err)} } return &ContainerRunner{ @@ -411,53 +414,46 @@ func (r *ContainerRunner) Run(ctx context.Context) error { defer cancel() if err := r.measureContainerClaims(ctx); err != nil { - return fmt.Errorf("failed to measure container claims: %v", err) + return &NonRetryableError{fmt.Errorf("failed to measure container claims: %v", err)} } if err := r.fetchAndWriteToken(ctx); err != nil { - return fmt.Errorf("failed to fetch and write OIDC token: %v", err) + return &NonRetryableError{fmt.Errorf("failed to fetch and write OIDC token: %v", err)} } - for { - var streamOpt cio.Opt - if r.launchSpec.LogRedirect { - streamOpt = cio.WithStreams(nil, r.logger.Writer(), r.logger.Writer()) - r.logger.Println("container stdout/stderr will be redirected") - } else { - streamOpt = cio.WithStreams(nil, nil, nil) - r.logger.Println("container stdout/stderr will not be redirected") - } - - task, err := r.container.NewTask(ctx, cio.NewCreator(streamOpt)) - if err != nil { - return err - } - exitStatus, err := task.Wait(ctx) - if err != nil { - return err - } - 
r.logger.Println("task started") + var streamOpt cio.Opt + if r.launchSpec.LogRedirect { + streamOpt = cio.WithStreams(nil, r.logger.Writer(), r.logger.Writer()) + r.logger.Println("container stdout/stderr will be redirected") + } else { + streamOpt = cio.WithStreams(nil, nil, nil) + r.logger.Println("container stdout/stderr will not be redirected") + } - if err := task.Start(ctx); err != nil { - return err - } - status := <-exitStatus + task, err := r.container.NewTask(ctx, cio.NewCreator(streamOpt)) + if err != nil { + return &RetryableError{err} + } + exitStatus, err := task.Wait(ctx) + if err != nil { + return &RetryableError{err} + } + r.logger.Println("workload task started") - code, _, err := status.Result() - if err != nil { - return err - } - task.Delete(ctx) - - r.logger.Printf("task ended with return code %d \n", code) - if r.launchSpec.RestartPolicy == spec.Always { - r.logger.Println("restarting task") - } else if r.launchSpec.RestartPolicy == spec.OnFailure && code != 0 { - r.logger.Println("restarting task on failure") - } else { - break - } + if err := task.Start(ctx); err != nil { + return &RetryableError{err} } + status := <-exitStatus + code, _, err := status.Result() + if err != nil { + return &NonRetryableError{err} + } + if _, err := task.Delete(ctx); err != nil { + return &NonRetryableError{err} + } + if code != 0 { + return &WorkloadError{code} + } return nil } diff --git a/launcher/container_runner_test.go b/launcher/container_runner_test.go index c6d7832e2..974f21def 100644 --- a/launcher/container_runner_test.go +++ b/launcher/container_runner_test.go @@ -1,4 +1,4 @@ -package main +package launcher import ( "bytes" diff --git a/launcher/errors.go b/launcher/errors.go new file mode 100644 index 000000000..4f4709655 --- /dev/null +++ b/launcher/errors.go @@ -0,0 +1,28 @@ +package launcher + +// RetryableError means launcher should reboot the VM to retry. 
+type RetryableError struct { + Err error +} + +// NonRetryableError means launcher shouldn't reboot the VM to retry. +type NonRetryableError struct { + Err error +} + +// WorkloadError represents a workload/task that finished with a non-zero return code. +type WorkloadError struct { + ReturnCode uint32 +} + +func (e *RetryableError) Error() string { + return e.Err.Error() +} + +func (e *NonRetryableError) Error() string { + return e.Err.Error() +} + +func (e *WorkloadError) Error() string { + return "workload finished with non-zero return code" +} diff --git a/launcher/image/container-runner.service b/launcher/image/container-runner.service index 644e72b92..f5973739c 100644 --- a/launcher/image/container-runner.service +++ b/launcher/image/container-runner.service @@ -4,12 +4,9 @@ Wants=network-online.target gcr-online.target containerd.service After=network-online.target gcr-online.target containerd.service [Service] -ExecStart=/var/lib/google/cc_container_launcher -# Shutdown the host after the launcher exits -ExecStopPost=/bin/sleep 60 -ExecStopPost=/usr/bin/systemctl poweroff +ExecStart=/usr/share/oem/confidential_space/cs_container_launcher +ExecStopPost=/usr/share/oem/confidential_space/exit_script.sh Restart=no -# RestartSec=90 StandardOutput=journal+console StandardError=journal+console diff --git a/launcher/image/entrypoint.sh b/launcher/image/entrypoint.sh index ee32613ed..374c8995e 100644 --- a/launcher/image/entrypoint.sh +++ b/launcher/image/entrypoint.sh @@ -1,14 +1,8 @@ #!/bin/bash main() { - # copy the binary - cp /usr/share/oem/cc_container_launcher /var/lib/google/cc_container_launcher - chmod +x /var/lib/google/cc_container_launcher - # copy systemd files - cp /usr/share/oem/container-runner.service /etc/systemd/system/container-runner.service - mkdir -p /etc/systemd/system/container-runner.service.d/ - cp /usr/share/oem/launcher.conf /etc/systemd/system/container-runner.service.d/launcher.conf + cp 
/usr/share/oem/confidential_space/container-runner.service /etc/systemd/system/container-runner.service systemctl daemon-reload systemctl enable container-runner.service diff --git a/launcher/image/exit_script.sh b/launcher/image/exit_script.sh new file mode 100755 index 000000000..9c12dd04c --- /dev/null +++ b/launcher/image/exit_script.sh @@ -0,0 +1,14 @@ +#! /bin/bash + +if [[ $EXIT_STATUS -eq 3 ]] +then + # reboot after 2 min + shutdown --reboot +2 +fi + +if [[ $EXIT_STATUS -eq 0 ]] || [[ $EXIT_STATUS -eq 1 ]] || [[ $EXIT_STATUS -eq 2 ]] +then + # poweroff after 2 min + shutdown --poweroff +2 +fi + diff --git a/launcher/image/preload.sh b/launcher/image/preload.sh index a8943c202..8f056fb69 100644 --- a/launcher/image/preload.sh +++ b/launcher/image/preload.sh @@ -1,20 +1,15 @@ #!/bin/bash +readonly OEM_PATH='/usr/share/oem' +readonly CS_PATH="${OEM_PATH}/confidential_space" + copy_launcher() { - cp launcher /usr/share/oem/cc_container_launcher + cp launcher "${CS_PATH}/cs_container_launcher" } setup_launcher_systemd_unit() { - cp container-runner.service /usr/share/oem/container-runner.service - - if [ "$IMAGE_ENV" == "hardened" ]; then - cp hardened.conf /usr/share/oem/launcher.conf - elif [ "$IMAGE_ENV" == "debug" ]; then - cp debug.conf /usr/share/oem/launcher.conf - else - echo "Unknown IMAGE_ENV: ${IMAGE_ENV}. Use hardened or debug" - exit 1 - fi + cp container-runner.service "${CS_PATH}/container-runner.service" + cp exit_script.sh "${CS_PATH}/exit_script.sh" } append_cmdline() { @@ -40,9 +35,9 @@ enable_unit() { } configure_entrypoint() { - cp "$1" /usr/share/oem/user-data - touch /usr/share/oem/meta-data - append_cmdline "'ds=nocloud;s=/usr/share/oem/'" + cp "$1" ${OEM_PATH}/user-data + touch ${OEM_PATH}/meta-data + append_cmdline "'ds=nocloud;s=${OEM_PATH}/'" } configure_necessary_systemd_units() { @@ -62,7 +57,6 @@ configure_systemd_units_for_debug() { # No-op for now, as debug will default to using multi-user.target. 
: } - configure_systemd_units_for_hardened() { configure_necessary_systemd_units # Make entrypoint (via cloud-init) the default unit. @@ -81,7 +75,8 @@ configure_systemd_units_for_hardened() { } main() { - mount -o remount,rw /usr/share/oem + mount -o remount,rw ${OEM_PATH} + mkdir ${CS_PATH} # Install container launcher entrypoint. configure_entrypoint "entrypoint.sh" diff --git a/launcher/launcher/main.go b/launcher/launcher/main.go new file mode 100644 index 000000000..9eca15b66 --- /dev/null +++ b/launcher/launcher/main.go @@ -0,0 +1,142 @@ +// package main is a program that will start a container with attestation. +package main + +import ( + "context" + "io" + "log" + "os" + + "cloud.google.com/go/compute/metadata" + "cloud.google.com/go/logging" + "github.com/containerd/containerd" + "github.com/containerd/containerd/defaults" + "github.com/containerd/containerd/namespaces" + "github.com/google/go-tpm-tools/launcher" + "github.com/google/go-tpm-tools/launcher/spec" + "github.com/google/go-tpm/tpm2" +) + +const ( + logName = "confidential-space-launcher" +) + +const ( + successRC = 0 // workload successful (no reboot) + failRC = 1 // workload or launcher internal failed (no reboot) + // panic() returns 2 + rebootRC = 3 // reboot + holdRC = 4 // hold +) + +var logger *log.Logger +var mdsClient *metadata.Client +var launchSpec spec.LaunchSpec + +func main() { + var exitCode int + defer func() { + os.Exit(exitCode) + }() + + logger = log.Default() + logger.Println("TEE container launcher initiating") + + mdsClient = metadata.NewClient(nil) + projectID, err := mdsClient.ProjectID() + if err != nil { + logger.Printf("cannot get projectID, not in GCE? 
%v", err) + // cannot get projectID from MDS, exit directly + exitCode = failRC + return + } + + logClient, err := logging.NewClient(context.Background(), projectID) + if err != nil { + logger.Printf("cannot setup Cloud Logging, using the default stdout logger %v", err) + } else { + defer logClient.Close() + logger.Printf("logs will be published to Cloud Logging under the log name %s\n", logName) + logger = logClient.Logger(logName).StandardLogger(logging.Info) + loggerAndStdout := io.MultiWriter(os.Stdout, logger.Writer()) // for now also print log to stdout + logger.SetOutput(loggerAndStdout) + } + + // get restart policy and ishardened from spec + launchSpec, err = spec.GetLaunchSpec(mdsClient) + if err != nil { + logger.Println(err) + // if cannot get launchSpec, exit directly + exitCode = failRC + return + } + + if err = startLauncher(); err != nil { + logger.Println(err) + } + + exitCode = getExitCode(launchSpec.Hardened, launchSpec.RestartPolicy, err) +} + +func getExitCode(isHardened bool, restartPolicy spec.RestartPolicy, err error) int { + exitCode := 0 + + // if in a debug image, will always hold + if !isHardened { + return holdRC + } + + if err != nil { + switch err.(type) { + default: + // unknown error + exitCode = failRC + case *launcher.RetryableError, *launcher.WorkloadError: + if restartPolicy == spec.Always || restartPolicy == spec.OnFailure { + exitCode = rebootRC + } else { + exitCode = failRC + } + case *launcher.NonRetryableError: + exitCode = failRC + } + } else { + // if no error + if restartPolicy == spec.Always { + exitCode = rebootRC + } else { + exitCode = successRC + } + } + + return exitCode +} + +func startLauncher() error { + logger.Println("Launch Spec: ", launchSpec) + client, err := containerd.New(defaults.DefaultAddress) + if err != nil { + return &launcher.RetryableError{Err: err} + } + defer client.Close() + + tpm, err := tpm2.OpenTPM("/dev/tpmrm0") + if err != nil { + return &launcher.RetryableError{Err: err} + } + defer 
tpm.Close() + + token, err := launcher.RetrieveAuthToken(mdsClient) + if err != nil { + logger.Printf("failed to retrieve auth token: %v, using empty auth", err) + } + + ctx := namespaces.WithNamespace(context.Background(), namespaces.Default) + r, err := launcher.NewRunner(ctx, client, token, launchSpec, mdsClient, tpm, logger) + if err != nil { + return err + } + defer r.Close(ctx) + + return r.Run(ctx) +} diff --git a/launcher/launcher/main_test.go b/launcher/launcher/main_test.go new file mode 100644 index 000000000..035523e44 --- /dev/null +++ b/launcher/launcher/main_test.go @@ -0,0 +1,131 @@ +package main + +import ( + "testing" + + "github.com/google/go-tpm-tools/launcher" + "github.com/google/go-tpm-tools/launcher/spec" +) + +func TestGetExitCode(t *testing.T) { + testcases := []struct { + name string + isHardened bool + restartPolicy spec.RestartPolicy + err error + expectedReturnCode int + }{ + // no error, debug image + { + "debug, always restart, nil error", + false, spec.Always, nil, holdRC, + }, + { + "debug, never restart, nil error", + false, spec.Never, nil, holdRC, + }, + { + "debug, onfailure restart, nil error", + false, spec.OnFailure, nil, holdRC, + }, + // no error, hardened image + { + "hardened, always restart, nil error", + true, spec.Always, nil, rebootRC, + }, + { + "hardened, never restart, nil error", + true, spec.Never, nil, successRC, + }, + { + "hardened, onfailure restart, nil error", + true, spec.OnFailure, nil, successRC, + }, + // retryable error, debug image + { + "debug, always restart, retryable error", + false, spec.Always, &launcher.RetryableError{}, holdRC, + }, + { + "debug, never restart, retryable error", + false, spec.Never, &launcher.RetryableError{}, holdRC, + }, + { + "debug, onfailure restart, retryable error", + false, spec.OnFailure, &launcher.RetryableError{}, holdRC, + }, + // workload error, debug image (same as retryable error) + { + "debug, always restart, workload error", + false, spec.Always, 
&launcher.WorkloadError{}, holdRC, + }, + { + "debug, never restart, workload error", + false, spec.Never, &launcher.WorkloadError{}, holdRC, + }, + { + "debug, onfailure restart, workload error", + false, spec.OnFailure, &launcher.WorkloadError{}, holdRC, + }, + // retryable error, hardened image + { + "hardened, always restart, retryable error", + true, spec.Always, &launcher.RetryableError{}, rebootRC, + }, + { + "hardened, never restart, retryable error", + true, spec.Never, &launcher.RetryableError{}, failRC, + }, + { + "hardened, onfailure restart, retryable error", + true, spec.OnFailure, &launcher.RetryableError{}, rebootRC, + }, + // workload error, hardened image (same as retryable error) + { + "hardened, always restart, workload error", + true, spec.Always, &launcher.WorkloadError{}, rebootRC, + }, + { + "hardened, never restart, workload error", + true, spec.Never, &launcher.WorkloadError{}, failRC, + }, + { + "hardened, onfailure restart, workload error", + true, spec.OnFailure, &launcher.WorkloadError{}, rebootRC, + }, + // non-retryable error, debug image + { + "debug, always restart, non-retryable error", + false, spec.Always, &launcher.NonRetryableError{}, holdRC, + }, + { + "debug, never restart, non-retryable error", + false, spec.Never, &launcher.NonRetryableError{}, holdRC, + }, + { + "debug, onfailure restart, non-retryable error", + false, spec.OnFailure, &launcher.NonRetryableError{}, holdRC, + }, + // non-retryable error, hardened image + { + "debug, always restart, non-retryable error", + true, spec.Always, &launcher.NonRetryableError{}, failRC, + }, + { + "debug, never restart, non-retryable error", + true, spec.Never, &launcher.NonRetryableError{}, failRC, + }, + { + "debug, onfailure restart, non-retryable error", + true, spec.OnFailure, &launcher.NonRetryableError{}, failRC, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + if rc := getExitCode(tc.isHardened, tc.restartPolicy, tc.err); rc != 
tc.expectedReturnCode { + t.Errorf("got %d, wanted %d", rc, tc.expectedReturnCode) + } + }) + } +} diff --git a/launcher/main.go b/launcher/main.go deleted file mode 100644 index c1de57130..000000000 --- a/launcher/main.go +++ /dev/null @@ -1,89 +0,0 @@ -// package main is a program that will start a container with attestation. -package main - -import ( - "context" - "flag" - "io" - "log" - "os" - - "cloud.google.com/go/compute/metadata" - "cloud.google.com/go/logging" - "github.com/containerd/containerd" - "github.com/containerd/containerd/defaults" - "github.com/containerd/containerd/namespaces" - "github.com/google/go-tpm-tools/launcher/spec" - "github.com/google/go-tpm/tpm2" -) - -const ( - logName = "confidential-space-launcher" -) - -func main() { - flag.Parse() - os.Exit(run()) -} - -func run() int { - logger := log.Default() - logger.Println("TEE container launcher starting...") - - mdsClient := metadata.NewClient(nil) - ctx := namespaces.WithNamespace(context.Background(), namespaces.Default) - projectID, err := mdsClient.ProjectID() - if err != nil { - logger.Printf("cannot get projectID, not in GCE? 
%v", err) - return 1 - } - logClient, err := logging.NewClient(context.Background(), projectID) - if err != nil { - logger.Printf("cannot setup Cloud Logging, using the default stdout logger %v", err) - } else { - defer logClient.Close() - logger.Printf("logs will be published to Cloud Logging under the log name %s\n", logName) - logger = logClient.Logger(logName).StandardLogger(logging.Info) - loggerAndStdout := io.MultiWriter(os.Stdout, logger.Writer()) // for now also print log to stdout - logger.SetOutput(loggerAndStdout) - } - - launchSpec, err := spec.GetLaunchSpec(mdsClient) - if err != nil { - logger.Println(err) - return 1 - } - logger.Println("Launch Spec: ", launchSpec) - - client, err := containerd.New(defaults.DefaultAddress) - if err != nil { - logger.Println(err) - return 1 - } - defer client.Close() - - tpm, err := tpm2.OpenTPM("/dev/tpmrm0") - if err != nil { - logger.Println(err) - return 1 - } - defer tpm.Close() - - token, err := RetrieveAuthToken(mdsClient) - if err != nil { - logger.Printf("failed to retrieve auth token: %v, using empty auth", err) - } - - r, err := NewRunner(ctx, client, token, launchSpec, mdsClient, tpm, logger) - if err != nil { - logger.Println(err) - return 1 - } - defer r.Close(ctx) - - if err := r.Run(ctx); err != nil { - logger.Println(err) - return 1 - } - return 0 -}