diff --git a/src/types/entities/channel.rs b/src/types/entities/channel.rs index 044c1e2d..a42dadf4 100644 --- a/src/types/entities/channel.rs +++ b/src/types/entities/channel.rs @@ -3,14 +3,16 @@ // file, You can obtain one at http://mozilla.org/MPL/2.0/. use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize}; use serde_repr::{Deserialize_repr, Serialize_repr}; -use std::fmt::Debug; +use std::fmt::{Debug, Formatter}; +use std::str::FromStr; use crate::types::{ PermissionFlags, Shared, entities::{GuildMember, User}, utils::Snowflake, + serde::string_or_u64 }; #[cfg(feature = "client")] @@ -24,6 +26,8 @@ use crate::gateway::Updateable; #[cfg(feature = "client")] use chorus_macros::{observe_option_vec, Composite, Updateable}; +use serde::de::{Error, Visitor}; + #[derive(Default, Debug, Serialize, Deserialize, Clone)] #[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))] @@ -155,7 +159,7 @@ pub struct PermissionOverwrite { } -#[derive(Debug, Serialize_repr, Deserialize_repr, Clone, PartialEq, Eq, PartialOrd)] +#[derive(Debug, Serialize_repr, Clone, PartialEq, Eq, PartialOrd)] #[repr(u8)] /// # Reference pub enum PermissionOverwriteType { @@ -163,6 +167,63 @@ pub enum PermissionOverwriteType { Member = 1, } +impl From for PermissionOverwriteType { + fn from(v: u8) -> Self { + match v { + 0 => PermissionOverwriteType::Role, + 1 => PermissionOverwriteType::Member, + _ => unreachable!(), + } + } +} + +impl FromStr for PermissionOverwriteType { + type Err = serde::de::value::Error; + + fn from_str(s: &str) -> Result { + match s { + "role" => Ok(PermissionOverwriteType::Role), + "member" => Ok(PermissionOverwriteType::Member), + _ => Err(Self::Err::custom("invalid permission overwrite type")), + } + } +} + +struct PermissionOverwriteTypeVisitor; + +impl<'de> Visitor<'de> for PermissionOverwriteTypeVisitor { + type Value = PermissionOverwriteType; + + fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { + formatter.write_str("a valid permission overwrite type") + } + + fn visit_u8(self, v: u8) -> Result where E: Error { + Ok(PermissionOverwriteType::from(v)) + } + + fn visit_u64(self, v: u64) -> Result where E: Error { + self.visit_u8(v as u8) + } + + fn visit_str(self, v: &str) -> Result where E: Error { + PermissionOverwriteType::from_str(v) + .map_err(E::custom) + } + + fn visit_string(self, v: String) -> Result where E: Error { + self.visit_str(v.as_str()) + } +} + +impl<'de> Deserialize<'de> for PermissionOverwriteType { + fn deserialize(deserializer: D) -> Result where D: Deserializer<'de> { + let val = deserializer.deserialize_any(PermissionOverwriteTypeVisitor)?; + + Ok(val) + } +} + #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] /// # Reference /// See