Support per-shard signing keys
This change enables key rotation with a per-shard signing key
configuration. The LogRanges structure now holds both active and
inactive shards, with each LogRange carrying a signer, an encoded
public key, and a log ID derived from the public key.

This change is backwards compatible: if no signing configuration is
specified for a shard, the active shard's signing configuration is
used for all shards, as sketched below.
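
To picture that fallback, here is a rough sketch of how per-shard signer
construction could work. The loop shape and the IsUnset helper are
assumptions for illustration, not code from this commit; only the
signer.New argument list appears in the diff below.

    // Sketch only: build a signer per shard, falling back to the active
    // shard's signing configuration. "IsUnset" is a hypothetical helper;
    // the real implementation lives in pkg/sharding.
    for i, shard := range inactiveShards {
        cfg := shard.SigningConfig
        if cfg.IsUnset() {
            cfg = activeSigningConfig // backwards-compatible default
        }
        s, err := signer.New(ctx, cfg.SigningSchemeOrKeyPath, cfg.FileSignerPassword,
            cfg.TinkKEKURI, cfg.TinkKeysetPath)
        if err != nil {
            return LogRanges{}, fmt.Errorf("creating signer for shard %d: %w", shard.TreeID, err)
        }
        inactiveShards[i].Signer = s
    }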

Minor change: standardized "log ID" vs. "tree ID", where the former is
the public key hash and the latter is the ID of the backing Trillian
tree.
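
The sharding types themselves are not among the files shown below, but
the call sites in this diff imply roughly the following shape (a sketch;
the exact definitions live in pkg/sharding/ranges.go and may differ):

    // Sketch of the types implied by this change's call sites.
    type LogRange struct {
        TreeID     int64            // ID of the backing Trillian tree
        TreeLength int64            // entry count of a frozen (inactive) shard
        Signer     signature.Signer // per-shard signing key
        PemPubKey  string           // PEM-encoded public key
        LogID      string           // SHA256 hash of the DER-encoded public key
    }

    type LogRanges struct {
        inactive []LogRange // frozen shards, in order
        active   LogRange   // shard currently accepting entries
    }

    func (l *LogRanges) GetActive() LogRange     { return l.active }
    func (l *LogRanges) GetInactive() []LogRange { return l.inactive }
    func (l *LogRanges) NoInactive() bool        { return len(l.inactive) == 0 }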

Signed-off-by: Hayden Blauzvern <hblauzvern@google.com>
haydentherapper committed Jan 15, 2025
1 parent dcad58c commit 052a08a
Showing 12 changed files with 468 additions and 182 deletions.
55 changes: 14 additions & 41 deletions pkg/api/api.go
@@ -17,10 +17,8 @@ package api
 
 import (
     "context"
-    "crypto/sha256"
     "crypto/tls"
     "crypto/x509"
-    "encoding/hex"
     "fmt"
     "os"
     "path/filepath"
@@ -42,9 +40,6 @@ import (
     "github.com/sigstore/rekor/pkg/storage"
     "github.com/sigstore/rekor/pkg/trillianclient"
     "github.com/sigstore/rekor/pkg/witness"
-    "github.com/sigstore/sigstore/pkg/cryptoutils"
-    "github.com/sigstore/sigstore/pkg/signature"
-    "github.com/sigstore/sigstore/pkg/signature/options"
 
     _ "github.com/sigstore/rekor/pkg/pubsub/gcp" // Load GCP pubsub implementation
 )
@@ -92,12 +87,9 @@ func dial(rpcServer string) (*grpc.ClientConn, error) {
 }
 
 type API struct {
-    logClient  trillian.TrillianLogClient
-    logID      int64
-    logRanges  sharding.LogRanges
-    pubkey     string // PEM encoded public key
-    pubkeyHash string // SHA256 hash of DER-encoded public key
-    signer     signature.Signer
+    logClient trillian.TrillianLogClient
+    treeID    int64
+    logRanges sharding.LogRanges
     // stops checkpoint publishing
     checkpointPublishCancel context.CancelFunc
     // Publishes notifications when new entries are added to the log. May be
@@ -117,12 +109,6 @@ func NewAPI(treeID uint) (*API, error) {
     logAdminClient := trillian.NewTrillianAdminClient(tConn)
     logClient := trillian.NewTrillianLogClient(tConn)
 
-    shardingConfig := viper.GetString("trillian_log_server.sharding_config")
-    ranges, err := sharding.NewLogRanges(ctx, logClient, shardingConfig, treeID)
-    if err != nil {
-        return nil, fmt.Errorf("unable get sharding details from sharding config: %w", err)
-    }
-
     tid := int64(treeID)
     if tid == 0 {
         log.Logger.Info("No tree ID specified, attempting to create a new tree")
@@ -133,27 +119,18 @@ func NewAPI(treeID uint) (*API, error) {
         tid = t.TreeId
     }
     log.Logger.Infof("Starting Rekor server with active tree %v", tid)
-    ranges.SetActive(tid)
 
-    rekorSigner, err := signer.New(ctx, viper.GetString("rekor_server.signer"),
-        viper.GetString("rekor_server.signer-passwd"),
-        viper.GetString("rekor_server.tink_kek_uri"),
-        viper.GetString("rekor_server.tink_keyset_path"),
-    )
-    if err != nil {
-        return nil, fmt.Errorf("getting new signer: %w", err)
-    }
-    pk, err := rekorSigner.PublicKey(options.WithContext(ctx))
-    if err != nil {
-        return nil, fmt.Errorf("getting public key: %w", err)
-    }
-    b, err := x509.MarshalPKIXPublicKey(pk)
-    if err != nil {
-        return nil, fmt.Errorf("marshalling public key: %w", err)
-    }
-    pubkeyHashBytes := sha256.Sum256(b)
-
-    pubkey := cryptoutils.PEMEncode(cryptoutils.PublicKeyPEMType, b)
+    shardingConfig := viper.GetString("trillian_log_server.sharding_config")
+    signingConfig := signer.SigningConfig{
+        SigningSchemeOrKeyPath: viper.GetString("rekor_server.signer"),
+        FileSignerPassword:     viper.GetString("rekor_server.signer-passwd"),
+        TinkKEKURI:             viper.GetString("rekor_server.tink_kek_uri"),
+        TinkKeysetPath:         viper.GetString("rekor_server.tink_keyset_path"),
+    }
+    ranges, err := sharding.NewLogRanges(ctx, logClient, shardingConfig, tid, signingConfig)
+    if err != nil {
+        return nil, fmt.Errorf("unable get sharding details from sharding config: %w", err)
+    }
 
     var newEntryPublisher pubsub.Publisher
     if p := viper.GetString("rekor_server.new_entry_publisher"); p != "" {
@@ -170,12 +147,8 @@ func NewAPI(treeID uint) (*API, error) {
     return &API{
         // Transparency Log Stuff
         logClient: logClient,
-        logID:     tid,
+        treeID:    tid,
         logRanges: ranges,
-        // Signing/verifying fields
-        pubkey:     string(pubkey),
-        pubkeyHash: hex.EncodeToString(pubkeyHashBytes[:]),
-        signer:     rekorSigner,
         // Utility functionality not required for operation of the core service
         newEntryPublisher: newEntryPublisher,
     }, nil
@@ -212,8 +185,8 @@ func ConfigureAPI(treeID uint) {
 
     if viper.GetBool("enable_stable_checkpoint") {
         redisClient = NewRedisClient()
-        checkpointPublisher := witness.NewCheckpointPublisher(context.Background(), api.logClient, api.logRanges.ActiveTreeID(),
-            viper.GetString("rekor_server.hostname"), api.signer, redisClient, viper.GetUint("publish_frequency"), CheckpointPublishCount)
+        checkpointPublisher := witness.NewCheckpointPublisher(context.Background(), api.logClient, api.logRanges.GetActive().TreeID,
+            viper.GetString("rekor_server.hostname"), api.logRanges.GetActive().Signer, redisClient, viper.GetUint("publish_frequency"), CheckpointPublishCount)
 
         // create context to cancel goroutine on server shutdown
         ctx, cancel := context.WithCancel(context.Background())
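
For reference, the signer.SigningConfig used above bundles the existing
server-level signer settings so they can now be specified per shard. A
sketch consistent with the fields in this hunk (the field names are taken
from the diff; the comments are assumptions about the existing rekor
signer options):

    // Sketch: options handed to sharding.NewLogRanges so each shard can
    // construct its own signer.
    type SigningConfig struct {
        SigningSchemeOrKeyPath string // e.g. a KMS URI, "memory", or a key file path
        FileSignerPassword     string // password for a file-based key
        TinkKEKURI             string // key-encryption-key URI for Tink
        TinkKeysetPath         string // path to an encrypted Tink keyset
    }
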
35 changes: 20 additions & 15 deletions pkg/api/entries.go
@@ -74,7 +74,7 @@ func signEntry(ctx context.Context, signer signature.Signer, entry models.LogEnt
 }
 
 // logEntryFromLeaf creates a signed LogEntry struct from trillian structs
-func logEntryFromLeaf(ctx context.Context, signer signature.Signer, _ trillianclient.TrillianClient, leaf *trillian.LogLeaf,
+func logEntryFromLeaf(ctx context.Context, _ trillianclient.TrillianClient, leaf *trillian.LogLeaf,
     signedLogRoot *trillian.SignedLogRoot, proof *trillian.Proof, tid int64, ranges sharding.LogRanges) (models.LogEntry, error) {
 
     log.ContextLogger(ctx).Debugf("log entry from leaf %d", leaf.GetLeafIndex())
@@ -88,19 +88,24 @@ func logEntryFromLeaf(ctx context.Context, _ trillianclient.TrillianClient, leaf
     }
 
     virtualIndex := sharding.VirtualLogIndex(leaf.GetLeafIndex(), tid, ranges)
+    logRange, err := ranges.GetLogRangeByTreeID(tid)
+    if err != nil {
+        return nil, err
+    }
 
     logEntryAnon := models.LogEntryAnon{
-        LogID:          swag.String(api.pubkeyHash),
+        LogID:          swag.String(logRange.LogID),
         LogIndex:       &virtualIndex,
         Body:           leaf.LeafValue,
         IntegratedTime: swag.Int64(leaf.IntegrateTimestamp.AsTime().Unix()),
     }
 
-    signature, err := signEntry(ctx, signer, logEntryAnon)
+    signature, err := signEntry(ctx, logRange.Signer, logEntryAnon)
     if err != nil {
         return nil, fmt.Errorf("signing entry error: %w", err)
     }
 
-    scBytes, err := util.CreateAndSignCheckpoint(ctx, viper.GetString("rekor_server.hostname"), tid, root.TreeSize, root.RootHash, api.signer)
+    scBytes, err := util.CreateAndSignCheckpoint(ctx, viper.GetString("rekor_server.hostname"), tid, root.TreeSize, root.RootHash, logRange.Signer)
     if err != nil {
         return nil, err
     }
@@ -194,7 +199,7 @@ func createLogEntry(params entries.CreateLogEntryParams) (models.LogEntry, middl
         return nil, handleRekorAPIError(params, http.StatusInternalServerError, err, failedToGenerateCanonicalEntry)
     }
 
-    tc := trillianclient.NewTrillianClient(ctx, api.logClient, api.logID)
+    tc := trillianclient.NewTrillianClient(ctx, api.logClient, api.treeID)
 
     resp := tc.AddLeaf(leaf)
     // this represents overall GRPC response state (not the results of insertion into the log)
@@ -209,7 +214,7 @@ func createLogEntry(params entries.CreateLogEntryParams) (models.LogEntry, middl
     case int32(code.Code_OK):
     case int32(code.Code_ALREADY_EXISTS), int32(code.Code_FAILED_PRECONDITION):
         existingUUID := hex.EncodeToString(rfc6962.DefaultHasher.HashLeaf(leaf))
-        activeTree := fmt.Sprintf("%x", api.logID)
+        activeTree := fmt.Sprintf("%x", api.treeID)
         entryIDstruct, err := sharding.CreateEntryIDFromParts(activeTree, existingUUID)
         if err != nil {
             err := fmt.Errorf("error creating EntryID from active treeID %v and uuid %v: %w", activeTree, existingUUID, err)
@@ -230,7 +235,7 @@ func createLogEntry(params entries.CreateLogEntryParams) (models.LogEntry, middl
     queuedLeaf := resp.GetAddResult.QueuedLeaf.Leaf
 
     uuid := hex.EncodeToString(queuedLeaf.GetMerkleLeafHash())
-    activeTree := fmt.Sprintf("%x", api.logID)
+    activeTree := fmt.Sprintf("%x", api.treeID)
     entryIDstruct, err := sharding.CreateEntryIDFromParts(activeTree, uuid)
     if err != nil {
         err := fmt.Errorf("error creating EntryID from active treeID %v and uuid %v: %w", activeTree, uuid, err)
@@ -239,9 +244,9 @@ func createLogEntry(params entries.CreateLogEntryParams) (models.LogEntry, middl
     entryID := entryIDstruct.ReturnEntryIDString()
 
     // The log index should be the virtual log index across all shards
-    virtualIndex := sharding.VirtualLogIndex(queuedLeaf.LeafIndex, api.logRanges.ActiveTreeID(), api.logRanges)
+    virtualIndex := sharding.VirtualLogIndex(queuedLeaf.LeafIndex, api.logRanges.GetActive().TreeID, api.logRanges)
     logEntryAnon := models.LogEntryAnon{
-        LogID:          swag.String(api.pubkeyHash),
+        LogID:          swag.String(api.logRanges.GetActive().LogID),
         LogIndex:       swag.Int64(virtualIndex),
         Body:           queuedLeaf.GetLeafValue(),
         IntegratedTime: swag.Int64(queuedLeaf.IntegrateTimestamp.AsTime().Unix()),
@@ -286,7 +291,7 @@ func createLogEntry(params entries.CreateLogEntryParams) (models.LogEntry, middl
         }
     }
 
-    signature, err := signEntry(ctx, api.signer, logEntryAnon)
+    signature, err := signEntry(ctx, api.logRanges.GetActive().Signer, logEntryAnon)
     if err != nil {
         return nil, handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("signing entry error: %w", err), signingError)
     }
@@ -300,7 +305,7 @@ func createLogEntry(params entries.CreateLogEntryParams) (models.LogEntry, middl
         hashes = append(hashes, hex.EncodeToString(hash))
     }
 
-    scBytes, err := util.CreateAndSignCheckpoint(ctx, viper.GetString("rekor_server.hostname"), api.logID, root.TreeSize, root.RootHash, api.signer)
+    scBytes, err := util.CreateAndSignCheckpoint(ctx, viper.GetString("rekor_server.hostname"), api.treeID, root.TreeSize, root.RootHash, api.logRanges.GetActive().Signer)
     if err != nil {
         return nil, handleRekorAPIError(params, http.StatusInternalServerError, err, sthGenerateError)
     }
@@ -511,7 +516,7 @@ func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Respo
             continue
         }
         tcs := trillianclient.NewTrillianClient(httpReqCtx, api.logClient, shard)
-        logEntry, err := logEntryFromLeaf(httpReqCtx, api.signer, tcs, leafResp.Leaf, leafResp.SignedLogRoot, leafResp.Proof, shard, api.logRanges)
+        logEntry, err := logEntryFromLeaf(httpReqCtx, tcs, leafResp.Leaf, leafResp.SignedLogRoot, leafResp.Proof, shard, api.logRanges)
         if err != nil {
             return handleRekorAPIError(params, http.StatusInternalServerError, err, err.Error())
         }
@@ -558,7 +563,7 @@ func retrieveLogEntryByIndex(ctx context.Context, logIndex int) (models.LogEntry
         return models.LogEntry{}, ErrNotFound
     }
 
-    return logEntryFromLeaf(ctx, api.signer, tc, leaf, result.SignedLogRoot, result.Proof, tid, api.logRanges)
+    return logEntryFromLeaf(ctx, tc, leaf, result.SignedLogRoot, result.Proof, tid, api.logRanges)
 }
 
 // Retrieve a Log Entry
@@ -580,7 +585,7 @@ func retrieveLogEntry(ctx context.Context, entryUUID string) (models.LogEntry, e
 
     // If we got a UUID instead of an EntryID, search all shards
     if errors.Is(err, sharding.ErrPlainUUID) {
-        trees := []sharding.LogRange{{TreeID: api.logRanges.ActiveTreeID()}}
+        trees := []sharding.LogRange{api.logRanges.GetActive()}
         trees = append(trees, api.logRanges.GetInactive()...)
 
         for _, t := range trees {
@@ -623,7 +628,7 @@ func retrieveUUIDFromTree(ctx context.Context, uuid string, tid int64) (models.L
         return models.LogEntry{}, err
     }
 
-    logEntry, err := logEntryFromLeaf(ctx, api.signer, tc, result.Leaf, result.SignedLogRoot, result.Proof, tid, api.logRanges)
+    logEntry, err := logEntryFromLeaf(ctx, tc, result.Leaf, result.SignedLogRoot, result.Proof, tid, api.logRanges)
     if err != nil {
         return models.LogEntry{}, fmt.Errorf("could not create log entry from leaf: %w", err)
     }
2 changes: 1 addition & 1 deletion pkg/api/public_key.go
@@ -27,7 +27,7 @@ import (
 
 func GetPublicKeyHandler(params pubkey.GetPublicKeyParams) middleware.Responder {
     treeID := swag.StringValue(params.TreeID)
-    pk, err := api.logRanges.PublicKey(api.pubkey, treeID)
+    pk, err := api.logRanges.PublicKey(treeID)
     if err != nil {
         return handleRekorAPIError(params, http.StatusBadRequest, err, "")
     }
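
With per-shard keys, LogRanges.PublicKey no longer needs the server-wide
key passed in; it can resolve the key by tree ID, defaulting to the
active shard. A hedged sketch of that lookup (not the committed
implementation, which is not shown on this page):

    // Sketch: resolve a shard's PEM public key by tree ID; empty means active.
    func (l *LogRanges) PublicKey(treeID string) (string, error) {
        if treeID == "" {
            return l.active.PemPubKey, nil
        }
        tid, err := strconv.Atoi(treeID)
        if err != nil {
            return "", err
        }
        for _, t := range l.inactive {
            if t.TreeID == int64(tid) {
                return t.PemPubKey, nil
            }
        }
        if int64(tid) == l.active.TreeID {
            return l.active.PemPubKey, nil
        }
        return "", fmt.Errorf("%d is not a valid tree ID", tid)
    }
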
22 changes: 10 additions & 12 deletions pkg/api/tlog.go
@@ -33,20 +33,18 @@ import (
     "github.com/sigstore/rekor/pkg/log"
     "github.com/sigstore/rekor/pkg/trillianclient"
     "github.com/sigstore/rekor/pkg/util"
+    "github.com/sigstore/sigstore/pkg/signature"
 )
 
 // GetLogInfoHandler returns the current size of the tree and the STH
 func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
-    tc := trillianclient.NewTrillianClient(params.HTTPRequest.Context(), api.logClient, api.logID)
+    tc := trillianclient.NewTrillianClient(params.HTTPRequest.Context(), api.logClient, api.treeID)
 
     // for each inactive shard, get the loginfo
     var inactiveShards []*models.InactiveShardLogInfo
     for _, shard := range api.logRanges.GetInactive() {
-        if shard.TreeID == api.logRanges.ActiveTreeID() {
-            break
-        }
         // Get details for this inactive shard
-        is, err := inactiveShardLogInfo(params.HTTPRequest.Context(), shard.TreeID)
+        is, err := inactiveShardLogInfo(params.HTTPRequest.Context(), shard.TreeID, shard.Signer)
         if err != nil {
             return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("inactive shard error: %w", err), unexpectedInactiveShardError)
         }
@@ -55,7 +53,7 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
 
     if swag.BoolValue(params.Stable) && redisClient != nil {
         // key is treeID/latest
-        key := fmt.Sprintf("%d/latest", api.logRanges.ActiveTreeID())
+        key := fmt.Sprintf("%d/latest", api.logRanges.GetActive().TreeID)
         redisResult, err := redisClient.Get(params.HTTPRequest.Context(), key).Result()
         if err != nil {
             return handleRekorAPIError(params, http.StatusInternalServerError,
@@ -79,7 +77,7 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
             RootHash:       stringPointer(hex.EncodeToString(checkpoint.Hash)),
             TreeSize:       swag.Int64(int64(checkpoint.Size)),
             SignedTreeHead: stringPointer(string(decoded)),
-            TreeID:         stringPointer(fmt.Sprintf("%d", api.logID)),
+            TreeID:         stringPointer(fmt.Sprintf("%d", api.treeID)),
             InactiveShards: inactiveShards,
         }
         return tlog.NewGetLogInfoOK().WithPayload(&logInfo)
@@ -100,7 +98,7 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
     treeSize := int64(root.TreeSize)
 
     scBytes, err := util.CreateAndSignCheckpoint(params.HTTPRequest.Context(),
-        viper.GetString("rekor_server.hostname"), api.logRanges.ActiveTreeID(), root.TreeSize, root.RootHash, api.signer)
+        viper.GetString("rekor_server.hostname"), api.logRanges.GetActive().TreeID, root.TreeSize, root.RootHash, api.logRanges.GetActive().Signer)
     if err != nil {
         return handleRekorAPIError(params, http.StatusInternalServerError, err, sthGenerateError)
     }
@@ -109,7 +107,7 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
         RootHash:       &hashString,
         TreeSize:       &treeSize,
         SignedTreeHead: stringPointer(string(scBytes)),
-        TreeID:         stringPointer(fmt.Sprintf("%d", api.logID)),
+        TreeID:         stringPointer(fmt.Sprintf("%d", api.treeID)),
         InactiveShards: inactiveShards,
     }

@@ -126,7 +124,7 @@ func GetLogProofHandler(params tlog.GetLogProofParams) middleware.Responder {
         errMsg := fmt.Sprintf(firstSizeLessThanLastSize, *params.FirstSize, params.LastSize)
         return handleRekorAPIError(params, http.StatusBadRequest, fmt.Errorf("consistency proof: %s", errMsg), errMsg)
     }
-    tc := trillianclient.NewTrillianClient(params.HTTPRequest.Context(), api.logClient, api.logID)
+    tc := trillianclient.NewTrillianClient(params.HTTPRequest.Context(), api.logClient, api.treeID)
     if treeID := swag.StringValue(params.TreeID); treeID != "" {
         id, err := strconv.Atoi(treeID)
         if err != nil {
@@ -170,7 +168,7 @@ func GetLogProofHandler(params tlog.GetLogProofParams) middleware.Responder {
     return tlog.NewGetLogProofOK().WithPayload(&consistencyProof)
 }
 
-func inactiveShardLogInfo(ctx context.Context, tid int64) (*models.InactiveShardLogInfo, error) {
+func inactiveShardLogInfo(ctx context.Context, tid int64, signer signature.Signer) (*models.InactiveShardLogInfo, error) {
     tc := trillianclient.NewTrillianClient(ctx, api.logClient, tid)
     resp := tc.GetLatest(0)
     if resp.Status != codes.OK {
@@ -186,7 +184,7 @@ func inactiveShardLogInfo(ctx context.Context, tid int64, signer signature.Signe
     hashString := hex.EncodeToString(root.RootHash)
     treeSize := int64(root.TreeSize)
 
-    scBytes, err := util.CreateAndSignCheckpoint(ctx, viper.GetString("rekor_server.hostname"), tid, root.TreeSize, root.RootHash, api.signer)
+    scBytes, err := util.CreateAndSignCheckpoint(ctx, viper.GetString("rekor_server.hostname"), tid, root.TreeSize, root.RootHash, signer)
     if err != nil {
         return nil, err
     }
4 changes: 2 additions & 2 deletions pkg/sharding/log_index.go
@@ -19,7 +19,7 @@ func VirtualLogIndex(leafIndex int64, tid int64, ranges LogRanges) int64 {
     // if we have no inactive ranges, we have just one log! return the leafIndex as is
     // as long as it matches the active tree ID
     if ranges.NoInactive() {
-        if ranges.GetActive() == tid {
+        if ranges.GetActive().TreeID == tid {
             return leafIndex
         }
         return -1
@@ -34,7 +34,7 @@ func VirtualLogIndex(leafIndex int64, tid int64, ranges LogRanges) int64 {
     }
 
     // If no TreeID in Inactive matches the tid, the virtual index should be the active tree
-    if ranges.GetActive() == tid {
+    if ranges.GetActive().TreeID == tid {
         return virtualIndex + leafIndex
     }
 
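A quick worked example of the rule above, using made-up shard sizes
rather than values from this diff: each inactive shard contributes its
full length before the matching tree's local leaf index is added.

    // Assumed setup: two inactive shards of lengths 5 and 4, active tree 400.
    ranges := LogRanges{
        inactive: []LogRange{{TreeID: 100, TreeLength: 5}, {TreeID: 300, TreeLength: 4}},
        active:   LogRange{TreeID: 400},
    }
    VirtualLogIndex(2, 300, ranges) // 5 + 2 = 7     (offset by shards before tree 300)
    VirtualLogIndex(2, 400, ranges) // 5 + 4 + 2 = 11 (offset by all inactive shards)
    VirtualLogIndex(2, 999, ranges) // -1             (tree ID not in any range)
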
8 changes: 4 additions & 4 deletions pkg/sharding/log_index_test.go
@@ -44,7 +44,7 @@ func TestVirtualLogIndex(t *testing.T) {
                 TreeID:     100,
                 TreeLength: 5,
             }},
-            active: 300,
+            active: LogRange{TreeID: 300},
         },
         expectedIndex: 7,
     },
@@ -64,7 +64,7 @@ func TestVirtualLogIndex(t *testing.T) {
                 TreeID:     300,
                 TreeLength: 4,
             }},
-            active: 400,
+            active: LogRange{TreeID: 400},
         },
         expectedIndex: 6,
     },
@@ -74,15 +74,15 @@ func TestVirtualLogIndex(t *testing.T) {
         leafIndex: 2,
         tid:       30,
         ranges: LogRanges{
-            active: 30,
+            active: LogRange{TreeID: 30},
         },
         expectedIndex: 2,
     }, {
         description: "invalid tid passed in",
         leafIndex:   2,
         tid:         4,
         ranges: LogRanges{
-            active: 30,
+            active: LogRange{TreeID: 30},
         },
         expectedIndex: -1,
     },