Skip to content

Commit

Permalink
Fix teeserver context reset issue & add container signature cache (#397)
Browse files Browse the repository at this point in the history
* Fix teeserver context reset issue

* Adding signature cache
  • Loading branch information
yawangwang authored Jan 30, 2024
1 parent e2d4797 commit fd156ad
Show file tree
Hide file tree
Showing 12 changed files with 299 additions and 36 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,11 @@ jobs:
run: |
GO_EXECUTABLE_PATH=$(which go)
sudo $GO_EXECUTABLE_PATH test -v -run "TestFetchImageSignaturesDockerPublic" ./launcher
- name: Run specific tests to capture potential data race
run: go test ./launcher/agent -race -run TestCacheConcurrentSetGet
if: (runner.os == 'Linux' || runner.os == 'macOS') && matrix.architecture == 'x64'
- name: Test all modules
run: go test -v ./... ./cmd/... ./launcher/...
run: go test -v ./... ./cmd/... ./launcher/... -skip=TestCacheConcurrentSetGet

lint:
strategy:
Expand Down
15 changes: 14 additions & 1 deletion cloudbuild.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,20 @@ steps:
gcloud builds submit --config=test_memory_monitoring.yaml --region us-west1 \
--substitutions _IMAGE_NAME=${OUTPUT_IMAGE_PREFIX}-hardened-${OUTPUT_IMAGE_SUFFIX},_IMAGE_PROJECT=${PROJECT_ID}
exit
- name: 'gcr.io/cloud-builders/gcloud'
id: ODAWithSignedContainerTest
waitFor: ['HardenedImageBuild']
env:
- 'OUTPUT_IMAGE_PREFIX=$_OUTPUT_IMAGE_PREFIX'
- 'OUTPUT_IMAGE_SUFFIX=$_OUTPUT_IMAGE_SUFFIX'
- 'PROJECT_ID=$PROJECT_ID'
script: |
#!/usr/bin/env bash
cd launcher/image/test
echo "running ODA and signed container tests on ${OUTPUT_IMAGE_PREFIX}-hardened-${OUTPUT_IMAGE_SUFFIX}"
gcloud builds submit --config=test_oda_with_signed_container.yaml --region us-west1 \
--substitutions _IMAGE_NAME=${OUTPUT_IMAGE_PREFIX}-hardened-${OUTPUT_IMAGE_SUFFIX},_IMAGE_PROJECT=${PROJECT_ID}
exit
options:
pool:
name: 'projects/confidential-space-images-dev/locations/us-west1/workerPools/cs-image-build-vpc'
48 changes: 41 additions & 7 deletions launcher/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type principalIDTokenFetcher func(audience string) ([][]byte, error)
type AttestationAgent interface {
MeasureEvent(cel.Content) error
Attest(context.Context, AttestAgentOpts) ([]byte, error)
Refresh(context.Context) error
}

// AttestAgentOpts contains user generated options when calling the
Expand All @@ -54,6 +55,7 @@ type agent struct {
cosCel cel.CEL
launchSpec spec.LaunchSpec
logger *log.Logger
sigsCache *sigsCache
}

// CreateAttestationAgent returns an agent capable of performing remote
Expand All @@ -72,6 +74,7 @@ func CreateAttestationAgent(tpm io.ReadWriteCloser, akFetcher tpmKeyFetcher, ver
sigsFetcher: sigsFetcher,
launchSpec: launchSpec,
logger: logger,
sigsCache: &sigsCache{},
}
}

Expand Down Expand Up @@ -111,12 +114,15 @@ func (a *agent) Attest(ctx context.Context, opts AttestAgentOpts) ([]byte, error
},
}

