Skip to content

Commit

Permalink
Fix goroutine leak, check if redis client is configured
Browse files Browse the repository at this point in the history
Signed-off-by: Hayden Blauzvern <hblauzvern@google.com>
  • Loading branch information
haydentherapper committed May 10, 2023
1 parent 7556390 commit 98cb398
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 13 deletions.
12 changes: 11 additions & 1 deletion pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ type API struct {
pubkey string // PEM encoded public key
pubkeyHash string // SHA256 hash of DER-encoded public key
signer signature.Signer
// stops checkpoint publishing
checkpointPublishCancel context.CancelFunc
}

func NewAPI(treeID uint) (*API, error) {
Expand Down Expand Up @@ -154,6 +156,14 @@ func ConfigureAPI(treeID uint) {
if viper.GetBool("enable_stable_checkpoint") {
checkpointPublisher := witness.NewCheckpointPublisher(context.Background(), api.logClient, api.logRanges.ActiveTreeID(),
viper.GetString("rekor_server.hostname"), api.signer, redisClient, viper.GetUint("publish_frequency"), CheckpointPublishCount)
checkpointPublisher.StartPublisher()

// create context to cancel goroutine on server shutdown
ctx, cancel := context.WithCancel(context.Background())
api.checkpointPublishCancel = cancel
checkpointPublisher.StartPublisher(ctx)
}
}

func StopAPI() {
api.checkpointPublishCancel()
}
3 changes: 2 additions & 1 deletion pkg/api/tlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,15 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
inactiveShards = append(inactiveShards, is)
}

if swag.BoolValue(params.Stable) {
if swag.BoolValue(params.Stable) && redisClient != nil {
// key is treeID/latest
key := fmt.Sprintf("%d/latest", api.logRanges.ActiveTreeID())
redisResult, err := redisClient.Get(params.HTTPRequest.Context(), key).Result()
if err != nil {
return handleRekorAPIError(params, http.StatusInternalServerError,
fmt.Errorf("error getting checkpoint from redis: %w", err), "error getting checkpoint from redis")
}
// should not occur, a checkpoint should always be present
if redisResult == "" {
return handleRekorAPIError(params, http.StatusInternalServerError,
fmt.Errorf("no checkpoint found in redis: %w", err), "no checkpoint found in redis")
Expand Down
4 changes: 3 additions & 1 deletion pkg/generated/restapi/configure_rekor_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ func configureAPI(api *operations.RekorServerAPI) http.Handler {
api.RegisterFormat("signedCheckpoint", &util.SignedNote{}, util.SignedCheckpointValidator)

api.PreServerShutdown = func() {}
api.ServerShutdown = func() {}
api.ServerShutdown = func() {
pkgapi.StopAPI()
}

return setupGlobalMiddleware(api.Serve(setupMiddlewares))
}
Expand Down
10 changes: 7 additions & 3 deletions pkg/witness/publish_checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func NewCheckpointPublisher(ctx context.Context,
// or Verifiers monitoring for fresh checkpoints. Failure can occur after a lock is obtained but
// before publishing the latest checkpoint. If this occurs due to a sporadic failure, this simply
// means that a witness will not see a fresh checkpoint for an additional period.
func (c *CheckpointPublisher) StartPublisher() {
func (c *CheckpointPublisher) StartPublisher(ctx context.Context) {
tc := trillianclient.NewTrillianClient(context.Background(), c.logClient, c.treeID)
sTreeID := strconv.FormatInt(c.treeID, 10)

Expand All @@ -90,8 +90,12 @@ func (c *CheckpointPublisher) StartPublisher() {
ticker := time.NewTicker(time.Duration(c.checkpointFreq) * time.Minute)
go func() {
for {
<-ticker.C
c.publish(&tc, sTreeID)
select {
case <-ctx.Done():
return
case <-ticker.C:
c.publish(&tc, sTreeID)
}
}
}()
}
Expand Down
29 changes: 22 additions & 7 deletions pkg/witness/publish_checkpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ func TestPublishCheckpoint(t *testing.T) {
mock.Regexp().ExpectSet(fmt.Sprintf("%d/latest", treeID), "[0-9a-fA-F]+", 0).SetVal("OK")

publisher := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClient, uint(freq), counter)
publisher.StartPublisher()

ctx, cancel := context.WithCancel(context.Background())
publisher.StartPublisher(ctx)
defer cancel()

// wait for initial publish
time.Sleep(1 * time.Second)
Expand Down Expand Up @@ -119,12 +122,16 @@ func TestPublishCheckpointMultiple(t *testing.T) {
mock.Regexp().ExpectSet(fmt.Sprintf("%d/latest", treeID), "[0-9a-fA-F]+", 0).SetVal("OK")

publisher := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClient, uint(freq), counter)
publisher.StartPublisher()
ctx, cancel := context.WithCancel(context.Background())
publisher.StartPublisher(ctx)
defer cancel()

redisClientEx, mockEx := redismock.NewClientMock()
mockEx.Regexp().ExpectSetNX(fmt.Sprintf("%d/%d", treeID, ts), "[0-9a-fA-F]+", 0).SetVal(false)
publisherEx := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClientEx, uint(freq), counter)
publisherEx.StartPublisher()
ctxEx, cancelEx := context.WithCancel(context.Background())
publisherEx.StartPublisher(ctxEx)
defer cancelEx()

// wait for initial publish
time.Sleep(1 * time.Second)
Expand Down Expand Up @@ -169,7 +176,9 @@ func TestPublishCheckpointTrillianError(t *testing.T) {
redisClient, _ := redismock.NewClientMock()

publisher := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClient, uint(freq), counter)
publisher.StartPublisher()
ctx, cancel := context.WithCancel(context.Background())
publisher.StartPublisher(ctx)
defer cancel()

// wait for initial publish
time.Sleep(1 * time.Second)
Expand Down Expand Up @@ -204,7 +213,9 @@ func TestPublishCheckpointInvalidTrillianResponse(t *testing.T) {
redisClient, _ := redismock.NewClientMock()

publisher := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClient, uint(freq), counter)
publisher.StartPublisher()
ctx, cancel := context.WithCancel(context.Background())
publisher.StartPublisher(ctx)
defer cancel()

// wait for initial publish
time.Sleep(1 * time.Second)
Expand Down Expand Up @@ -246,7 +257,9 @@ func TestPublishCheckpointRedisFailure(t *testing.T) {
mock.Regexp().ExpectSetNX(".+", "[0-9a-fA-F]+", 0).SetErr(errors.New("redis error"))

publisher := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClient, uint(freq), counter)
publisher.StartPublisher()
ctx, cancel := context.WithCancel(context.Background())
publisher.StartPublisher(ctx)
defer cancel()

// wait for initial publish
time.Sleep(1 * time.Second)
Expand Down Expand Up @@ -289,7 +302,9 @@ func TestPublishCheckpointRedisLatestFailure(t *testing.T) {
mock.Regexp().ExpectSet(".*", "[0-9a-fA-F]+", 0).SetErr(errors.New("error"))

publisher := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClient, uint(freq), counter)
publisher.StartPublisher()
ctx, cancel := context.WithCancel(context.Background())
publisher.StartPublisher(ctx)
defer cancel()

// wait for initial publish
time.Sleep(1 * time.Second)
Expand Down

0 comments on commit 98cb398

Please sign in to comment.