diff --git a/conn.go b/conn.go index d09a7658..3ba4375b 100644 --- a/conn.go +++ b/conn.go @@ -521,10 +521,3 @@ func (c *Conn) stmtsIter(yield func(*Stmt) bool) { } } } - -// DriverConn is implemented by the SQLite [database/sql] driver connection. -// -// Deprecated: use [github.com/ncruces/go-sqlite3/driver.Conn] instead. -type DriverConn interface { - Raw() *Conn -} diff --git a/ext/bloom/testdata/bloom.db b/ext/bloom/testdata/bloom.db index f255762a..6e8b569a 100644 Binary files a/ext/bloom/testdata/bloom.db and b/ext/bloom/testdata/bloom.db differ diff --git a/tests/testdata/utf16be.db b/tests/testdata/utf16be.db index 08dc812c..336613f9 100644 Binary files a/tests/testdata/utf16be.db and b/tests/testdata/utf16be.db differ diff --git a/tests/testdata/wal.db b/tests/testdata/wal.db index e113317f..90b6151c 100644 Binary files a/tests/testdata/wal.db and b/tests/testdata/wal.db differ diff --git a/tests/wal_test.go b/tests/wal_test.go index 477edac9..733d5a8d 100644 --- a/tests/wal_test.go +++ b/tests/wal_test.go @@ -77,7 +77,7 @@ func TestWAL_readonly(t *testing.T) { // Select the data using the second (readonly) connection. var name string - err = db2.QueryRow("SELECT name FROM t").Scan(&name) + err = db2.QueryRow(`SELECT name FROM t`).Scan(&name) if err != nil { t.Fatal(err) } @@ -95,7 +95,7 @@ func TestWAL_readonly(t *testing.T) { } // Select the data using the second (readonly) connection. - err = db2.QueryRow("SELECT name FROM t").Scan(&name) + err = db2.QueryRow(`SELECT name FROM t`).Scan(&name) if err != nil { t.Fatal(err) } diff --git a/txn.go b/txn.go index 18d421ec..57ba979a 100644 --- a/txn.go +++ b/txn.go @@ -143,7 +143,7 @@ func (c *Conn) Savepoint() Savepoint { // Names can be reused, but this makes catching bugs more likely. name = QuoteIdentifier(name + "_" + strconv.Itoa(int(rand.Int31()))) - err := c.txnExecInterrupted("SAVEPOINT " + name) + err := c.txnExecInterrupted(`SAVEPOINT ` + name) if err != nil { panic(err) } @@ -187,7 +187,7 @@ func (s Savepoint) Release(errp *error) { if s.c.GetAutocommit() { // There is nothing to commit. return } - *errp = s.c.Exec("RELEASE " + s.name) + *errp = s.c.Exec(`RELEASE ` + s.name) if *errp == nil { return } @@ -199,8 +199,7 @@ func (s Savepoint) Release(errp *error) { return } // ROLLBACK and RELEASE even if interrupted. - err := s.c.txnExecInterrupted("ROLLBACK TO " + - s.name + "; RELEASE " + s.name) + err := s.c.txnExecInterrupted(`ROLLBACK TO ` + s.name + `; RELEASE ` + s.name) if err != nil { panic(err) } @@ -213,7 +212,7 @@ func (s Savepoint) Release(errp *error) { // https://sqlite.org/lang_transaction.html func (s Savepoint) Rollback() error { // ROLLBACK even if interrupted. - return s.c.txnExecInterrupted("ROLLBACK TO " + s.name) + return s.c.txnExecInterrupted(`ROLLBACK TO ` + s.name) } func (c *Conn) txnExecInterrupted(sql string) error { diff --git a/util/vfsutil/wrap.go b/util/vfsutil/wrap.go index 72ac7784..43b86406 100644 --- a/util/vfsutil/wrap.go +++ b/util/vfsutil/wrap.go @@ -22,6 +22,14 @@ func UnwrapFile[T vfs.File](f vfs.File) (_ T, _ bool) { } } +// WrapOpenFilename helps wrap [vfs.VFSFilename]. +func WrapOpenFilename(f vfs.VFS, name *vfs.Filename, flags vfs.OpenFlag) (file vfs.File, _ vfs.OpenFlag, err error) { + if f, ok := f.(vfs.VFSFilename); ok { + return f.OpenFilename(name, flags) + } + return f.Open(name.String(), flags) +} + // WrapLockState helps wrap [vfs.FileLockState]. func WrapLockState(f vfs.File) vfs.LockLevel { if f, ok := f.(vfs.FileLockState); ok { diff --git a/vfs/adiantum/adiantum_test.go b/vfs/adiantum/adiantum_test.go index dc327d1d..6e767733 100644 --- a/vfs/adiantum/adiantum_test.go +++ b/vfs/adiantum/adiantum_test.go @@ -21,7 +21,7 @@ var testDB string func Test_fileformat(t *testing.T) { readervfs.Create("test.db", ioutil.NewSizeReaderAt(strings.NewReader(testDB))) - adiantum.Register("radiantum", vfs.Find("reader"), nil) + vfs.Register("radiantum", adiantum.Wrap(vfs.Find("reader"), nil)) db, err := driver.Open("file:test.db?vfs=radiantum") if err != nil { diff --git a/vfs/adiantum/api.go b/vfs/adiantum/api.go index 09b180ac..97643c76 100644 --- a/vfs/adiantum/api.go +++ b/vfs/adiantum/api.go @@ -40,24 +40,25 @@ import ( ) func init() { - Register("adiantum", vfs.Find(""), nil) + vfs.Register("adiantum", Wrap(vfs.Find(""), nil)) } -// Register registers an encrypting VFS, wrapping a base VFS, -// and possibly using a custom HBSH cipher construction. +// Wrap wraps a base VFS to create an encrypting VFS, +// possibly using a custom HBSH cipher construction. +// // To use the default Adiantum construction, set cipher to nil. // // The default construction uses a 32 byte key/hexkey. // If a textkey is provided, the default KDF is Argon2id // with 64 MiB of memory, 3 iterations, and 4 threads. -func Register(name string, base vfs.VFS, cipher HBSHCreator) { +func Wrap(base vfs.VFS, cipher HBSHCreator) vfs.VFS { if cipher == nil { cipher = adiantumCreator{} } - vfs.Register(name, &hbshVFS{ + return &hbshVFS{ VFS: base, init: cipher, - }) + } } // HBSHCreator creates an [hbsh.HBSH] cipher diff --git a/vfs/adiantum/example_test.go b/vfs/adiantum/example_test.go index 590e9afe..aae7ed96 100644 --- a/vfs/adiantum/example_test.go +++ b/vfs/adiantum/example_test.go @@ -17,7 +17,7 @@ import ( ) func ExampleRegister_hpolyc() { - adiantum.Register("hpolyc", vfs.Find(""), hpolycCreator{}) + vfs.Register("hpolyc", adiantum.Wrap(vfs.Find(""), hpolycCreator{})) db, err := sqlite3.Open("file:demo.db?vfs=hpolyc" + "&textkey=correct+horse+battery+staple") diff --git a/vfs/adiantum/hbsh.go b/vfs/adiantum/hbsh.go index 5e301790..cdb1dcbb 100644 --- a/vfs/adiantum/hbsh.go +++ b/vfs/adiantum/hbsh.go @@ -24,11 +24,7 @@ func (h *hbshVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, } func (h *hbshVFS) OpenFilename(name *vfs.Filename, flags vfs.OpenFlag) (file vfs.File, _ vfs.OpenFlag, err error) { - if hf, ok := h.VFS.(vfs.VFSFilename); ok { - file, flags, err = hf.OpenFilename(name, flags) - } else { - file, flags, err = h.VFS.Open(name.String(), flags) - } + file, flags, err = vfsutil.WrapOpenFilename(h.VFS, name, flags) // Encrypt everything except super journals and memory files. if err != nil || flags&(vfs.OPEN_SUPER_JOURNAL|vfs.OPEN_MEMORY) != 0 { @@ -49,13 +45,14 @@ func (h *hbshVFS) OpenFilename(name *vfs.Filename, flags vfs.OpenFlag) (file vfs } else if t, ok := params["textkey"]; ok && len(t[0]) > 0 { key = h.init.KDF(t[0]) } else if flags&vfs.OPEN_MAIN_DB != 0 { - // Main datatabases may have their key specified as a PRAGMA. + // Main databases may have their key specified as a PRAGMA. return &hbshFile{File: file, init: h.init}, flags, nil } hbsh = h.init.HBSH(key) } if hbsh == nil { + file.Close() return nil, flags, sqlite3.CANTOPEN } return &hbshFile{File: file, hbsh: hbsh, init: h.init}, flags, nil diff --git a/vfs/cksmvfs/README.md b/vfs/cksmvfs/README.md new file mode 100644 index 00000000..dd4b26ea --- /dev/null +++ b/vfs/cksmvfs/README.md @@ -0,0 +1,20 @@ +# Go `cksmvfs` SQLite VFS + +This package wraps an SQLite VFS to help detect database corruption. + +The `"cksmvfs"` VFS wraps the default SQLite VFS adding an 8-byte checksum +to the end of every page in an SQLite database.\ +The checksum is added as each page is written +and verified as each page is read.\ +The checksum is intended to help detect database corruption +caused by random bit-flips in the mass storage device. + +This implementation is compatible with SQLite's +[Checksum VFS Shim](https://sqlite.org/cksumvfs.html). + +> [!IMPORTANT] +> [Checksums](https://en.wikipedia.org/wiki/Checksum) +> are meant to protect against _silent data corruption_ (bit rot). +> They do not offer _authenticity_ (i.e. protect against _forgery_), +> nor prevent _silent loss of durability_. +> Checkpoint WAL mode databases to improve durabiliy. \ No newline at end of file diff --git a/vfs/cksmvfs/api.go b/vfs/cksmvfs/api.go new file mode 100644 index 00000000..087022f4 --- /dev/null +++ b/vfs/cksmvfs/api.go @@ -0,0 +1,75 @@ +// Package cksmvfs wraps an SQLite VFS to help detect database corruption. +// +// The "cksmvfs" [vfs.VFS] wraps the default VFS adding an 8-byte checksum +// to the end of every page in an SQLite database. +// The checksum is added as each page is written +// and verified as each page is read. +// The checksum is intended to help detect database corruption +// caused by random bit-flips in the mass storage device. +// +// This implementation is compatible with SQLite's +// [Checksum VFS Shim]. +// +// [Checksum VFS Shim]: https://sqlite.org/cksumvfs.html +package cksmvfs + +import ( + "fmt" + + "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/vfs" +) + +func init() { + vfs.Register("cksmvfs", Wrap(vfs.Find(""))) +} + +// Wrap wraps a base VFS to create a checksumming VFS. +func Wrap(base vfs.VFS) vfs.VFS { + return &cksmVFS{VFS: base} +} + +// EnableChecksums enables checksums on a database. +func EnableChecksums(db *sqlite3.Conn, schema string) error { + if f, ok := db.Filename("").DatabaseFile().(*cksmFile); !ok { + return fmt.Errorf("cksmvfs: incorrect type: %T", f) + } + + r, err := db.FileControl(schema, sqlite3.FCNTL_RESERVE_BYTES) + if err != nil { + return err + } + if r == 8 { + // Correct value, enabled. + return nil + } + if r == 0 { + // Default value, enable. + _, err = db.FileControl(schema, sqlite3.FCNTL_RESERVE_BYTES, 8) + if err != nil { + return err + } + r, err = db.FileControl(schema, sqlite3.FCNTL_RESERVE_BYTES) + if err != nil { + return err + } + } + if r != 8 { + // Invalid value. + return fmt.Errorf("cksmvfs: reserve bytes must be 8, is: %d", r) + } + + // VACUUM the database. + if schema != "" { + err = db.Exec(`VACUUM ` + sqlite3.QuoteIdentifier(schema)) + } else { + err = db.Exec(`VACUUM`) + } + if err != nil { + return err + } + + // Checkpoint the WAL. + _, _, err = db.WALCheckpoint(schema, sqlite3.CHECKPOINT_RESTART) + return err +} diff --git a/vfs/cksmvfs/api_test.go b/vfs/cksmvfs/api_test.go new file mode 100644 index 00000000..04d4f393 --- /dev/null +++ b/vfs/cksmvfs/api_test.go @@ -0,0 +1,133 @@ +package cksmvfs_test + +import ( + _ "embed" + "log" + "path/filepath" + "strings" + "testing" + + "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/driver" + _ "github.com/ncruces/go-sqlite3/embed" + _ "github.com/ncruces/go-sqlite3/internal/testcfg" + "github.com/ncruces/go-sqlite3/util/ioutil" + "github.com/ncruces/go-sqlite3/vfs" + "github.com/ncruces/go-sqlite3/vfs/cksmvfs" + "github.com/ncruces/go-sqlite3/vfs/memdb" + "github.com/ncruces/go-sqlite3/vfs/readervfs" +) + +//go:embed testdata/cksm.db +var cksmDB string + +func Test_fileformat(t *testing.T) { + readervfs.Create("test.db", ioutil.NewSizeReaderAt(strings.NewReader(cksmDB))) + vfs.Register("rcksm", cksmvfs.Wrap(vfs.Find("reader"))) + + db, err := driver.Open("file:test.db?vfs=rcksm") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + var enabled bool + err = db.QueryRow(`PRAGMA checksum_verification`).Scan(&enabled) + if err != nil { + t.Fatal(err) + } + if !enabled { + t.Error("want true") + } + + db.SetMaxIdleConns(0) // Clears the page cache. + + _, err = db.Exec(`PRAGMA integrity_check`) + if err != nil { + t.Fatal(err) + } +} + +//go:embed testdata/test.db +var testDB []byte + +func Test_enable(t *testing.T) { + memdb.Create("nockpt.db", testDB) + vfs.Register("mcksm", cksmvfs.Wrap(vfs.Find("memdb"))) + + db, err := driver.Open("file:/nockpt.db?vfs=mcksm", + func(db *sqlite3.Conn) error { + return cksmvfs.EnableChecksums(db, "") + }) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + var enabled bool + err = db.QueryRow(`PRAGMA checksum_verification`).Scan(&enabled) + if err != nil { + t.Fatal(err) + } + if !enabled { + t.Error("want true") + } + + db.SetMaxIdleConns(0) // Clears the page cache. + + _, err = db.Exec(`PRAGMA integrity_check`) + if err != nil { + t.Fatal(err) + } +} + +func Test_new(t *testing.T) { + if !vfs.SupportsFileLocking { + t.Skip("skipping without locks") + } + + name := "file:" + + filepath.ToSlash(filepath.Join(t.TempDir(), "test.db")) + + "?vfs=cksmvfs&_pragma=journal_mode(wal)" + + db, err := driver.Open(name) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + var enabled bool + err = db.QueryRow(`PRAGMA checksum_verification`).Scan(&enabled) + if err != nil { + t.Fatal(err) + } + if !enabled { + t.Error("want true") + } + + var size int + err = db.QueryRow(`PRAGMA page_size=1024`).Scan(&size) + if err != nil { + t.Fatal(err) + } + if size != 4096 { + t.Errorf("got %d, want 4096", size) + } + + _, err = db.Exec(`CREATE TABLE users (id INT, name VARCHAR(10))`) + if err != nil { + log.Fatal(err) + } + + _, err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`) + if err != nil { + log.Fatal(err) + } + + db.SetMaxIdleConns(0) // Clears the page cache. + + _, err = db.Exec(`PRAGMA integrity_check`) + if err != nil { + t.Fatal(err) + } +} diff --git a/vfs/cksmvfs/cksmvfs.go b/vfs/cksmvfs/cksmvfs.go new file mode 100644 index 00000000..82db6965 --- /dev/null +++ b/vfs/cksmvfs/cksmvfs.go @@ -0,0 +1,234 @@ +package cksmvfs + +import ( + "bytes" + _ "embed" + "encoding/binary" + "io" + "runtime" + "strconv" + + "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/internal/util" + "github.com/ncruces/go-sqlite3/util/sql3util" + "github.com/ncruces/go-sqlite3/util/vfsutil" + "github.com/ncruces/go-sqlite3/vfs" +) + +type cksmVFS struct { + vfs.VFS +} + +func (c *cksmVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) { + // notest // OpenFilename is called instead + return nil, 0, sqlite3.CANTOPEN +} + +func (c *cksmVFS) OpenFilename(name *vfs.Filename, flags vfs.OpenFlag) (file vfs.File, _ vfs.OpenFlag, err error) { + // Prevent accidental wrapping. + if pc, _, _, ok := runtime.Caller(1); ok { + if fn := runtime.FuncForPC(pc); fn != nil { + if fn.Name() != "github.com/ncruces/go-sqlite3/vfs.vfsOpen" { + return nil, 0, sqlite3.CANTOPEN + } + } + } + + file, flags, err = vfsutil.WrapOpenFilename(c.VFS, name, flags) + + // Checksum only main databases and WALs. + if err != nil || flags&(vfs.OPEN_MAIN_DB|vfs.OPEN_WAL) == 0 { + return file, flags, err + } + + cksm := cksmFile{File: file} + + if flags&vfs.OPEN_WAL != 0 { + main, _ := name.DatabaseFile().(*cksmFile) + cksm.cksmFlags = main.cksmFlags + } else { + cksm.isDB = true + cksm.cksmFlags = new(cksmFlags) + } + const createDB = vfs.OPEN_CREATE | vfs.OPEN_READWRITE | vfs.OPEN_MAIN_DB + cksm.createDB = flags&createDB == createDB + + return &cksm, flags, err +} + +type cksmFile struct { + vfs.File + *cksmFlags + isDB bool + createDB bool +} + +type cksmFlags struct { + computeCksm bool + verifyCksm bool + inCkpt bool + pageSize int +} + +//go:embed empty.db +var empty string + +func (c *cksmFile) ReadAt(p []byte, off int64) (n int, err error) { + n, err = c.File.ReadAt(p, off) + + // SQLite is trying to read from the first page of an empty database file. + // Instead, read from an empty database that had checksums enabled, + // so checksums are enabled by default. + if c.createDB && n == 0 && err == io.EOF && off < 100 { + n = copy(p, empty[off:]) + if n < len(p) { + clear(p[n:]) + } + err = nil + } + + // SQLite is reading the header of a database file. + if c.isDB && off == 0 && len(p) >= 100 && + bytes.HasPrefix(p, []byte("SQLite format 3\000")) { + c.updateFlags(p) + } + + // Verify checksums. + if c.verifyCksm && !c.inCkpt && len(p) == c.pageSize { + cksm1 := cksmCompute(p[:len(p)-8]) + cksm2 := *(*[8]byte)(p[len(p)-8:]) + if cksm1 != cksm2 { + return 0, sqlite3.IOERR_DATA + } + } + return n, err +} + +func (c *cksmFile) WriteAt(p []byte, off int64) (n int, err error) { + // SQLite is writing the first page of a database file. + if c.isDB && off == 0 && len(p) >= 100 && + bytes.HasPrefix(p, []byte("SQLite format 3\000")) { + c.updateFlags(p) + } + + // Compute checksums. + if c.computeCksm && !c.inCkpt && len(p) == c.pageSize { + *(*[8]byte)(p[len(p)-8:]) = cksmCompute(p[:len(p)-8]) + } + + return c.File.WriteAt(p, off) +} + +func (c *cksmFile) updateFlags(header []byte) { + c.pageSize = 256 * int(binary.LittleEndian.Uint16(header[16:18])) + if r := header[20] == 8; r != c.computeCksm { + c.computeCksm = r + c.verifyCksm = r + } +} + +func (c *cksmFile) CheckpointStart() { + c.inCkpt = true +} + +func (c *cksmFile) CheckpointDone() { + c.inCkpt = false +} + +func (c *cksmFile) Pragma(name string, value string) (string, error) { + switch name { + case "checksum_verification": + b, ok := sql3util.ParseBool(value) + if ok { + c.verifyCksm = b && c.computeCksm + } + if !c.verifyCksm { + return "0", nil + } + return "1", nil + + case "page_size": + if c.computeCksm { + // Do not allow page size changes on a checksum database. + return strconv.Itoa(c.pageSize), nil + } + } + return vfsutil.WrapPragma(c.File, name, value) +} + +func cksmCompute(a []byte) (cksm [8]byte) { + var s1, s2 uint32 + for len(a) >= 8 { + s1 += binary.LittleEndian.Uint32(a[0:4]) + s2 + s2 += binary.LittleEndian.Uint32(a[4:8]) + s1 + a = a[8:] + } + if len(a) != 0 { + panic(util.AssertErr()) + } + binary.LittleEndian.PutUint32(cksm[0:4], s1) + binary.LittleEndian.PutUint32(cksm[4:8], s2) + return +} + +func (c *cksmFile) Unwrap() vfs.File { + return c.File +} + +func (c *cksmFile) SharedMemory() vfs.SharedMemory { + return vfsutil.WrapSharedMemory(c.File) +} + +// Wrap optional methods. + +func (c *cksmFile) LockState() vfs.LockLevel { + return vfsutil.WrapLockState(c.File) // notest +} + +func (c *cksmFile) PersistentWAL() bool { + return vfsutil.WrapPersistentWAL(c.File) // notest +} + +func (c *cksmFile) SetPersistentWAL(keepWAL bool) { + vfsutil.WrapSetPersistentWAL(c.File, keepWAL) // notest +} + +func (c *cksmFile) PowersafeOverwrite() bool { + return vfsutil.WrapPowersafeOverwrite(c.File) // notest +} + +func (c *cksmFile) SetPowersafeOverwrite(psow bool) { + vfsutil.WrapSetPowersafeOverwrite(c.File, psow) // notest +} + +func (c *cksmFile) ChunkSize(size int) { + vfsutil.WrapChunkSize(c.File, size) // notest +} + +func (c *cksmFile) SizeHint(size int64) error { + return vfsutil.WrapSizeHint(c.File, size) // notest +} + +func (c *cksmFile) HasMoved() (bool, error) { + return vfsutil.WrapHasMoved(c.File) // notest +} + +func (c *cksmFile) Overwrite() error { + return vfsutil.WrapOverwrite(c.File) // notest +} + +func (c *cksmFile) CommitPhaseTwo() error { + return vfsutil.WrapCommitPhaseTwo(c.File) // notest +} + +func (c *cksmFile) BeginAtomicWrite() error { + return vfsutil.WrapBeginAtomicWrite(c.File) // notest +} + +func (c *cksmFile) CommitAtomicWrite() error { + return vfsutil.WrapCommitAtomicWrite(c.File) // notest +} + +func (c *cksmFile) RollbackAtomicWrite() error { + return vfsutil.WrapRollbackAtomicWrite(c.File) // notest +} diff --git a/vfs/cksmvfs/empty.db b/vfs/cksmvfs/empty.db new file mode 100644 index 00000000..1093329c Binary files /dev/null and b/vfs/cksmvfs/empty.db differ diff --git a/vfs/cksmvfs/testdata/cksm.db b/vfs/cksmvfs/testdata/cksm.db new file mode 100644 index 00000000..0a46b4fe Binary files /dev/null and b/vfs/cksmvfs/testdata/cksm.db differ diff --git a/vfs/cksmvfs/testdata/test.db b/vfs/cksmvfs/testdata/test.db new file mode 100644 index 00000000..bd97e005 Binary files /dev/null and b/vfs/cksmvfs/testdata/test.db differ diff --git a/vfs/memdb/testdata/test.db b/vfs/memdb/testdata/test.db index 48ea4e75..bd97e005 100644 Binary files a/vfs/memdb/testdata/test.db and b/vfs/memdb/testdata/test.db differ diff --git a/vfs/memdb/testdata/wal.db b/vfs/memdb/testdata/wal.db index e113317f..90b6151c 100644 Binary files a/vfs/memdb/testdata/wal.db and b/vfs/memdb/testdata/wal.db differ diff --git a/vfs/readervfs/testdata/test.db b/vfs/readervfs/testdata/test.db index 48ea4e75..bd97e005 100644 Binary files a/vfs/readervfs/testdata/test.db and b/vfs/readervfs/testdata/test.db differ diff --git a/vfs/xts/aes_test.go b/vfs/xts/aes_test.go index d7b78cc3..f412a932 100644 --- a/vfs/xts/aes_test.go +++ b/vfs/xts/aes_test.go @@ -21,7 +21,7 @@ var testDB string func Test_fileformat(t *testing.T) { readervfs.Create("test.db", ioutil.NewSizeReaderAt(strings.NewReader(testDB))) - xts.Register("rxts", vfs.Find("reader"), nil) + vfs.Register("rxts", xts.Wrap(vfs.Find("reader"), nil)) db, err := driver.Open("file:test.db?vfs=rxts") if err != nil { diff --git a/vfs/xts/api.go b/vfs/xts/api.go index 4bf197df..c1be2e02 100644 --- a/vfs/xts/api.go +++ b/vfs/xts/api.go @@ -40,25 +40,26 @@ import ( ) func init() { - Register("xts", vfs.Find(""), nil) + vfs.Register("xts", Wrap(vfs.Find(""), nil)) } -// Register registers an encrypting VFS, wrapping a base VFS, -// and possibly using a custom XTS cipher construction. +// Wrap wraps a base VFS to create an encrypting VFS, +// possibly using a custom XTS cipher construction. +// // To use the default AES-XTS construction, set cipher to nil. // // The default construction uses AES-128, AES-192, or AES-256 // if the key/hexkey is 32, 48, or 64 bytes, respectively. // If a textkey is provided, the default KDF is PBKDF2-HMAC-SHA512 // with 10,000 iterations, always producing a 32 byte key. -func Register(name string, base vfs.VFS, cipher XTSCreator) { +func Wrap(base vfs.VFS, cipher XTSCreator) vfs.VFS { if cipher == nil { cipher = aesCreator{} } - vfs.Register(name, &xtsVFS{ + return &xtsVFS{ VFS: base, init: cipher, - }) + } } // XTSCreator creates an [xts.Cipher] diff --git a/vfs/xts/xts.go b/vfs/xts/xts.go index 9b75c2ea..4cc0d7f9 100644 --- a/vfs/xts/xts.go +++ b/vfs/xts/xts.go @@ -23,11 +23,7 @@ func (x *xtsVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, } func (x *xtsVFS) OpenFilename(name *vfs.Filename, flags vfs.OpenFlag) (file vfs.File, _ vfs.OpenFlag, err error) { - if hf, ok := x.VFS.(vfs.VFSFilename); ok { - file, flags, err = hf.OpenFilename(name, flags) - } else { - file, flags, err = x.VFS.Open(name.String(), flags) - } + file, flags, err = vfsutil.WrapOpenFilename(x.VFS, name, flags) // Encrypt everything except super journals and memory files. if err != nil || flags&(vfs.OPEN_SUPER_JOURNAL|vfs.OPEN_MEMORY) != 0 { @@ -48,13 +44,14 @@ func (x *xtsVFS) OpenFilename(name *vfs.Filename, flags vfs.OpenFlag) (file vfs. } else if t, ok := params["textkey"]; ok && len(t[0]) > 0 { key = x.init.KDF(t[0]) } else if flags&vfs.OPEN_MAIN_DB != 0 { - // Main datatabases may have their key specified as a PRAGMA. + // Main databases may have their key specified as a PRAGMA. return &xtsFile{File: file, init: x.init}, flags, nil } cipher = x.init.XTS(key) } if cipher == nil { + file.Close() return nil, flags, sqlite3.CANTOPEN } return &xtsFile{File: file, cipher: cipher, init: x.init}, flags, nil