Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wallet] Reopen CDBEnv after encryption instead of shutting down #2648

Merged
merged 9 commits into from
Jan 11, 2022
5 changes: 2 additions & 3 deletions src/qt/askpassphrasedialog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,9 @@ void AskPassphraseDialog::warningMessage()
openStandardDialog(
tr("Wallet encrypted"),
"<qt>" +
tr("%1 will close now to finish the encryption process. "
tr("Your wallet is now encrypted. "
"Remember that encrypting your wallet cannot fully protect "
"your PIVs from being stolen by malware infecting your computer.").arg(PACKAGE_NAME) +
"your PIVs from being stolen by malware infecting your computer.") +
"<br><br><b>" +
tr("IMPORTANT: Any previous backups you have made of your wallet file "
"should be replaced with the newly generated, encrypted wallet file. "
Expand All @@ -362,7 +362,6 @@ void AskPassphraseDialog::warningMessage()
"</b></qt>",
tr("OK")
);
QApplication::quit();
}

void AskPassphraseDialog::errorEncryptingWallet()
Expand Down
12 changes: 11 additions & 1 deletion src/test/librust/sapling_rpc_wallet_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "key_io.h"
#include "consensus/merkle.h"
#include "wallet/wallet.h"
#include "wallet/walletutil.h"

#include "sapling/key_io_sapling.h"
#include "sapling/address.h"
Expand Down Expand Up @@ -497,7 +498,14 @@ BOOST_AUTO_TEST_CASE(rpc_shieldsendmany_taddr_to_sapling)
vpwallets.erase(vpwallets.begin());
}

BOOST_AUTO_TEST_CASE(rpc_wallet_encrypted_wallet_sapzkeys)
struct RealWalletTestingSetup : public WalletTestingSetupBase
{
RealWalletTestingSetup() : WalletTestingSetupBase(CBaseChainParams::MAIN,
"test_wallet",
WalletDatabase::Create(fs::absolute("test_wallet", GetWalletDir()))) {};
};

