diff --git a/challenge/challenge.go b/challenge/challenge.go index 2a0049b..d1b0ee4 100644 --- a/challenge/challenge.go +++ b/challenge/challenge.go @@ -2,6 +2,7 @@ package challenge import ( + "context" "crypto/x509" "errors" @@ -16,8 +17,8 @@ type Store interface { } // Middleware wraps next in a CSRSigner that verifies and invalidates the challenge -func Middleware(store Store, next scepserver.CSRSigner) scepserver.CSRSignerFunc { - return func(m *scep.CSRReqMessage) (*x509.Certificate, error) { +func Middleware(store Store, next scepserver.CSRSignerContext) scepserver.CSRSignerContextFunc { + return func(ctx context.Context, m *scep.CSRReqMessage) (*x509.Certificate, error) { // TODO: compare challenge only for PKCSReq? valid, err := store.HasChallenge(m.ChallengePassword) if err != nil { @@ -26,6 +27,6 @@ func Middleware(store Store, next scepserver.CSRSigner) scepserver.CSRSignerFunc if !valid { return nil, errors.New("invalid challenge") } - return next.SignCSR(m) + return next.SignCSRContext(ctx, m) } } diff --git a/challenge/challenge_bolt_test.go b/challenge/challenge_bolt_test.go index 003acbd..1a34cb5 100644 --- a/challenge/challenge_bolt_test.go +++ b/challenge/challenge_bolt_test.go @@ -1,6 +1,7 @@ package challenge import ( + "context" "io/ioutil" "os" "testing" @@ -69,12 +70,14 @@ func TestDynamicChallenge(t *testing.T) { ChallengePassword: challengePassword, } - _, err = signer.SignCSR(csrReq) + ctx := context.Background() + + _, err = signer.SignCSRContext(ctx, csrReq) if err != nil { t.Error(err) } - _, err = signer.SignCSR(csrReq) + _, err = signer.SignCSRContext(ctx, csrReq) if err == nil { t.Error("challenge should not be valid twice") } diff --git a/cmd/scepserver/scepserver.go b/cmd/scepserver/scepserver.go index 4e1ade5..97c0170 100644 --- a/cmd/scepserver/scepserver.go +++ b/cmd/scepserver/scepserver.go @@ -147,9 +147,9 @@ func main() { if *flSignServerAttrs { signerOpts = append(signerOpts, scepdepot.WithSeverAttrs()) } - var signer scepserver.CSRSigner = scepdepot.NewSigner(depot, signerOpts...) + var signer scepserver.CSRSignerContext = scepserver.SignCSRAdapter(scepdepot.NewSigner(depot, signerOpts...)) if *flChallengePassword != "" { - signer = scepserver.ChallengeMiddleware(*flChallengePassword, signer) + signer = scepserver.StaticChallengeMiddleware(*flChallengePassword, signer) } if csrVerifier != nil { signer = csrverifier.Middleware(csrVerifier, signer) diff --git a/cryptoutil/cryptoutil_test.go b/cryptoutil/cryptoutil_test.go index ab83c2e..53a73ee 100644 --- a/cryptoutil/cryptoutil_test.go +++ b/cryptoutil/cryptoutil_test.go @@ -4,18 +4,23 @@ import ( "crypto" "crypto/ecdsa" "crypto/elliptic" + "crypto/rand" "crypto/rsa" "math/big" "testing" ) func TestGenerateSubjectKeyID(t *testing.T) { + ecKey, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + if err != nil { + t.Fatal(err) + } for _, test := range []struct { testName string pub crypto.PublicKey }{ {"RSA", &rsa.PublicKey{N: big.NewInt(123), E: 65537}}, - {"ECDSA", &ecdsa.PublicKey{X: big.NewInt(123), Y: big.NewInt(123), Curve: elliptic.P224()}}, + {"ECDSA", ecKey.Public()}, } { test := test t.Run(test.testName, func(t *testing.T) { diff --git a/csrverifier/csrverifier.go b/csrverifier/csrverifier.go index bfc350b..da6f5aa 100644 --- a/csrverifier/csrverifier.go +++ b/csrverifier/csrverifier.go @@ -2,6 +2,7 @@ package csrverifier import ( + "context" "crypto/x509" "errors" @@ -15,8 +16,8 @@ type CSRVerifier interface { } // Middleware wraps next in a CSRSigner that runs verifier -func Middleware(verifier CSRVerifier, next scepserver.CSRSigner) scepserver.CSRSignerFunc { - return func(m *scep.CSRReqMessage) (*x509.Certificate, error) { +func Middleware(verifier CSRVerifier, next scepserver.CSRSignerContext) scepserver.CSRSignerContextFunc { + return func(ctx context.Context, m *scep.CSRReqMessage) (*x509.Certificate, error) { ok, err := verifier.Verify(m.RawDecrypted) if err != nil { return nil, err @@ -24,6 +25,6 @@ func Middleware(verifier CSRVerifier, next scepserver.CSRSigner) scepserver.CSRS if !ok { return nil, errors.New("CSR verify failed") } - return next.SignCSR(m) + return next.SignCSRContext(ctx, m) } } diff --git a/server/csrsigner.go b/server/csrsigner.go index 604776f..1ddfad0 100644 --- a/server/csrsigner.go +++ b/server/csrsigner.go @@ -1,6 +1,7 @@ package scepserver import ( + "context" "crypto/subtle" "crypto/x509" "errors" @@ -8,6 +9,22 @@ import ( "github.com/micromdm/scep/v2/scep" ) +// CSRSignerContext is a handler for signing CSRs by a CA/RA. +// +// SignCSRContext should take the CSR in the CSRReqMessage and return a +// Certificate signed by the CA. +type CSRSignerContext interface { + SignCSRContext(context.Context, *scep.CSRReqMessage) (*x509.Certificate, error) +} + +// CSRSignerContextFunc is an adapter for CSR signing by the CA/RA. +type CSRSignerContextFunc func(context.Context, *scep.CSRReqMessage) (*x509.Certificate, error) + +// SignCSR calls f(ctx, m). +func (f CSRSignerContextFunc) SignCSRContext(ctx context.Context, m *scep.CSRReqMessage) (*x509.Certificate, error) { + return f(ctx, m) +} + // CSRSigner is a handler for CSR signing by the CA/RA // // SignCSR should take the CSR in the CSRReqMessage and return a @@ -16,29 +33,36 @@ type CSRSigner interface { SignCSR(*scep.CSRReqMessage) (*x509.Certificate, error) } -// CSRSignerFunc is an adapter for CSR signing by the CA/RA +// CSRSignerFunc is an adapter for CSR signing by the CA/RA. type CSRSignerFunc func(*scep.CSRReqMessage) (*x509.Certificate, error) -// SignCSR calls f(m) +// SignCSR calls f(m). func (f CSRSignerFunc) SignCSR(m *scep.CSRReqMessage) (*x509.Certificate, error) { return f(m) } -// NopCSRSigner does nothing -func NopCSRSigner() CSRSignerFunc { - return func(m *scep.CSRReqMessage) (*x509.Certificate, error) { +// NopCSRSigner does nothing. +func NopCSRSigner() CSRSignerContextFunc { + return func(_ context.Context, _ *scep.CSRReqMessage) (*x509.Certificate, error) { return nil, nil } } -// ChallengeMiddleware wraps next in a CSRSigner that validates the challenge from the CSR -func ChallengeMiddleware(challenge string, next CSRSigner) CSRSignerFunc { +// StaticChallengeMiddleware wraps next and validates the challenge from the CSR. +func StaticChallengeMiddleware(challenge string, next CSRSignerContext) CSRSignerContextFunc { challengeBytes := []byte(challenge) - return func(m *scep.CSRReqMessage) (*x509.Certificate, error) { + return func(ctx context.Context, m *scep.CSRReqMessage) (*x509.Certificate, error) { // TODO: compare challenge only for PKCSReq? if subtle.ConstantTimeCompare(challengeBytes, []byte(m.ChallengePassword)) != 1 { return nil, errors.New("invalid challenge") } + return next.SignCSRContext(ctx, m) + } +} + +// SignCSRAdapter adapts a next (i.e. no context) to a context signer. +func SignCSRAdapter(next CSRSigner) CSRSignerContextFunc { + return func(_ context.Context, m *scep.CSRReqMessage) (*x509.Certificate, error) { return next.SignCSR(m) } } diff --git a/server/csrsigner_test.go b/server/csrsigner_test.go index b4c17bf..62696de 100644 --- a/server/csrsigner_test.go +++ b/server/csrsigner_test.go @@ -1,6 +1,7 @@ package scepserver import ( + "context" "testing" "github.com/micromdm/scep/v2/scep" @@ -8,18 +9,20 @@ import ( func TestChallengeMiddleware(t *testing.T) { testPW := "RIGHT" - signer := ChallengeMiddleware(testPW, NopCSRSigner()) + signer := StaticChallengeMiddleware(testPW, NopCSRSigner()) csrReq := &scep.CSRReqMessage{ChallengePassword: testPW} - _, err := signer.SignCSR(csrReq) + ctx := context.Background() + + _, err := signer.SignCSRContext(ctx, csrReq) if err != nil { t.Error(err) } csrReq.ChallengePassword = "WRONG" - _, err = signer.SignCSR(csrReq) + _, err = signer.SignCSRContext(ctx, csrReq) if err == nil { t.Error("invalid challenge should generate an error") } diff --git a/server/service.go b/server/service.go index 58ef85e..b20eb47 100644 --- a/server/service.go +++ b/server/service.go @@ -47,7 +47,7 @@ type service struct { // The (chainable) CSR signing function. Intended to handle all // SCEP request functionality such as CSR & challenge checking, CA // issuance, RA proxying, etc. - signer CSRSigner + signer CSRSignerContext /// info logging is implemented in the service middleware layer. debugLogger log.Logger @@ -80,7 +80,7 @@ func (svc *service) PKIOperation(ctx context.Context, data []byte) ([]byte, erro return nil, err } - crt, err := svc.signer.SignCSR(msg.CSRReqMessage) + crt, err := svc.signer.SignCSRContext(ctx, msg.CSRReqMessage) if err == nil && crt == nil { err = errors.New("no signed certificate") } @@ -119,7 +119,7 @@ func WithAddlCA(ca *x509.Certificate) ServiceOption { } // NewService creates a new scep service -func NewService(crt *x509.Certificate, key *rsa.PrivateKey, signer CSRSigner, opts ...ServiceOption) (Service, error) { +func NewService(crt *x509.Certificate, key *rsa.PrivateKey, signer CSRSignerContext, opts ...ServiceOption) (Service, error) { s := &service{ crt: crt, key: key, diff --git a/server/service_bolt_test.go b/server/service_bolt_test.go index 9bd40b0..3c5f46c 100644 --- a/server/service_bolt_test.go +++ b/server/service_bolt_test.go @@ -46,7 +46,7 @@ func TestCaCert(t *testing.T) { caCert := certs[0] // SCEP service - svc, err := scepserver.NewService(caCert, key, scepdepot.NewSigner(depot)) + svc, err := scepserver.NewService(caCert, key, scepserver.SignCSRAdapter(scepdepot.NewSigner(depot))) if err != nil { t.Fatal(err) }