Skip to content

Commit

Permalink
Saner error handling (#17)
Browse files Browse the repository at this point in the history
Co-authored-by: Travis Athougies <travis@athougies.net>
  • Loading branch information
TroyNeubauer and tathougies authored Aug 10, 2024
1 parent aa0616e commit fcac9c5
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 96 deletions.
13 changes: 13 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,16 @@ jobs:
with:
token: ${{ secrets.GITHUB_TOKEN }}
args: --all-features

# TODO(Troy): re-enable after we bump to fix breaking change in winapi breaking our old compilation
# semver:
# name: Check semver
# runs-on: windows-latest
# steps:
# - uses: actions/checkout@v2
# - uses: actions-rs/toolchain@v1
# with:
# profile: minimal
# toolchain: stable
# override: true
# - uses: obi1kenobi/cargo-semver-checks-action@v2
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ targets = ["aarch64-pc-windows-msvc", "i686-pc-windows-msvc", "x86_64-pc-windows

[dependencies]
bitflags = "2"
getrandom = "0.2.15"
getrandom = "0.2"
ipnet = "2.3"
libloading = "0.8"
log = "0.4"
thiserror = "1.0"
widestring = "0.4"
windows-sys = { version = "0.52", features = [
windows-sys = { version = "0.59", features = [
"Win32_Foundation",
"Win32_Networking",
"Win32_Networking_WinSock",
Expand Down
21 changes: 10 additions & 11 deletions examples/demo_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,19 @@ fn main() {
get_demo_server_config(public.as_bytes()).expect("Failed to get demo server credentials");
println!("Connecting to {} - internal ip: {}", endpoint, internal_ip);

//Must be run as Administrator because we create network adapters
//Load the wireguard dll file so that we can call the underlying C functions
//Unsafe because we are loading an arbitrary dll file
// Must be run as Administrator because we create network adapters

// Load the wireguard dll file so that we can call the underlying C functions
// Unsafe because we are loading an arbitrary dll file
let wireguard =
unsafe { wireguard_nt::load_from_path("examples/wireguard_nt/bin/amd64/wireguard.dll") }
.expect("Failed to load wireguard dll");

//Try to open an adapter from the given pool with the name "Demo"
let adapter =
wireguard_nt::Adapter::open(wireguard, "Demo").unwrap_or_else(|(_, wireguard)| {
wireguard_nt::Adapter::create(wireguard, "WireGuard", "Demo", None)
.map_err(|e| e.0)
.expect("Failed to create wireguard adapter!")
});
// Try to open an adapter from the given pool with the name "Demo"
let adapter = wireguard_nt::Adapter::open(&wireguard, "Demo").unwrap_or_else(|_| {
wireguard_nt::Adapter::create(&wireguard, "WireGuard", "Demo", None)
.expect("Failed to create wireguard adapter!")
});
let mut interface_private = [0; 32];
let mut peer_pub = [0; 32];

Expand Down Expand Up @@ -64,7 +63,7 @@ fn main() {
Ok(()) => {}
Err(err) => panic!("Failed to set default route: {}", err),
}
assert!(adapter.up());
assert!(adapter.up().is_ok());

// Go to http://demo.wireguard.com/ and see the bandwidth numbers change!
println!("Printing peer bandwidth statistics");
Expand Down
108 changes: 57 additions & 51 deletions src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,13 @@ use windows_sys::Win32::{
};

use crate::log::AdapterLoggingLevel;
use crate::util;
use crate::util::{StructReader, UnsafeHandle};
use crate::util::{self, StructReader, UnsafeHandle};
use crate::wireguard_nt_raw::{
in6_addr, in_addr, wireguard, GUID, WIREGUARD_ADAPTER_HANDLE, WIREGUARD_ALLOWED_IP,
WIREGUARD_INTERFACE, WIREGUARD_INTERFACE_FLAG, WIREGUARD_PEER, WIREGUARD_PEER_FLAG,
_NET_LUID_LH,
};
use crate::WireGuardError;
use crate::{Error, Result, Wireguard};

/// Representation of a wireGuard adapter with safe idiomatic bindings to the functionality provided by
/// the WireGuard* C functions.
Expand Down Expand Up @@ -87,33 +86,18 @@ pub struct SetInterface {
pub peers: Vec<SetPeer>,
}

fn encode_name(
name: &str,
wireguard: Arc<wireguard>,
) -> Result<(U16CString, Arc<wireguard>), (WireGuardError, Arc<wireguard>)> {
let utf16 = match U16CString::from_str(name) {
Ok(u) => u,
Err(e) => return Err((e.into(), wireguard)),
};
fn encode_name(name: &str) -> Result<U16CString> {
let utf16 = U16CString::from_str(name)?;
let max = crate::MAX_NAME;
if utf16.len() >= max {
//max_characters is the maximum number of characters including the null terminator. And .len() measures the
//number of characters (excluding the null terminator). Therefore, we can hold a string with
//max_characters - 1 because the null terminator sits in the last element. A string
//of length max_characters needs max_characters + 1 to store the null terminator so the >=
//check holds
Err((
format!(
//TODO: Better error handling
"Length too large. Size: {}, Max: {}",
utf16.len(),
max,
)
.into(),
wireguard,
))
Err(Error::NameTooLarge)
} else {
Ok((utf16, wireguard))
Ok(utf16)
}
}

Expand All @@ -123,9 +107,9 @@ pub struct EnumeratedAdapter {
pub name: String,
}

fn win_error(context: &str, error_code: u32) -> Result<(), Box<dyn std::error::Error>> {
fn win_error(context: &str, error_code: u32) -> Result<()> {
let e = std::io::Error::from_raw_os_error(error_code as i32);
Err(format!("{} - {}", context, e).into())
Err(Error::Windows(context.to_string(), e))
}

const WIREGUARD_STATE_DOWN: i32 = 0;
Expand All @@ -138,13 +122,13 @@ impl Adapter {
///
/// Optionally a GUID can be specified that will become the GUID of this adapter once created.
pub fn create(
wireguard: Arc<wireguard>,
wireguard: &Wireguard,
pool: &str,
name: &str,
guid: Option<u128>,
) -> Result<Adapter, (WireGuardError, Arc<wireguard>)> {
let (pool_utf16, wireguard) = encode_name(pool, wireguard)?;
let (name_utf16, wireguard) = encode_name(name, wireguard)?;
) -> Result<Adapter> {
let pool_utf16 = encode_name(pool)?;
let name_utf16 = encode_name(name)?;

let guid = guid.unwrap_or_else(|| {
let mut guid_bytes = [0u8; 16];
Expand All @@ -159,7 +143,7 @@ impl Adapter {
//the byte order of the segments of the GUID struct that are larger than a byte. Verify
//that this works as expected

crate::log::set_default_logger_if_unset(&wireguard);
crate::log::set_default_logger_if_unset(wireguard);

//SAFETY: the function is loaded from the wireguard dll properly, we are providing valid
//pointers, and all the strings are correct null terminated UTF-16. This safety rationale
Expand All @@ -173,38 +157,35 @@ impl Adapter {
};

if result.is_null() {
Err(("Failed to create adapter".into(), wireguard))
Err(Error::Driver(std::io::Error::last_os_error()))
} else {
Ok(Self {
adapter: UnsafeHandle(result),
wireguard,
wireguard: Arc::clone(&wireguard.0),
})
}
}

/// Attempts to open an existing wireguard with name `name`.
pub fn open(
wireguard: Arc<wireguard>,
name: &str,
) -> Result<Adapter, (WireGuardError, Arc<wireguard>)> {
let (name_utf16, wireguard) = encode_name(name, wireguard)?;
pub fn open(wireguard: &Wireguard, name: &str) -> Result<Adapter> {
let name_utf16 = encode_name(name)?;

crate::log::set_default_logger_if_unset(&wireguard);
crate::log::set_default_logger_if_unset(wireguard);

let result = unsafe { wireguard.WireGuardOpenAdapter(name_utf16.as_ptr()) };

if result.is_null() {
Err(("WireGuardOpenAdapter failed".into(), wireguard))
Err(Error::Driver(std::io::Error::last_os_error()))
} else {
Ok(Adapter {
adapter: UnsafeHandle(result),
wireguard,
wireguard: Arc::clone(&wireguard.0),
})
}
}

/// Sets the wireguard configuration of this adapter
pub fn set_config(&self, config: &SetInterface) -> Result<(), WireGuardError> {
pub fn set_config(&self, config: &SetInterface) -> Result<()> {
bitflags::bitflags! {
struct InterfaceFlags: i32 {
const HAS_PUBLIC_KEY = 1 << 0;
Expand Down Expand Up @@ -345,7 +326,7 @@ impl Adapter {
};

match result {
0 => Err("WireGuardSetConfiguration failed".into()),
0 => Err(Error::Driver(std::io::Error::last_os_error())),
_ => Ok(()),
}
}
Expand All @@ -356,7 +337,7 @@ impl Adapter {
&self,
interface_addrs: &[IpNet],
config: &SetInterface,
) -> Result<(), Box<dyn std::error::Error>> {
) -> Result<()> {
// Set the route with metric = 0 (highest priority / default)
self.set_route_with_metric(interface_addrs, config, 0)
}
Expand All @@ -369,7 +350,7 @@ impl Adapter {
interface_addrs: &[IpNet],
config: &SetInterface,
metric: u32,
) -> Result<(), Box<dyn std::error::Error>> {
) -> Result<()> {
let luid = self.get_luid();
unsafe {
for allowed_ip in config.peers.iter().flat_map(|p| p.allowed_ips.iter()) {
Expand Down Expand Up @@ -403,7 +384,7 @@ impl Adapter {

let err = CreateIpForwardEntry2(&default_route);
if err != ERROR_SUCCESS && err != ERROR_OBJECT_ALREADY_EXISTS {
return win_error("Failed to set default route", err);
return win_error("CreateIpForwardEntry2", err);
}
}

Expand Down Expand Up @@ -440,42 +421,67 @@ impl Adapter {

let err = CreateUnicastIpAddressEntry(&address_row);
if err != ERROR_SUCCESS && err != ERROR_OBJECT_ALREADY_EXISTS {
return win_error("Failed to set IP interface", err);
return win_error("CreateUnicastIpAddressEntry", err);
}
}

let err = GetIpInterfaceEntry(&mut ip_interface);
if err != ERROR_SUCCESS {
return win_error("Failed to get IP interface", err);
return win_error("GetIpInterfaceEntry", err);
}
ip_interface.UseAutomaticMetric = 0;
ip_interface.Metric = metric;
ip_interface.NlMtu = 1420;
ip_interface.SitePrefixLength = 0;
let err = SetIpInterfaceEntry(&mut ip_interface);
if err != ERROR_SUCCESS {
return win_error("Failed to set metric and MTU", err);
return win_error("SetIpInterfaceEntry", err);
}

Ok(())
}
}

/// Get the state of this adapter
pub fn is_up(&self) -> Result<bool> {
let mut state = 0;
let success = unsafe {
self.wireguard
.WireGuardGetAdapterState(self.adapter.0, &mut state)
!= 0
};
if success {
Ok(state == WIREGUARD_STATE_UP)
} else {
Err(Error::Driver(std::io::Error::last_os_error()))
}
}

/// Puts this adapter into the up state
pub fn up(&self) -> bool {
unsafe {
pub fn up(&self) -> Result<()> {
let success = unsafe {
self.wireguard
.WireGuardSetAdapterState(self.adapter.0, WIREGUARD_STATE_UP)
!= 0
};
if success {
Ok(())
} else {
Err(Error::Driver(std::io::Error::last_os_error()))
}
}

/// Puts this adapter into the down state
pub fn down(&self) -> bool {
unsafe {
pub fn down(&self) -> Result<()> {
let success = unsafe {
self.wireguard
.WireGuardSetAdapterState(self.adapter.0, WIREGUARD_STATE_DOWN)
!= 0
};
if success {
Ok(())
} else {
Err(Error::Driver(std::io::Error::last_os_error()))
}
}

Expand Down
Loading

0 comments on commit fcac9c5

Please sign in to comment.