Skip to content

Commit

Permalink
Add getters for protocol version and name; default to "MQTT" for name
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeff Hiner authored and zonyitoo committed Nov 12, 2020
1 parent 8e95393 commit 87ed144
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 20 deletions.
2 changes: 1 addition & 1 deletion examples/pub-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ fn main() {
info!("Connected!");

info!("Client identifier {:?}", client_id);
let mut conn = ConnectPacket::new("MQTT", client_id);
let mut conn = ConnectPacket::new(client_id);
conn.set_clean_session(true);
let mut buf = Vec::new();
conn.encode(&mut buf).unwrap();
Expand Down
2 changes: 1 addition & 1 deletion examples/sub-client-async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ async fn main() {
info!("Connected!");

info!("Client identifier {:?}", client_id);
let mut conn = ConnectPacket::new("MQTT", client_id);
let mut conn = ConnectPacket::new(client_id);
conn.set_clean_session(true);
conn.set_keep_alive(keep_alive);
let mut buf = Vec::new();
Expand Down
2 changes: 1 addition & 1 deletion examples/sub-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ fn main() {
info!("Connected!");

info!("Client identifier {:?}", client_id);
let mut conn = ConnectPacket::new("MQTT", client_id);
let mut conn = ConnectPacket::new(client_id);
conn.set_clean_session(true);
conn.set_keep_alive(keep_alive);
let mut buf = Vec::new();
Expand Down
3 changes: 3 additions & 0 deletions src/control/variable_header/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub enum VariableHeaderError {
InvalidReservedFlag,
FromUtf8Error(FromUtf8Error),
TopicNameError(TopicNameError),
InvalidProtocolVersion,
}

impl From<io::Error> for VariableHeaderError {
Expand Down Expand Up @@ -69,6 +70,7 @@ impl fmt::Display for VariableHeaderError {
VariableHeaderError::InvalidReservedFlag => write!(f, "Invalid reserved flags"),
VariableHeaderError::FromUtf8Error(ref err) => write!(f, "{}", err),
VariableHeaderError::TopicNameError(ref err) => write!(f, "{}", err),
VariableHeaderError::InvalidProtocolVersion => write!(f, "Invalid protocol version"),
}
}
}
Expand All @@ -81,6 +83,7 @@ impl Error for VariableHeaderError {
VariableHeaderError::InvalidReservedFlag => None,
VariableHeaderError::FromUtf8Error(ref err) => Some(err),
VariableHeaderError::TopicNameError(ref err) => Some(err),
VariableHeaderError::InvalidProtocolVersion => None,
}
}
}
30 changes: 26 additions & 4 deletions src/control/variable_header/protocol_level.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,27 @@ use std::io::{Read, Write};

use byteorder::{ReadBytesExt, WriteBytesExt};

use crate::{Decodable, Encodable};
use crate::control::variable_header::VariableHeaderError;
use crate::{Decodable, Encodable};

pub const SPEC_3_1_0: u8 = 0x03;
pub const SPEC_3_1_1: u8 = 0x04;
pub const SPEC_5_0: u8 = 0x05;

/// Protocol level in MQTT (`0x04` in v3.1.1)
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub struct ProtocolLevel(pub u8);
#[repr(u8)]
pub enum ProtocolLevel {
Version310 = SPEC_3_1_0,
Version311 = SPEC_3_1_1,
Version50 = SPEC_5_0,
}

