Skip to content

Commit

Permalink
Merge pull request #367 from jkl73/customaudnonce
Browse files Browse the repository at this point in the history
[launcher] Add tee server and custom audience/nonces token support
  • Loading branch information
jkl73 authored Nov 1, 2023
2 parents 0dd0099 + 431350b commit a8d45d8
Show file tree
Hide file tree
Showing 13 changed files with 363 additions and 265 deletions.
168 changes: 0 additions & 168 deletions go.work.sum

Large diffs are not rendered by default.

13 changes: 11 additions & 2 deletions launcher/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@ type principalIDTokenFetcher func(audience string) ([][]byte, error)
// struct to make testing easier.
type AttestationAgent interface {
MeasureEvent(cel.Content) error
Attest(context.Context) ([]byte, error)
Attest(context.Context, AttestAgentOpts) ([]byte, error)
}

// AttestAgentOpts contains user generated options when calling the
// VerifyAttestation API
type AttestAgentOpts struct {
Aud string
Nonces []string
}

type agent struct {
Expand Down Expand Up @@ -76,7 +83,7 @@ func (a *agent) MeasureEvent(event cel.Content) error {
// Attest fetches the nonce and connection ID from the Attestation Service,
// creates an attestation message, and returns the resultant
// principalIDTokens and Metadata Server-generated ID tokens for the instance.
func (a *agent) Attest(ctx context.Context) ([]byte, error) {
func (a *agent) Attest(ctx context.Context, opts AttestAgentOpts) ([]byte, error) {
challenge, err := a.client.CreateChallenge(ctx)
if err != nil {
return nil, err
Expand All @@ -96,6 +103,8 @@ func (a *agent) Attest(ctx context.Context) ([]byte, error) {
Challenge: challenge,
GcpCredentials: principalTokens,
Attestation: attestation,
CustomAudience: opts.Aud,
CustomNonce: opts.Nonces,
}

if a.launchSpec.Experiments.EnableSignedContainerImage {
Expand Down
2 changes: 1 addition & 1 deletion launcher/agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func TestAttest(t *testing.T) {

agent := CreateAttestationAgent(tpm, client.AttestationKeyECC, verifierClient, tc.principalIDTokenFetcher, tc.containerSignaturesFetcher, tc.launchSpec, log.Default())

tokenBytes, err := agent.Attest(context.Background())
tokenBytes, err := agent.Attest(context.Background(), AttestAgentOpts{})
if err != nil {
t.Errorf("failed to attest to Attestation Service: %v", err)
}
Expand Down
40 changes: 17 additions & 23 deletions launcher/container_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ import (
"github.com/google/go-tpm-tools/launcher/internal/signaturediscovery"
"github.com/google/go-tpm-tools/launcher/launcherfile"
"github.com/google/go-tpm-tools/launcher/spec"
"github.com/google/go-tpm-tools/launcher/teeserver"
"github.com/google/go-tpm-tools/launcher/verifier"
"github.com/google/go-tpm-tools/launcher/verifier/rest"
v1 "github.com/opencontainers/image-spec/specs-go/v1"
specs "github.com/opencontainers/runtime-spec/specs-go"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/impersonate"
"google.golang.org/api/option"
)

Expand All @@ -53,6 +53,8 @@ type ContainerRunner struct {

const tokenFileTmp = ".token.tmp"

const teeServerSocket = "teeserver.sock"

// Since we only allow one container on a VM, using a deterministic id is probably fine
const (
containerID = "tee-container"
Expand All @@ -74,26 +76,6 @@ const (
defaultRefreshJitter = 0.1
)

func fetchImpersonatedToken(ctx context.Context, serviceAccount string, audience string, opts ...option.ClientOption) ([]byte, error) {
config := impersonate.IDTokenConfig{
Audience: audience,
TargetPrincipal: serviceAccount,
IncludeEmail: true,
}

tokenSource, err := impersonate.IDTokenSource(ctx, config, opts...)
if err != nil {
return nil, fmt.Errorf("error creating token source: %v", err)
}

token, err := tokenSource.Token()
if err != nil {
return nil, fmt.Errorf("error retrieving token: %v", err)
}

return []byte(token.AccessToken), nil
}

// NewRunner returns a runner.
func NewRunner(ctx context.Context, cdClient *containerd.Client, token oauth2.Token, launchSpec spec.LaunchSpec, mdsClient *metadata.Client, tpm io.ReadWriteCloser, logger *log.Logger, serialConsole *os.File) (*ContainerRunner, error) {
image, err := initImage(ctx, cdClient, launchSpec, token)
Expand All @@ -103,6 +85,7 @@ func NewRunner(ctx context.Context, cdClient *containerd.Client, token oauth2.To

mounts := make([]specs.Mount, 0)
mounts = appendTokenMounts(mounts)

envs, err := formatEnvVars(launchSpec.Envs)
if err != nil {
return nil, err
Expand Down Expand Up @@ -214,7 +197,7 @@ func NewRunner(ctx context.Context, cdClient *containerd.Client, token oauth2.To

// Fetch impersonated ID tokens.
for _, sa := range launchSpec.ImpersonateServiceAccounts {
idToken, err := fetchImpersonatedToken(ctx, sa, audience)
idToken, err := FetchImpersonatedToken(ctx, sa, audience)
if err != nil {
return nil, fmt.Errorf("failed to get impersonated token for %v: %w", sa, err)
}
Expand Down Expand Up @@ -360,7 +343,8 @@ func (r *ContainerRunner) measureContainerClaims(ctx context.Context) error {
// The token file will be written to a tmp file and then renamed.
func (r *ContainerRunner) refreshToken(ctx context.Context) (time.Duration, error) {
r.logger.Print("refreshing attestation verifier OIDC token")
token, err := r.attestAgent.Attest(ctx)
// request a default token
token, err := r.attestAgent.Attest(ctx, agent.AttestAgentOpts{})
if err != nil {
return 0, fmt.Errorf("failed to retrieve attestation service token: %v", err)
}
Expand Down Expand Up @@ -512,6 +496,16 @@ func (r *ContainerRunner) Run(ctx context.Context) error {
}

r.logger.Printf("EnableTestFeatureForImage is set to %v\n", r.launchSpec.Experiments.EnableTestFeatureForImage)
// create and start the TEE server behind the experiment
if r.launchSpec.Experiments.EnableOnDemandAttestation {
r.logger.Println("EnableOnDemandAttestation is enabled: initializing TEE server.")
teeServer, err := teeserver.New(path.Join(launcherfile.HostTmpPath, teeServerSocket), r.attestAgent, r.logger)
if err != nil {
return fmt.Errorf("failed to create the TEE server: %v", err)
}
go teeServer.Serve()
defer teeServer.Shutdown(ctx)
}

var streamOpt cio.Opt
switch r.launchSpec.LogRedirect {
Expand Down
82 changes: 13 additions & 69 deletions launcher/container_runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,9 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"path"
"strconv"
Expand All @@ -23,10 +20,10 @@ import (
"github.com/containerd/containerd/namespaces"
"github.com/golang-jwt/jwt/v4"
"github.com/google/go-tpm-tools/cel"
"github.com/google/go-tpm-tools/launcher/agent"
"github.com/google/go-tpm-tools/launcher/launcherfile"
"github.com/google/go-tpm-tools/launcher/spec"
"golang.org/x/oauth2"
"google.golang.org/api/option"
)

const (
Expand All @@ -36,7 +33,7 @@ const (
// Fake attestation agent.
type fakeAttestationAgent struct {
measureEventFunc func(cel.Content) error
attestFunc func(context.Context) ([]byte, error)
attestFunc func(context.Context, agent.AttestAgentOpts) ([]byte, error)
}

func (f *fakeAttestationAgent) MeasureEvent(event cel.Content) error {
Expand All @@ -47,9 +44,9 @@ func (f *fakeAttestationAgent) MeasureEvent(event cel.Content) error {
return fmt.Errorf("unimplemented")
}

func (f *fakeAttestationAgent) Attest(ctx context.Context) ([]byte, error) {
func (f *fakeAttestationAgent) Attest(ctx context.Context, _ agent.AttestAgentOpts) ([]byte, error) {
if f.attestFunc != nil {
return f.attestFunc(ctx)
return f.attestFunc(ctx, agent.AttestAgentOpts{})
}

return nil, fmt.Errorf("unimplemented")
Expand Down Expand Up @@ -102,7 +99,7 @@ func TestRefreshToken(t *testing.T) {

runner := ContainerRunner{
attestAgent: &fakeAttestationAgent{
attestFunc: func(context.Context) ([]byte, error) {
attestFunc: func(context.Context, agent.AttestAgentOpts) ([]byte, error) {
return expectedToken, nil
},
},
Expand Down Expand Up @@ -146,15 +143,15 @@ func TestRefreshTokenError(t *testing.T) {
{
name: "Attest fails",
agent: &fakeAttestationAgent{
attestFunc: func(context.Context) ([]byte, error) {
attestFunc: func(context.Context, agent.AttestAgentOpts) ([]byte, error) {
return nil, errors.New("attest error")
},
},
},
{
name: "Attest returns expired token",
agent: &fakeAttestationAgent{
attestFunc: func(context.Context) ([]byte, error) {
attestFunc: func(context.Context, agent.AttestAgentOpts) ([]byte, error) {
return createJWT(t, -5*time.Second), nil
},
},
Expand Down Expand Up @@ -184,7 +181,7 @@ func TestFetchAndWriteTokenSucceeds(t *testing.T) {

runner := ContainerRunner{
attestAgent: &fakeAttestationAgent{
attestFunc: func(context.Context) ([]byte, error) {
attestFunc: func(context.Context, agent.AttestAgentOpts) ([]byte, error) {
return expectedToken, nil
},
},
Expand Down Expand Up @@ -212,11 +209,11 @@ func TestTokenIsNotChangedIfRefreshFails(t *testing.T) {

expectedToken := createJWT(t, 5*time.Second)
ttl := 5 * time.Second
successfulAttestFunc := func(context.Context) ([]byte, error) {
successfulAttestFunc := func(context.Context, agent.AttestAgentOpts) ([]byte, error) {
return expectedToken, nil
}

errorAttestFunc := func(context.Context) ([]byte, error) {
errorAttestFunc := func(context.Context, agent.AttestAgentOpts) ([]byte, error) {
return nil, errors.New("attest unsuccessful")
}

Expand Down Expand Up @@ -289,7 +286,7 @@ func testRetryPolicyWithNTries(t *testing.T, numTries int, expectRefresh bool) {
// Wait the initial token's 5s plus a second per retry (MaxInterval).
ttl := time.Duration(numTries)*time.Second + 5*time.Second
retry := -1
attestFunc := func(context.Context) ([]byte, error) {
attestFunc := func(context.Context, agent.AttestAgentOpts) ([]byte, error) {
retry++
// Success on the initial fetch (subsequent calls use refresher goroutine).
if retry == 0 {
Expand Down Expand Up @@ -350,7 +347,7 @@ func TestFetchAndWriteTokenWithTokenRefresh(t *testing.T) {

runner := ContainerRunner{
attestAgent: &fakeAttestationAgent{
attestFunc: func(context.Context) ([]byte, error) {
attestFunc: func(context.Context, agent.AttestAgentOpts) ([]byte, error) {
return expectedToken, nil
},
},
Expand All @@ -374,7 +371,7 @@ func TestFetchAndWriteTokenWithTokenRefresh(t *testing.T) {
// Change attest agent to return new token.
expectedRefreshedToken := createJWT(t, 10*time.Second)
runner.attestAgent = &fakeAttestationAgent{
attestFunc: func(context.Context) ([]byte, error) {
attestFunc: func(context.Context, agent.AttestAgentOpts) ([]byte, error) {
return expectedRefreshedToken, nil
},
}
Expand Down Expand Up @@ -402,59 +399,6 @@ func TestFetchAndWriteTokenWithTokenRefresh(t *testing.T) {
}
}

type testRoundTripper struct {
roundTripFunc func(*http.Request) *http.Response
}

func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return t.roundTripFunc(req), nil
}

type idTokenResp struct {
Token string `json:"token"`
}

func TestFetchImpersonatedToken(t *testing.T) {
expectedEmail := "test2@google.com"

expectedToken := []byte("test_token")

expectedURL := fmt.Sprintf(idTokenEndpoint, expectedEmail)
client := &http.Client{
Transport: &testRoundTripper{
roundTripFunc: func(req *http.Request) *http.Response {
if req.URL.String() != expectedURL {
t.Errorf("HTTP call was not made to a endpoint: got %v, want %v", req.URL.String(), expectedURL)
}

resp := idTokenResp{
Token: string(expectedToken),
}

respBody, err := json.Marshal(resp)
if err != nil {
t.Fatalf("Unable to marshal HTTP response: %v", err)
}

return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(bytes.NewBuffer(respBody)),
}
},
},
}

token, err := fetchImpersonatedToken(context.Background(), expectedEmail, "test_aud", option.WithHTTPClient(client))
if err != nil {
t.Fatalf("fetchImpersonatedToken returned error: %v", err)
}

if !bytes.Equal(token, expectedToken) {
t.Errorf("fetchImpersonatedToken did not return expected token: got %v, want %v", token, expectedToken)
}
}

func TestGetNextRefresh(t *testing.T) {
// 0 <= random < 1.
for _, randNum := range []float64{0, .1415926, .5, .75, .999999999} {
Expand Down
1 change: 1 addition & 0 deletions launcher/internal/experiments/experiments.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
type Experiments struct {
EnableTestFeatureForImage bool
EnableSignedContainerImage bool
EnableOnDemandAttestation bool
}

// New takes a filepath, opens the file, and calls ReadJsonInput with the contents
Expand Down
Loading

0 comments on commit a8d45d8

Please sign in to comment.