diff --git a/launcher/container_runner.go b/launcher/container_runner.go index b03ba5f7a..55959c6df 100644 --- a/launcher/container_runner.go +++ b/launcher/container_runner.go @@ -10,6 +10,7 @@ import ( "net/url" "os" "path" + "strings" "time" "cloud.google.com/go/compute/metadata" @@ -23,10 +24,13 @@ import ( "github.com/google/go-tpm-tools/client" "github.com/google/go-tpm-tools/launcher/agent" "github.com/google/go-tpm-tools/launcher/spec" + "github.com/google/go-tpm-tools/launcher/verifier" "github.com/google/go-tpm-tools/launcher/verifier/grpcclient" + "github.com/google/go-tpm-tools/launcher/verifier/rest" v1 "github.com/opencontainers/image-spec/specs-go/v1" specs "github.com/opencontainers/runtime-spec/specs-go" "golang.org/x/oauth2" + "golang.org/x/oauth2/google" "google.golang.org/api/impersonate" "google.golang.org/api/option" "google.golang.org/grpc" @@ -161,14 +165,6 @@ func NewRunner(ctx context.Context, cdClient *containerd.Client, token oauth2.To return nil, fmt.Errorf("length of Args [%d] is shorter or equal to the length of the given Cmd [%d], maybe the Entrypoint is set to empty in the image?", len(containerSpec.Process.Args), len(launchSpec.Cmd)) } - // TODO(b/212586174): Dial with secure credentials. - opt := grpc.WithTransportCredentials(insecure.NewCredentials()) - conn, err := grpc.Dial(launchSpec.AttestationServiceAddr, opt) - if err != nil { - return nil, fmt.Errorf("failed to open connection to attestation service: %v", err) - } - verifierClient := grpcclient.NewClient(conn, logger) - // Fetch ID token with specific audience. // See https://cloud.google.com/functions/docs/securing/authenticating#functions-bearer-token-example-go. principalFetcher := func(audience string) ([][]byte, error) { @@ -198,6 +194,21 @@ func NewRunner(ctx context.Context, cdClient *containerd.Client, token oauth2.To return tokens, nil } + asAddr := launchSpec.AttestationServiceAddr + var verifierClient verifier.Client + var conn *grpc.ClientConn + // Temporary support for both gRPC and REST-based attestation verifier. + // Use REST when empty flag or the presence of http in the addr, else gRPC. + // TODO: remove once fully migrated to the REST-based verifier. + if asAddr == "" || strings.Contains(asAddr, "http") { + verifierClient, err = getRESTClient(ctx, asAddr, launchSpec) + } else { + verifierClient, conn, err = getGRPCClient(asAddr, logger) + } + if err != nil { + return nil, fmt.Errorf("failed to create verifier client: %v", err) + } + return &ContainerRunner{ container, launchSpec, @@ -207,6 +218,39 @@ func NewRunner(ctx context.Context, cdClient *containerd.Client, token oauth2.To }, nil } +// getGRPCClient returns a gRPC verifier.Client pointing to the given address. +// It also returns a grpc.ClientConn for closing out the connection. +func getGRPCClient(asAddr string, logger *log.Logger) (verifier.Client, *grpc.ClientConn, error) { + opt := grpc.WithTransportCredentials(insecure.NewCredentials()) + conn, err := grpc.Dial(asAddr, opt) + if err != nil { + return nil, nil, fmt.Errorf("failed to open connection to gRPC attestation service: %v", err) + } + return grpcclient.NewClient(conn, logger), conn, nil +} + +// getRESTClient returns a REST verifier.Client that points to the given address. +// It defaults to the Attestation Verifier instance at +// https://confidentialcomputing.googleapis.com. +func getRESTClient(ctx context.Context, asAddr string, spec spec.LauncherSpec) (verifier.Client, error) { + httpClient, err := google.DefaultClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client: %v", err) + } + + opts := []option.ClientOption{option.WithHTTPClient(httpClient)} + if asAddr != "" { + opts = append(opts, option.WithEndpoint(asAddr)) + } + + const defaultRegion = "us-central1" + restClient, err := rest.NewClient(ctx, spec.ProjectID, defaultRegion, opts...) + if err != nil { + return nil, err + } + return restClient, nil +} + // parseEnvVars parses the environment variables to the oci format func parseEnvVars(envVars []spec.EnvVar) []string { var result []string @@ -433,5 +477,7 @@ func (r *ContainerRunner) Close(ctx context.Context) { // Exit gracefully: // Delete container and close connection to attestation service. r.container.Delete(ctx, containerd.WithSnapshotCleanup) - r.attestConn.Close() + if r.attestConn != nil { + r.attestConn.Close() + } } diff --git a/launcher/spec/launcher_spec.go b/launcher/spec/launcher_spec.go index f6381e360..a3bd977da 100644 --- a/launcher/spec/launcher_spec.go +++ b/launcher/spec/launcher_spec.go @@ -28,10 +28,6 @@ const ( Never RestartPolicy = "Never" ) -const ( - defaultAttestationServiceEndpoint = "attestation-verifier.confidential-computing-test-org.joonix.net:9090" -) - const ( imageRefKey = "tee-image-reference" restartPolicyKey = "tee-restart-policy" @@ -56,12 +52,15 @@ type EnvVar struct { // LauncherSpec contains specification set by the operator who wants to // launch a container. type LauncherSpec struct { + // MDS-based values. ImageRef string RestartPolicy RestartPolicy Cmd []string Envs []EnvVar AttestationServiceAddr string ImpersonateServiceAccounts []string + ProjectID string + Region string } // UnmarshalJSON unmarshals an instance attributes list in JSON format from the metadata @@ -110,6 +109,18 @@ func (s *LauncherSpec) UnmarshalJSON(b []byte) error { return nil } +func getRegion(client *metadata.Client) (string, error) { + zone, err := client.Zone() + if err != nil { + return "", fmt.Errorf("failed to retrieve zone from MDS: %v", err) + } + lastDash := strings.LastIndex(zone, "-") + if lastDash == -1 { + return "", fmt.Errorf("got malformed zone from MDS: %v", zone) + } + return zone[:lastDash], nil +} + // GetLauncherSpec takes in a metadata server client, reads and parse operator's // input to the GCE instance custom metadata and return a LauncherSpec. // ImageRef (tee-image-reference) is required, will return an error if @@ -125,8 +136,14 @@ func GetLauncherSpec(client *metadata.Client) (LauncherSpec, error) { return LauncherSpec{}, err } - if spec.AttestationServiceAddr == "" { - spec.AttestationServiceAddr = defaultAttestationServiceEndpoint + spec.ProjectID, err = client.ProjectID() + if err != nil { + return LauncherSpec{}, fmt.Errorf("failed to retrieve projectID from MDS: %v", err) + } + + spec.Region, err = getRegion(client) + if err != nil { + return LauncherSpec{}, err } return *spec, nil