From ab5ddac82853042a97221b436469cf66ee6baaba Mon Sep 17 00:00:00 2001 From: Jesse Peterson Date: Thu, 30 Nov 2023 22:23:50 -0800 Subject: [PATCH] BoltDB Depot serial number changes for data races (#190) --- depot/bolt/depot.go | 47 +++++++++++++++++----------------------- depot/bolt/depot_test.go | 2 +- depot/file/depot.go | 21 ++++++++++-------- depot/signer.go | 5 ----- 4 files changed, 33 insertions(+), 42 deletions(-) diff --git a/depot/bolt/depot.go b/depot/bolt/depot.go index a0c1161..cdeee81 100644 --- a/depot/bolt/depot.go +++ b/depot/bolt/depot.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "math/big" + "sync" "github.com/micromdm/scep/v2/depot" @@ -18,6 +19,7 @@ import ( // https://github.com/boltdb/bolt type Depot struct { *bolt.DB + serialMu sync.RWMutex } const ( @@ -36,7 +38,7 @@ func NewBoltDepot(db *bolt.DB) (*Depot, error) { if err != nil { return nil, err } - return &Depot{db}, nil + return &Depot{DB: db}, nil } // For some read operations Bolt returns a direct memory reference to @@ -93,26 +95,28 @@ func (db *Depot) Put(cn string, crt *x509.Certificate) error { if crt == nil || crt.Raw == nil { return fmt.Errorf("%q does not specify a valid certificate for storage", cn) } - serial, err := db.Serial() - if err != nil { - return err - } - - err = db.Update(func(tx *bolt.Tx) error { + err := db.Update(func(tx *bolt.Tx) error { bucket := tx.Bucket([]byte(certBucket)) if bucket == nil { return fmt.Errorf("bucket %q not found!", certBucket) } - name := cn + "." + serial.String() + name := cn + "." + crt.SerialNumber.String() return bucket.Put([]byte(name), crt.Raw) }) + return err +} + +func (db *Depot) Serial() (*big.Int, error) { + db.serialMu.Lock() + defer db.serialMu.Unlock() + s, err := db.readSerial() if err != nil { - return err + return nil, err } - return db.incrementSerial(serial) + return s, db.incrementSerial(s) } -func (db *Depot) Serial() (*big.Int, error) { +func (db *Depot) readSerial() (*big.Int, error) { s := big.NewInt(2) if !db.hasKey([]byte("serial")) { if err := db.writeSerial(s); err != nil { @@ -132,10 +136,7 @@ func (db *Depot) Serial() (*big.Int, error) { s = s.SetBytes(k) return nil }) - if err != nil { - return nil, err - } - return s, nil + return s, err } func (db *Depot) writeSerial(s *big.Int) error { @@ -156,7 +157,7 @@ func (db *Depot) hasKey(name []byte) bool { if bucket == nil { return fmt.Errorf("bucket %q not found!", certBucket) } - k := bucket.Get([]byte("serial")) + k := bucket.Get(name) if k != nil { present = true } @@ -166,15 +167,8 @@ func (db *Depot) hasKey(name []byte) bool { } func (db *Depot) incrementSerial(s *big.Int) error { - serial := s.Add(s, big.NewInt(1)) - err := db.Update(func(tx *bolt.Tx) error { - bucket := tx.Bucket([]byte(certBucket)) - if bucket == nil { - return fmt.Errorf("bucket %q not found!", certBucket) - } - return bucket.Put([]byte("serial"), []byte(serial.Bytes())) - }) - return err + serial := new(big.Int).Add(s, big.NewInt(1)) + return db.writeSerial(serial) } func (db *Depot) HasCN(cn string, allowTime int, cert *x509.Certificate, revokeOldCertificate bool) (bool, error) { @@ -185,8 +179,7 @@ func (db *Depot) HasCN(cn string, allowTime int, cert *x509.Certificate, revokeO } var hasCN bool err := db.View(func(tx *bolt.Tx) error { - // TODO: "scep_certificates" is internal const in micromdm/scep - curs := tx.Bucket([]byte("scep_certificates")).Cursor() + curs := tx.Bucket([]byte(certBucket)).Cursor() prefix := []byte(cert.Subject.CommonName) for k, v := curs.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = curs.Next() { if bytes.Compare(v, cert.Raw) == 0 { diff --git a/depot/bolt/depot_test.go b/depot/bolt/depot_test.go index e64dbed..34a9e94 100644 --- a/depot/bolt/depot_test.go +++ b/depot/bolt/depot_test.go @@ -100,7 +100,7 @@ func TestDepot_incrementSerial(t *testing.T) { if err := db.incrementSerial(tt.args); (err != nil) != tt.wantErr { t.Errorf("%q. Depot.incrementSerial() error = %v, wantErr %v", tt.name, err, tt.wantErr) } - got, _ := db.Serial() + got, _ := db.readSerial() if !reflect.DeepEqual(got, tt.want) { t.Errorf("%q. Depot.Serial() = %v, want %v", tt.name, got, tt.want) } diff --git a/depot/file/depot.go b/depot/file/depot.go index b843b95..ab8d336 100644 --- a/depot/file/depot.go +++ b/depot/file/depot.go @@ -16,6 +16,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "time" ) @@ -31,7 +32,9 @@ func NewFileDepot(path string) (*fileDepot, error) { } type fileDepot struct { - dirPath string + dirPath string + serialMu sync.Mutex + dbMu sync.Mutex } func (d *fileDepot) CA(pass []byte) ([]*x509.Certificate, *rsa.PrivateKey, error) { @@ -75,10 +78,7 @@ func (d *fileDepot) Put(cn string, crt *x509.Certificate) error { return err } - serial, err := d.Serial() - if err != nil { - return err - } + serial := crt.SerialNumber if crt.Subject.CommonName == "" { // this means our cn was replaced by the certificate Signature @@ -103,14 +103,12 @@ func (d *fileDepot) Put(cn string, crt *x509.Certificate) error { return err } - if err := d.incrementSerial(serial); err != nil { - return err - } - return nil } func (d *fileDepot) Serial() (*big.Int, error) { + d.serialMu.Lock() + defer d.serialMu.Unlock() name := d.path("serial") s := big.NewInt(2) if err := d.check("serial"); err != nil { @@ -136,6 +134,9 @@ func (d *fileDepot) Serial() (*big.Int, error) { if !ok { return nil, errors.New("could not convert " + string(data) + " to serial number") } + if err := d.incrementSerial(serial); err != nil { + return serial, err + } return serial, nil } @@ -255,6 +256,8 @@ func (d *fileDepot) HasCN(_ string, allowTime int, cert *x509.Certificate, revok } func (d *fileDepot) writeDB(cn string, serial *big.Int, filename string, cert *x509.Certificate) error { + d.dbMu.Lock() + defer d.dbMu.Unlock() var dbEntry bytes.Buffer diff --git a/depot/signer.go b/depot/signer.go index 6fe0109..3e3bdb5 100644 --- a/depot/signer.go +++ b/depot/signer.go @@ -3,7 +3,6 @@ package depot import ( "crypto/rand" "crypto/x509" - "sync" "time" "github.com/micromdm/scep/v2/cryptoutil" @@ -13,7 +12,6 @@ import ( // Signer signs x509 certificates and stores them in a Depot type Signer struct { depot Depot - mu sync.Mutex caPass string allowRenewalDays int validityDays int @@ -81,9 +79,6 @@ func (s *Signer) SignCSR(m *scep.CSRReqMessage) (*x509.Certificate, error) { return nil, err } - s.mu.Lock() - defer s.mu.Unlock() - serial, err := s.depot.Serial() if err != nil { return nil, err