Skip to content

Commit

Permalink
feat(stun-types): do not fail on encode
Browse files Browse the repository at this point in the history
  • Loading branch information
kbalt committed Jan 17, 2025
1 parent d9568a0 commit 5c14bbc
Show file tree
Hide file tree
Showing 11 changed files with 114 additions and 176 deletions.
9 changes: 3 additions & 6 deletions crates/stun-types/src/attributes/addr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,8 @@ impl Attribute<'_> for MappedAddress {
decode_addr(attr.get_value(msg.buffer()), 0, 0, 0).map(Self)
}

fn encode(&self, _: Self::Context, builder: &mut MessageBuilder) -> Result<(), Error> {
fn encode(&self, _: Self::Context, builder: &mut MessageBuilder) {
encode_addr(self.0, builder.buffer(), 0, 0, 0);
Ok(())
}

fn encode_len(&self) -> Result<u16, Error> {
Expand All @@ -95,10 +94,9 @@ impl Attribute<'_> for XorMappedAddress {
decode_addr(attr.get_value(msg.buffer()), XOR16, COOKIE, xor128).map(Self)
}

fn encode(&self, _: Self::Context, builder: &mut MessageBuilder) -> Result<(), Error> {
fn encode(&self, _: Self::Context, builder: &mut MessageBuilder) {
let xor128 = builder.id();
encode_addr(self.0, builder.buffer(), XOR16, COOKIE, xor128);
Ok(())
}

fn encode_len(&self) -> Result<u16, Error> {
Expand All @@ -120,9 +118,8 @@ impl Attribute<'_> for AlternateServer {
decode_addr(attr.get_value(msg.buffer()), 0, 0, 0).map(Self)
}

fn encode(&self, _: Self::Context, builder: &mut MessageBuilder) -> Result<(), Error> {
fn encode(&self, _: Self::Context, builder: &mut MessageBuilder) {
encode_addr(self.0, builder.buffer(), 0, 0, 0);
Ok(())
}

fn encode_len(&self) -> Result<u16, Error> {
Expand Down
14 changes: 5 additions & 9 deletions crates/stun-types/src/attributes/error_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl<'s> Attribute<'s> for ErrorCode<'s> {
})
}

fn encode(&self, _: Self::Context, builder: &mut MessageBuilder) -> Result<(), Error> {
fn encode(&self, _: Self::Context, builder: &mut MessageBuilder) {
let class = self.number / 100;
let number = self.number % 100;

Expand All @@ -57,8 +57,6 @@ impl<'s> Attribute<'s> for ErrorCode<'s> {

builder.buffer().put_u32(head.0);
builder.buffer().extend_from_slice(self.reason.as_ref());

Ok(())
}

fn encode_len(&self) -> Result<u16, Error> {
Expand All @@ -81,12 +79,10 @@ mod test {
fn error_code() {
let mut builder =
MessageBuilder::new(Class::Error, Method::Binding, TransactionId::new([0; 12]));
builder
.add_attr(&ErrorCode {
number: 400,
reason: "Bad Request",
})
.unwrap();
builder.add_attr(ErrorCode {
number: 400,
reason: "Bad Request",
});

let bytes = builder.finish();

Expand Down
10 changes: 6 additions & 4 deletions crates/stun-types/src/attributes/fingerprint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,22 +85,24 @@ impl Attribute<'_> for Fingerprint {
)
}

fn encode(&self, _: Self::Context, builder: &mut MessageBuilder) -> Result<(), Error> {
fn encode(&self, _: Self::Context, builder: &mut MessageBuilder) {
// First set the length of the message to the end of the fingerprint attribute
// 4 bytes containing type and length is already written into the buffer
let message_length_with_fingerprint_attribute =
(builder.buffer().len() + 4) - STUN_HEADER_LENGTH;

builder.set_len(message_length_with_fingerprint_attribute.try_into()?);
builder.set_len(
message_length_with_fingerprint_attribute
.try_into()
.expect("stun messages must fit withing 65535 bytes"),
);

// Calculate the checksum
let data = builder.buffer();
let data = &data[..data.len() - ATTRIBUTE_HEADER_LEN];
let crc = Self::crc32(data) ^ 0x5354554e;

builder.buffer().put_u32(crc);

Ok(())
}

fn encode_len(&self) -> Result<u16, Error> {
Expand Down
19 changes: 4 additions & 15 deletions crates/stun-types/src/attributes/ice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,9 @@ impl Attribute<'_> for Priority {
Ok(Self(value.read_u32::<NE>()?))
}

fn encode(&self, _: Self::Context, builder: &mut MessageBuilder) -> Result<(), Error> {
fn encode(&self, _: Self::Context, builder: &mut MessageBuilder) {
let data = builder.buffer();

data.put_u32(self.0);

Ok(())
}

fn encode_len(&self) -> Result<u16, Error> {
Expand All @@ -46,9 +43,7 @@ impl Attribute<'_> for UseCandidate {
Ok(Self)
}

fn encode(&self, _: Self::Context, _builder: &mut MessageBuilder) -> Result<(), Error> {
Ok(())
}
fn encode(&self, _: Self::Context, _builder: &mut MessageBuilder) {}

fn encode_len(&self) -> Result<u16, Error> {
Ok(0)
Expand All @@ -71,12 +66,9 @@ impl Attribute<'_> for IceControlled {
Ok(Self(value.read_u64::<NE>()?))
}

fn encode(&self, _: Self::Context, builder: &mut MessageBuilder) -> Result<(), Error> {
fn encode(&self, _: Self::Context, builder: &mut MessageBuilder) {
let data = builder.buffer();

data.put_u64(self.0);

Ok(())
}

fn encode_len(&self) -> Result<u16, Error> {
Expand All @@ -100,12 +92,9 @@ impl Attribute<'_> for IceControlling {
Ok(Self(value.read_u64::<NE>()?))
}

fn encode(&self, _: Self::Context, builder: &mut MessageBuilder) -> Result<(), Error> {
fn encode(&self, _: Self::Context, builder: &mut MessageBuilder) {
let data = builder.buffer();

data.put_u64(self.0);

Ok(())
}

fn encode_len(&self) -> Result<u16, Error> {
Expand Down
133 changes: 54 additions & 79 deletions crates/stun-types/src/attributes/integrity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,88 +8,69 @@ use hmac::digest::{Digest, Update};
use hmac::{Mac, SimpleHmac};
use sha1::Sha1;
use sha2::Sha256;
use std::borrow::Cow;
use std::convert::TryFrom;
use std::marker::PhantomData;

pub struct MessageIntegrityKey<'s>(Cow<'s, [u8]>);

impl<'s> MessageIntegrityKey<'s> {
pub fn new_long_term_md5(username: &str, realm: &str, password: &str) -> Self {
let key = md5::compute(format!("{}:{}:{}", username, realm, password))
.0
.to_vec();

Self(Cow::Owned(key))
}
pub fn long_term_password_md5(username: &str, realm: &str, password: &str) -> Vec<u8> {
md5::compute(format!("{}:{}:{}", username, realm, password).as_bytes()).to_vec()
}

pub fn new_long_term_sha256(username: &str, realm: &str, password: &str) -> Self {
let key =
Sha256::digest(format!("{}:{}:{}", username, realm, password).as_bytes()).to_vec();
pub fn long_term_password_sha256(username: &str, realm: &str, password: &str) -> Vec<u8> {
Sha256::digest(format!("{}:{}:{}", username, realm, password).as_bytes()).to_vec()
}

Self(Cow::Owned(key))
}
pub struct MessageIntegrityKey(SimpleHmac<Sha1>);

pub fn new_short_term(password: &'s str) -> Self {
Self(Cow::Borrowed(password.as_bytes()))
}

pub fn new_raw(raw: Cow<'s, [u8]>) -> Self {
Self(raw)
impl MessageIntegrityKey {
pub fn new(key: impl AsRef<[u8]>) -> Self {
Self(SimpleHmac::new_from_slice(key.as_ref()).expect("any key length is valid"))
}
}

/// [RFC8489](https://datatracker.ietf.org/doc/html/rfc8489#section-14.5)
#[derive(Default)]
pub struct MessageIntegrity<'k>(PhantomData<&'k ()>);
pub struct MessageIntegrity;

impl<'k> Attribute<'_> for MessageIntegrity<'k> {
type Context = &'k MessageIntegrityKey<'k>;
impl Attribute<'_> for MessageIntegrity {
type Context = MessageIntegrityKey;
const TYPE: u16 = 0x0008;

fn decode(ctx: Self::Context, msg: &mut Message, attr: AttrSpan) -> Result<Self, Error> {
let hmac: SimpleHmac<Sha1> = SimpleHmac::new_from_slice(&ctx.0)
.map_err(|_| Error::InvalidData("invalid key length"))?;

message_integrity_decode(hmac, msg, attr)?;
message_integrity_decode(ctx.0, msg, attr)?;

Ok(Self(PhantomData))
Ok(Self)
}

fn encode(&self, ctx: Self::Context, builder: &mut MessageBuilder) -> Result<(), Error> {
let hmac: SimpleHmac<Sha1> = SimpleHmac::new_from_slice(&ctx.0)
.map_err(|_| Error::InvalidData("invalid key length"))?;

message_integrity_encode(hmac, builder)
fn encode(&self, ctx: Self::Context, builder: &mut MessageBuilder) {
message_integrity_encode(ctx.0, builder)
}

fn encode_len(&self) -> Result<u16, Error> {
Ok(u16::try_from(Sha1::output_size())?)
}
}

pub struct MessageIntegritySha256Key(SimpleHmac<Sha256>);

impl MessageIntegritySha256Key {
pub fn new(key: impl AsRef<[u8]>) -> Self {
Self(SimpleHmac::new_from_slice(key.as_ref()).expect("any key length is valid"))
}
}

/// [RFC8489](https://datatracker.ietf.org/doc/html/rfc8489#section-14.6)
#[derive(Default)]
pub struct MessageIntegritySha256<'k>(PhantomData<&'k ()>);
pub struct MessageIntegritySha256;

impl<'k> Attribute<'_> for MessageIntegritySha256<'k> {
type Context = &'k MessageIntegrityKey<'k>;
impl Attribute<'_> for MessageIntegritySha256 {
type Context = MessageIntegritySha256Key;
const TYPE: u16 = 0x001C;

fn decode(ctx: Self::Context, msg: &mut Message, attr: AttrSpan) -> Result<Self, Error> {
let hmac: SimpleHmac<Sha256> = SimpleHmac::new_from_slice(&ctx.0)
.map_err(|_| Error::InvalidData("invalid key length"))?;

message_integrity_decode(hmac, msg, attr)?;
message_integrity_decode(ctx.0, msg, attr)?;

Ok(Self(PhantomData))
Ok(Self)
}

fn encode(&self, ctx: Self::Context, builder: &mut MessageBuilder) -> Result<(), Error> {
let hmac: SimpleHmac<Sha256> = SimpleHmac::new_from_slice(&ctx.0)
.map_err(|_| Error::InvalidData("invalid key length"))?;

message_integrity_encode(hmac, builder)
fn encode(&self, ctx: Self::Context, builder: &mut MessageBuilder) {
message_integrity_encode(ctx.0, builder)
}

fn encode_len(&self) -> Result<u16, Error> {
Expand Down Expand Up @@ -134,18 +115,19 @@ where
)
}

fn message_integrity_encode<D>(
mut hmac: SimpleHmac<D>,
builder: &mut MessageBuilder,
) -> Result<(), Error>
fn message_integrity_encode<D>(mut hmac: SimpleHmac<D>, builder: &mut MessageBuilder)
where
D: Digest + BlockSizeUser,
{
// 4 bytes containing type and length is already written into the buffer
let message_length_with_integrity_attribute =
(builder.buffer().len() + <D as Digest>::output_size()) - STUN_HEADER_LENGTH;

builder.set_len(message_length_with_integrity_attribute.try_into()?);
builder.set_len(
message_length_with_integrity_attribute
.try_into()
.expect("stun messages must fit withing 65535 bytes"),
);

// Calculate the digest of the message up until the previous attribute
let data = builder.buffer();
Expand All @@ -154,13 +136,13 @@ where
let digest = hmac.finalize().into_bytes();

builder.buffer().extend_from_slice(&digest);

Ok(())
}

#[cfg(test)]
mod test {
use super::{MessageIntegrity, MessageIntegrityKey, MessageIntegritySha256};
use super::{
MessageIntegrity, MessageIntegrityKey, MessageIntegritySha256, MessageIntegritySha256Key,
};
use crate::attributes::Software;
use crate::builder::MessageBuilder;
use crate::header::{Class, Method};
Expand All @@ -174,19 +156,15 @@ mod test {
let mut message =
MessageBuilder::new(Class::Request, Method::Binding, TransactionId::new([0; 12]));

message.add_attr(&Software::new("ezk-stun")).unwrap();
message
.add_attr_with(
&MessageIntegrity::default(),
&MessageIntegrityKey::new_short_term(password),
)
.unwrap();
message.add_attr(Software::new("ezk-stun"));
message.add_attr_with(MessageIntegrity, MessageIntegrityKey::new(password));

let bytes = message.finish();
let bytes = Vec::from(&bytes[..]);

let mut msg = Message::parse(bytes).unwrap();

msg.attribute_with::<MessageIntegrity>(&MessageIntegrityKey::new_short_term(password))
msg.attribute_with::<MessageIntegrity>(MessageIntegrityKey::new(password))
.unwrap()
.unwrap();
}
Expand All @@ -198,22 +176,19 @@ mod test {
let mut message =
MessageBuilder::new(Class::Request, Method::Binding, TransactionId::new([0; 12]));

message.add_attr(&Software::new("ezk-stun")).unwrap();
message
.add_attr_with(
&MessageIntegritySha256::default(),
&MessageIntegrityKey::new_short_term(password),
)
.unwrap();
message.add_attr(Software::new("ezk-stun"));
message.add_attr_with(
MessageIntegritySha256,
MessageIntegritySha256Key::new(password),
);

let bytes = message.finish();
let bytes = Vec::from(&bytes[..]);

let mut msg = Message::parse(bytes).unwrap();

msg.attribute_with::<MessageIntegritySha256>(&MessageIntegrityKey::new_short_term(
password,
))
.unwrap()
.unwrap();
msg.attribute_with::<MessageIntegritySha256>(MessageIntegritySha256Key::new(password))
.unwrap()
.unwrap();
}
}
Loading

0 comments on commit 5c14bbc

Please sign in to comment.