Skip to content

Commit

Permalink
fix(connector): send join requests in a single batch
Browse files Browse the repository at this point in the history
Issue: #112
  • Loading branch information
CBenoit committed Jan 9, 2024
1 parent 7519d8c commit f53b7cb
Showing 1 changed file with 40 additions and 44 deletions.
84 changes: 40 additions & 44 deletions crates/ironrdp-connector/src/channel_connection.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::HashSet;
use std::mem;

use ironrdp_pdu::write_buf::WriteBuf;
Expand All @@ -17,11 +18,9 @@ pub enum ChannelConnectionState {
WaitAttachUserConfirm,
SendChannelJoinRequest {
user_channel_id: u16,
index: usize,
},
WaitChannelJoinConfirm {
user_channel_id: u16,
index: usize,
},
AllJoined {
user_channel_id: u16,
Expand Down Expand Up @@ -54,13 +53,15 @@ impl State for ChannelConnectionState {
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct ChannelConnectionSequence {
pub state: ChannelConnectionState,
pub channel_ids: Vec<u16>,
pub channel_ids: HashSet<u16>,
}

impl ChannelConnectionSequence {
pub fn new(io_channel_id: u16, mut channel_ids: Vec<u16>) -> Self {
// I/O channel ID must be joined as well
channel_ids.push(io_channel_id);
pub fn new(io_channel_id: u16, channel_ids: Vec<u16>) -> Self {
let mut channel_ids: HashSet<u16> = channel_ids.into_iter().collect();

// I/O channel ID must be joined as well.
channel_ids.insert(io_channel_id);

Self {
state: ChannelConnectionState::SendErectDomainRequest,
Expand Down Expand Up @@ -125,49 +126,46 @@ impl Sequence for ChannelConnectionSequence {

let user_channel_id = attach_user_confirm.initiator_id;

// user channel ID must also be joined
self.channel_ids.push(user_channel_id);
// User channel ID must also be joined.
self.channel_ids.insert(user_channel_id);

debug!(message = ?attach_user_confirm, user_channel_id, "Received");

debug_assert!(!self.channel_ids.is_empty());

(
Written::Nothing,
ChannelConnectionState::SendChannelJoinRequest {
user_channel_id,
index: 0,
},
ChannelConnectionState::SendChannelJoinRequest { user_channel_id },
)
}

// TODO(#112): send all the join requests in a single batch
// (RDP 4.0, 5.0, 5.1, 5.2, 6.0, 6.1, 7.0, 7.1, 8.0, 10.2, 10.3,
// 10.4, and 10.5 clients send a Channel Join Request to the server only after the
// Channel Join Confirm for a previously sent request has been received. RDP 8.1,
// 10.0, and 10.1 clients send all of the Channel Join Requests to the server in a
// single batch to minimize the overall connection sequence time.)
ChannelConnectionState::SendChannelJoinRequest { user_channel_id, index } => {
let channel_id = self.channel_ids[index];

let channel_join_request = mcs::ChannelJoinRequest {
initiator_id: user_channel_id,
channel_id,
};
// Send all the join requests in a single batch.
// > RDP 4.0, 5.0, 5.1, 5.2, 6.0, 6.1, 7.0, 7.1, 8.0, 10.2, 10.3,
// > 10.4, and 10.5 clients send a Channel Join Request to the server only after the
// > Channel Join Confirm for a previously sent request has been received. RDP 8.1,
// > 10.0, and 10.1 clients send all of the Channel Join Requests to the server in a
// > single batch to minimize the overall connection sequence time.
ChannelConnectionState::SendChannelJoinRequest { user_channel_id } => {
let mut written = 0;

for channel_id in self.channel_ids.iter().copied() {
let channel_join_request = mcs::ChannelJoinRequest {
initiator_id: user_channel_id,
channel_id,
};

debug!(message = ?channel_join_request, "Send");
debug!(message = ?channel_join_request, "Send");

let written = ironrdp_pdu::encode_buf(&channel_join_request, output).map_err(ConnectorError::pdu)?;
written += ironrdp_pdu::encode_buf(&channel_join_request, output).map_err(ConnectorError::pdu)?;
}

(
Written::from_size(written)?,
ChannelConnectionState::WaitChannelJoinConfirm { user_channel_id, index },
ChannelConnectionState::WaitChannelJoinConfirm { user_channel_id },
)
}

ChannelConnectionState::WaitChannelJoinConfirm { user_channel_id, index } => {
let channel_id = self.channel_ids[index];

ChannelConnectionState::WaitChannelJoinConfirm { user_channel_id } => {
let channel_join_confirm =
ironrdp_pdu::decode::<mcs::ChannelJoinConfirm>(input).map_err(ConnectorError::pdu)?;

Expand All @@ -180,33 +178,31 @@ impl Sequence for ChannelConnectionSequence {
)
}

if channel_id != channel_join_confirm.requested_channel_id {
let is_expected = self.channel_ids.remove(&channel_join_confirm.requested_channel_id);

if !is_expected {
return Err(reason_err!(
"ChannelJoinConfirm",
"unexpected requested_channel_id in MCS Channel Join Confirm: received {}, got {}",
channel_id,
"unexpected requested_channel_id in MCS Channel Join Confirm: got {}, expected one of: {:?}",
channel_join_confirm.requested_channel_id,
self.channel_ids,
));
}

if channel_id != channel_join_confirm.channel_id {
if channel_join_confirm.requested_channel_id != channel_join_confirm.channel_id {
// We could handle that gracefully by updating the StaticChannelSet, but it doesn’t seem to ever happen.
return Err(reason_err!(
"ChannelJoinConfirm",
"unexpected channel_id in MCS Channel Join Confirm: received {}, got {}",
channel_id,
"a channel was joined with a different channel ID than requested: requested {}, got {}",
channel_join_confirm.requested_channel_id,
channel_join_confirm.channel_id,
));
}

let next_index = index.checked_add(1).unwrap();

let next_state = if next_index == self.channel_ids.len() {
let next_state = if self.channel_ids.is_empty() {
ChannelConnectionState::AllJoined { user_channel_id }
} else {
ChannelConnectionState::SendChannelJoinRequest {
user_channel_id,
index: next_index,
}
ChannelConnectionState::WaitChannelJoinConfirm { user_channel_id }
};

(Written::Nothing, next_state)
Expand Down

0 comments on commit f53b7cb

Please sign in to comment.