if a.launchSpec.Experiments.EnableSignedContainerImage {
signatures := fetchContainerImageSignatures(ctx, a.sigsFetcher, a.launchSpec.SignedImageRepos, a.logger)
if len(signatures) > 0 {
req.ContainerImageSignatures = signatures
a.logger.Printf("Found container image signatures: %v\n", signatures)
}
var signatures []oci.Signature
if a.launchSpec.Experiments.EnableSignedContainerCache {
signatures = a.sigsCache.get()
} else {
signatures = fetchContainerImageSignatures(ctx, a.sigsFetcher, a.launchSpec.SignedImageRepos, a.logger)
}
if len(signatures) > 0 {
req.ContainerImageSignatures = signatures
a.logger.Printf("Found container image signatures: %v\n", signatures)
}

resp, err := a.client.VerifyAttestation(ctx, req)
Expand All @@ -129,6 +135,17 @@ func (a *agent) Attest(ctx context.Context, opts AttestAgentOpts) ([]byte, error
return resp.ClaimsToken, nil
}

// Refresh refreshes the internal state of the attestation agent.
// It will reset the container image signatures for now.
func (a *agent) Refresh(ctx context.Context) error {
if a.launchSpec.Experiments.EnableSignedContainerCache {
signatures := fetchContainerImageSignatures(ctx, a.sigsFetcher, a.launchSpec.SignedImageRepos, a.logger)
a.sigsCache.set(signatures)
a.logger.Printf("Refreshed container image signature cache: %v\n", signatures)
}
return nil
}

func (a *agent) getAttestation(nonce []byte) (*pb.Attestation, error) {
ak, err := a.akFetcher(a.tpm)
if err != nil {
Expand All @@ -148,7 +165,6 @@ func (a *agent) getAttestation(nonce []byte) (*pb.Attestation, error) {
return attestation, nil
}

// TODO: cache signatures so we don't need to fetch every time.
func fetchContainerImageSignatures(ctx context.Context, fetcher signaturediscovery.Fetcher, targetRepos []string, logger *log.Logger) []oci.Signature {
signatures := make([][]oci.Signature, len(targetRepos))

Expand All @@ -173,3 +189,21 @@ func fetchContainerImageSignatures(ctx context.Context, fetcher signaturediscove
}
return foundSigs
}

type sigsCache struct {
mu sync.RWMutex
items []oci.Signature
}

func (c *sigsCache) set(sigs []oci.Signature) {
c.mu.Lock()
defer c.mu.Unlock()
c.items = make([]oci.Signature, len(sigs))
copy(c.items, sigs)
}

func (c *sigsCache) get() []oci.Signature {
c.mu.RLock()
defer c.mu.RUnlock()
return c.items
}
75 changes: 64 additions & 11 deletions launcher/agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,25 @@ import (
"encoding/base64"
"fmt"
"log"
"runtime"
"sync"
"testing"

"github.com/golang-jwt/jwt/v4"
"github.com/google/go-cmp/cmp"
"github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm-tools/internal/test"
"github.com/google/go-tpm-tools/launcher/internal/experiments"
"github.com/google/go-tpm-tools/launcher/internal/oci"
"github.com/google/go-tpm-tools/launcher/internal/oci/cosign"
"github.com/google/go-tpm-tools/launcher/internal/signaturediscovery"
"github.com/google/go-tpm-tools/launcher/spec"
"github.com/google/go-tpm-tools/launcher/verifier"
"github.com/google/go-tpm-tools/launcher/verifier/fake"
)

func TestAttest(t *testing.T) {
ctx := context.Background()
testCases := []struct {
name string
launchSpec spec.LaunchSpec
Expand All @@ -37,7 +42,7 @@ func TestAttest(t *testing.T) {
name: "enable signed container",
launchSpec: spec.LaunchSpec{
SignedImageRepos: []string{signaturediscovery.FakeRepoWithSignatures},
Experiments: experiments.Experiments{EnableSignedContainerImage: true},
Experiments: experiments.Experiments{EnableSignedContainerCache: true},
},
principalIDTokenFetcher: placeholderPrincipalFetcher,
containerSignaturesFetcher: signaturediscovery.NewFakeClient(),
Expand All @@ -61,7 +66,10 @@ func TestAttest(t *testing.T) {

agent := CreateAttestationAgent(tpm, client.AttestationKeyECC, verifierClient, tc.principalIDTokenFetcher, tc.containerSignaturesFetcher, tc.launchSpec, log.Default())

tokenBytes, err := agent.Attest(context.Background(), AttestAgentOpts{})
if err := agent.Refresh(ctx); err != nil {
t.Errorf("failed to fresh attestation agent: %v", err)
}
tokenBytes, err := agent.Attest(ctx, AttestAgentOpts{})
if err != nil {
t.Errorf("failed to attest to Attestation Service: %v", err)
}
Expand All @@ -88,7 +96,7 @@ func TestAttest(t *testing.T) {
if claims.Subject != "https://www.googleapis.com/compute/v1/projects/fakeProject/zones/fakeZone/instances/fakeInstance" {
t.Errorf("Invalid sub")
}
if tc.launchSpec.Experiments.EnableSignedContainerImage {
if tc.launchSpec.Experiments.EnableSignedContainerCache {
got := claims.ContainerImageSignatures
want := []fake.ContainerImageSignatureClaims{
{
Expand Down Expand Up @@ -214,14 +222,7 @@ func TestFetchContainerImageSignatures(t *testing.T) {
if len(gotSigs) != len(tc.wantBase64Sigs) {
t.Errorf("fetchContainerImageSignatures did not return expected signatures for test case %s, got signatures length %d, but want %d", tc.name, len(gotSigs), len(tc.wantBase64Sigs))
}
var gotBase64Sigs []string
for _, gotSig := range gotSigs {
base64Sig, err := gotSig.Base64Encoded()
if err != nil {
t.Fatalf("fetchContainerImageSignatures did not return expected base64 signatures for test case %s: %v", tc.name, err)
}
gotBase64Sigs = append(gotBase64Sigs, base64Sig)
}
gotBase64Sigs := convertOCISignatureToBase64(t, gotSigs)
if !cmp.Equal(gotBase64Sigs, tc.wantBase64Sigs) {
t.Errorf("fetchContainerImageSignatures did not return expected signatures for test case %s, got signatures %v, but want %v", tc.name, gotBase64Sigs, tc.wantBase64Sigs)
}
Expand Down Expand Up @@ -255,3 +256,55 @@ func TestFetchContainerImageSignatures(t *testing.T) {
})
}
}

func TestCacheConcurrentSetGet(t *testing.T) {
cache := &sigsCache{}
if sigs := cache.get(); len(sigs) != 0 {
t.Errorf("signature cache should be empty, but got: %v", sigs)
}

var wg sync.WaitGroup
for i := 0; i < runtime.NumCPU(); i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
if idx%2 == 1 {
sigs := generateRandSigs(t)
cache.set(sigs)
} else {
cache.get()
}
}(i)
}
wg.Wait()
}

func generateRandSigs(t *testing.T) []oci.Signature {
t.Helper()

b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
t.Fatalf("Unable to generate random bytes: %v", err)
}

randB64Str := base64.StdEncoding.EncodeToString(b)
return []oci.Signature{
cosign.NewFakeSignature(randB64Str, oci.ECDSAP256SHA256),
}
}

func convertOCISignatureToBase64(t *testing.T, sigs []oci.Signature) []string {
t.Helper()

var base64Sigs []string
for _, sig := range sigs {
b64Sig, err := sig.Base64Encoded()
if err != nil {
t.Fatalf("oci.Signature did not return expected base64 signature: %v", err)
}
base64Sigs = append(base64Sigs, b64Sig)
}

return base64Sigs
}
5 changes: 4 additions & 1 deletion launcher/container_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,9 @@ func (r *ContainerRunner) measureContainerClaims(ctx context.Context) error {
// The token file will be written to a tmp file and then renamed.
func (r *ContainerRunner) refreshToken(ctx context.Context) (time.Duration, error) {
r.logger.Print("refreshing attestation verifier OIDC token")
if err := r.attestAgent.Refresh(ctx); err != nil {
return 0, fmt.Errorf("failed to refresh attestation agent: %v", err)
}
// request a default token
token, err := r.attestAgent.Attest(ctx, agent.AttestAgentOpts{})
if err != nil {
Expand Down Expand Up @@ -502,7 +505,7 @@ func (r *ContainerRunner) Run(ctx context.Context) error {
// create and start the TEE server behind the experiment
if r.launchSpec.Experiments.EnableOnDemandAttestation {
r.logger.Println("EnableOnDemandAttestation is enabled: initializing TEE server.")
teeServer, err := teeserver.New(path.Join(launcherfile.HostTmpPath, teeServerSocket), r.attestAgent, r.logger)
teeServer, err := teeserver.New(ctx, path.Join(launcherfile.HostTmpPath, teeServerSocket), r.attestAgent, r.logger)
if err != nil {
return fmt.Errorf("failed to create the TEE server: %v", err)
}
Expand Down
Loading

0 comments on commit fd156ad

Please sign in to comment.