Skip to content

Commit

Permalink
Merge pull request #401 from cBournhonesque/cb/add-hook-to-accept-con…
Browse files Browse the repository at this point in the history
…nection

Add user-provided hook to accept/reject connections
  • Loading branch information
cBournhonesque authored May 27, 2024
2 parents e610332 + 4faab86 commit 9d0dfb7
Show file tree
Hide file tree
Showing 9 changed files with 330 additions and 23 deletions.
1 change: 1 addition & 0 deletions examples/common/src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ pub(crate) fn get_server_net_configs(settings: &Settings) -> Vec<server::NetConf
game_port: *game_port,
query_port: *query_port,
max_clients: 16,
accept_connection_request_fn: None,
version: "1.0".to_string(),
},
conditioner: settings
Expand Down
8 changes: 7 additions & 1 deletion examples/simple_box/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use bevy::utils::HashMap;
use lightyear::prelude::server::*;
use lightyear::prelude::*;
use lightyear::shared::replication::components::ReplicationTarget;
use std::sync::Arc;

use crate::protocol::*;
use crate::shared;
Expand All @@ -28,7 +29,12 @@ impl Plugin for ExampleServerPlugin {
}

/// Start the server
fn start_server(mut commands: Commands) {
fn start_server(mut config: ResMut<ServerConfig>, mut commands: Commands) {
for net_config in &mut config.net {
net_config.set_accept_connection_request_fn(Arc::new(|client_id| {
client_id != ClientId::Netcode(0)
}));
}
commands.start_server();
}

Expand Down
6 changes: 5 additions & 1 deletion lightyear/src/connection/netcode/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,13 @@ impl<Ctx> NetcodeClient<Ctx> {
}
match (packet, self.state) {
(
Packet::Denied(_),
Packet::Denied(pkt),
ClientState::SendingConnectionRequest | ClientState::SendingChallengeResponse,
) => {
error!(
"client connection denied by server. Reason: {:?}",
pkt.reason
);
self.should_disconnect = true;
self.should_disconnect_state = ClientState::ConnectionDenied;
}
Expand Down
137 changes: 129 additions & 8 deletions lightyear/src/connection/netcode/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use chacha20poly1305::XNonce;
use tracing::debug;

use crate::connection::netcode::ClientId;
use crate::connection::server::DeniedReason;

use super::{
bytes::Bytes,
Expand Down Expand Up @@ -158,22 +159,107 @@ impl Bytes for RequestPacket {
}
}

pub struct DeniedPacket {}
pub struct DeniedPacket {
pub reason: DeniedReason,
}

impl DeniedPacket {
pub fn create() -> Packet<'static> {
Packet::Denied(DeniedPacket {})
pub fn create(reason: DeniedReason) -> Packet<'static> {
Packet::Denied(DeniedPacket { reason })
}
}

impl Bytes for DeniedReason {
type Error = io::Error;

fn write_to(&self, writer: &mut impl WriteBytesExt) -> Result<(), Self::Error> {
match self {
DeniedReason::ServerFull => {
writer.write_u8(0)?;
}
DeniedReason::Banned => {
writer.write_u8(1)?;
}
DeniedReason::InternalError => {
writer.write_u8(2)?;
}
DeniedReason::AlreadyConnected => {
writer.write_u8(3)?;
}
DeniedReason::TokenAlreadyUsed => {
writer.write_u8(4)?;
}
DeniedReason::InvalidToken => {
writer.write_u8(5)?;
}
DeniedReason::Custom(reason) => {
writer.write_u8(6)?;
// the reason cannot exceed u8::MAX in size
if reason.len() > u8::MAX as usize {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"custom denied reason too long",
));
}
writer.write_u8(reason.len() as u8)?;
let num_write = writer.write(reason.as_bytes())?;
if num_write != reason.len() {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid denied reason",
));
}
}
}
Ok(())
}

fn read_from(reader: &mut impl ReadBytesExt) -> Result<Self, Self::Error> {
let variant = reader.read_u8()?;
if variant == 0 {
Ok(DeniedReason::ServerFull)
} else if variant == 1 {
Ok(DeniedReason::Banned)
} else if variant == 2 {
Ok(DeniedReason::InternalError)
} else if variant == 3 {
Ok(DeniedReason::AlreadyConnected)
} else if variant == 4 {
Ok(DeniedReason::TokenAlreadyUsed)
} else if variant == 5 {
Ok(DeniedReason::InvalidToken)
} else if variant == 6 {
let len = reader.read_u8()? as usize;
let mut reason = [0; u8::MAX as usize];
let num_read = reader.read(&mut reason)?;
if num_read != len {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid denied reason",
));
}
let reason_str = String::from_utf8(reason[..len].to_vec())
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid denied reason"))?;
Ok(DeniedReason::Custom(reason_str))
} else {
Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid denied reason",
))
}
}
}