BOOST_FIXTURE_TEST_CASE(rpc_wallet_encrypted_wallet_sapzkeys, RealWalletTestingSetup)
{
UniValue retValue;
int n = 100;
Expand Down Expand Up @@ -535,6 +543,7 @@ BOOST_AUTO_TEST_CASE(rpc_wallet_encrypted_wallet_sapzkeys)

PushCurrentDirectory push_dir(gArgs.GetArg("-datadir","/tmp/thisshouldnothappen"));
BOOST_CHECK(m_wallet.EncryptWallet(strWalletPass));
BOOST_CHECK(m_wallet.IsCrypted());

// Verify we can still list the keys imported
BOOST_CHECK_NO_THROW(retValue = CallRPC("listshieldaddresses"));
Expand All @@ -547,6 +556,7 @@ BOOST_AUTO_TEST_CASE(rpc_wallet_encrypted_wallet_sapzkeys)
// We can't call RPC walletpassphrase as that invokes RPCRunLater which breaks tests.
// So we manually unlock.
BOOST_CHECK(m_wallet.Unlock(strWalletPass));
BOOST_CHECK(m_wallet.IsCrypted());

// Now add a key
BOOST_CHECK_NO_THROW(CallRPC("getnewshieldaddress"));
Expand Down
84 changes: 63 additions & 21 deletions src/wallet/db.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@


namespace {

//! Make sure database has a unique fileid within the environment. If it
//! doesn't, throw an error. BDB caches do not work properly when more than one
//! open database has the same fileid (values written to one database may show
Expand All @@ -33,25 +34,19 @@ namespace {
//! (https://docs.oracle.com/cd/E17275_01/html/programmer_reference/program_copy.html),
//! so bitcoin should never create different databases with the same fileid, but
//! this error can be triggered if users manually copy database files.
void CheckUniqueFileid(const BerkeleyEnvironment& env, const std::string& filename, Db& db)
void CheckUniqueFileid(const BerkeleyEnvironment& env, const std::string& filename, Db& db, WalletDatabaseFileId& fileid)
{
if (env.IsMock()) return;

u_int8_t fileid[DB_FILE_ID_LEN];
int ret = db.get_mpf()->get_fileid(fileid);
int ret = db.get_mpf()->get_fileid(fileid.value);
if (ret != 0) {
throw std::runtime_error(strprintf("BerkeleyBatch: Can't open database %s (get_fileid failed with %d)", filename, ret));
}

for (const auto& item : env.mapDb) {
u_int8_t item_fileid[DB_FILE_ID_LEN];
if (item.second && item.second->get_mpf()->get_fileid(item_fileid) == 0 &&
memcmp(fileid, item_fileid, sizeof(fileid)) == 0) {
const char* item_filename = nullptr;
item.second->get_dbname(&item_filename, nullptr);
for (const auto& item : env.m_fileids) {
if (fileid == item.second && &fileid != &item.second) {
throw std::runtime_error(strprintf("BerkeleyBatch: Can't open database %s (duplicates fileid %s from %s)", filename,
HexStr(item_fileid),
item_filename ? item_filename : "(unknown database)"));
HexStr(item.second.value), item.first));
}
}
}
Expand All @@ -60,6 +55,11 @@ RecursiveMutex cs_db;
std::map<std::string, BerkeleyEnvironment> g_dbenvs; //!< Map from directory name to open db environment.
} // namespace

bool WalletDatabaseFileId::operator==(const WalletDatabaseFileId& rhs) const
{
return memcmp(value, &rhs.value, sizeof(value)) == 0;
}

BerkeleyEnvironment* GetWalletEnv(const fs::path& wallet_path, std::string& database_filename)
{
fs::path env_directory;
Expand Down Expand Up @@ -106,7 +106,7 @@ void BerkeleyEnvironment::Close()

int ret = dbenv->close(0);
if (ret != 0)
LogPrintf("Error %d shutting down database environment: %s\n", ret, DbEnv::strerror(ret));
LogPrintf("%s: Error %d closing database environment: %s\n", __func__, ret, DbEnv::strerror(ret));
if (!fMockDb)
DbEnv((u_int32_t)0).remove(strPath.c_str(), 0);
}
Expand Down Expand Up @@ -172,8 +172,12 @@ bool BerkeleyEnvironment::Open(bool retry)
nEnvFlags,
S_IRUSR | S_IWUSR);
if (ret != 0) {
dbenv->close(0);
LogPrintf("BerkeleyEnvironment::Open: Error %d opening database environment: %s\n", ret, DbEnv::strerror(ret));
LogPrintf("%s: Error %d opening database environment: %s\n", __func__, ret, DbEnv::strerror(ret));
int ret2 = dbenv->close(0);
if (ret2 != 0) {
LogPrintf("%s: Error %d closing failed database environment: %s\n", __func__, ret2, DbEnv::strerror(ret2));
}
Reset();
if (retry) {
// try moving the database env out of the way
fs::path pathDatabaseBak = pathIn / strprintf("database.%d.bak", GetTime());
Expand Down Expand Up @@ -499,8 +503,8 @@ BerkeleyBatch::BerkeleyBatch(BerkeleyDatabase& database, const char* pszMode, bo
// be implemented, so no equality checks are needed at all. (Newer
// versions of BDB have an set_lk_exclusive method for this
// purpose, but the older version we use does not.)
for (auto& env : g_dbenvs) {
CheckUniqueFileid(env.second, strFilename, *pdb_temp);
for (const auto& env : g_dbenvs) {
CheckUniqueFileid(env.second, strFilename, *pdb_temp, this->env->m_fileids[strFilename]);
}

pdb = pdb_temp.release();
Expand Down Expand Up @@ -552,6 +556,7 @@ void BerkeleyBatch::Close()
LOCK(cs_db);
--env->mapFileUseCount[strFile];
}
env->m_db_in_use.notify_all();
}

void BerkeleyEnvironment::CloseDb(const std::string& strFile)
Expand All @@ -568,6 +573,32 @@ void BerkeleyEnvironment::CloseDb(const std::string& strFile)
}
}

void BerkeleyEnvironment::ReloadDbEnv()
{
// Make sure that no Db's are in use
AssertLockNotHeld(cs_db);
std::unique_lock<RecursiveMutex> lock(cs_db);
m_db_in_use.wait(lock, [this](){
for (auto& count : mapFileUseCount) {
if (count.second > 0) return false;
}
return true;
});

std::vector<std::string> filenames;
for (auto it : mapDb) {
filenames.push_back(it.first);
}
// Close the individual Db's
for (const std::string& filename : filenames) {
CloseDb(filename);
}
// Reset the environment
Flush(true); // This will flush and close the environment
Reset();
Open(true);
}

bool BerkeleyBatch::Rewrite(BerkeleyDatabase& database, const char* pszSkip)
{
if (database.IsDummy()) {
Expand Down Expand Up @@ -693,7 +724,6 @@ void BerkeleyEnvironment::Flush(bool fShutdown)
if (!fMockDb) {
fs::remove_all(fs::path(strPath) / "database");
}
g_dbenvs.erase(strPath);
}
}
}
Expand Down Expand Up @@ -793,12 +823,24 @@ void BerkeleyDatabase::Flush(bool shutdown)
{
if (!IsDummy()) {
env->Flush(shutdown);
if (shutdown) env = nullptr;
if (shutdown) {
LOCK(cs_db);
g_dbenvs.erase(env->Directory().string());
env = nullptr;
} else {
// TODO: To avoid g_dbenvs.erase erasing the environment prematurely after the
// first database shutdown when multiple databases are open in the same
// environment, should replace raw database `env` pointers with shared or weak
// pointers, or else separate the database and environment shutdowns so
// environments can be shut down after databases.
env->m_fileids.erase(strFile);
}
}
}

