From 4f808569bd3469bd40575acb970fa62ea763aec5 Mon Sep 17 00:00:00 2001 From: Jesse Peterson Date: Wed, 6 Apr 2022 14:11:05 -0700 Subject: [PATCH] BoltDB Depot serial number changes for data races --- depot/bolt/depot.go | 47 +++++++++++++++++----------------------- depot/bolt/depot_test.go | 2 +- 2 files changed, 21 insertions(+), 28 deletions(-) diff --git a/depot/bolt/depot.go b/depot/bolt/depot.go index a0c1161..cdeee81 100644 --- a/depot/bolt/depot.go +++ b/depot/bolt/depot.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "math/big" + "sync" "github.com/micromdm/scep/v2/depot" @@ -18,6 +19,7 @@ import ( // https://github.com/boltdb/bolt type Depot struct { *bolt.DB + serialMu sync.RWMutex } const ( @@ -36,7 +38,7 @@ func NewBoltDepot(db *bolt.DB) (*Depot, error) { if err != nil { return nil, err } - return &Depot{db}, nil + return &Depot{DB: db}, nil } // For some read operations Bolt returns a direct memory reference to @@ -93,26 +95,28 @@ func (db *Depot) Put(cn string, crt *x509.Certificate) error { if crt == nil || crt.Raw == nil { return fmt.Errorf("%q does not specify a valid certificate for storage", cn) } - serial, err := db.Serial() - if err != nil { - return err - } - - err = db.Update(func(tx *bolt.Tx) error { + err := db.Update(func(tx *bolt.Tx) error { bucket := tx.Bucket([]byte(certBucket)) if bucket == nil { return fmt.Errorf("bucket %q not found!", certBucket) } - name := cn + "." + serial.String() + name := cn + "." + crt.SerialNumber.String() return bucket.Put([]byte(name), crt.Raw) }) + return err +} + +func (db *Depot) Serial() (*big.Int, error) { + db.serialMu.Lock() + defer db.serialMu.Unlock() + s, err := db.readSerial() if err != nil { - return err + return nil, err } - return db.incrementSerial(serial) + return s, db.incrementSerial(s) } -func (db *Depot) Serial() (*big.Int, error) { +func (db *Depot) readSerial() (*big.Int, error) { s := big.NewInt(2) if !db.hasKey([]byte("serial")) { if err := db.writeSerial(s); err != nil { @@ -132,10 +136,7 @@ func (db *Depot) Serial() (*big.Int, error) { s = s.SetBytes(k) return nil }) - if err != nil { - return nil, err - } - return s, nil + return s, err } func (db *Depot) writeSerial(s *big.Int) error { @@ -156,7 +157,7 @@ func (db *Depot) hasKey(name []byte) bool { if bucket == nil { return fmt.Errorf("bucket %q not found!", certBucket) } - k := bucket.Get([]byte("serial")) + k := bucket.Get(name) if k != nil { present = true } @@ -166,15 +167,8 @@ func (db *Depot) hasKey(name []byte) bool { } func (db *Depot) incrementSerial(s *big.Int) error { - serial := s.Add(s, big.NewInt(1)) - err := db.Update(func(tx *bolt.Tx) error { - bucket := tx.Bucket([]byte(certBucket)) - if bucket == nil { - return fmt.Errorf("bucket %q not found!", certBucket) - } - return bucket.Put([]byte("serial"), []byte(serial.Bytes())) - }) - return err + serial := new(big.Int).Add(s, big.NewInt(1)) + return db.writeSerial(serial) } func (db *Depot) HasCN(cn string, allowTime int, cert *x509.Certificate, revokeOldCertificate bool) (bool, error) { @@ -185,8 +179,7 @@ func (db *Depot) HasCN(cn string, allowTime int, cert *x509.Certificate, revokeO } var hasCN bool err := db.View(func(tx *bolt.Tx) error { - // TODO: "scep_certificates" is internal const in micromdm/scep - curs := tx.Bucket([]byte("scep_certificates")).Cursor() + curs := tx.Bucket([]byte(certBucket)).Cursor() prefix := []byte(cert.Subject.CommonName) for k, v := curs.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = curs.Next() { if bytes.Compare(v, cert.Raw) == 0 { diff --git a/depot/bolt/depot_test.go b/depot/bolt/depot_test.go index e64dbed..34a9e94 100644 --- a/depot/bolt/depot_test.go +++ b/depot/bolt/depot_test.go @@ -100,7 +100,7 @@ func TestDepot_incrementSerial(t *testing.T) { if err := db.incrementSerial(tt.args); (err != nil) != tt.wantErr { t.Errorf("%q. Depot.incrementSerial() error = %v, wantErr %v", tt.name, err, tt.wantErr) } - got, _ := db.Serial() + got, _ := db.readSerial() if !reflect.DeepEqual(got, tt.want) { t.Errorf("%q. Depot.Serial() = %v, want %v", tt.name, got, tt.want) }