From 98cb3987f617aeca4a89c9f53d36d7492372c978 Mon Sep 17 00:00:00 2001 From: Hayden Blauzvern Date: Wed, 10 May 2023 21:35:41 +0000 Subject: [PATCH] Fix goroutine leak, check if redis client is configured Signed-off-by: Hayden Blauzvern --- pkg/api/api.go | 12 +++++++- pkg/api/tlog.go | 3 +- .../restapi/configure_rekor_server.go | 4 ++- pkg/witness/publish_checkpoint.go | 10 +++++-- pkg/witness/publish_checkpoint_test.go | 29 ++++++++++++++----- 5 files changed, 45 insertions(+), 13 deletions(-) diff --git a/pkg/api/api.go b/pkg/api/api.go index 102e2de7d..12925b6bd 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -61,6 +61,8 @@ type API struct { pubkey string // PEM encoded public key pubkeyHash string // SHA256 hash of DER-encoded public key signer signature.Signer + // stops checkpoint publishing + checkpointPublishCancel context.CancelFunc } func NewAPI(treeID uint) (*API, error) { @@ -154,6 +156,14 @@ func ConfigureAPI(treeID uint) { if viper.GetBool("enable_stable_checkpoint") { checkpointPublisher := witness.NewCheckpointPublisher(context.Background(), api.logClient, api.logRanges.ActiveTreeID(), viper.GetString("rekor_server.hostname"), api.signer, redisClient, viper.GetUint("publish_frequency"), CheckpointPublishCount) - checkpointPublisher.StartPublisher() + + // create context to cancel goroutine on server shutdown + ctx, cancel := context.WithCancel(context.Background()) + api.checkpointPublishCancel = cancel + checkpointPublisher.StartPublisher(ctx) } } + +func StopAPI() { + api.checkpointPublishCancel() +} diff --git a/pkg/api/tlog.go b/pkg/api/tlog.go index 3bb01c6a0..ba5309c80 100644 --- a/pkg/api/tlog.go +++ b/pkg/api/tlog.go @@ -53,7 +53,7 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder { inactiveShards = append(inactiveShards, is) } - if swag.BoolValue(params.Stable) { + if swag.BoolValue(params.Stable) && redisClient != nil { // key is treeID/latest key := fmt.Sprintf("%d/latest", api.logRanges.ActiveTreeID()) redisResult, err := redisClient.Get(params.HTTPRequest.Context(), key).Result() @@ -61,6 +61,7 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder { return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("error getting checkpoint from redis: %w", err), "error getting checkpoint from redis") } + // should not occur, a checkpoint should always be present if redisResult == "" { return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("no checkpoint found in redis: %w", err), "no checkpoint found in redis") diff --git a/pkg/generated/restapi/configure_rekor_server.go b/pkg/generated/restapi/configure_rekor_server.go index 2041f0759..b66a0577f 100644 --- a/pkg/generated/restapi/configure_rekor_server.go +++ b/pkg/generated/restapi/configure_rekor_server.go @@ -156,7 +156,9 @@ func configureAPI(api *operations.RekorServerAPI) http.Handler { api.RegisterFormat("signedCheckpoint", &util.SignedNote{}, util.SignedCheckpointValidator) api.PreServerShutdown = func() {} - api.ServerShutdown = func() {} + api.ServerShutdown = func() { + pkgapi.StopAPI() + } return setupGlobalMiddleware(api.Serve(setupMiddlewares)) } diff --git a/pkg/witness/publish_checkpoint.go b/pkg/witness/publish_checkpoint.go index 6b8b91662..8f946ff35 100644 --- a/pkg/witness/publish_checkpoint.go +++ b/pkg/witness/publish_checkpoint.go @@ -80,7 +80,7 @@ func NewCheckpointPublisher(ctx context.Context, // or Verifiers monitoring for fresh checkpoints. Failure can occur after a lock is obtained but // before publishing the latest checkpoint. If this occurs due to a sporadic failure, this simply // means that a witness will not see a fresh checkpoint for an additional period. -func (c *CheckpointPublisher) StartPublisher() { +func (c *CheckpointPublisher) StartPublisher(ctx context.Context) { tc := trillianclient.NewTrillianClient(context.Background(), c.logClient, c.treeID) sTreeID := strconv.FormatInt(c.treeID, 10) @@ -90,8 +90,12 @@ func (c *CheckpointPublisher) StartPublisher() { ticker := time.NewTicker(time.Duration(c.checkpointFreq) * time.Minute) go func() { for { - <-ticker.C - c.publish(&tc, sTreeID) + select { + case <-ctx.Done(): + return + case <-ticker.C: + c.publish(&tc, sTreeID) + } } }() } diff --git a/pkg/witness/publish_checkpoint_test.go b/pkg/witness/publish_checkpoint_test.go index faa51164d..d99d150db 100644 --- a/pkg/witness/publish_checkpoint_test.go +++ b/pkg/witness/publish_checkpoint_test.go @@ -67,7 +67,10 @@ func TestPublishCheckpoint(t *testing.T) { mock.Regexp().ExpectSet(fmt.Sprintf("%d/latest", treeID), "[0-9a-fA-F]+", 0).SetVal("OK") publisher := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClient, uint(freq), counter) - publisher.StartPublisher() + + ctx, cancel := context.WithCancel(context.Background()) + publisher.StartPublisher(ctx) + defer cancel() // wait for initial publish time.Sleep(1 * time.Second) @@ -119,12 +122,16 @@ func TestPublishCheckpointMultiple(t *testing.T) { mock.Regexp().ExpectSet(fmt.Sprintf("%d/latest", treeID), "[0-9a-fA-F]+", 0).SetVal("OK") publisher := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClient, uint(freq), counter) - publisher.StartPublisher() + ctx, cancel := context.WithCancel(context.Background()) + publisher.StartPublisher(ctx) + defer cancel() redisClientEx, mockEx := redismock.NewClientMock() mockEx.Regexp().ExpectSetNX(fmt.Sprintf("%d/%d", treeID, ts), "[0-9a-fA-F]+", 0).SetVal(false) publisherEx := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClientEx, uint(freq), counter) - publisherEx.StartPublisher() + ctxEx, cancelEx := context.WithCancel(context.Background()) + publisherEx.StartPublisher(ctxEx) + defer cancelEx() // wait for initial publish time.Sleep(1 * time.Second) @@ -169,7 +176,9 @@ func TestPublishCheckpointTrillianError(t *testing.T) { redisClient, _ := redismock.NewClientMock() publisher := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClient, uint(freq), counter) - publisher.StartPublisher() + ctx, cancel := context.WithCancel(context.Background()) + publisher.StartPublisher(ctx) + defer cancel() // wait for initial publish time.Sleep(1 * time.Second) @@ -204,7 +213,9 @@ func TestPublishCheckpointInvalidTrillianResponse(t *testing.T) { redisClient, _ := redismock.NewClientMock() publisher := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClient, uint(freq), counter) - publisher.StartPublisher() + ctx, cancel := context.WithCancel(context.Background()) + publisher.StartPublisher(ctx) + defer cancel() // wait for initial publish time.Sleep(1 * time.Second) @@ -246,7 +257,9 @@ func TestPublishCheckpointRedisFailure(t *testing.T) { mock.Regexp().ExpectSetNX(".+", "[0-9a-fA-F]+", 0).SetErr(errors.New("redis error")) publisher := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClient, uint(freq), counter) - publisher.StartPublisher() + ctx, cancel := context.WithCancel(context.Background()) + publisher.StartPublisher(ctx) + defer cancel() // wait for initial publish time.Sleep(1 * time.Second) @@ -289,7 +302,9 @@ func TestPublishCheckpointRedisLatestFailure(t *testing.T) { mock.Regexp().ExpectSet(".*", "[0-9a-fA-F]+", 0).SetErr(errors.New("error")) publisher := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClient, uint(freq), counter) - publisher.StartPublisher() + ctx, cancel := context.WithCancel(context.Background()) + publisher.StartPublisher(ctx) + defer cancel() // wait for initial publish time.Sleep(1 * time.Second)