void BerkeleyDatabase::CloseAndReset()
void BerkeleyDatabase::ReloadDbEnv()
{
env->Close();
env->Reset();
if (!IsDummy()) {
env->ReloadDbEnv();
}
}
16 changes: 12 additions & 4 deletions src/wallet/db.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,19 @@
#include <atomic>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

#include <db_cxx.h>

static const unsigned int DEFAULT_WALLET_DBLOGSIZE = 100;
static const bool DEFAULT_WALLET_PRIVDB = true;

struct WalletDatabaseFileId {
u_int8_t value[DB_FILE_ID_LEN];
bool operator==(const WalletDatabaseFileId& rhs) const;
};

class BerkeleyEnvironment
{
private:
Expand All @@ -37,6 +43,8 @@ class BerkeleyEnvironment
std::unique_ptr<DbEnv> dbenv;
std::map<std::string, int> mapFileUseCount;
std::map<std::string, Db*> mapDb;
std::unordered_map<std::string, WalletDatabaseFileId> m_fileids;
std::condition_variable_any m_db_in_use;

BerkeleyEnvironment(const fs::path& env_directory);
~BerkeleyEnvironment();
Expand Down Expand Up @@ -75,6 +83,7 @@ class BerkeleyEnvironment
void CheckpointLSN(const std::string& strFile);

void CloseDb(const std::string& strFile);
void ReloadDbEnv();

DbTxn* TxnBegin(int flags = DB_TXN_WRITE_NOSYNC)
{
Expand Down Expand Up @@ -143,11 +152,10 @@ class BerkeleyDatabase
*/
void Flush(bool shutdown);

/** Close and reset.
*/
void CloseAndReset();

void IncrementUpdateCounter();

void ReloadDbEnv();

std::atomic<unsigned int> nUpdateCounter;
unsigned int nLastSeen;
unsigned int nLastFlushed;
Expand Down
7 changes: 1 addition & 6 deletions src/wallet/rpcwallet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3630,7 +3630,6 @@ UniValue encryptwallet(const JSONRPCRequest& request)
"will require the passphrase to be set prior the making these calls.\n"
"Use the walletpassphrase call for this, and then walletlock call.\n"
"If the wallet is already encrypted, use the walletpassphrasechange call.\n"
"Note that this will shutdown the server.\n"

"\nArguments:\n"
"1. \"passphrase\" (string) The pass phrase to encrypt the wallet with. It must be at least 1 character, but should be long.\n"
Expand Down Expand Up @@ -3669,11 +3668,7 @@ UniValue encryptwallet(const JSONRPCRequest& request)
if (!pwallet->EncryptWallet(strWalletPass))
throw JSONRPCError(RPC_WALLET_ENCRYPTION_FAILED, "Error: Failed to encrypt the wallet.");

// BDB seems to have a bad habit of writing old data into
// slack space in .dat files; that is bad if the old data is
// unencrypted private keys. So:
StartShutdown();
return "wallet encrypted; pivx server stopping, restart to run with encrypted wallet. The keypool has been flushed, you need to make a new backup.";
return "wallet encrypted; The keypool has been flushed, you need to make a new backup.";
}

UniValue listunspent(const JSONRPCRequest& request)
Expand Down
11 changes: 8 additions & 3 deletions src/wallet/test/wallet_test_fixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
#include "wallet/rpcwallet.h"
#include "wallet/wallet.h"

WalletTestingSetup::WalletTestingSetup(const std::string& chainName):
SaplingTestingSetup(chainName), m_wallet("mock", WalletDatabase::CreateMock())
WalletTestingSetupBase::WalletTestingSetupBase(const std::string& chainName,
const std::string& wallet_name,
std::unique_ptr<WalletDatabase> db) :
SaplingTestingSetup(chainName), m_wallet(wallet_name, std::move(db))
{
bool fFirstRun;
m_wallet.LoadWallet(fFirstRun);
Expand All @@ -20,7 +22,10 @@ WalletTestingSetup::WalletTestingSetup(const std::string& chainName):
RegisterWalletRPCCommands(tableRPC);
}

WalletTestingSetup::~WalletTestingSetup()
WalletTestingSetupBase::~WalletTestingSetupBase()
{
UnregisterValidationInterface(&m_wallet);
}

WalletTestingSetup::WalletTestingSetup(const std::string& chainName) :
WalletTestingSetupBase(chainName, "mock", WalletDatabase::CreateMock()) {}
14 changes: 10 additions & 4 deletions src/wallet/test/wallet_test_fixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,20 @@

/** Testing setup and teardown for wallet.
*/
struct WalletTestingSetup : public SaplingTestingSetup
struct WalletTestingSetupBase : public SaplingTestingSetup
{
WalletTestingSetup(const std::string& chainName = CBaseChainParams::MAIN);
~WalletTestingSetup();

WalletTestingSetupBase(const std::string& chainName,
const std::string& wallet_name,
std::unique_ptr<WalletDatabase> db);
~WalletTestingSetupBase();
CWallet m_wallet;
};

struct WalletTestingSetup : public WalletTestingSetupBase
{
WalletTestingSetup(const std::string& chainName = CBaseChainParams::MAIN);
};

struct WalletRegTestingSetup : public WalletTestingSetup
{
WalletRegTestingSetup() : WalletTestingSetup(CBaseChainParams::REGTEST) {}
Expand Down
6 changes: 6 additions & 0 deletions src/wallet/wallet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,12 @@ bool CWallet::EncryptWallet(const SecureString& strWalletPassphrase)
// Need to completely rewrite the wallet file; if we don't, bdb might keep
// bits of the unencrypted private key in slack space in the database file.
database->Rewrite();

// BDB seems to have a bad habit of writing old data into
// slack space in .dat files; that is bad if the old data is
// unencrypted private keys. So:
database->ReloadDbEnv();

}
NotifyStatusChanged(this);

Expand Down
7 changes: 1 addition & 6 deletions test/functional/rpc_fundrawtransaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
assert_raises_rpc_error,
assert_greater_than,
assert_greater_than_or_equal,
connect_nodes,
count_bytes,
find_vout_for_address,
Decimal,
Expand Down Expand Up @@ -371,11 +370,7 @@ def test_spend_2of2(self):
def test_locked_wallet(self):
self.log.info("test locked wallet")

self.nodes[1].node_encrypt_wallet("test")
self.start_node(1)
connect_nodes(self.nodes[0], 1)
connect_nodes(self.nodes[1], 2)
self.sync_all()
self.nodes[1].encryptwallet("test")

# Drain the keypool.
self.nodes[1].getnewaddress()
Expand Down
Loading