diff --git a/lightning/src/util/persist.rs b/lightning/src/util/persist.rs index f5cb420d5bf..f26b2449303 100644 --- a/lightning/src/util/persist.rs +++ b/lightning/src/util/persist.rs @@ -30,6 +30,8 @@ use crate::routing::gossip::NetworkGraph; use crate::util::logger::Logger; use crate::util::ser::{Readable, ReadableArgs, Writeable}; +use self::types::{MonitorName, UpdateName}; + /// The namespace under which the [`ChannelManager`] will be persisted. pub const CHANNEL_MANAGER_PERSISTENCE_NAMESPACE: &str = ""; /// The key under which the [`ChannelManager`] will be persisted. @@ -198,10 +200,11 @@ where } Ok(res) } - enum KVStoreUpdatingPersisterError<'a> { /// The monitor name was improperly formatted. BadMonitorName { reason: &'a str, context: &'a str }, + /// The update name was improperly formatted. + BadUpdateName { reason: &'a str, context: &'a str }, /// The monitor could not be decoded. MonitorDecodeFailed { reason: DecodeError, @@ -223,100 +226,36 @@ impl<'a> From> for io::Error { match value { KVStoreUpdatingPersisterError::BadMonitorName { reason, context } => io::Error::new( io::ErrorKind::InvalidInput, - format!("BadMonitorName, {}, context: {}'", reason, context), + format!("BadMonitorName, {}, context: {}", reason, context), + ), + KVStoreUpdatingPersisterError::BadUpdateName { reason, context } => io::Error::new( + io::ErrorKind::InvalidInput, + format!("BadUpdateName, {}, context: {}", reason, context), ), KVStoreUpdatingPersisterError::MonitorDecodeFailed { reason, context } => { io::Error::new( io::ErrorKind::InvalidData, - format!("MonitorDecodeFailed, {}, context: {}'", reason, context), + format!("MonitorDecodeFailed, {}, context: {}", reason, context), ) } KVStoreUpdatingPersisterError::UpdateDecodeFailed { reason, context } => { io::Error::new( io::ErrorKind::InvalidData, - format!("UpdateDecodeFailed, {}, context: {}'", reason, context), + format!("UpdateDecodeFailed, {}, context: {}", reason, context), ) } KVStoreUpdatingPersisterError::StorageReadFailed { reason, context } => io::Error::new( io::ErrorKind::Other, - format!("StorageReadFailed, {}, context: {}'", reason, context), + format!("StorageReadFailed, {}, context: {}", reason, context), ), KVStoreUpdatingPersisterError::UpdateFailed { reason, context } => io::Error::new( io::ErrorKind::InvalidData, - format!("UpdateFailed, {}, context: {}'", reason, context), + format!("UpdateFailed, {}, context: {}", reason, context), ), } } } -/// A struct representing a name for a monitor. -#[derive(Clone, Debug)] -pub struct MonitorName(String); - -impl MonitorName { - /// The key to store this monitor with. - fn storage_key(&self) -> &str { - &self.0 - } -} - -impl TryFrom for OutPoint { - type Error = std::io::Error; - - fn try_from(value: MonitorName) -> Result { - let mut parts = value.0.splitn(2, '_'); - let txid_hex = - parts - .next() - .ok_or_else(|| KVStoreUpdatingPersisterError::BadMonitorName { - reason: "no txid found, maybe there is no underscore", - context: &value.0, - })?; - let index = parts - .next() - .ok_or_else(|| KVStoreUpdatingPersisterError::BadMonitorName { - reason: "no index value found after underscore", - context: &value.0, - })?; - let index = index - .parse() - .map_err(|_| KVStoreUpdatingPersisterError::BadMonitorName { - reason: "could not parse index value in monitor name", - context: &value.0, - })?; - let txid = Txid::from_hex(txid_hex).map_err(|_| { - KVStoreUpdatingPersisterError::BadMonitorName { - reason: "bad txid in monitor name", - context: &value.0, - } - })?; - Ok(OutPoint { txid, index }) - } -} - -impl From for MonitorName { - fn from(value: OutPoint) -> Self { - MonitorName(format!("{}_{}", value.txid.to_hex(), value.index)) - } -} - -/// A struct representing a name for an update. -#[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord)] -pub struct UpdateName(String); - -impl UpdateName { - /// The key to store this update with. - fn storage_key(&self) -> &str { - &self.0 - } -} - -impl From for UpdateName { - fn from(value: u64) -> Self { - Self(format!("{:0>20}", value)) - } -} - struct KVStoreUpdatingPersister where K::Target: KVStore, @@ -368,7 +307,7 @@ where .update_monitor(&update, broadcaster, fee_estimator.clone(), &self.logger) .map_err(|_| KVStoreUpdatingPersisterError::UpdateFailed { reason: "update_monitor returned Err(())", - context: &monitor_name.0, + context: monitor_name.as_str(), })?; } } @@ -383,7 +322,7 @@ where // Append the monitor name to the namespace with an underscore. [ CHANNEL_MONITOR_UPDATE_PERSISTENCE_NAMESPACE, - monitor_name.storage_key(), + monitor_name.as_str(), ] .join("_") } @@ -391,15 +330,21 @@ where /// List all the names of monitors. fn list_monitor_names(&self) -> io::Result> { let list = self.kv.list(CHANNEL_MONITOR_PERSISTENCE_NAMESPACE)?; - Ok(list.into_iter().map(MonitorName).collect()) + list.into_iter().map(MonitorName::new).collect() } /// List all the names of updates corresponding to a given monitor name. fn list_update_names(&self, monitor_name: &MonitorName) -> io::Result> { let update_ns = self.monitor_update_namespace(monitor_name); - let mut list = self.kv.list(&update_ns)?; - list.sort(); - Ok(list.into_iter().map(UpdateName).collect()) + let list = self.kv.list(&update_ns)?; + + let mut update_name_list: Vec = list + .into_iter() + .map(UpdateName::new) + .collect::>>()?; + + update_name_list.sort(); + Ok(update_name_list) } /// Deserialize a channel monitor. @@ -416,7 +361,7 @@ where ES::Target: EntropySource + Sized, SP::Target: SignerProvider + Sized, { - let key = monitor_name.storage_key(); + let key = monitor_name.as_str(); let outpoint: OutPoint = monitor_name.to_owned().try_into()?; match <( BlockHash, @@ -451,8 +396,6 @@ where } } - - /// Deserialize a channel monitor update. fn deserialize_monitor_update( &self, @@ -460,7 +403,7 @@ where update_name: &UpdateName, ) -> io::Result { let ns = self.monitor_update_namespace(monitor_name); - let key = update_name.storage_key(); + let key = update_name.as_str(); Ok( ChannelMonitorUpdate::read(&mut self.kv.read(&ns, &key).map_err(|e| { KVStoreUpdatingPersisterError::StorageReadFailed { @@ -489,7 +432,7 @@ where && update.update_id <= monitor.get_latest_update_id() { let ns = self.monitor_update_namespace(&monitor_name); - let key = update_name.storage_key(); + let key = update_name.as_str(); self.kv.remove(&ns, key)?; } } @@ -513,10 +456,11 @@ where monitor: &ChannelMonitor, _update_id: MonitorUpdateId, ) -> chain::ChannelMonitorUpdateStatus { - let key = match MonitorName::try_from(funding_txo) { - Ok(monitor_name) => monitor_name.0, + let monitor_name = match MonitorName::try_from(funding_txo) { + Ok(n) => n, Err(_) => return chain::ChannelMonitorUpdateStatus::PermanentFailure, }; + let key = monitor_name.as_str(); match self.kv.write( CHANNEL_MONITOR_PERSISTENCE_NAMESPACE, &key, @@ -553,10 +497,11 @@ where Ok(monitor_name) => self.monitor_update_namespace(&monitor_name), Err(_) => return chain::ChannelMonitorUpdateStatus::PermanentFailure, }; - let key = match UpdateName::try_from(update.update_id) { - Ok(update_name) => update_name.0, + let update_name = match UpdateName::try_from(update.update_id) { + Ok(update_name) => update_name, Err(_) => return chain::ChannelMonitorUpdateStatus::PermanentFailure, }; + let key = update_name.as_str(); match self.kv.write(&ns, &key, &update.encode()) { Ok(()) => chain::ChannelMonitorUpdateStatus::Completed, Err(_) => chain::ChannelMonitorUpdateStatus::PermanentFailure, @@ -569,6 +514,131 @@ where } } +pub mod types { + use crate::chain::transaction::OutPoint; + use crate::util::persist::KVStoreUpdatingPersisterError; + use bitcoin::hashes::hex::{FromHex, ToHex}; + use bitcoin::Txid; + use core::convert::TryFrom; + use std::io::Error; + + /// A struct representing a name for a monitor. + #[derive(Clone, Debug)] + pub struct MonitorName(String); + + impl MonitorName { + /// Verifies that an [`OutPoint`] can be formed from the given `name`. + pub fn new(name: String) -> Result { + let me = Self(name); + me.do_try_into_outpoint()?; + Ok(me) + } + /// Convert this monitor name to a str. + pub fn as_str(&self) -> &str { + &self.0 + } + fn do_try_into_outpoint(&self) -> Result { + let mut parts = self.0.splitn(2, '_'); + let txid_hex = + parts + .next() + .ok_or_else(|| KVStoreUpdatingPersisterError::BadMonitorName { + reason: "name is not a splittable string", + context: &self.0, + })?; + let index = + parts + .next() + .ok_or_else(|| KVStoreUpdatingPersisterError::BadMonitorName { + reason: "no index value found after underscore", + context: &self.0, + })?; + let index = + index + .parse() + .map_err(|_| KVStoreUpdatingPersisterError::BadMonitorName { + reason: "could not parse index value in monitor name", + context: &self.0, + })?; + let txid = Txid::from_hex(txid_hex).map_err(|_| { + KVStoreUpdatingPersisterError::BadMonitorName { + reason: "bad txid in monitor name", + context: &self.0, + } + })?; + Ok(OutPoint { txid, index }) + } + } + + impl TryFrom for OutPoint { + type Error = Error; + + fn try_from(value: MonitorName) -> Result { + value.do_try_into_outpoint() + } + } + + impl From for MonitorName { + fn from(value: OutPoint) -> Self { + MonitorName(format!("{}_{}", value.txid.to_hex(), value.index)) + } + } + + /// A struct representing a name for an update. + #[derive(Clone, Debug)] + pub struct UpdateName(u64, String); + + impl UpdateName { + /// Validates that an update sequence ID can be derived from the given `name`. + pub fn new(name: String) -> Result { + match name.parse::() { + Ok(u) => Ok(Self(u, UpdateName::left_padded_string(u))), + Err(_) => Err(KVStoreUpdatingPersisterError::BadUpdateName { + reason: "cannot parse u64 from update name", + context: &name, + } + .into()), + } + } + + /// Convert this monitor update name to a &str + pub fn as_str(&self) -> &str { + &self.1 + } + + /// Left-pad the sequential update id for string representation + fn left_padded_string(value: u64) -> String { + format!("{:0>20}", value) + } + } + + impl From for UpdateName { + fn from(value: u64) -> Self { + Self(value, UpdateName::left_padded_string(value)) + } + } + + impl PartialEq for UpdateName { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } + } + + impl PartialOrd for UpdateName { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + + impl Eq for UpdateName {} + + impl Ord for UpdateName { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.0.cmp(&other.0) + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -590,8 +660,7 @@ mod tests { } } - fn read_only_test_persister() -> KVStoreUpdatingPersister, Box> - { + fn read_only_test_persister() -> KVStoreUpdatingPersister, Box> { KVStoreUpdatingPersister { kv: Box::new(TestStore::new(false)), logger: Box::new(TestLogger::new()), @@ -603,7 +672,7 @@ mod tests { // ================================= #[test] - fn test_update_from_u64_and_storage_key_works() { + fn test_update_from_u64_and_as_str_works() { let cases = [ (0, "00000000000000000000"), (21, "00000000000000000021"), @@ -611,13 +680,31 @@ mod tests { ]; for (input, expected) in &cases { let update_name = UpdateName::from(*input); - assert_eq!(update_name.0, *expected); - assert_eq!(update_name.storage_key(), *expected); + assert_eq!(update_name.as_str(), *expected); } } #[test] - fn test_monitor_from_outpoint_and_storage_key_works() { + fn test_update_from_bad_string_fails() { + let cases = [ + ("deadb33f".to_string(), "cannot parse u64"), + ("-0000000000000000001".to_string(), "cannot parse u64"), + ]; + for (input, expected) in &cases { + let got = UpdateName::new(input.to_owned()) + .expect_err("bad update name accepted") + .to_string(); + assert!( + got.contains(*expected), + "wrong update name error\nexpected: {}\ngot: {}", + expected, + got + ); + } + } + + #[test] + fn test_monitor_from_outpoint_and_as_str_works() { let cases = [ ( OutPoint { @@ -642,8 +729,37 @@ mod tests { ]; for (input, expected) in &cases { let monitor_name = MonitorName::from(*input); - assert_eq!(monitor_name.0, *expected); - assert_eq!(monitor_name.storage_key(), *expected); + assert_eq!(monitor_name.as_str(), *expected); + } + } + + #[test] + fn test_monitor_from_bad_string_fails() { + let cases = [ + ( + "deadb33fdeadb33fdeadb33fdeadb33fdeadb33fdeadb33fdeadb33fdeadb33f".to_string(), + "no index", + ), + ( + "deadb33fdeadb33fdeadb33fdeadb33fdeadb33fdeadb33fdeadb33fdeadb33f_65536" + .to_string(), + "could not parse index value", + ), + ( + "deadb33fdeadb33fdeadb33fdeadb33fdeadb33fdeadb33fdeadb33f_21".to_string(), + "bad txid", + ), + ]; + for (input, expected) in &cases { + let got = MonitorName::new(input.to_owned()) + .expect_err("bad monitor name accepted") + .to_string(); + assert!( + got.contains(*expected), + "wrong monitor name error\nexpected: {}\ngot: {}", + expected, + got + ); } } @@ -653,17 +769,29 @@ mod tests { // Add some dummy entries for monitor files persister .kv - .write(CHANNEL_MONITOR_PERSISTENCE_NAMESPACE, "deadb33f_0", &[0]) + .write( + CHANNEL_MONITOR_PERSISTENCE_NAMESPACE, + "deadb33fdeadb33fdeadb33fdeadb33fdeadb33fdeadb33fdeadb33fdeadb33f_0", + &[0], + ) .unwrap(); persister .kv - .write(CHANNEL_MONITOR_PERSISTENCE_NAMESPACE, "feedbeef_1", &[0]) + .write( + CHANNEL_MONITOR_PERSISTENCE_NAMESPACE, + "f33dbeeff33dbeeff33dbeeff33dbeeff33dbeeff33dbeeff33dbeeff33dbeef_65535", + &[0], + ) .unwrap(); // Test that these monitors are found where they should be. let listed_monitor_names = persister.list_monitor_names().unwrap(); assert_eq!(listed_monitor_names.len(), 2); - assert!(listed_monitor_names.iter().any(|m| m.0 == "deadb33f_0")); - assert!(listed_monitor_names.iter().any(|m| m.0 == "feedbeef_1")); + assert!(listed_monitor_names + .iter() + .any(|m| m.as_str() + == "deadb33fdeadb33fdeadb33fdeadb33fdeadb33fdeadb33fdeadb33fdeadb33f_0")); + assert!(listed_monitor_names.iter().any(|m| m.as_str() + == "f33dbeeff33dbeeff33dbeeff33dbeeff33dbeeff33dbeeff33dbeeff33dbeef_65535")); } #[test] @@ -682,7 +810,7 @@ mod tests { let update_ns = persister.monitor_update_namespace(&monitor_name); let expected_ns = [ CHANNEL_MONITOR_UPDATE_PERSISTENCE_NAMESPACE, - &monitor_name.storage_key(), + &monitor_name.as_str(), ] .join("_"); assert_eq!(update_ns, expected_ns); @@ -705,7 +833,7 @@ mod tests { .kv .write( &persister.monitor_update_namespace(&monitor_name), - &UpdateName::from(i).storage_key(), + &UpdateName::from(i).as_str(), &[0], ) .unwrap();