Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin' into kletterstein/cacert-keyusage
Browse files Browse the repository at this point in the history
  • Loading branch information
bkstein committed Dec 1, 2023
2 parents c3b810d + 988fe4e commit 2ae66e4
Show file tree
Hide file tree
Showing 19 changed files with 266 additions and 113 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ usage: scep [<command>] [<args>]
type <command> --help to see usage for each subcommand
```
Use the `ca -init` subcommand to create a new CA and private key.
Use the `ca -init` subcommand to create a new CA and private key.
CA sub-command usage:
```
Expand All @@ -95,6 +95,8 @@ Usage of ca:
password to store rsa key
-keySize int
rsa key size (default 4096)
-common_name string
common name (CN) for CA cert (default "MICROMDM SCEP CA")
-organization string
organization for CA cert (default "scep-ca")
-organizational_unit string
Expand Down
7 changes: 4 additions & 3 deletions challenge/challenge.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package challenge

import (
"context"
"crypto/x509"
"errors"

Expand All @@ -16,8 +17,8 @@ type Store interface {
}

// Middleware wraps next in a CSRSigner that verifies and invalidates the challenge
func Middleware(store Store, next scepserver.CSRSigner) scepserver.CSRSignerFunc {
return func(m *scep.CSRReqMessage) (*x509.Certificate, error) {
func Middleware(store Store, next scepserver.CSRSignerContext) scepserver.CSRSignerContextFunc {
return func(ctx context.Context, m *scep.CSRReqMessage) (*x509.Certificate, error) {
// TODO: compare challenge only for PKCSReq?
valid, err := store.HasChallenge(m.ChallengePassword)
if err != nil {
Expand All @@ -26,6 +27,6 @@ func Middleware(store Store, next scepserver.CSRSigner) scepserver.CSRSignerFunc
if !valid {
return nil, errors.New("invalid challenge")
}
return next.SignCSR(m)
return next.SignCSRContext(ctx, m)
}
}
7 changes: 5 additions & 2 deletions challenge/challenge_bolt_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package challenge

import (
"context"
"io/ioutil"
"os"
"testing"
Expand Down Expand Up @@ -69,12 +70,14 @@ func TestDynamicChallenge(t *testing.T) {
ChallengePassword: challengePassword,
}

_, err = signer.SignCSR(csrReq)
ctx := context.Background()

_, err = signer.SignCSRContext(ctx, csrReq)
if err != nil {
t.Error(err)
}

_, err = signer.SignCSR(csrReq)
_, err = signer.SignCSRContext(ctx, csrReq)
if err == nil {
t.Error("challenge should not be valid twice")
}
Expand Down
5 changes: 3 additions & 2 deletions cmd/scepclient/csr.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ const (
)

type csrOptions struct {
cn, org, country, ou, locality, province, challenge string
key *rsa.PrivateKey
cn, org, country, ou, locality, province, dnsName, challenge string
key *rsa.PrivateKey
}

func loadOrMakeCSR(path string, opts *csrOptions) (*x509.CertificateRequest, error) {
Expand All @@ -44,6 +44,7 @@ func loadOrMakeCSR(path string, opts *csrOptions) (*x509.CertificateRequest, err
CertificateRequest: x509.CertificateRequest{
Subject: subject,
SignatureAlgorithm: x509.SHA256WithRSA,
DNSNames: subjOrNil(opts.dnsName),
},
}
if opts.challenge != "" {
Expand Down
13 changes: 9 additions & 4 deletions cmd/scepclient/scepclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type runCfg struct {
debug bool
logfmt string
caCertMsg string
dnsName string
}

func run(cfg runCfg) error {
Expand Down Expand Up @@ -88,6 +89,7 @@ func run(cfg runCfg) error {
province: cfg.province,
challenge: cfg.challenge,
key: key,
dnsName: cfg.dnsName,
}

csr, err := loadOrMakeCSR(cfg.csrPath, opts)
Expand Down Expand Up @@ -234,10 +236,11 @@ func logCerts(logger log.Logger, certs []*x509.Certificate) {

// validateFingerprint makes sure fingerprint looks like a hash.
// We remove spaces and colons from fingerprint as it may come in various forms:
// e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
// E3B0C44298FC1C149AFBF4C8996FB92427AE41E4649B934CA495991B7852B855
// e3b0c442 98fc1c14 9afbf4c8 996fb924 27ae41e4 649b934c a495991b 7852b855
// e3:b0:c4:42:98:fc:1c:14:9a:fb:f4:c8:99:6f:b9:24:27:ae:41:e4:64:9b:93:4c:a4:95:99:1b:78:52:b8:55
//
// e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
// E3B0C44298FC1C149AFBF4C8996FB92427AE41E4649B934CA495991B7852B855
// e3b0c442 98fc1c14 9afbf4c8 996fb924 27ae41e4 649b934c a495991b 7852b855
// e3:b0:c4:42:98:fc:1c:14:9a:fb:f4:c8:99:6f:b9:24:27:ae:41:e4:64:9b:93:4c:a4:95:99:1b:78:52:b8:55
func validateFingerprint(fingerprint string) (hash []byte, err error) {
fingerprint = strings.NewReplacer(" ", "", ":", "").Replace(fingerprint)
hash, err = hex.DecodeString(fingerprint)
Expand Down Expand Up @@ -279,6 +282,7 @@ func main() {
flProvince = flag.String("province", "", "province for certificate")
flCountry = flag.String("country", "US", "country code in certificate")
flCACertMessage = flag.String("cacert-message", "", "message sent with GetCACert operation")
flDNSName = flag.String("dnsname", "", "DNS name to be included in the certificate (SAN)")

// in case of multiple certificate authorities, we need to figure out who the recipient of the encrypted
// data is.
Expand Down Expand Up @@ -340,6 +344,7 @@ func main() {
debug: *flDebugLogging,
logfmt: logfmt,
caCertMsg: *flCACertMessage,
dnsName: *flDNSName,
}

if err := run(cfg); err != nil {
Expand Down
65 changes: 47 additions & 18 deletions cmd/scepserver/scepserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ func main() {
//main flags
var (
flVersion = flag.Bool("version", false, "prints version information")
flPort = flag.String("port", envString("SCEP_HTTP_LISTEN_PORT", "8080"), "port to listen on")
flHTTPAddr = flag.String("http-addr", envString("SCEP_HTTP_ADDR", ""), "http listen address. defaults to \":8080\"")
flPort = flag.String("port", envString("SCEP_HTTP_LISTEN_PORT", "8080"), "http port to listen on (if you want to specify an address, use -http-addr instead)")
flDepotPath = flag.String("depot", envString("SCEP_FILE_DEPOT", "depot"), "path to ca folder")
flCAPass = flag.String("capass", envString("SCEP_CA_PASS", ""), "passwd for the ca.key")
flClDuration = flag.String("crtvalid", envString("SCEP_CERT_VALID", "365"), "validity for new client certificates in days")
Expand All @@ -52,6 +53,7 @@ func main() {
flCSRVerifierExec = flag.String("csrverifierexec", envString("SCEP_CSR_VERIFIER_EXEC", ""), "will be passed the CSRs for verification")
flDebug = flag.Bool("debug", envBool("SCEP_LOG_DEBUG"), "enable debug logging")
flLogJSON = flag.Bool("log-json", envBool("SCEP_LOG_JSON"), "output JSON logs")
flSignServerAttrs = flag.Bool("sign-server-attrs", envBool("SCEP_SIGN_SERVER_ATTRS"), "sign cert attrs for server usage")
)
flag.Usage = func() {
flag.PrintDefaults()
Expand All @@ -67,7 +69,19 @@ func main() {
fmt.Println(version)
os.Exit(0)
}
port := ":" + *flPort

// -http-addr and -port conflict. Don't allow the user to set both.
httpAddrSet := setByUser("http-addr", "SCEP_HTTP_ADDR")
portSet := setByUser("port", "SCEP_HTTP_LISTEN_PORT")
var httpAddr string
if httpAddrSet && portSet {
fmt.Fprintln(os.Stderr, "cannot set both -http-addr and -port")
os.Exit(1)
} else if httpAddrSet {
httpAddr = *flHTTPAddr
} else {
httpAddr = ":" + *flPort
}

var logger log.Logger
{
Expand Down Expand Up @@ -125,14 +139,17 @@ func main() {
lginfo.Log("err", "missing CA certificate")
os.Exit(1)
}
var signer scepserver.CSRSigner = scepdepot.NewSigner(
depot,
signerOpts := []scepdepot.Option{
scepdepot.WithAllowRenewalDays(allowRenewal),
scepdepot.WithValidityDays(clientValidity),
scepdepot.WithCAPass(*flCAPass),
)
}
if *flSignServerAttrs {
signerOpts = append(signerOpts, scepdepot.WithSeverAttrs())
}
var signer scepserver.CSRSignerContext = scepserver.SignCSRAdapter(scepdepot.NewSigner(depot, signerOpts...))
if *flChallengePassword != "" {
signer = scepserver.ChallengeMiddleware(*flChallengePassword, signer)
signer = scepserver.StaticChallengeMiddleware(*flChallengePassword, signer)
}
if csrVerifier != nil {
signer = csrverifier.Middleware(csrVerifier, signer)
Expand All @@ -156,8 +173,8 @@ func main() {
// start http server
errs := make(chan error, 2)
go func() {
lginfo.Log("transport", "http", "address", port, "msg", "listening")
errs <- http.ListenAndServe(port, h)
lginfo.Log("transport", "http", "address", httpAddr, "msg", "listening")
errs <- http.ListenAndServe(httpAddr, h)
}()
go func() {
c := make(chan os.Signal)
Expand All @@ -170,14 +187,15 @@ func main() {

func caMain(cmd *flag.FlagSet) int {
var (
flDepotPath = cmd.String("depot", "depot", "path to ca folder")
flInit = cmd.Bool("init", false, "create a new CA")
flYears = cmd.Int("years", 10, "default CA years")
flKeySize = cmd.Int("keySize", 4096, "rsa key size")
flOrg = cmd.String("organization", "scep-ca", "organization for CA cert")
flOrgUnit = cmd.String("organizational_unit", "SCEP CA", "organizational unit (OU) for CA cert")
flPassword = cmd.String("key-password", "", "password to store rsa key")
flCountry = cmd.String("country", "US", "country for CA cert")
flDepotPath = cmd.String("depot", "depot", "path to ca folder")
flInit = cmd.Bool("init", false, "create a new CA")
flYears = cmd.Int("years", 10, "default CA years")
flKeySize = cmd.Int("keySize", 4096, "rsa key size")
flCommonName = cmd.String("common_name", "MICROMDM SCEP CA", "common name (CN) for CA cert")
flOrg = cmd.String("organization", "scep-ca", "organization for CA cert")
flOrgUnit = cmd.String("organizational_unit", "SCEP CA", "organizational unit (OU) for CA cert")
flPassword = cmd.String("key-password", "", "password to store rsa key")
flCountry = cmd.String("country", "US", "country for CA cert")
)
cmd.Parse(os.Args[2:])
if *flInit {
Expand All @@ -187,7 +205,7 @@ func caMain(cmd *flag.FlagSet) int {
fmt.Println(err)
return 1
}
if err := createCertificateAuthority(key, *flYears, *flOrg, *flOrgUnit, *flCountry, *flDepotPath); err != nil {
if err := createCertificateAuthority(key, *flYears, *flCommonName, *flOrg, *flOrgUnit, *flCountry, *flDepotPath); err != nil {
fmt.Println(err)
return 1
}
Expand Down Expand Up @@ -232,9 +250,10 @@ func createKey(bits int, password []byte, depot string) (*rsa.PrivateKey, error)
return key, nil
}

func createCertificateAuthority(key *rsa.PrivateKey, years int, organization string, organizationalUnit string, country string, depot string) error {
func createCertificateAuthority(key *rsa.PrivateKey, years int, commonName string, organization string, organizationalUnit string, country string, depot string) error {
cert := scepdepot.NewCACert(
scepdepot.WithYears(years),
scepdepot.WithCommonName(commonName),
scepdepot.WithOrganization(organization),
scepdepot.WithOrganizationalUnit(organizationalUnit),
scepdepot.WithCountry(country),
Expand Down Expand Up @@ -288,3 +307,13 @@ func envBool(key string) bool {
}
return false
}

func setByUser(flagName, envName string) bool {
userDefinedFlags := make(map[string]bool)
flag.Visit(func(f *flag.Flag) {
userDefinedFlags[f.Name] = true
})
flagSet := userDefinedFlags[flagName]
_, envSet := os.LookupEnv(envName)
return flagSet || envSet
}
7 changes: 6 additions & 1 deletion cryptoutil/cryptoutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,23 @@ import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"math/big"
"testing"
)

func TestGenerateSubjectKeyID(t *testing.T) {
ecKey, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
if err != nil {
t.Fatal(err)
}
for _, test := range []struct {
testName string
pub crypto.PublicKey
}{
{"RSA", &rsa.PublicKey{N: big.NewInt(123), E: 65537}},
{"ECDSA", &ecdsa.PublicKey{X: big.NewInt(123), Y: big.NewInt(123), Curve: elliptic.P224()}},
{"ECDSA", ecKey.Public()},
} {
test := test
t.Run(test.testName, func(t *testing.T) {
Expand Down
7 changes: 4 additions & 3 deletions csrverifier/csrverifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package csrverifier

import (
"context"
"crypto/x509"
"errors"

Expand All @@ -15,15 +16,15 @@ type CSRVerifier interface {
}

// Middleware wraps next in a CSRSigner that runs verifier
func Middleware(verifier CSRVerifier, next scepserver.CSRSigner) scepserver.CSRSignerFunc {
return func(m *scep.CSRReqMessage) (*x509.Certificate, error) {
func Middleware(verifier CSRVerifier, next scepserver.CSRSignerContext) scepserver.CSRSignerContextFunc {
return func(ctx context.Context, m *scep.CSRReqMessage) (*x509.Certificate, error) {
ok, err := verifier.Verify(m.RawDecrypted)
if err != nil {
return nil, err
}
if !ok {
return nil, errors.New("CSR verify failed")
}
return next.SignCSR(m)
return next.SignCSRContext(ctx, m)
}
}
Loading

0 comments on commit 2ae66e4

Please sign in to comment.