diff --git a/Changelog.md b/Changelog.md index 4080b39e..8b056f67 100644 --- a/Changelog.md +++ b/Changelog.md @@ -12,6 +12,10 @@ ### New Features +- [#581]: Allow `Deserializer` to set `quick_xml::de::EntityResolver` for + resolving unknown entities that would otherwise cause the parser to return + an [`EscapeError::UnrecognizedSymbol`] error. + ### Bug Fixes ### Misc Changes diff --git a/src/de/map.rs b/src/de/map.rs index 8c13554f..37226787 100644 --- a/src/de/map.rs +++ b/src/de/map.rs @@ -2,6 +2,7 @@ use crate::{ de::key::QNameDeserializer, + de::resolver::EntityResolver, de::simple_type::SimpleTypeDeserializer, de::{str2bool, DeEvent, Deserializer, XmlRead, TEXT_KEY, VALUE_KEY}, encoding::Decoder, @@ -165,13 +166,14 @@ enum ValueSource { /// /// - `'a` lifetime represents a parent deserializer, which could own the data /// buffer. -pub(crate) struct MapAccess<'de, 'a, R> +pub(crate) struct MapAccess<'de, 'a, R, E> where R: XmlRead<'de>, + E: EntityResolver, { /// Tag -- owner of attributes start: BytesStart<'de>, - de: &'a mut Deserializer<'de, R>, + de: &'a mut Deserializer<'de, R, E>, /// State of the iterator over attributes. Contains the next position in the /// inner `start` slice, from which next attribute should be parsed. iter: IterState, @@ -190,13 +192,14 @@ where has_value_field: bool, } -impl<'de, 'a, R> MapAccess<'de, 'a, R> +impl<'de, 'a, R, E> MapAccess<'de, 'a, R, E> where R: XmlRead<'de>, + E: EntityResolver, { /// Create a new MapAccess pub fn new( - de: &'a mut Deserializer<'de, R>, + de: &'a mut Deserializer<'de, R, E>, start: BytesStart<'de>, fields: &'static [&'static str], ) -> Result { @@ -211,9 +214,10 @@ where } } -impl<'de, 'a, R> de::MapAccess<'de> for MapAccess<'de, 'a, R> +impl<'de, 'a, R, E> de::MapAccess<'de> for MapAccess<'de, 'a, R, E> where R: XmlRead<'de>, + E: EntityResolver, { type Error = DeError; @@ -369,13 +373,14 @@ macro_rules! forward { /// A deserializer for a value of map or struct. 
That deserializer slightly /// differently processes events for a primitive types and sequences than /// a [`Deserializer`]. -struct MapValueDeserializer<'de, 'a, 'm, R> +struct MapValueDeserializer<'de, 'a, 'm, R, E> where R: XmlRead<'de>, + E: EntityResolver, { /// Access to the map that created this deserializer. Gives access to the /// context, such as list of fields, that current map known about. - map: &'m mut MapAccess<'de, 'a, R>, + map: &'m mut MapAccess<'de, 'a, R, E>, /// Determines, should [`Deserializer::read_string_impl()`] expand the second /// level of tags or not. /// @@ -453,9 +458,10 @@ where allow_start: bool, } -impl<'de, 'a, 'm, R> MapValueDeserializer<'de, 'a, 'm, R> +impl<'de, 'a, 'm, R, E> MapValueDeserializer<'de, 'a, 'm, R, E> where R: XmlRead<'de>, + E: EntityResolver, { /// Returns a next string as concatenated content of consequent [`Text`] and /// [`CData`] events, used inside [`deserialize_primitives!()`]. @@ -468,9 +474,10 @@ where } } -impl<'de, 'a, 'm, R> de::Deserializer<'de> for MapValueDeserializer<'de, 'a, 'm, R> +impl<'de, 'a, 'm, R, E> de::Deserializer<'de> for MapValueDeserializer<'de, 'a, 'm, R, E> where R: XmlRead<'de>, + E: EntityResolver, { type Error = DeError; @@ -629,13 +636,14 @@ impl<'de> TagFilter<'de> { /// /// [`Text`]: crate::events::Event::Text /// [`CData`]: crate::events::Event::CData -struct MapValueSeqAccess<'de, 'a, 'm, R> +struct MapValueSeqAccess<'de, 'a, 'm, R, E> where R: XmlRead<'de>, + E: EntityResolver, { /// Accessor to a map that creates this accessor and to a deserializer for /// a sequence items. - map: &'m mut MapAccess<'de, 'a, R>, + map: &'m mut MapAccess<'de, 'a, R, E>, /// Filter that determines whether a tag is a part of this sequence. 
/// /// When feature `overlapped-lists` is not activated, iteration will stop @@ -653,18 +661,20 @@ where } #[cfg(feature = "overlapped-lists")] -impl<'de, 'a, 'm, R> Drop for MapValueSeqAccess<'de, 'a, 'm, R> +impl<'de, 'a, 'm, R, E> Drop for MapValueSeqAccess<'de, 'a, 'm, R, E> where R: XmlRead<'de>, + E: EntityResolver, { fn drop(&mut self) { self.map.de.start_replay(self.checkpoint); } } -impl<'de, 'a, 'm, R> SeqAccess<'de> for MapValueSeqAccess<'de, 'a, 'm, R> +impl<'de, 'a, 'm, R, E> SeqAccess<'de> for MapValueSeqAccess<'de, 'a, 'm, R, E> where R: XmlRead<'de>, + E: EntityResolver, { type Error = DeError; @@ -705,18 +715,20 @@ where //////////////////////////////////////////////////////////////////////////////////////////////////// /// A deserializer for a single item of a sequence. -struct SeqItemDeserializer<'de, 'a, 'm, R> +struct SeqItemDeserializer<'de, 'a, 'm, R, E> where R: XmlRead<'de>, + E: EntityResolver, { /// Access to the map that created this deserializer. Gives access to the /// context, such as list of fields, that current map known about. - map: &'m mut MapAccess<'de, 'a, R>, + map: &'m mut MapAccess<'de, 'a, R, E>, } -impl<'de, 'a, 'm, R> SeqItemDeserializer<'de, 'a, 'm, R> +impl<'de, 'a, 'm, R, E> SeqItemDeserializer<'de, 'a, 'm, R, E> where R: XmlRead<'de>, + E: EntityResolver, { /// Returns a next string as concatenated content of consequent [`Text`] and /// [`CData`] events, used inside [`deserialize_primitives!()`]. @@ -729,9 +741,10 @@ where } } -impl<'de, 'a, 'm, R> de::Deserializer<'de> for SeqItemDeserializer<'de, 'a, 'm, R> +impl<'de, 'a, 'm, R, E> de::Deserializer<'de> for SeqItemDeserializer<'de, 'a, 'm, R, E> where R: XmlRead<'de>, + E: EntityResolver, { type Error = DeError; diff --git a/src/de/mod.rs b/src/de/mod.rs index 013a50f4..9fd0ebf6 100644 --- a/src/de/mod.rs +++ b/src/de/mod.rs @@ -1833,10 +1833,13 @@ macro_rules! 
deserialize_option { mod key; mod map; +mod resolver; mod simple_type; mod var; pub use crate::errors::serialize::DeError; +pub use resolver::{EntityResolver, NoEntityResolver}; + use crate::{ encoding::Decoder, errors::Error, @@ -1935,6 +1938,8 @@ pub enum PayloadEvent<'a> { Text(BytesText<'a>), /// Unescaped character data stored in `<![CDATA[...]]>`. CData(BytesCData<'a>), + /// Document type definition data (DTD) stored in `<!DOCTYPE ...>`. + DocType(BytesText<'a>), /// End of XML document. Eof, } @@ -1948,6 +1953,7 @@ impl<'a> PayloadEvent<'a> { PayloadEvent::End(e) => PayloadEvent::End(e.into_owned()), PayloadEvent::Text(e) => PayloadEvent::Text(e.into_owned()), PayloadEvent::CData(e) => PayloadEvent::CData(e.into_owned()), + PayloadEvent::DocType(e) => PayloadEvent::DocType(e.into_owned()), PayloadEvent::Eof => PayloadEvent::Eof, } } @@ -1956,7 +1962,7 @@ impl<'a> PayloadEvent<'a> { /// An intermediate reader that consumes [`PayloadEvent`]s and produces final [`DeEvent`]s. /// [`PayloadEvent::Text`] events, that followed by any event except /// [`PayloadEvent::Text`] or [`PayloadEvent::CData`], are trimmed from the end. -struct XmlReader<'i, R: XmlRead<'i>> { +struct XmlReader<'i, R: XmlRead<'i>, E: EntityResolver = NoEntityResolver> { /// A source of low-level XML events reader: R, /// Intermediate event, that could be returned by the next call to `next()`. @@ -1964,15 +1970,32 @@ struct XmlReader<'i, R: XmlRead<'i>> { /// trailing spaces is not. Before the event will be returned, trimming of /// the spaces could be necessary lookahead: Result, DeError>, + + /// Used to resolve unknown entities that would otherwise cause the parser + /// to return an [`EscapeError::UnrecognizedSymbol`] error.
+ /// + /// [`EscapeError::UnrecognizedSymbol`]: crate::escape::EscapeError::UnrecognizedSymbol + entity_resolver: E, } -impl<'i, R: XmlRead<'i>> XmlReader<'i, R> { - fn new(mut reader: R) -> Self { +impl<'i, R: XmlRead<'i>, E: EntityResolver> XmlReader<'i, R, E> { + fn new(reader: R) -> Self + where + E: Default, + { + Self::with_resolver(reader, E::default()) + } + + fn with_resolver(mut reader: R, entity_resolver: E) -> Self { // Lookahead by one event immediately, so we do not need to check in the // loop if we need lookahead or not let lookahead = reader.next(); - Self { reader, lookahead } + Self { + reader, + lookahead, + entity_resolver, + } } /// Read next event and put it in lookahead, return the current lookahead @@ -2028,7 +2051,7 @@ impl<'i, R: XmlRead<'i>> XmlReader<'i, R> { if self.need_trim_end() { e.inplace_trim_end(); } - Ok(e.unescape()?) + Ok(e.unescape_with(|entity| self.entity_resolver.resolve(entity))?) } PayloadEvent::CData(e) => Ok(e.decode()?), @@ -2047,9 +2070,15 @@ impl<'i, R: XmlRead<'i>> XmlReader<'i, R> { if self.need_trim_end() && e.inplace_trim_end() { continue; } - self.drain_text(e.unescape()?) + self.drain_text(e.unescape_with(|entity| self.entity_resolver.resolve(entity))?) } PayloadEvent::CData(e) => self.drain_text(e.decode()?), + PayloadEvent::DocType(e) => { + self.entity_resolver + .capture(e) + .map_err(|err| DeError::Custom(format!("cannot parse DTD: {}", err)))?; + continue; + } PayloadEvent::Eof => Ok(DeEvent::Eof), }; } @@ -2166,12 +2195,12 @@ where //////////////////////////////////////////////////////////////////////////////////////////////////// /// A structure that deserializes XML into Rust values. 
-pub struct Deserializer<'de, R> +pub struct Deserializer<'de, R, E: EntityResolver = NoEntityResolver> where R: XmlRead<'de>, { /// An XML reader that streams events into this deserializer - reader: XmlReader<'de, R>, + reader: XmlReader<'de, R, E>, /// When deserializing sequences sometimes we have to skip unwanted events. /// That events should be stored and then replayed. This is a replay buffer, @@ -2226,7 +2255,13 @@ where peek: None, } } +} +impl<'de, R, E> Deserializer<'de, R, E> +where + R: XmlRead<'de>, + E: EntityResolver, +{ /// Set the maximum number of events that could be skipped during deserialization /// of sequences. /// @@ -2556,20 +2591,49 @@ where /// instead, because it will borrow instead of copy. If you have `&[u8]` which /// is known to represent UTF-8, you can decode it first before using [`from_str`]. pub fn from_reader(reader: R) -> Self { + Self::with_resolver(reader, NoEntityResolver) + } +} + +impl<'de, R, E> Deserializer<'de, IoReader, E> +where + R: BufRead, + E: EntityResolver, +{ + /// Create new deserializer that will copy data from the specified reader + /// into internal buffer and will use `entity_resolver` to resolve entities + /// that the parser does not recognize. Use [`Deserializer::from_reader`] when no custom + /// resolver is needed; its notes about string and `&[u8]` inputs apply here as well.
+ pub fn with_resolver(reader: R, entity_resolver: E) -> Self { let mut reader = Reader::from_reader(reader); reader.expand_empty_elements(true).check_end_names(true); - Self::new(IoReader { + let io_reader = IoReader { reader, start_trimmer: StartTrimmer::default(), buf: Vec::new(), - }) + }; + + Self { + reader: XmlReader::with_resolver(io_reader, entity_resolver), + + #[cfg(feature = "overlapped-lists")] + read: VecDeque::new(), + #[cfg(feature = "overlapped-lists")] + write: VecDeque::new(), + #[cfg(feature = "overlapped-lists")] + limit: None, + + #[cfg(not(feature = "overlapped-lists"))] + peek: None, + } } } -impl<'de, 'a, R> de::Deserializer<'de> for &'a mut Deserializer<'de, R> +impl<'de, 'a, R, E> de::Deserializer<'de> for &'a mut Deserializer<'de, R, E> where R: XmlRead<'de>, + E: EntityResolver, { type Error = DeError; @@ -2705,9 +2769,10 @@ where /// /// Technically, multiple top-level elements violates XML rule of only one top-level /// element, but we consider this as several concatenated XML documents. -impl<'de, 'a, R> SeqAccess<'de> for &'a mut Deserializer<'de, R> +impl<'de, 'a, R, E> SeqAccess<'de> for &'a mut Deserializer<'de, R, E> where R: XmlRead<'de>, + E: EntityResolver, { type Error = DeError; @@ -2743,6 +2808,7 @@ impl StartTrimmer { #[inline(always)] fn trim<'a>(&mut self, event: Event<'a>) -> Option> { let (event, trim_next_event) = match event { + Event::DocType(e) => (PayloadEvent::DocType(e), false), Event::Start(e) => (PayloadEvent::Start(e), true), Event::End(e) => (PayloadEvent::End(e), true), Event::Eof => (PayloadEvent::Eof, true), diff --git a/src/de/resolver.rs b/src/de/resolver.rs new file mode 100644 index 00000000..3e1f4552 --- /dev/null +++ b/src/de/resolver.rs @@ -0,0 +1,104 @@ +//! 
Entity resolver module + +use std::convert::Infallible; +use std::error::Error; + +use crate::events::BytesText; + +/// Used to resolve unknown entities while parsing +/// +/// # Example +/// +/// ``` +/// # use serde::Deserialize; +/// # use pretty_assertions::assert_eq; +/// use regex::bytes::Regex; +/// use std::collections::BTreeMap; +/// use std::string::FromUtf8Error; +/// use quick_xml::de::{Deserializer, EntityResolver}; +/// use quick_xml::events::BytesText; +/// +/// struct DocTypeEntityResolver { +/// re: Regex, +/// map: BTreeMap<String, String>, +/// } +/// +/// impl Default for DocTypeEntityResolver { +/// fn default() -> Self { +/// Self { +/// // We do not focus on true parsing in this example +/// // You should use special libraries to parse DTD +/// re: Regex::new(r#"<!ENTITY\s+([^ \t\r\n]+)\s+"([^"]*)"\s*>"#).unwrap(), +/// map: BTreeMap::new(), +/// } +/// } +/// } +/// +/// impl EntityResolver for DocTypeEntityResolver { +/// type Error = FromUtf8Error; +/// +/// fn capture(&mut self, doctype: BytesText) -> Result<(), Self::Error> { +/// for cap in self.re.captures_iter(&doctype) { +/// self.map.insert( +/// String::from_utf8(cap[1].to_vec())?, +/// String::from_utf8(cap[2].to_vec())?, +/// ); +/// } +/// Ok(()) +/// } +/// +/// fn resolve(&self, entity: &str) -> Option<&str> { +/// self.map.get(entity).map(|s| s.as_str()) +/// } +/// } +/// +/// let xml_reader = br#" +/// <!DOCTYPE dict[ <!ENTITY e1 "entity 1"> ]> +/// <root> +/// <entity_one>&e1;</entity_one> +/// </root> +/// "#.as_ref(); +/// +/// let mut de = Deserializer::with_resolver( +/// xml_reader, +/// DocTypeEntityResolver::default(), +/// ); +/// let data: BTreeMap<String, String> = BTreeMap::deserialize(&mut de).unwrap(); +/// +/// assert_eq!(data.get("entity_one"), Some(&"entity 1".to_string())); +/// ``` +pub trait EntityResolver { + /// The error type that represents DTD parse error + type Error: Error; + + /// Called on contents of [`Event::DocType`] to capture declared entities. + /// Can be called multiple times, for each parsed `<!DOCTYPE >` declaration.
+ /// + /// [`Event::DocType`]: crate::events::Event::DocType + fn capture(&mut self, doctype: BytesText) -> Result<(), Self::Error>; + + /// Called when an entity needs to be resolved. + /// + /// `None` is returned if a suitable value can not be found. + /// In that case an [`EscapeError::UnrecognizedSymbol`] will be returned by + /// a deserializer. + /// + /// [`EscapeError::UnrecognizedSymbol`]: crate::escape::EscapeError::UnrecognizedSymbol + fn resolve(&self, entity: &str) -> Option<&str>; +} + +/// An `EntityResolver` that does nothing and always returns `None`. +#[derive(Default, Copy, Clone)] +pub struct NoEntityResolver; + +impl EntityResolver for NoEntityResolver { + type Error = Infallible; + + fn capture(&mut self, _doctype: BytesText) -> Result<(), Self::Error> { + Ok(()) + } + + fn resolve(&self, _entity: &str) -> Option<&str> { + None + } +} diff --git a/src/de/var.rs b/src/de/var.rs index 32295258..9c97a9c1 100644 --- a/src/de/var.rs +++ b/src/de/var.rs @@ -1,5 +1,6 @@ use crate::{ de::key::QNameDeserializer, + de::resolver::EntityResolver, de::simple_type::SimpleTypeDeserializer, de::{DeEvent, Deserializer, XmlRead, TEXT_KEY}, errors::serialize::DeError, @@ -8,30 +9,33 @@ use serde::de::value::BorrowedStrDeserializer; use serde::de::{self, DeserializeSeed, Deserializer as _, Visitor}; /// An enum access -pub struct EnumAccess<'de, 'a, R> +pub struct EnumAccess<'de, 'a, R, E> where R: XmlRead<'de>, + E: EntityResolver, { - de: &'a mut Deserializer<'de, R>, + de: &'a mut Deserializer<'de, R, E>, } -impl<'de, 'a, R> EnumAccess<'de, 'a, R> +impl<'de, 'a, R, E> EnumAccess<'de, 'a, R, E> where R: XmlRead<'de>, + E: EntityResolver, { - pub fn new(de: &'a mut Deserializer<'de, R>) -> Self { + pub fn new(de: &'a mut Deserializer<'de, R, E>) -> Self { EnumAccess { de } } } -impl<'de, 'a, R> de::EnumAccess<'de> for EnumAccess<'de, 'a, R> +impl<'de, 'a, R, E> de::EnumAccess<'de> for EnumAccess<'de, 'a, R, E> where R: XmlRead<'de>, + E: EntityResolver, { type 
Error = DeError; - type Variant = VariantAccess<'de, 'a, R>; + type Variant = VariantAccess<'de, 'a, R, E>; - fn variant_seed(self, seed: V) -> Result<(V::Value, VariantAccess<'de, 'a, R>), DeError> + fn variant_seed(self, seed: V) -> Result<(V::Value, VariantAccess<'de, 'a, R, E>), DeError> where V: DeserializeSeed<'de>, { @@ -58,19 +62,21 @@ where } } -pub struct VariantAccess<'de, 'a, R> +pub struct VariantAccess<'de, 'a, R, E> where R: XmlRead<'de>, + E: EntityResolver, { - de: &'a mut Deserializer<'de, R>, + de: &'a mut Deserializer<'de, R, E>, /// `true` if variant should be deserialized from a textual content /// and `false` if from tag is_text: bool, } -impl<'de, 'a, R> de::VariantAccess<'de> for VariantAccess<'de, 'a, R> +impl<'de, 'a, R, E> de::VariantAccess<'de> for VariantAccess<'de, 'a, R, E> where R: XmlRead<'de>, + E: EntityResolver, { type Error = DeError; diff --git a/tests/serde-de.rs b/tests/serde-de.rs index 907a2cc6..7b12fd02 100644 --- a/tests/serde-de.rs +++ b/tests/serde-de.rs @@ -6427,3 +6427,72 @@ mod borrow { ); } } + +/// Test for entity resolver +mod resolve { + use super::*; + use pretty_assertions::assert_eq; + use quick_xml::de::EntityResolver; + use quick_xml::events::BytesText; + use std::collections::BTreeMap; + use std::convert::Infallible; + use std::iter::FromIterator; + + struct TestEntityResolver { + capture_called: bool, + } + + impl EntityResolver for TestEntityResolver { + type Error = Infallible; + + fn capture(&mut self, doctype: BytesText) -> Result<(), Self::Error> { + self.capture_called = true; + + assert_eq!(doctype.as_ref(), br#"dict[ ]"#); + + Ok(()) + } + + fn resolve(&self, entity: &str) -> Option<&str> { + assert!( + self.capture_called, + "`EntityResolver::capture` should be called before `EntityResolver::resolve`" + ); + match entity { + "t1" => Some("test_one"), + "t2" => Some("test_two"), + _ => None, + } + } + } + + #[test] + fn resolve_custom_entity() { + let resolver = TestEntityResolver { + 
capture_called: false, + }; + let mut de = Deserializer::with_resolver( + br#" + ]> + + + &t1; + &t2; + non-entity + + "# + .as_ref(), + resolver, + ); + + let data: BTreeMap = BTreeMap::deserialize(&mut de).unwrap(); + assert_eq!( + data, + BTreeMap::from_iter([ + (String::from("entity_one"), String::from("test_one")), + (String::from("entity_two"), String::from("test_two")), + (String::from("entity_three"), String::from("non-entity")), + ]) + ); + } +}