diff --git a/cmd/rekor-server/app/root.go b/cmd/rekor-server/app/root.go index df313365c..c91013fe9 100644 --- a/cmd/rekor-server/app/root.go +++ b/cmd/rekor-server/app/root.go @@ -80,6 +80,9 @@ func init() { rootCmd.PersistentFlags().Uint("trillian_log_server.tlog_id", 0, "Trillian tree id") rootCmd.PersistentFlags().String("trillian_log_server.sharding_config", "", "path to config file for inactive shards, in JSON or YAML") + rootCmd.PersistentFlags().Bool("enable_stable_checkpoint", true, "publish stable checkpoints to Redis. When disabled, gossiping may not be possible if the log checkpoint updates too frequently") + rootCmd.PersistentFlags().Uint("publish_frequency", 5, "how often to publish a new checkpoint, in minutes") + hostname, err := os.Hostname() if err != nil { hostname = "localhost" diff --git a/docker-compose.yml b/docker-compose.yml index 69d6b0bc1..fb4c9914d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -92,6 +92,7 @@ services: "--rekor_server.signer=memory", "--enable_attestation_storage", "--attestation_storage_bucket=file:///var/run/attestations", + "--enable_stable_checkpoint", # Uncomment this for production logging # "--log_type=prod", ] diff --git a/go.mod b/go.mod index d03ded8a6..4f9109201 100644 --- a/go.mod +++ b/go.mod @@ -53,6 +53,7 @@ require ( require ( github.com/AdamKorcz/go-fuzz-headers-1 v0.0.0-20230329111138-12e09aba5ebd github.com/cyberphone/json-canonicalization v0.0.0-20220623050100-57a0ce2678a7 + github.com/go-redis/redismock/v9 v9.0.3 github.com/golang/mock v1.6.0 github.com/hashicorp/go-cleanhttp v0.5.2 github.com/hashicorp/go-retryablehttp v0.7.2 @@ -64,6 +65,7 @@ require ( cloud.google.com/go/compute/metadata v0.2.3 // indirect filippo.io/edwards25519 v1.0.0 // indirect github.com/cyphar/filepath-securejoin v0.2.3 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/go-logr/logr v1.2.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect diff --git a/go.sum b/go.sum index a9774043b..0d88b09f4 100644 --- a/go.sum +++ b/go.sum @@ -1024,6 +1024,8 @@ github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GO github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= github.com/go-playground/validator/v10 v10.13.0 h1:cFRQdfaSMCOSfGCCLB20MHvuoHb/s5G8L5pu2ppK5AQ= github.com/go-playground/validator/v10 v10.13.0/go.mod h1:dwu7+CG8/CtBiJFZDz4e+5Upb6OLw04gtBYw0mcG/z4= +github.com/go-redis/redismock/v9 v9.0.3 h1:mtHQi2l51lCmXIbTRTqb1EiHYe9tL5Yk5oorlSJJqR0= +github.com/go-redis/redismock/v9 v9.0.3/go.mod h1:F6tJRfnU8R/NZ0E+Gjvoluk14MqMC5ueSZX6vVQypc0= github.com/go-resty/resty/v2 v2.1.1-0.20191201195748-d7b97669fe48/go.mod h1:dZGr0i9PLlaaTD4H/hoZIDjQ+r6xq8mgbRzHZf7f2J8= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= @@ -1596,6 +1598,7 @@ github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OS github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/oklog/run v1.1.0/go.mod h1:sVPdnTZT1zYwAJeCMu2Th4T21pA3FPOQRfWjQlk7DVU= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= @@ -1615,6 +1618,7 @@ github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1ls github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= github.com/onsi/ginkgo v1.16.2/go.mod h1:CObGmKUOKaSC0RjmoAK7tKyn4Azo5P2IWuoMnvwxz1E= github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo/v2 v2.1.3/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c= github.com/onsi/ginkgo/v2 v2.1.4/go.mod h1:um6tUpWM/cxCK3/FK8BXqEiUMUwRgSM4JXG47RKZmLU= github.com/onsi/ginkgo/v2 v2.1.6/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk= @@ -1637,6 +1641,7 @@ github.com/onsi/gomega v1.20.1/go.mod h1:DtrZpjmvpn2mPm4YWQa0/ALMDj9v4YxLgojwPeR github.com/onsi/gomega v1.21.1/go.mod h1:iYAIXgPSaDHak0LCMA+AWBpIKBr8WZicMxnE8luStNc= github.com/onsi/gomega v1.22.1/go.mod h1:x6n7VNe4hw0vkyYUM4mjIXx3JbLiPaBPNgB7PRQ1tuM= github.com/onsi/gomega v1.23.0/go.mod h1:Z/NWtiqwBrwUt4/2loMmHL63EDLnYHmVbuBpDr2vQAg= +github.com/onsi/gomega v1.25.0 h1:Vw7br2PCDYijJHSfBOWhov+8cAnUf8MfMaIOV323l6Y= github.com/op/go-logging v0.0.0-20160315200505-970db520ece7/go.mod h1:HzydrMdWErDVzsI23lYNej1Htcns9BCg93Dk0bBINWk= github.com/opencontainers/go-digest v0.0.0-20170106003457-a6d0ee40d420/go.mod h1:cMLVZDEM3+U2I4VmLI6N8jQYUd2OVphdqWwCJHrFt2s= github.com/opencontainers/go-digest v0.0.0-20180430190053-c9281466c8b2/go.mod h1:cMLVZDEM3+U2I4VmLI6N8jQYUd2OVphdqWwCJHrFt2s= @@ -2925,6 +2930,7 @@ gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/telebot.v3 v3.0.0/go.mod h1:7rExV8/0mDDNu9epSrDm/8j22KLaActH1Tbee6YjzWg= gopkg.in/telebot.v3 v3.1.2/go.mod h1:GJKwwWqp9nSkIVN51eRKU78aB5f5OnQuWdwiIZfPbko= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= diff --git a/openapi.yaml b/openapi.yaml index d90503f96..5b5e960f1 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -67,6 +67,12 @@ paths: operationId: getLogInfo tags: - tlog + parameters: + - in: query + name: stable + type: boolean + default: false + description: Whether to return a stable checkpoint for the active shard responses: 200: description: A JSON object with the root hash and tree size as properties diff --git a/pkg/api/api.go b/pkg/api/api.go index bfec95c0a..12925b6bd 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -35,6 +35,7 @@ import ( "github.com/sigstore/rekor/pkg/signer" "github.com/sigstore/rekor/pkg/storage" "github.com/sigstore/rekor/pkg/trillianclient" + "github.com/sigstore/rekor/pkg/witness" "github.com/sigstore/sigstore/pkg/cryptoutils" "github.com/sigstore/sigstore/pkg/signature" "github.com/sigstore/sigstore/pkg/signature/options" @@ -60,6 +61,8 @@ type API struct { pubkey string // PEM encoded public key pubkeyHash string // SHA256 hash of DER-encoded public key signer signature.Signer + // stops checkpoint publishing + checkpointPublishCancel context.CancelFunc } func NewAPI(treeID uint) (*API, error) { @@ -134,7 +137,8 @@ func ConfigureAPI(treeID uint) { if err != nil { log.Logger.Panic(err) } - if viper.GetBool("enable_retrieve_api") || slices.Contains(viper.GetStringSlice("enabled_api_endpoints"), "searchIndex") { + if viper.GetBool("enable_retrieve_api") || viper.GetBool("enable_stable_checkpoint") || + slices.Contains(viper.GetStringSlice("enabled_api_endpoints"), "searchIndex") { redisClient = redis.NewClient(&redis.Options{ Addr: fmt.Sprintf("%v:%v", viper.GetString("redis_server.address"), viper.GetUint64("redis_server.port")), Network: "tcp", @@ -148,4 +152,18 @@ func ConfigureAPI(treeID uint) { log.Logger.Panic(err) } } + + if viper.GetBool("enable_stable_checkpoint") { + checkpointPublisher := witness.NewCheckpointPublisher(context.Background(), api.logClient, api.logRanges.ActiveTreeID(), + viper.GetString("rekor_server.hostname"), api.signer, redisClient, viper.GetUint("publish_frequency"), CheckpointPublishCount) + + // create context to cancel goroutine on server shutdown + ctx, cancel := context.WithCancel(context.Background()) + api.checkpointPublishCancel = cancel + checkpointPublisher.StartPublisher(ctx) + } +} + +func StopAPI() { + api.checkpointPublishCancel() } diff --git a/pkg/api/metrics.go b/pkg/api/metrics.go index 36d0d324a..8f0efb9d6 100644 --- a/pkg/api/metrics.go +++ b/pkg/api/metrics.go @@ -53,6 +53,11 @@ var ( Help: "Api QPS by path, method, and response code", }, []string{"path", "method", "code"}) + CheckpointPublishCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "rekor_checkpoint_publish", + Help: "Checkpoint publishing by shard and code", + }, []string{"shard", "code"}) + _ = promauto.NewGaugeFunc( prometheus.GaugeOpts{ Namespace: "rekor", diff --git a/pkg/api/tlog.go b/pkg/api/tlog.go index 33da70b15..ba5309c80 100644 --- a/pkg/api/tlog.go +++ b/pkg/api/tlog.go @@ -53,6 +53,38 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder { inactiveShards = append(inactiveShards, is) } + if swag.BoolValue(params.Stable) && redisClient != nil { + // key is treeID/latest + key := fmt.Sprintf("%d/latest", api.logRanges.ActiveTreeID()) + redisResult, err := redisClient.Get(params.HTTPRequest.Context(), key).Result() + if err != nil { + return handleRekorAPIError(params, http.StatusInternalServerError, + fmt.Errorf("error getting checkpoint from redis: %w", err), "error getting checkpoint from redis") + } + // should not occur, a checkpoint should always be present + if redisResult == "" { + return handleRekorAPIError(params, http.StatusInternalServerError, + fmt.Errorf("no checkpoint found in redis: %w", err), "no checkpoint found in redis") + } + decoded, err := hex.DecodeString(redisResult) + if err != nil { + return handleRekorAPIError(params, http.StatusInternalServerError, + fmt.Errorf("error decoding checkpoint from redis: %w", err), "error decoding checkpoint from redis") + } + checkpoint := util.SignedCheckpoint{} + if err := checkpoint.UnmarshalText(decoded); err != nil { + return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("invalid checkpoint: %w", err), "invalid checkpoint") + } + logInfo := models.LogInfo{ + RootHash: stringPointer(hex.EncodeToString(checkpoint.Hash)), + TreeSize: swag.Int64(int64(checkpoint.Size)), + SignedTreeHead: stringPointer(string(decoded)), + TreeID: stringPointer(fmt.Sprintf("%d", api.logID)), + InactiveShards: inactiveShards, + } + return tlog.NewGetLogInfoOK().WithPayload(&logInfo) + } + resp := tc.GetLatest(0) if resp.Status != codes.OK { return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.Err), trillianCommunicationError) diff --git a/pkg/generated/client/tlog/get_log_info_parameters.go b/pkg/generated/client/tlog/get_log_info_parameters.go index e0ae2cdd3..b2e329427 100644 --- a/pkg/generated/client/tlog/get_log_info_parameters.go +++ b/pkg/generated/client/tlog/get_log_info_parameters.go @@ -30,6 +30,7 @@ import ( "github.com/go-openapi/runtime" cr "github.com/go-openapi/runtime/client" "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" ) // NewGetLogInfoParams creates a new GetLogInfoParams object, @@ -76,6 +77,13 @@ GetLogInfoParams contains all the parameters to send to the API endpoint Typically these are written to a http.Request. */ type GetLogInfoParams struct { + + /* Stable. + + Whether to return a stable checkpoint for the active shard + */ + Stable *bool + timeout time.Duration Context context.Context HTTPClient *http.Client @@ -93,7 +101,18 @@ func (o *GetLogInfoParams) WithDefaults() *GetLogInfoParams { // // All values with no default are reset to their zero value. func (o *GetLogInfoParams) SetDefaults() { - // no default values defined for this parameter + var ( + stableDefault = bool(false) + ) + + val := GetLogInfoParams{ + Stable: &stableDefault, + } + + val.timeout = o.timeout + val.Context = o.Context + val.HTTPClient = o.HTTPClient + *o = val } // WithTimeout adds the timeout to the get log info params @@ -129,6 +148,17 @@ func (o *GetLogInfoParams) SetHTTPClient(client *http.Client) { o.HTTPClient = client } +// WithStable adds the stable to the get log info params +func (o *GetLogInfoParams) WithStable(stable *bool) *GetLogInfoParams { + o.SetStable(stable) + return o +} + +// SetStable adds the stable to the get log info params +func (o *GetLogInfoParams) SetStable(stable *bool) { + o.Stable = stable +} + // WriteToRequest writes these params to a swagger request func (o *GetLogInfoParams) WriteToRequest(r runtime.ClientRequest, reg strfmt.Registry) error { @@ -137,6 +167,23 @@ func (o *GetLogInfoParams) WriteToRequest(r runtime.ClientRequest, reg strfmt.Re } var res []error + if o.Stable != nil { + + // query param stable + var qrStable bool + + if o.Stable != nil { + qrStable = *o.Stable + } + qStable := swag.FormatBool(qrStable) + if qStable != "" { + + if err := r.SetQueryParam("stable", qStable); err != nil { + return err + } + } + } + if len(res) > 0 { return errors.CompositeValidationError(res...) } diff --git a/pkg/generated/restapi/configure_rekor_server.go b/pkg/generated/restapi/configure_rekor_server.go index 2041f0759..b66a0577f 100644 --- a/pkg/generated/restapi/configure_rekor_server.go +++ b/pkg/generated/restapi/configure_rekor_server.go @@ -156,7 +156,9 @@ func configureAPI(api *operations.RekorServerAPI) http.Handler { api.RegisterFormat("signedCheckpoint", &util.SignedNote{}, util.SignedCheckpointValidator) api.PreServerShutdown = func() {} - api.ServerShutdown = func() {} + api.ServerShutdown = func() { + pkgapi.StopAPI() + } return setupGlobalMiddleware(api.Serve(setupMiddlewares)) } diff --git a/pkg/generated/restapi/embedded_spec.go b/pkg/generated/restapi/embedded_spec.go index 9466b1ffd..70d596679 100644 --- a/pkg/generated/restapi/embedded_spec.go +++ b/pkg/generated/restapi/embedded_spec.go @@ -99,6 +99,15 @@ func init() { ], "summary": "Get information about the current state of the transparency log", "operationId": "getLogInfo", + "parameters": [ + { + "type": "boolean", + "default": false, + "description": "Whether to return a stable checkpoint for the active shard", + "name": "stable", + "in": "query" + } + ], "responses": { "200": { "description": "A JSON object with the root hash and tree size as properties", @@ -995,6 +1004,15 @@ func init() { ], "summary": "Get information about the current state of the transparency log", "operationId": "getLogInfo", + "parameters": [ + { + "type": "boolean", + "default": false, + "description": "Whether to return a stable checkpoint for the active shard", + "name": "stable", + "in": "query" + } + ], "responses": { "200": { "description": "A JSON object with the root hash and tree size as properties", diff --git a/pkg/generated/restapi/operations/tlog/get_log_info_parameters.go b/pkg/generated/restapi/operations/tlog/get_log_info_parameters.go index f6b7e151b..72fc76472 100644 --- a/pkg/generated/restapi/operations/tlog/get_log_info_parameters.go +++ b/pkg/generated/restapi/operations/tlog/get_log_info_parameters.go @@ -25,15 +25,25 @@ import ( "net/http" "github.com/go-openapi/errors" + "github.com/go-openapi/runtime" "github.com/go-openapi/runtime/middleware" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" ) // NewGetLogInfoParams creates a new GetLogInfoParams object -// -// There are no default values defined in the spec. +// with the default values initialized. func NewGetLogInfoParams() GetLogInfoParams { - return GetLogInfoParams{} + var ( + // initialize parameters with default values + + stableDefault = bool(false) + ) + + return GetLogInfoParams{ + Stable: &stableDefault, + } } // GetLogInfoParams contains all the bound params for the get log info operation @@ -44,6 +54,12 @@ type GetLogInfoParams struct { // HTTP Request Object HTTPRequest *http.Request `json:"-"` + + /*Whether to return a stable checkpoint for the active shard + In: query + Default: false + */ + Stable *bool } // BindRequest both binds and validates a request, it assumes that complex things implement a Validatable(strfmt.Registry) error interface @@ -55,8 +71,38 @@ func (o *GetLogInfoParams) BindRequest(r *http.Request, route *middleware.Matche o.HTTPRequest = r + qs := runtime.Values(r.URL.Query()) + + qStable, qhkStable, _ := qs.GetOK("stable") + if err := o.bindStable(qStable, qhkStable, route.Formats); err != nil { + res = append(res, err) + } if len(res) > 0 { return errors.CompositeValidationError(res...) } return nil } + +// bindStable binds and validates parameter Stable from query. +func (o *GetLogInfoParams) bindStable(rawData []string, hasKey bool, formats strfmt.Registry) error { + var raw string + if len(rawData) > 0 { + raw = rawData[len(rawData)-1] + } + + // Required: false + // AllowEmptyValue: false + + if raw == "" { // empty values pass all other validations + // Default values have been previously initialized by NewGetLogInfoParams() + return nil + } + + value, err := swag.ConvertBool(raw) + if err != nil { + return errors.InvalidType("stable", "query", "bool", raw) + } + o.Stable = &value + + return nil +} diff --git a/pkg/generated/restapi/operations/tlog/get_log_info_urlbuilder.go b/pkg/generated/restapi/operations/tlog/get_log_info_urlbuilder.go index e344fac34..19587e476 100644 --- a/pkg/generated/restapi/operations/tlog/get_log_info_urlbuilder.go +++ b/pkg/generated/restapi/operations/tlog/get_log_info_urlbuilder.go @@ -25,11 +25,17 @@ import ( "errors" "net/url" golangswaggerpaths "path" + + "github.com/go-openapi/swag" ) // GetLogInfoURL generates an URL for the get log info operation type GetLogInfoURL struct { + Stable *bool + _basePath string + // avoid unkeyed usage + _ struct{} } // WithBasePath sets the base path for this url builder, only required when it's different from the @@ -56,6 +62,18 @@ func (o *GetLogInfoURL) Build() (*url.URL, error) { _basePath := o._basePath _result.Path = golangswaggerpaths.Join(_basePath, _path) + qs := make(url.Values) + + var stableQ string + if o.Stable != nil { + stableQ = swag.FormatBool(*o.Stable) + } + if stableQ != "" { + qs.Set("stable", stableQ) + } + + _result.RawQuery = qs.Encode() + return &_result, nil } diff --git a/pkg/witness/mockclient/generate.go b/pkg/witness/mockclient/generate.go new file mode 100644 index 000000000..a3be36d65 --- /dev/null +++ b/pkg/witness/mockclient/generate.go @@ -0,0 +1,18 @@ +// Copyright 2023 The Sigstore Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package mockclient provides a mockable version of the Trillian log client API. +package mockclient + +//go:generate mockgen -package mockclient -destination mock_log_client.go github.com/google/trillian TrillianLogClient diff --git a/pkg/witness/mockclient/mock_log_client.go b/pkg/witness/mockclient/mock_log_client.go new file mode 100644 index 000000000..33859f9f7 --- /dev/null +++ b/pkg/witness/mockclient/mock_log_client.go @@ -0,0 +1,217 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/google/trillian (interfaces: TrillianLogClient) + +// Package mockclient is a generated GoMock package. +package mockclient + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + trillian "github.com/google/trillian" + grpc "google.golang.org/grpc" +) + +// MockTrillianLogClient is a mock of TrillianLogClient interface. +type MockTrillianLogClient struct { + ctrl *gomock.Controller + recorder *MockTrillianLogClientMockRecorder +} + +// MockTrillianLogClientMockRecorder is the mock recorder for MockTrillianLogClient. +type MockTrillianLogClientMockRecorder struct { + mock *MockTrillianLogClient +} + +// NewMockTrillianLogClient creates a new mock instance. +func NewMockTrillianLogClient(ctrl *gomock.Controller) *MockTrillianLogClient { + mock := &MockTrillianLogClient{ctrl: ctrl} + mock.recorder = &MockTrillianLogClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTrillianLogClient) EXPECT() *MockTrillianLogClientMockRecorder { + return m.recorder +} + +// AddSequencedLeaves mocks base method. +func (m *MockTrillianLogClient) AddSequencedLeaves(arg0 context.Context, arg1 *trillian.AddSequencedLeavesRequest, arg2 ...grpc.CallOption) (*trillian.AddSequencedLeavesResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "AddSequencedLeaves", varargs...) + ret0, _ := ret[0].(*trillian.AddSequencedLeavesResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddSequencedLeaves indicates an expected call of AddSequencedLeaves. +func (mr *MockTrillianLogClientMockRecorder) AddSequencedLeaves(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddSequencedLeaves", reflect.TypeOf((*MockTrillianLogClient)(nil).AddSequencedLeaves), varargs...) +} + +// GetConsistencyProof mocks base method. +func (m *MockTrillianLogClient) GetConsistencyProof(arg0 context.Context, arg1 *trillian.GetConsistencyProofRequest, arg2 ...grpc.CallOption) (*trillian.GetConsistencyProofResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetConsistencyProof", varargs...) + ret0, _ := ret[0].(*trillian.GetConsistencyProofResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetConsistencyProof indicates an expected call of GetConsistencyProof. +func (mr *MockTrillianLogClientMockRecorder) GetConsistencyProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConsistencyProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetConsistencyProof), varargs...) +} + +// GetEntryAndProof mocks base method. +func (m *MockTrillianLogClient) GetEntryAndProof(arg0 context.Context, arg1 *trillian.GetEntryAndProofRequest, arg2 ...grpc.CallOption) (*trillian.GetEntryAndProofResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetEntryAndProof", varargs...) + ret0, _ := ret[0].(*trillian.GetEntryAndProofResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetEntryAndProof indicates an expected call of GetEntryAndProof. +func (mr *MockTrillianLogClientMockRecorder) GetEntryAndProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEntryAndProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetEntryAndProof), varargs...) +} + +// GetInclusionProof mocks base method. +func (m *MockTrillianLogClient) GetInclusionProof(arg0 context.Context, arg1 *trillian.GetInclusionProofRequest, arg2 ...grpc.CallOption) (*trillian.GetInclusionProofResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetInclusionProof", varargs...) + ret0, _ := ret[0].(*trillian.GetInclusionProofResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetInclusionProof indicates an expected call of GetInclusionProof. +func (mr *MockTrillianLogClientMockRecorder) GetInclusionProof(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProof", reflect.TypeOf((*MockTrillianLogClient)(nil).GetInclusionProof), varargs...) +} + +// GetInclusionProofByHash mocks base method. +func (m *MockTrillianLogClient) GetInclusionProofByHash(arg0 context.Context, arg1 *trillian.GetInclusionProofByHashRequest, arg2 ...grpc.CallOption) (*trillian.GetInclusionProofByHashResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetInclusionProofByHash", varargs...) + ret0, _ := ret[0].(*trillian.GetInclusionProofByHashResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetInclusionProofByHash indicates an expected call of GetInclusionProofByHash. +func (mr *MockTrillianLogClientMockRecorder) GetInclusionProofByHash(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInclusionProofByHash", reflect.TypeOf((*MockTrillianLogClient)(nil).GetInclusionProofByHash), varargs...) +} + +// GetLatestSignedLogRoot mocks base method. +func (m *MockTrillianLogClient) GetLatestSignedLogRoot(arg0 context.Context, arg1 *trillian.GetLatestSignedLogRootRequest, arg2 ...grpc.CallOption) (*trillian.GetLatestSignedLogRootResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetLatestSignedLogRoot", varargs...) + ret0, _ := ret[0].(*trillian.GetLatestSignedLogRootResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetLatestSignedLogRoot indicates an expected call of GetLatestSignedLogRoot. +func (mr *MockTrillianLogClientMockRecorder) GetLatestSignedLogRoot(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestSignedLogRoot", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLatestSignedLogRoot), varargs...) +} + +// GetLeavesByRange mocks base method. +func (m *MockTrillianLogClient) GetLeavesByRange(arg0 context.Context, arg1 *trillian.GetLeavesByRangeRequest, arg2 ...grpc.CallOption) (*trillian.GetLeavesByRangeResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetLeavesByRange", varargs...) + ret0, _ := ret[0].(*trillian.GetLeavesByRangeResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetLeavesByRange indicates an expected call of GetLeavesByRange. +func (mr *MockTrillianLogClientMockRecorder) GetLeavesByRange(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLeavesByRange", reflect.TypeOf((*MockTrillianLogClient)(nil).GetLeavesByRange), varargs...) +} + +// InitLog mocks base method. +func (m *MockTrillianLogClient) InitLog(arg0 context.Context, arg1 *trillian.InitLogRequest, arg2 ...grpc.CallOption) (*trillian.InitLogResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "InitLog", varargs...) + ret0, _ := ret[0].(*trillian.InitLogResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// InitLog indicates an expected call of InitLog. +func (mr *MockTrillianLogClientMockRecorder) InitLog(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitLog", reflect.TypeOf((*MockTrillianLogClient)(nil).InitLog), varargs...) +} + +// QueueLeaf mocks base method. +func (m *MockTrillianLogClient) QueueLeaf(arg0 context.Context, arg1 *trillian.QueueLeafRequest, arg2 ...grpc.CallOption) (*trillian.QueueLeafResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueueLeaf", varargs...) + ret0, _ := ret[0].(*trillian.QueueLeafResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueueLeaf indicates an expected call of QueueLeaf. +func (mr *MockTrillianLogClientMockRecorder) QueueLeaf(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueLeaf", reflect.TypeOf((*MockTrillianLogClient)(nil).QueueLeaf), varargs...) +} diff --git a/pkg/witness/publish_checkpoint.go b/pkg/witness/publish_checkpoint.go new file mode 100644 index 000000000..8f946ff35 --- /dev/null +++ b/pkg/witness/publish_checkpoint.go @@ -0,0 +1,199 @@ +// Copyright 2023 The Sigstore Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package witness + +import ( + "context" + "encoding/hex" + "fmt" + "strconv" + "time" + + "github.com/google/trillian" + "github.com/google/trillian/types" + "github.com/prometheus/client_golang/prometheus" + "github.com/redis/go-redis/v9" + "github.com/sigstore/rekor/pkg/log" + "github.com/sigstore/rekor/pkg/trillianclient" + "github.com/sigstore/rekor/pkg/util" + "github.com/sigstore/sigstore/pkg/signature" + "google.golang.org/grpc/codes" +) + +// CheckpointPublisher is a long-running job to periodically publish signed checkpoints to etc.d +type CheckpointPublisher struct { + ctx context.Context + // logClient is the client for Trillian + logClient trillian.TrillianLogClient + // treeID is used to construct the origin and configure the Trillian client + treeID int64 + // hostname is used to construct the origin ("hostname - treeID") + hostname string + // signer signs the checkpoint + signer signature.Signer + // publishFreq is how often a new checkpoint is published to Rekor, in minutes + checkpointFreq uint + // redisClient to upload signed checkpoints + redisClient *redis.Client + // reqCounter tracks successes and failures for publishing + reqCounter *prometheus.CounterVec +} + +// Constant values used with metrics +const ( + Success = iota + SuccessObtainLock + GetCheckpoint + UnmarshalCheckpoint + SignCheckpoint + RedisFailure + RedisLatestFailure +) + +// NewCheckpointPublisher creates a CheckpointPublisher to write stable checkpoints to Redis +func NewCheckpointPublisher(ctx context.Context, + logClient trillian.TrillianLogClient, + treeID int64, + hostname string, + signer signature.Signer, + redisClient *redis.Client, + checkpointFreq uint, + reqCounter *prometheus.CounterVec) CheckpointPublisher { + return CheckpointPublisher{ctx: ctx, logClient: logClient, treeID: treeID, hostname: hostname, + signer: signer, checkpointFreq: checkpointFreq, redisClient: redisClient, reqCounter: reqCounter} +} + +// StartPublisher creates a long-running task that publishes the latest checkpoint every X minutes +// Writing to Redis is best effort. Failure will be detected either through metrics or by witnesses +// or Verifiers monitoring for fresh checkpoints. Failure can occur after a lock is obtained but +// before publishing the latest checkpoint. If this occurs due to a sporadic failure, this simply +// means that a witness will not see a fresh checkpoint for an additional period. +func (c *CheckpointPublisher) StartPublisher(ctx context.Context) { + tc := trillianclient.NewTrillianClient(context.Background(), c.logClient, c.treeID) + sTreeID := strconv.FormatInt(c.treeID, 10) + + // publish on startup to ensure a checkpoint is available the first time Rekor starts up + c.publish(&tc, sTreeID) + + ticker := time.NewTicker(time.Duration(c.checkpointFreq) * time.Minute) + go func() { + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + c.publish(&tc, sTreeID) + } + } + }() +} + +// publish publishes the latest checkpoint to Redis once +func (c *CheckpointPublisher) publish(tc *trillianclient.TrillianClient, sTreeID string) { + // get latest checkpoint + resp := tc.GetLatest(0) + if resp.Status != codes.OK { + c.reqCounter.With( + map[string]string{ + "shard": sTreeID, + "code": strconv.Itoa(GetCheckpoint), + }).Inc() + log.Logger.Errorf("error getting latest checkpoint to publish: %v", resp.Status) + return + } + + // unmarshal checkpoint + root := &types.LogRootV1{} + if err := root.UnmarshalBinary(resp.GetLatestResult.SignedLogRoot.LogRoot); err != nil { + c.reqCounter.With( + map[string]string{ + "shard": sTreeID, + "code": strconv.Itoa(UnmarshalCheckpoint), + }).Inc() + log.Logger.Errorf("error unmarshalling latest checkpoint to publish: %v", err) + return + } + + // sign checkpoint with Rekor private key + checkpoint, err := util.CreateAndSignCheckpoint(context.Background(), c.hostname, c.treeID, root, c.signer) + if err != nil { + c.reqCounter.With( + map[string]string{ + "shard": sTreeID, + "code": strconv.Itoa(SignCheckpoint), + }).Inc() + log.Logger.Errorf("error signing checkpoint to publish: %v", err) + return + } + + // encode checkpoint as hex to write to redis + hexCP := hex.EncodeToString(checkpoint) + + // write checkpoint to Redis if key does not yet exist + // this prevents multiple instances of Rekor from writing different checkpoints in the same time window + ts := time.Now().Truncate(time.Duration(c.checkpointFreq) * time.Minute).UnixNano() + // key is treeID/timestamp, where timestamp is rounded down to the nearest X minutes + key := fmt.Sprintf("%d/%d", c.treeID, ts) + ctx, cancel := context.WithTimeout(c.ctx, 10*time.Second) + defer cancel() + + // return value ignored, which is whether or not the entry was set + // no error is thrown if the key already exists + successNX, err := c.redisClient.SetNX(ctx, key, hexCP, 0).Result() + if err != nil { + c.reqCounter.With( + map[string]string{ + "shard": sTreeID, + "code": strconv.Itoa(RedisFailure), + }).Inc() + log.Logger.Errorf("error with client publishing checkpoint: %v", err) + return + } + // if the key was not set, then the key already exists for this time period + if !successNX { + return + } + + // successful obtaining of lock for time period + c.reqCounter.With( + map[string]string{ + "shard": sTreeID, + "code": strconv.Itoa(SuccessObtainLock), + }).Inc() + + // on successfully obtaining the "lock" for the time window, update latest checkpoint + latestKey := fmt.Sprintf("%d/latest", c.treeID) + latestCtx, latestCancel := context.WithTimeout(c.ctx, 10*time.Second) + defer latestCancel() + + // return value ignored, which is whether or not the entry was set + // no error is thrown if the key already exists + if _, err = c.redisClient.Set(latestCtx, latestKey, hexCP, 0).Result(); err != nil { + c.reqCounter.With( + map[string]string{ + "shard": sTreeID, + "code": strconv.Itoa(RedisLatestFailure), + }).Inc() + log.Logger.Errorf("error with client publishing latest checkpoint: %v", err) + return + } + + // successful publish + c.reqCounter.With( + map[string]string{ + "shard": sTreeID, + "code": strconv.Itoa(Success), + }).Inc() +} diff --git a/pkg/witness/publish_checkpoint_test.go b/pkg/witness/publish_checkpoint_test.go new file mode 100644 index 000000000..c6453f35e --- /dev/null +++ b/pkg/witness/publish_checkpoint_test.go @@ -0,0 +1,324 @@ +// Copyright 2023 The Sigstore Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package witness + +import ( + "context" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "errors" + "fmt" + "testing" + "time" + + "github.com/go-redis/redismock/v9" + "github.com/golang/mock/gomock" + "github.com/google/trillian" + "github.com/google/trillian/types" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/sigstore/rekor/pkg/witness/mockclient" + "github.com/sigstore/sigstore/pkg/signature" + "go.uber.org/goleak" +) + +func TestPublishCheckpoint(t *testing.T) { + treeID := 1234 + hostname := "rekor-test" + freq := 1 + counter := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "rekor_checkpoint_publish", + Help: "Checkpoint publishing by shard and code", + }, []string{"shard", "code"}) + priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + signer, _ := signature.LoadSigner(priv, crypto.SHA256) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + root := &types.LogRootV1{TreeSize: 10, RootHash: []byte{1}, TimestampNanos: 123, Revision: 0} + mRoot, err := root.MarshalBinary() + if err != nil { + t.Fatalf("error marshalling log root: %v", err) + } + + mockTrillianLogClient := mockclient.NewMockTrillianLogClient(ctrl) + mockTrillianLogClient.EXPECT().GetLatestSignedLogRoot(gomock.Any(), &trillian.GetLatestSignedLogRootRequest{ + LogId: int64(treeID), + FirstTreeSize: 0, + }).Return(&trillian.GetLatestSignedLogRootResponse{SignedLogRoot: &trillian.SignedLogRoot{LogRoot: mRoot}}, nil) + + redisClient, mock := redismock.NewClientMock() + ts := time.Now().Truncate(time.Duration(freq) * time.Minute).UnixNano() + mock.Regexp().ExpectSetNX(fmt.Sprintf("%d/%d", treeID, ts), "[0-9a-fA-F]+", 0).SetVal(true) + mock.Regexp().ExpectSet(fmt.Sprintf("%d/latest", treeID), "[0-9a-fA-F]+", 0).SetVal("OK") + + publisher := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClient, uint(freq), counter) + + ctx, cancel := context.WithCancel(context.Background()) + publisher.StartPublisher(ctx) + defer cancel() + + // wait for initial publish + time.Sleep(1 * time.Second) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Error(err) + } + + if res := testutil.CollectAndCount(counter); res != 2 { + t.Fatalf("unexpected number of metrics: %d", res) + } + if res := testutil.ToFloat64(counter.WithLabelValues(fmt.Sprint(treeID), fmt.Sprint(Success))); res != 1.0 { + t.Fatalf("unexpected number of metrics: %2f", res) + } + if res := testutil.ToFloat64(counter.WithLabelValues(fmt.Sprint(treeID), fmt.Sprint(SuccessObtainLock))); res != 1.0 { + t.Fatalf("unexpected number of metrics: %2f", res) + } +} + +func TestPublishCheckpointMultiple(t *testing.T) { + treeID := 1234 + hostname := "rekor-test" + freq := 1 + counter := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "rekor_checkpoint_publish", + Help: "Checkpoint publishing by shard and code", + }, []string{"shard", "code"}) + priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + signer, _ := signature.LoadSigner(priv, crypto.SHA256) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + root := &types.LogRootV1{TreeSize: 10, RootHash: []byte{1}, TimestampNanos: 123, Revision: 0} + mRoot, err := root.MarshalBinary() + if err != nil { + t.Fatalf("error marshalling log root: %v", err) + } + + mockTrillianLogClient := mockclient.NewMockTrillianLogClient(ctrl) + mockTrillianLogClient.EXPECT().GetLatestSignedLogRoot(gomock.Any(), &trillian.GetLatestSignedLogRootRequest{ + LogId: int64(treeID), + FirstTreeSize: 0, + }).Return(&trillian.GetLatestSignedLogRootResponse{SignedLogRoot: &trillian.SignedLogRoot{LogRoot: mRoot}}, nil).MaxTimes(2) + + redisClient, mock := redismock.NewClientMock() + ts := time.Now().Truncate(time.Duration(freq) * time.Minute).UnixNano() + mock.Regexp().ExpectSetNX(fmt.Sprintf("%d/%d", treeID, ts), "[0-9a-fA-F]+", 0).SetVal(true) + mock.Regexp().ExpectSet(fmt.Sprintf("%d/latest", treeID), "[0-9a-fA-F]+", 0).SetVal("OK") + + publisher := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClient, uint(freq), counter) + ctx, cancel := context.WithCancel(context.Background()) + publisher.StartPublisher(ctx) + defer cancel() + + redisClientEx, mockEx := redismock.NewClientMock() + mockEx.Regexp().ExpectSetNX(fmt.Sprintf("%d/%d", treeID, ts), "[0-9a-fA-F]+", 0).SetVal(false) + publisherEx := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClientEx, uint(freq), counter) + ctxEx, cancelEx := context.WithCancel(context.Background()) + publisherEx.StartPublisher(ctxEx) + defer cancelEx() + + // wait for initial publish + time.Sleep(1 * time.Second) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Error(err) + } + if err := mockEx.ExpectationsWereMet(); err != nil { + t.Error(err) + } + + // only publishes once + if res := testutil.CollectAndCount(counter); res != 2 { + t.Fatalf("unexpected number of metrics: %d", res) + } + if res := testutil.ToFloat64(counter.WithLabelValues(fmt.Sprint(treeID), fmt.Sprint(Success))); res != 1.0 { + t.Fatalf("unexpected number of metrics: %2f", res) + } + if res := testutil.ToFloat64(counter.WithLabelValues(fmt.Sprint(treeID), fmt.Sprint(SuccessObtainLock))); res != 1.0 { + t.Fatalf("unexpected number of metrics: %2f", res) + } +} + +func TestPublishCheckpointTrillianError(t *testing.T) { + treeID := 1234 + hostname := "rekor-test" + freq := 1 + counter := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "rekor_checkpoint_publish", + Help: "Checkpoint publishing by shard and code", + }, []string{"shard", "code"}) + priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + signer, _ := signature.LoadSigner(priv, crypto.SHA256) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // return error + mockTrillianLogClient := mockclient.NewMockTrillianLogClient(ctrl) + mockTrillianLogClient.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()).Return(nil, errors.New("error: LatestSLR")) + + redisClient, _ := redismock.NewClientMock() + + publisher := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClient, uint(freq), counter) + ctx, cancel := context.WithCancel(context.Background()) + publisher.StartPublisher(ctx) + defer cancel() + + // wait for initial publish + time.Sleep(1 * time.Second) + + if res := testutil.CollectAndCount(counter); res != 1 { + t.Fatalf("unexpected number of metrics: %d", res) + } + if res := testutil.ToFloat64(counter.WithLabelValues(fmt.Sprint(treeID), fmt.Sprint(GetCheckpoint))); res != 1.0 { + t.Fatalf("unexpected number of metrics: %2f", res) + } +} + +func TestPublishCheckpointInvalidTrillianResponse(t *testing.T) { + treeID := 1234 + hostname := "rekor-test" + freq := 1 + counter := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "rekor_checkpoint_publish", + Help: "Checkpoint publishing by shard and code", + }, []string{"shard", "code"}) + priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + signer, _ := signature.LoadSigner(priv, crypto.SHA256) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // set no log root in response + mockTrillianLogClient := mockclient.NewMockTrillianLogClient(ctrl) + mockTrillianLogClient.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()). + Return(&trillian.GetLatestSignedLogRootResponse{SignedLogRoot: &trillian.SignedLogRoot{LogRoot: []byte{}}}, nil) + + redisClient, _ := redismock.NewClientMock() + + publisher := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClient, uint(freq), counter) + ctx, cancel := context.WithCancel(context.Background()) + publisher.StartPublisher(ctx) + defer cancel() + + // wait for initial publish + time.Sleep(1 * time.Second) + + if res := testutil.CollectAndCount(counter); res != 1 { + t.Fatalf("unexpected number of metrics: %d", res) + } + if res := testutil.ToFloat64(counter.WithLabelValues(fmt.Sprint(treeID), fmt.Sprint(UnmarshalCheckpoint))); res != 1.0 { + t.Fatalf("unexpected number of metrics: %2f", res) + } +} + +func TestPublishCheckpointRedisFailure(t *testing.T) { + treeID := 1234 + hostname := "rekor-test" + freq := 1 + counter := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "rekor_checkpoint_publish", + Help: "Checkpoint publishing by shard and code", + }, []string{"shard", "code"}) + priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + signer, _ := signature.LoadSigner(priv, crypto.SHA256) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + root := &types.LogRootV1{TreeSize: 10, RootHash: []byte{1}, TimestampNanos: 123, Revision: 0} + mRoot, err := root.MarshalBinary() + if err != nil { + t.Fatalf("error marshalling log root: %v", err) + } + + mockTrillianLogClient := mockclient.NewMockTrillianLogClient(ctrl) + mockTrillianLogClient.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()). + Return(&trillian.GetLatestSignedLogRootResponse{SignedLogRoot: &trillian.SignedLogRoot{LogRoot: mRoot}}, nil) + + redisClient, mock := redismock.NewClientMock() + // error on first redis call + mock.Regexp().ExpectSetNX(".+", "[0-9a-fA-F]+", 0).SetErr(errors.New("redis error")) + + publisher := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClient, uint(freq), counter) + ctx, cancel := context.WithCancel(context.Background()) + publisher.StartPublisher(ctx) + defer cancel() + + // wait for initial publish + time.Sleep(1 * time.Second) + + if res := testutil.CollectAndCount(counter); res != 1 { + t.Fatalf("unexpected number of metrics: %d", res) + } + if res := testutil.ToFloat64(counter.WithLabelValues(fmt.Sprint(treeID), fmt.Sprint(RedisFailure))); res != 1.0 { + t.Fatalf("unexpected number of metrics: %2f", res) + } +} + +func TestPublishCheckpointRedisLatestFailure(t *testing.T) { + treeID := 1234 + hostname := "rekor-test" + freq := 1 + counter := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "rekor_checkpoint_publish", + Help: "Checkpoint publishing by shard and code", + }, []string{"shard", "code"}) + priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + signer, _ := signature.LoadSigner(priv, crypto.SHA256) + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + root := &types.LogRootV1{TreeSize: 10, RootHash: []byte{1}, TimestampNanos: 123, Revision: 0} + mRoot, err := root.MarshalBinary() + if err != nil { + t.Fatalf("error marshalling log root: %v", err) + } + + mockTrillianLogClient := mockclient.NewMockTrillianLogClient(ctrl) + mockTrillianLogClient.EXPECT().GetLatestSignedLogRoot(gomock.Any(), gomock.Any()). + Return(&trillian.GetLatestSignedLogRootResponse{SignedLogRoot: &trillian.SignedLogRoot{LogRoot: mRoot}}, nil) + + redisClient, mock := redismock.NewClientMock() + mock.Regexp().ExpectSetNX(".+", "[0-9a-fA-F]+", 0).SetVal(true) + // error on second redis call + mock.Regexp().ExpectSet(".*", "[0-9a-fA-F]+", 0).SetErr(errors.New("error")) + + publisher := NewCheckpointPublisher(context.Background(), mockTrillianLogClient, int64(treeID), hostname, signer, redisClient, uint(freq), counter) + ctx, cancel := context.WithCancel(context.Background()) + publisher.StartPublisher(ctx) + defer cancel() + + // wait for initial publish + time.Sleep(1 * time.Second) + + // two metrics, one success for initial redis and one failure for latest + if res := testutil.CollectAndCount(counter); res != 2 { + t.Fatalf("unexpected number of metrics: %d", res) + } + if res := testutil.ToFloat64(counter.WithLabelValues(fmt.Sprint(treeID), fmt.Sprint(RedisLatestFailure))); res != 1.0 { + t.Fatalf("unexpected number of metrics: %2f", res) + } +} + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +}