forked from torvalds/linux
-
Notifications
You must be signed in to change notification settings - Fork 430
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
net/tcp: add Rust implementation of BIC
Reimplement the Binary Increase Congestion (BIC) control algorithm in Rust. BIC is one of the smallest CCAs in the kernel and this mainly serves as a minimal example for a real-world algorithm.
- Loading branch information
Valentin Obst
committed
Feb 18, 2024
1 parent
17fbe62
commit 1ecb4d6
Showing
3 changed files
with
314 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,300 @@ | ||
//! SPDX-License-Identifier: GPL-2.0 | ||
//! | ||
//! Binary Increase Congestion control (BIC). Based on: | ||
//! Binary Increase Congestion Control (BIC) for Fast Long-Distance | ||
//! Networks - Lisong Xu, Khaled Harfoush, and Injong Rhee | ||
//! IEEE INFOCOM 2004, Hong Kong, China, 2004, pp. 2514-2524 vol.4 | ||
//! doi: 10.1109/INFCOM.2004.1354672 | ||
//! Link: https://doi.org/10.1109/INFCOM.2004.1354672 | ||
//! Link: https://web.archive.org/web/20160417213452/http://netsrv.csc.ncsu.edu/export/bitcp.pdf | ||
|
||
use core::cmp::{max, min}; | ||
use core::num::NonZeroU32; | ||
use kernel::net::tcp::cong; | ||
use kernel::prelude::*; | ||
use kernel::time; | ||
use kernel::{c_str, module_cca}; | ||
|
||
const ACK_RATIO_SHIFT: u32 = 4; | ||
|
||
// TODO: Convert to module parameters once they are available. | ||
/// Value of ssthresh for new connections. | ||
const INITIAL_SSTHRESH: Option<u32> = None; | ||
/// If cwnd is larger than this threshold, BIC engages; otherwise normal TCP | ||
/// increase/decrease will be performed. | ||
// NOTE: cwnd is expressed in units of full-sized segments. | ||
const LOW_WINDOW: u32 = 14; | ||
/// In binary search, go to point: `cwnd + (W_max - cwnd) / BICTCP_B`. | ||
// SAFETY: This will panic at compile time when passing zero. | ||
// TODO: Convert to `new::(x).unwrap()` once 'const_option' is stabilised. | ||
const BICTCP_B: NonZeroU32 = unsafe { NonZeroU32::new_unchecked(4) }; | ||
/// The maximum increment, i.e., `S_max`. This is used during additive increase. | ||
/// After crossing `W_max`, slow start is performed until passing | ||
/// `MAX_INCREMENT * (BICTCP_B - 1)`. | ||
// SAFETY: This will panic at compile time when passing zero. | ||
const MAX_INCREMENT: NonZeroU32 = unsafe { NonZeroU32::new_unchecked(16) }; | ||
/// The number of RTT it takes to get from `W_max - BICTCP_B` to `W_max` (and | ||
/// from `W_max` to `W_max + BICTCP_B`). This is not part of the original paper | ||
/// and results in a slow additive increase across `W_max`. | ||
const SMOOTH_PART: u32 = 20; | ||
/// Enable or disable fast convergence. | ||
const FAST_CONVERGENCE: bool = true; | ||
/// Factor for multiplicative decrease. In fast retransmit we have: | ||
/// `cwnd = cwnd * BETA/BETA_SCALE` | ||
/// and if fast convergence is active: | ||
/// `W_max = cwnd * (1 + BETA/BETA_SCALE)/2` | ||
/// instead of `W_max = cwnd`. | ||
const BETA: u32 = 819; | ||
/// Used to calculate beta in [0, 1] with integer arithmetics. | ||
// SAFETY: This will panic at compile time when passing zero. | ||
const BETA_SCALE: NonZeroU32 = unsafe { NonZeroU32::new_unchecked(1024) }; | ||
/// The minimum amount of time that has to pass between two updates of the cwnd. | ||
const MIN_UPDATE_INTERVAL: time::Nsecs = 31250000; | ||
|
||
// Registers this congestion-control algorithm as a kernel module, with `Bic`
// as the implementing type. The `name` is the module name; the algorithm
// itself is selected by `Algorithm::NAME` ("bic_rust").
module_cca! {
    type: Bic,
    name: "tcp_bic_rust",
    author: "Rust for Linux Contributors",
    description: "Binary Increase Congestion (BIC) control algorithm, Rust implementation",
    license: "GPL v2",
}
|
||
struct Bic {} | ||
|
||
#[vtable]
impl cong::Algorithm for Bic {
    /// Per-connection private data, stored in the socket's CA area.
    type Data = BicState;

    const NAME: &'static CStr = c_str!("bic_rust");

    /// Maintains `delayed_ack`, the fixed-point (`<< ACK_RATIO_SHIFT`)
    /// estimate of packets acknowledged per ACK. Only samples while the
    /// connection is in the `Open` state (no loss/recovery in progress).
    fn pkts_acked(sk: &mut cong::Sock<'_, Self>, sample: &cong::AckSample) {
        if let Ok(cong::State::Open) = sk.inet_csk().ca_state() {
            let ca = sk.inet_csk_ca_mut();

            // This is supposed to wrap.
            // Moving average: delayed_ack += pkts_acked - estimate, where the
            // estimate is delayed_ack >> ACK_RATIO_SHIFT.
            ca.delayed_ack = ca.delayed_ack.wrapping_add(
                sample
                    .pkts_acked()
                    .wrapping_sub(ca.delayed_ack >> ACK_RATIO_SHIFT),
            );
        }
    }

    /// Returns the new slow-start threshold on a loss event (multiplicative
    /// decrease) and records `W_max` as the target for the next binary search.
    fn ssthresh(sk: &mut cong::Sock<'_, Self>) -> u32 {
        let cwnd = sk.tcp_sk().snd_cwnd();
        let ca = sk.inet_csk_ca_mut();

        pr_info!(
            // TODO: remove
            "Enter fast retransmit: time {}, start {}",
            time::ktime_get_boot_fast_ns(),
            ca.start_time
        );

        // Epoch has ended.
        ca.epoch_start = 0;
        // With fast convergence, a loss below the previous W_max sets the new
        // W_max below the current cwnd: W_max = cwnd * (1 + BETA/BETA_SCALE)/2.
        ca.last_max_cwnd = if cwnd < ca.last_max_cwnd && FAST_CONVERGENCE {
            (cwnd * (BETA_SCALE.get() + BETA)) / (2 * BETA_SCALE.get())
        } else {
            cwnd
        };

        // Halve like standard TCP for small windows, otherwise apply the
        // gentler BETA/BETA_SCALE decrease; never drop below 2 segments.
        if cwnd <= LOW_WINDOW {
            max(cwnd >> 1, 2)
        } else {
            max((cwnd * BETA) / BETA_SCALE, 2)
        }
    }

    /// Main cwnd growth hook: standard slow start while below ssthresh, then
    /// congestion avoidance paced by the `cnt` computed in [`BicState::update`].
    fn cong_avoid(sk: &mut cong::Sock<'_, Self>, _ack: u32, mut acked: u32) {
        // Don't grow the window if the sender isn't actually limited by it.
        if !sk.tcp_is_cwnd_limited() {
            return;
        }

        let tp = sk.tcp_sk_mut();

        if tp.in_slow_start() {
            // `slow_start` consumes ACKs up to ssthresh; any remainder is
            // carried into congestion avoidance below.
            acked = tp.slow_start(acked);
            if acked == 0 {
                pr_info!(
                    // TODO: remove
                    "New cwnd {}, time {}, ssthresh {}, start {}, ss 1",
                    sk.tcp_sk().snd_cwnd(),
                    time::ktime_get_boot_fast_ns(),
                    sk.tcp_sk().snd_ssthresh(),
                    sk.inet_csk_ca().start_time
                );
                return;
            }
        }

        // BIC's core: derive the per-ACK pacing count, then let the generic
        // additive-increase helper apply it.
        let cwnd = tp.snd_cwnd();
        let cnt = sk.inet_csk_ca_mut().update(cwnd);
        sk.tcp_sk_mut().cong_avoid_ai(cnt, acked);

        pr_info!(
            // TODO: remove
            "New cwnd {}, time {}, ssthresh {}, start {}, ss 0",
            sk.tcp_sk().snd_cwnd(),
            time::ktime_get_boot_fast_ns(),
            sk.tcp_sk().snd_ssthresh(),
            sk.inet_csk_ca().start_time
        );
    }

    /// Resets the algorithm state when the connection enters the `Loss` state
    /// (retransmission timeout); other state transitions are ignored.
    fn set_state(sk: &mut cong::Sock<'_, Self>, new_state: cong::State) {
        if matches!(new_state, cong::State::Loss) {
            pr_info!(
                // TODO: remove
                "Retransmission timeout fired: time {}, start {}",
                time::ktime_get_boot_fast_ns(),
                sk.inet_csk_ca().start_time
            );
            sk.inet_csk_ca_mut().reset()
        }
    }

    /// Undoes a cwnd reduction after a loss turned out to be spurious;
    /// delegates to the shared Reno implementation.
    fn undo_cwnd(sk: &mut cong::Sock<'_, Self>) -> u32 {
        pr_info!(
            // TODO: remove
            "Undo cwnd reduction: time {}, start {}",
            time::ktime_get_boot_fast_ns(),
            sk.inet_csk_ca().start_time
        );

        cong::reno::undo_cwnd(sk)
    }

    /// Called when a socket switches to this algorithm; optionally seeds
    /// ssthresh from the compile-time `INITIAL_SSTHRESH` setting.
    fn init(sk: &mut cong::Sock<'_, Self>) {
        if let Some(ssthresh) = INITIAL_SSTHRESH {
            sk.tcp_sk_mut().set_snd_ssthresh(ssthresh);
        }

        // TODO: remove
        pr_info!("Socket created: start {}", sk.inet_csk_ca().start_time);
    }

    // TODO: remove
    // Debug-only hook: logs the socket's lifetime; performs no cleanup.
    fn release(sk: &mut cong::Sock<'_, Self>) {
        pr_info!(
            "Socket destroyed: start {}, end {}",
            sk.inet_csk_ca().start_time,
            time::ktime_get_boot_fast_ns()
        );
    }
}
|
||
/// Internal state of each instance of the algorithm. | ||
struct BicState { | ||
/// During congestion avoidance, cwnd is increased at most every `cnt` | ||
/// acknowledged packets, i.e., the average increase per acknowledged packet | ||
/// is proportional to `1 / cnt`. | ||
// NOTE: The C impl initialises this to zero. It then ensures that zero is | ||
// never passed to `cong_avoid_ai`, which could divide by it. Make it | ||
// explicit in the types that zero is not a valid value. | ||
cnt: NonZeroU32, | ||
/// Last maximum `snd_cwnd`, i.e, `W_max`. | ||
last_max_cwnd: u32, | ||
/// The last `snd_cwnd`. | ||
last_cwnd: u32, | ||
/// Time when `last_cwnd` was updated. | ||
last_time: time::Nsecs, | ||
/// Records the beginning of an epoch. | ||
epoch_start: time::Nsecs, | ||
/// Estimates the ratio of `packets/ACK << 4`. This allows us to adjust cwnd | ||
/// per packet when a receiver is sending a single ACK for multiple received | ||
/// packets. | ||
delayed_ack: u32, | ||
/// Time when algorithm was initialised. | ||
// TODO: remove | ||
start_time: time::Nsecs, | ||
} | ||
|
||
impl Default for BicState {
    /// Creates the state for a fresh connection. Also used by
    /// [`BicState::reset`] after a retransmission timeout.
    fn default() -> Self {
        Self {
            // NOTE: Initializing this to 1 deviates from the C code. It does
            // not change the behavior.
            cnt: NonZeroU32::MIN,
            last_max_cwnd: 0,
            last_cwnd: 0,
            last_time: 0,
            // Zero marks the start of a new epoch on the next `update` call.
            epoch_start: 0,
            // Initial estimate of 2 packets per ACK, in the
            // `<< ACK_RATIO_SHIFT` fixed-point encoding.
            delayed_ack: 2 << ACK_RATIO_SHIFT,
            // TODO: remove
            start_time: time::ktime_get_boot_fast_ns(),
        }
    }
}
|
||
impl BicState { | ||
/// Compute congestion window to use. Returns the new `cnt`. | ||
/// | ||
/// This governs the behavior of the algorithm during congestion avoidance. | ||
fn update(&mut self, cwnd: u32) -> NonZeroU32 { | ||
let timestamp = time::ktime_get_boot_fast_ns(); | ||
|
||
// Do nothing if we are invoked too frequently. | ||
if self.last_cwnd == cwnd && (timestamp - self.last_time) <= MIN_UPDATE_INTERVAL { | ||
return self.cnt; | ||
} | ||
|
||
self.last_cwnd = cwnd; | ||
self.last_time = timestamp; | ||
|
||
// Record the beginning of an epoch. | ||
if self.epoch_start == 0 { | ||
self.epoch_start = timestamp; | ||
} | ||
|
||
// Start off like normal TCP. | ||
if cwnd <= LOW_WINDOW { | ||
self.cnt = NonZeroU32::new(cwnd).unwrap_or(NonZeroU32::MIN); | ||
return self.cnt; | ||
} | ||
|
||
let mut new_cnt = if cwnd < self.last_max_cwnd { | ||
// binary increase | ||
let dist: u32 = (self.last_max_cwnd - cwnd) / BICTCP_B; | ||
|
||
if dist > MAX_INCREMENT.get() { | ||
cwnd / MAX_INCREMENT // additive increase | ||
} else if dist <= 1 { | ||
(cwnd * SMOOTH_PART) / BICTCP_B // careful additive increase | ||
} else { | ||
cwnd / dist // binary search | ||
} | ||
} else { | ||
if cwnd < self.last_max_cwnd + BICTCP_B.get() { | ||
(cwnd * SMOOTH_PART) / BICTCP_B // careful additive increase | ||
} else if cwnd < self.last_max_cwnd + MAX_INCREMENT.get() * (BICTCP_B.get() - 1) { | ||
(cwnd * (BICTCP_B.get() - 1)) / (cwnd - self.last_max_cwnd) // slow start | ||
} else { | ||
cwnd / MAX_INCREMENT // linear increase | ||
} | ||
}; | ||
|
||
// If in initial slow start or link utilization is very low. | ||
if self.last_max_cwnd == 0 { | ||
new_cnt = min(new_cnt, 20); | ||
} | ||
|
||
// Account for estimated packets/ACK to ensure that we increase per | ||
// packet. | ||
new_cnt = (new_cnt << ACK_RATIO_SHIFT) / self.delayed_ack; | ||
|
||
self.cnt = NonZeroU32::new(new_cnt).unwrap_or(NonZeroU32::MIN); | ||
|
||
self.cnt | ||
} | ||
|
||
fn reset(&mut self) { | ||
// TODO: remove | ||
let tmp = self.start_time; | ||
|
||
*self = Self::default(); | ||
|
||
// TODO: remove | ||
self.start_time = tmp; | ||
} | ||
} |