impl Encodable for ProtocolLevel {
type Err = VariableHeaderError;

fn encode<W: Write>(&self, writer: &mut W) -> Result<(), VariableHeaderError> {
writer.write_u8(self.0).map_err(From::from)
writer.write_u8(*self as u8).map_err(From::from)
}

fn encoded_length(&self) -> u32 {
Expand All @@ -31,6 +38,21 @@ impl Decodable for ProtocolLevel {
type Cond = ();

fn decode_with<R: Read>(reader: &mut R, _rest: Option<()>) -> Result<ProtocolLevel, VariableHeaderError> {
reader.read_u8().map(ProtocolLevel).map_err(From::from)
reader
.read_u8()
.map_err(From::from)
.map(ProtocolLevel::from_u8)
.and_then(|x| x.ok_or(VariableHeaderError::InvalidProtocolVersion))
}
}

impl ProtocolLevel {
pub fn from_u8(n: u8) -> Option<ProtocolLevel> {
match n {
SPEC_3_1_0 => Some(ProtocolLevel::Version310),
SPEC_3_1_1 => Some(ProtocolLevel::Version311),
SPEC_5_0 => Some(ProtocolLevel::Version50),
_ => None,
}
}
}
28 changes: 18 additions & 10 deletions src/packet/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::fmt;
use std::io::{self, Read, Write};

use crate::control::variable_header::protocol_level::SPEC_3_1_1;
use crate::control::variable_header::{ConnectFlags, KeepAlive, ProtocolLevel, ProtocolName};
use crate::control::variable_header::{ConnectFlags, KeepAlive, ProtocolLevel, ProtocolName, VariableHeaderError};
use crate::control::{ControlType, FixedHeader, PacketType};
use crate::encodable::{StringEncodeError, VarBytes};
use crate::packet::{Packet, PacketError};
Expand All @@ -26,31 +26,31 @@ pub struct ConnectPacket {
}

impl ConnectPacket {
pub fn new<P, C>(protoname: P, client_identifier: C) -> ConnectPacket
pub fn new<C>(client_identifier: C) -> ConnectPacket
where
P: Into<String>,
C: Into<String>,
{
ConnectPacket::with_level(protoname, client_identifier, SPEC_3_1_1)
ConnectPacket::with_level("MQTT", client_identifier, SPEC_3_1_1).expect("SPEC_3_1_1 should always be valid")
}

pub fn with_level<P, C>(protoname: P, client_identifier: C, level: u8) -> ConnectPacket
pub fn with_level<P, C>(protoname: P, client_identifier: C, level: u8) -> Result<ConnectPacket, VariableHeaderError>
where
P: Into<String>,
C: Into<String>,
{
let protocol_level = ProtocolLevel::from_u8(level).ok_or(VariableHeaderError::InvalidProtocolVersion)?;
let mut pk = ConnectPacket {
fixed_header: FixedHeader::new(PacketType::with_default(ControlType::Connect), 0),
protocol_name: ProtocolName(protoname.into()),
protocol_level: ProtocolLevel(level),
protocol_level,
flags: ConnectFlags::empty(),
keep_alive: KeepAlive(0),
payload: ConnectPacketPayload::new(client_identifier.into()),
};

pk.fixed_header.remaining_length = pk.calculate_remaining_length();

pk
Ok(pk)
}

#[inline]
Expand Down Expand Up @@ -142,6 +142,14 @@ impl ConnectPacket {
&self.payload.client_identifier[..]
}

pub fn protocol_name(&self) -> &str {
&self.protocol_name.0
}

pub fn protocol_level(&self) -> ProtocolLevel {
self.protocol_level
}

pub fn clean_session(&self) -> bool {
self.flags.clean_session
}
Expand Down Expand Up @@ -379,7 +387,7 @@ mod test {

#[test]
fn test_connect_packet_encode_basic() {
let packet = ConnectPacket::new("MQTT".to_owned(), "12345".to_owned());
let packet = ConnectPacket::new("12345".to_owned());
let expected = b"\x10\x11\x00\x04MQTT\x04\x00\x00\x00\x00\x0512345";

let mut buf = Vec::new();
Expand All @@ -395,13 +403,13 @@ mod test {
let mut buf = Cursor::new(&encoded_data[..]);
let packet = ConnectPacket::decode(&mut buf).unwrap();

let expected = ConnectPacket::new("MQTT".to_owned(), "12345".to_owned());
let expected = ConnectPacket::new("12345".to_owned());
assert_eq!(expected, packet);
}

#[test]
fn test_connect_packet_user_name() {
let mut packet = ConnectPacket::new("MQTT".to_owned(), "12345".to_owned());
let mut packet = ConnectPacket::new("12345".to_owned());
packet.set_user_name(Some("mqtt_player".to_owned()));

let mut buf = Vec::new();
Expand Down
6 changes: 3 additions & 3 deletions src/packet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ mod test {

#[test]
fn test_variable_packet_basic() {
let packet = ConnectPacket::new("MQTT".to_owned(), "1234".to_owned());
let packet = ConnectPacket::new("1234".to_owned());

// Wrap it
let var_packet = VariablePacket::new(packet);
Expand All @@ -449,7 +449,7 @@ mod test {
#[cfg(feature = "async")]
#[tokio::test]
async fn test_variable_packet_async_parse() {
let packet = ConnectPacket::new("MQTT".to_owned(), "1234".to_owned());
let packet = ConnectPacket::new("1234".to_owned());

// Wrap it
let var_packet = VariablePacket::new(packet);
Expand All @@ -468,7 +468,7 @@ mod test {
#[cfg(feature = "async")]
#[tokio::test]
async fn test_variable_packet_async_peek() {
let packet = ConnectPacket::new("MQTT".to_owned(), "1234".to_owned());
let packet = ConnectPacket::new("1234".to_owned());

// Wrap it
let var_packet = VariablePacket::new(packet);
Expand Down

0 comments on commit 87ed144

Please sign in to comment.