Skip to content

Commit

Permalink
Merge ACME package back into the PKI package (#19826)
Browse files Browse the repository at this point in the history
* Squash pki/acme package down to pki folder

Without refactoring most of PKI to export the storage layer, which we
were initially hesitant about, it would be nearly impossible to have the
ACME layer handle its own storage while being in the acme/ subpackage
under the pki package.

Thus, merge the two packages together again.

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>

* Properly format errors for missing parameters

When missing required ACME request parameters, don't return Vault-level
errors, but drop into the PKI package to return properly-formatted ACME
error messages.

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>

* Error type clarifications

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>

* Fix GetOk with type conversion calls

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>

---------

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>
  • Loading branch information
cipherboy committed Mar 29, 2023
1 parent 7b40f73 commit 32e3cd6
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 56 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package acme
package pki

import (
"encoding/json"
Expand Down
24 changes: 12 additions & 12 deletions builtin/logical/pki/acme/jws.go → builtin/logical/pki/acme_jws.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package acme
package pki

import (
"crypto"
Expand All @@ -23,7 +23,7 @@ var AllowedOuterJWSTypes = map[string]interface{}{
}

// This wraps a JWS message structure.
type JWSCtx struct {
type jwsCtx struct {
Algo string `json:"alg"`
Kid string `json:"kid"`
jwk json.RawMessage `json:"jwk"`
Expand All @@ -33,7 +33,7 @@ type JWSCtx struct {
Existing bool `json:"-"`
}

func (c *JWSCtx) UnmarshalJSON(a *ACMEState, jws []byte) error {
func (c *jwsCtx) UnmarshalJSON(a *acmeState, jws []byte) error {
var err error
if err = json.Unmarshal(jws, c); err != nil {
return err
Expand All @@ -44,15 +44,15 @@ func (c *JWSCtx) UnmarshalJSON(a *ACMEState, jws []byte) error {
//
// > The "jwk" and "kid" fields are mutually exclusive. Servers MUST
// > reject requests that contain both.
return fmt.Errorf("invalid header: got both account 'kid' and 'jwk' in the same message; expected only one")
return fmt.Errorf("invalid header: got both account 'kid' and 'jwk' in the same message; expected only one: %w", ErrMalformed)
}

if c.Kid == "" && len(c.jwk) == 0 {
// See RFC 8555 Section 6.2. Request Authentication:
//
// > Either "jwk" (JSON Web Key) or "kid" (Key ID) as specified
// > below
return fmt.Errorf("invalid header: got neither required fields of 'kid' nor 'jwk'")
return fmt.Errorf("invalid header: got neither required fields of 'kid' nor 'jwk': %w", ErrMalformed)
}

if _, present := AllowedOuterJWSTypes[c.Algo]; !present {
Expand All @@ -65,7 +65,7 @@ func (c *JWSCtx) UnmarshalJSON(a *ACMEState, jws []byte) error {
// > * This field MUST NOT contain "none" or a Message
// > Authentication Code (MAC) algorithm (e.g. one in which the
// > algorithm registry description mentions MAC/HMAC).
return fmt.Errorf("invalid header: unexpected value for 'algo'")
return fmt.Errorf("invalid header: unexpected value for 'algo': %w", ErrMalformed)
}

if c.Kid != "" {
Expand All @@ -82,7 +82,7 @@ func (c *JWSCtx) UnmarshalJSON(a *ACMEState, jws []byte) error {
}

if !c.key.Valid() {
return fmt.Errorf("received invalid jwk")
return fmt.Errorf("received invalid jwk: %w", ErrMalformed)
}

if c.Kid != "" {
Expand All @@ -103,29 +103,29 @@ func hasValues(h jose.Header) bool {
return h.KeyID != "" || h.JSONWebKey != nil || h.Algorithm != "" || h.Nonce != "" || len(h.ExtraHeaders) > 0
}

func (c *JWSCtx) VerifyJWS(signature string) (map[string]interface{}, error) {
func (c *jwsCtx) VerifyJWS(signature string) (map[string]interface{}, error) {
// See RFC 8555 Section 6.2. Request Authentication:
//
// > The JWS Unencoded Payload Option [RFC7797] MUST NOT be used
//
// This is validated by go-jose.
sig, err := jose.ParseSigned(signature)
if err != nil {
return nil, fmt.Errorf("error parsing signature: %w", err)
return nil, fmt.Errorf("error parsing signature: %s: %w", err, ErrMalformed)
}

if len(sig.Signatures) > 1 {
// See RFC 8555 Section 6.2. Request Authentication:
//
// > The JWS MUST NOT have multiple signatures
return nil, fmt.Errorf("request had multiple signatures")
return nil, fmt.Errorf("request had multiple signatures: %w", ErrMalformed)
}

if hasValues(sig.Signatures[0].Unprotected) {
// See RFC 8555 Section 6.2. Request Authentication:
//
// > The JWS Unprotected Header [RFC7515] MUST NOT be used
return nil, fmt.Errorf("request had unprotected headers")
return nil, fmt.Errorf("request had unprotected headers: %w", ErrMalformed)
}

payload, err := sig.Verify(c.key)
Expand All @@ -135,7 +135,7 @@ func (c *JWSCtx) VerifyJWS(signature string) (map[string]interface{}, error) {

var m map[string]interface{}
if err := json.Unmarshal(payload, &m); err != nil {
return nil, fmt.Errorf("failed to json unmarshal 'payload': %w", err)
return nil, fmt.Errorf("failed to json unmarshal 'payload': %s: %w", err, ErrMalformed)
}

return m, nil
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package acme
package pki

import (
"crypto/rand"
Expand All @@ -15,13 +15,13 @@ import (
// How long nonces are considered valid.
const nonceExpiry = 15 * time.Minute

type ACMEState struct {
type acmeState struct {
nextExpiry *atomic.Int64
nonces *sync.Map // map[string]time.Time
}

func NewACMEState() *ACMEState {
return &ACMEState{
func NewACMEState() *acmeState {
return &acmeState{
nextExpiry: new(atomic.Int64),
nonces: new(sync.Map),
}
Expand All @@ -36,7 +36,7 @@ func generateNonce() (string, error) {
return base64.RawURLEncoding.EncodeToString(data), nil
}

func (a *ACMEState) GetNonce() (string, time.Time, error) {
func (a *acmeState) GetNonce() (string, time.Time, error) {
now := time.Now()
nonce, err := generateNonce()
if err != nil {
Expand All @@ -55,7 +55,7 @@ func (a *ACMEState) GetNonce() (string, time.Time, error) {
return nonce, then, nil
}

func (a *ACMEState) RedeemNonce(nonce string) bool {
func (a *acmeState) RedeemNonce(nonce string) bool {
rawTimeout, present := a.nonces.LoadAndDelete(nonce)
if !present {
return false
Expand All @@ -69,7 +69,7 @@ func (a *ACMEState) RedeemNonce(nonce string) bool {
return true
}

func (a *ACMEState) DoTidyNonces() {
func (a *acmeState) DoTidyNonces() {
now := time.Now()
expiry := a.nextExpiry.Load()
then := time.Unix(expiry, 0)
Expand All @@ -79,7 +79,7 @@ func (a *ACMEState) DoTidyNonces() {
}
}

func (a *ACMEState) TidyNonces() {
func (a *acmeState) TidyNonces() {
now := time.Now()
nextRun := now.Add(nonceExpiry)

Expand All @@ -99,22 +99,22 @@ func (a *ACMEState) TidyNonces() {
a.nextExpiry.Store(nextRun.Unix())
}

func (a *ACMEState) CreateAccount(c *JWSCtx, contact []string, termsOfServiceAgreed bool) (map[string]interface{}, error) {
func (a *acmeState) CreateAccount(c *jwsCtx, contact []string, termsOfServiceAgreed bool) (map[string]interface{}, error) {
// TODO
return nil, nil
}

func (a *ACMEState) LoadAccount(keyID string) (map[string]interface{}, error) {
func (a *acmeState) LoadAccount(keyID string) (map[string]interface{}, error) {
// TODO
return nil, nil
}

func (a *ACMEState) DoesAccountExist(keyId string) bool {
func (a *acmeState) DoesAccountExist(keyId string) bool {
account, err := a.LoadAccount(keyId)
return err == nil && len(account) > 0
}

func (a *ACMEState) LoadJWK(keyID string) ([]byte, error) {
func (a *acmeState) LoadJWK(keyID string) ([]byte, error) {
key, err := a.LoadAccount(keyID)
if err != nil {
return nil, err
Expand All @@ -128,15 +128,20 @@ func (a *ACMEState) LoadJWK(keyID string) ([]byte, error) {
return jwk.([]byte), nil
}

func (a *ACMEState) ParseRequestParams(data *framework.FieldData) (*JWSCtx, map[string]interface{}, error) {
var c JWSCtx
func (a *acmeState) ParseRequestParams(data *framework.FieldData) (*jwsCtx, map[string]interface{}, error) {
var c jwsCtx
var m map[string]interface{}

// Parse the key out.
jwkBase64 := data.Get("protected").(string)
rawJWKBase64, ok := data.GetOk("protected")
if !ok {
return nil, nil, fmt.Errorf("missing required field 'protected': %w", ErrMalformed)
}
jwkBase64 := rawJWKBase64.(string)

jwkBytes, err := base64.RawURLEncoding.DecodeString(jwkBase64)
if err != nil {
return nil, nil, fmt.Errorf("failed to base64 parse 'protected': %w", err)
return nil, nil, fmt.Errorf("failed to base64 parse 'protected': %s: %w", err, ErrMalformed)
}
if err = c.UnmarshalJSON(a, jwkBytes); err != nil {
return nil, nil, fmt.Errorf("failed to json unmarshal 'protected': %w", err)
Expand All @@ -146,11 +151,20 @@ func (a *ACMEState) ParseRequestParams(data *framework.FieldData) (*JWSCtx, map[
// should read and redeem the nonce here too, to avoid doing any extra
// work if it is invalid.
if !a.RedeemNonce(c.Nonce) {
return nil, nil, fmt.Errorf("invalid or reused nonce")
return nil, nil, fmt.Errorf("invalid or reused nonce: %w", ErrBadNonce)
}

payloadBase64 := data.Get("payload").(string)
signatureBase64 := data.Get("signature").(string)
rawPayloadBase64, ok := data.GetOk("payload")
if !ok {
return nil, nil, fmt.Errorf("missing required field 'payload': %w", ErrMalformed)
}
payloadBase64 := rawPayloadBase64.(string)

rawSignatureBase64, ok := data.GetOk("signature")
if !ok {
return nil, nil, fmt.Errorf("missing required field 'signature': %w", ErrMalformed)
}
signatureBase64 := rawSignatureBase64.(string)

// go-jose only seems to support compact signature encodings.
compactSig := fmt.Sprintf("%v.%v.%v", jwkBase64, payloadBase64, signatureBase64)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package acme
package pki

import (
"testing"
Expand Down
6 changes: 2 additions & 4 deletions builtin/logical/pki/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ import (

atomic2 "go.uber.org/atomic"

"github.com/hashicorp/vault/builtin/logical/pki/acme"

"github.com/armon/go-metrics"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/helper/constants"
Expand Down Expand Up @@ -290,7 +288,7 @@ func Backend(conf *logical.BackendConfig) *backend {

b.unifiedTransferStatus = newUnifiedTransferStatus()

b.acmeState = acme.NewACMEState()
b.acmeState = NewACMEState()
return &b
}

Expand Down Expand Up @@ -325,7 +323,7 @@ type backend struct {
issuersLock sync.RWMutex

// Context around ACME operations
acmeState *acme.ACMEState
acmeState *acmeState
}

type roleOperation func(ctx context.Context, req *logical.Request, data *framework.FieldData, role *roleEntry) (*logical.Response, error)
Expand Down
9 changes: 4 additions & 5 deletions builtin/logical/pki/path_acme_directory.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"net/url"
"strings"

"github.com/hashicorp/vault/builtin/logical/pki/acme"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
Expand Down Expand Up @@ -69,7 +68,7 @@ func (b *backend) acmeWrapper(op acmeOperation) framework.OperationFunc {

if false {
// TODO sclark: Check if ACME is enable here
return nil, fmt.Errorf("ACME is disabled in configuration: %w", acme.ErrServerInternal)
return nil, fmt.Errorf("ACME is disabled in configuration: %w", ErrServerInternal)
}

baseUrl, err := getAcmeBaseUrl(sc, r.Path)
Expand All @@ -93,12 +92,12 @@ func getAcmeBaseUrl(sc *storageContext, path string) (*url.URL, error) {
}

if cfg.Path == "" {
return nil, fmt.Errorf("ACME feature requires local cluster path configuration to be set: %w", acme.ErrServerInternal)
return nil, fmt.Errorf("ACME feature requires local cluster path configuration to be set: %w", ErrServerInternal)
}

baseUrl, err := url.Parse(cfg.Path)
if err != nil {
return nil, fmt.Errorf("ACME feature a proper URL configured in local cluster path: %w", acme.ErrServerInternal)
return nil, fmt.Errorf("ACME feature a proper URL configured in local cluster path: %w", ErrServerInternal)
}

directoryPrefix := ""
Expand All @@ -114,7 +113,7 @@ func acmeErrorWrapper(op framework.OperationFunc) framework.OperationFunc {
return func(ctx context.Context, r *logical.Request, data *framework.FieldData) (*logical.Response, error) {
resp, err := op(ctx, r, data)
if err != nil {
return acme.TranslateError(err)
return TranslateError(err)
}

return resp, nil
Expand Down
Loading

0 comments on commit 32e3cd6

Please sign in to comment.