impl Bytes for DeniedPacket {
type Error = io::Error;
fn write_to(&self, _writer: &mut impl WriteBytesExt) -> Result<(), Self::Error> {
fn write_to(&self, writer: &mut impl WriteBytesExt) -> Result<(), Self::Error> {
self.reason.write_to(writer)?;
Ok(())
}

fn read_from(_reader: &mut impl byteorder::ReadBytesExt) -> Result<Self, io::Error> {
Ok(Self {})
fn read_from(reader: &mut impl byteorder::ReadBytesExt) -> Result<Self, io::Error> {
let reason = DeniedReason::read_from(reader)?;
Ok(Self { reason })
}
}

Expand Down Expand Up @@ -578,14 +664,48 @@ mod tests {
assert_eq!(connect_token_private.user_data, user_data);
}

#[test]
fn denied_packet_custom_reason() {
let packet_key = generate_key();
let protocol_id = 0x1234_5678_9abc_def0;
let sequence = 0u64;
let mut replay_protection = ReplayProtection::new();

let packet = Packet::Denied(DeniedPacket {
reason: DeniedReason::Custom(String::from("a")),
});

let mut buf = [0u8; MAX_PKT_BUF_SIZE];
let size = packet
.write(&mut buf, sequence, &packet_key, protocol_id)
.unwrap();

let packet = Packet::read(
&mut buf[..size],
protocol_id,
0,
packet_key,
Some(&mut replay_protection),
0xff,
)
.unwrap();

let Packet::Denied(denied_pkt) = packet else {
panic!("wrong packet type");
};
assert_eq!(denied_pkt.reason, DeniedReason::Custom(String::from("a")));
}

#[test]
fn denied_packet() {
let packet_key = generate_key();
let protocol_id = 0x1234_5678_9abc_def0;
let sequence = 0u64;
let mut replay_protection = ReplayProtection::new();

let packet = Packet::Denied(DeniedPacket {});
let packet = Packet::Denied(DeniedPacket {
reason: DeniedReason::ServerFull,
});

let mut buf = [0u8; MAX_PKT_BUF_SIZE];
let size = packet
Expand All @@ -602,9 +722,10 @@ mod tests {
)
.unwrap();

let Packet::Denied(_denied_pkt) = packet else {
let Packet::Denied(denied_pkt) = packet else {
panic!("wrong packet type");
};
assert_eq!(denied_pkt.reason, DeniedReason::ServerFull);
}

#[test]
Expand Down
27 changes: 24 additions & 3 deletions lightyear/src/connection/netcode/server.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::{HashMap, VecDeque};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};

use anyhow::{anyhow, Context};
Expand All @@ -8,7 +9,9 @@ use tracing::{debug, error, trace};

use crate::connection::id;
use crate::connection::netcode::token::TOKEN_EXPIRE_SEC;
use crate::connection::server::{IoConfig, NetServer};
use crate::connection::server::{
ConnectionRequestHandler, DefaultConnectionRequestHandler, DeniedReason, IoConfig, NetServer,
};
use crate::serialize::bitcode::reader::BufferPool;
use crate::serialize::reader::ReadBuffer;
use crate::server::config::NetcodeConfig;
Expand Down Expand Up @@ -246,6 +249,7 @@ pub struct ServerConfig<Ctx> {
keep_alive_send_rate: f64,
token_expire_secs: i32,
client_timeout_secs: i32,
connection_request_handler: Arc<dyn ConnectionRequestHandler>,
server_addr: SocketAddr,
context: Ctx,
on_connect: Option<Callback<Ctx>>,
Expand All @@ -259,6 +263,7 @@ impl Default for ServerConfig<()> {
keep_alive_send_rate: PACKET_SEND_RATE_SEC,
token_expire_secs: TOKEN_EXPIRE_SEC,
client_timeout_secs: CLIENT_TIMEOUT_SECS,
connection_request_handler: Arc::new(DefaultConnectionRequestHandler),
server_addr: SocketAddr::from(([0, 0, 0, 0], 0)),
context: (),
on_connect: None,
Expand All @@ -279,6 +284,7 @@ impl<Ctx> ServerConfig<Ctx> {
keep_alive_send_rate: PACKET_SEND_RATE_SEC,
token_expire_secs: TOKEN_EXPIRE_SEC,
client_timeout_secs: CLIENT_TIMEOUT_SECS,
connection_request_handler: Arc::new(DefaultConnectionRequestHandler),
server_addr: SocketAddr::from(([0, 0, 0, 0], 0)),
context: ctx,
on_connect: None,
Expand Down Expand Up @@ -605,13 +611,27 @@ impl<Ctx> NetcodeServer<Ctx> {
if self.num_connected_clients() >= MAX_CLIENTS {
debug!("server denied connection request. server is full");
self.send_to_addr(
DeniedPacket::create(),
DeniedPacket::create(DeniedReason::ServerFull),
from_addr,
token.server_to_client_key,
sender,
)?;
return Ok(());
};
if let Some(denied_reason) = self
.cfg
.connection_request_handler
.handle_request(crate::prelude::ClientId::Netcode(token.client_id))
{
debug!("server denied connection request. handle_connection_request_fn returned false");
self.send_to_addr(
DeniedPacket::create(denied_reason),
from_addr,
token.server_to_client_key,
sender,
)?;
return Ok(());
}
self.conn_cache.add(
token.client_id,
from_addr,
Expand Down Expand Up @@ -662,7 +682,7 @@ impl<Ctx> NetcodeServer<Ctx> {
if self.num_connected_clients() >= MAX_CLIENTS {
debug!("server denied connection response. server is full");
self.send_to_addr(
DeniedPacket::create(),
DeniedPacket::create(DeniedReason::ServerFull),
from_addr,
self.conn_cache
.clients
Expand Down Expand Up @@ -1136,6 +1156,7 @@ impl Server {
cfg = cfg.keep_alive_send_rate(config.keep_alive_send_rate);
cfg = cfg.num_disconnect_packets(config.num_disconnect_packets);
cfg = cfg.client_timeout_secs(config.client_timeout_secs);
cfg.connection_request_handler = config.connection_request_handler;
let server = NetcodeServer::with_config(config.protocol_id, config.private_key, cfg)
.expect("Could not create server netcode");

Expand Down
54 changes: 53 additions & 1 deletion lightyear/src/connection/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ use anyhow::{anyhow, Result};
use bevy::prelude::Resource;
use bevy::utils::HashMap;
use enum_dispatch::enum_dispatch;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::fmt::Debug;
use std::net::SocketAddr;
use std::sync::Arc;

use crate::connection::id::ClientId;
#[cfg(all(feature = "steam", not(target_family = "wasm")))]
Expand All @@ -15,6 +19,36 @@ use crate::server::config::NetcodeConfig;
use crate::server::io::Io;
use crate::transport::config::SharedIoConfig;

/// Reasons for denying a connection request
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum DeniedReason {
ServerFull,
Banned,
InternalError,
AlreadyConnected,
TokenAlreadyUsed,
InvalidToken,
Custom(String),
}

/// Trait for handling connection requests from clients.
pub trait ConnectionRequestHandler: Debug + Send + Sync {
/// Handle a connection request from a client.
/// Returns None if the connection is accepted,
/// Returns Some(reason) if the connection is denied.
fn handle_request(&self, client_id: ClientId) -> Option<DeniedReason>;
}

/// By default, all connection requests are accepted by the server.
#[derive(Debug, Clone)]
pub struct DefaultConnectionRequestHandler;

impl ConnectionRequestHandler for DefaultConnectionRequestHandler {
fn handle_request(&self, client_id: ClientId) -> Option<DeniedReason> {
None
}
}

#[enum_dispatch]
pub trait NetServer: Send + Sync {
/// Start the server
Expand Down Expand Up @@ -75,6 +109,24 @@ pub enum NetConfig {
},
}

impl NetConfig {
/// Update the `accept_connection_request_fn` field in the config
pub fn set_connection_request_handler(
&mut self,
connection_request_handler: Arc<dyn ConnectionRequestHandler>,
) {
match self {
NetConfig::Netcode { config, .. } => {
config.connection_request_handler = connection_request_handler;
}
#[cfg(all(feature = "steam", not(target_family = "wasm")))]
NetConfig::Steam { config, .. } => {
config.connection_request_handler = connection_request_handler;
}
}
}
}

impl Default for NetConfig {
fn default() -> Self {
NetConfig::Netcode {
Expand Down Expand Up @@ -115,7 +167,7 @@ type ServerConnectionIdx = usize;
#[derive(Resource)]
pub struct ServerConnections {
/// list of the various `ServerConnection`s available. Will be static after first insertion.
pub(crate) servers: Vec<ServerConnection>,
pub servers: Vec<ServerConnection>,
/// Mapping from the connection's [`ClientId`] into the index of the [`ServerConnection`] in the `servers` list
pub(crate) client_server_map: HashMap<ClientId, ServerConnectionIdx>,
/// Track whether the server is ready to listen to incoming connections
Expand Down
Loading

0 comments on commit 9d0dfb7

Please sign in to comment.