Skip to content

Commit

Permalink
Merge pull request #46 from quartiq/feature/dns-support
Browse files Browse the repository at this point in the history
Adding smoltcp-based DNS support
  • Loading branch information
ryan-summers authored Aug 22, 2023
2 parents 360ad26 + 2627081 commit 8ba6776
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 1 deletion.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

This document describes the changes to smoltcp-nal between releases.

# [Unreleased]

## Added
* Added support for `embedded_nal::Dns` traits

# [0.4.0] - 2023-07-21

## Added
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ features = ["wyrand"]

[dependencies.smoltcp]
version = "0.10"
features = ["medium-ethernet", "proto-ipv6", "socket-tcp", "socket-dhcpv4", "socket-udp"]
features = ["medium-ethernet", "proto-ipv6", "socket-tcp", "socket-dns", "socket-dhcpv4", "socket-udp"]
default-features = false

[dependencies.shared-bus]
Expand Down
87 changes: 87 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ pub enum SmoltcpError {
#[derive(Debug, Copy, Clone)]
pub enum NetworkError {
NoSocket,
DnsStart(smoltcp::socket::dns::StartQueryError),
DnsFailure,
UdpConnectionFailure(smoltcp::socket::udp::BindError),
TcpConnectionFailure(smoltcp::socket::tcp::ConnectError),
TcpReadFailure(smoltcp::socket::tcp::RecvError),
Expand Down Expand Up @@ -113,6 +115,8 @@ where
device: Device,
sockets: smoltcp::iface::SocketSet<'a>,
dhcp_handle: Option<SocketHandle>,
dns_handle: Option<SocketHandle>,
dns_lookups: heapless::LinearMap<heapless::String<255>, smoltcp::socket::dns::QueryHandle, 2>,
unused_tcp_handles: Vec<SocketHandle, 16>,
unused_udp_handles: Vec<SocketHandle, 16>,
clock: Clock,
Expand Down Expand Up @@ -152,6 +156,7 @@ where
let mut unused_tcp_handles: Vec<SocketHandle, 16> = Vec::new();
let mut unused_udp_handles: Vec<SocketHandle, 16> = Vec::new();
let mut dhcp_handle: Option<SocketHandle> = None;
let mut dns_handle: Option<SocketHandle> = None;

for (handle, socket) in sockets.iter() {
match socket {
Expand All @@ -164,6 +169,9 @@ where
smoltcp::socket::Socket::Dhcpv4(_) => {
dhcp_handle.replace(handle);
}
smoltcp::socket::Socket::Dns(_) => {
dns_handle.replace(handle);
}

// This branch may be enabled through cargo feature unification (e.g. if an
// application enables raw-sockets). To accomodate this, we provide a default match
Expand All @@ -178,9 +186,11 @@ where
sockets,
device,
dhcp_handle,
dns_handle,
unused_tcp_handles,
unused_udp_handles,
last_poll: None,
dns_lookups: heapless::LinearMap::new(),
clock,
stack_time: smoltcp::time::Instant::from_secs(0),
rand: WyRand::new_seed(0),
Expand Down Expand Up @@ -234,6 +244,7 @@ where
// Service the DHCP client.
if let Some(handle) = self.dhcp_handle {
let mut close_sockets = false;
let mut dns_server = None;

if let Some(event) = self.sockets.get_mut::<dhcpv4::Socket>(handle).poll() {
match event {
Expand All @@ -246,6 +257,15 @@ where
Self::set_ipv4_addr(&mut self.network_interface, config.address);
}

if let Some(server) = config
.dns_servers
.iter()
.next()
.map(|ipv4| smoltcp::wire::IpAddress::Ipv4(*ipv4))
{
dns_server.replace(server);
}

if let Some(route) = config.router {
// Note: If the user did not provide enough route storage, we may not be
// able to store the gateway.
Expand Down Expand Up @@ -274,6 +294,18 @@ where
if close_sockets {
self.close_sockets();
}

if let Some((server, handle)) = dns_server.zip(self.dns_handle) {
let dns = self.sockets.get_mut::<smoltcp::socket::dns::Socket>(handle);

// Clear out all pending DNS queries now that we have a new server.
for (_query, handle) in self.dns_lookups.iter() {
dns.cancel_query(*handle);
}
self.dns_lookups.clear();

dns.update_servers(&[server]);
}
}

Ok(updated)
Expand Down Expand Up @@ -649,3 +681,58 @@ where
.map_err(|e| embedded_nal::nb::Error::Other(NetworkError::UdpWriteFailure(e)))
}
}

impl<'a, Device, Clock> embedded_nal::Dns for NetworkStack<'a, Device, Clock>
where
Device: smoltcp::phy::Device,
Clock: embedded_time::Clock,
u32: From<Clock::T>,
{
type Error = NetworkError;
fn get_host_by_name(
&mut self,
hostname: &str,
_addr_type: embedded_nal::AddrType,
) -> embedded_nal::nb::Result<embedded_nal::IpAddr, Self::Error> {
let handle = self.dns_handle.ok_or(NetworkError::Unsupported)?;
let dns_socket: &mut smoltcp::socket::dns::Socket = self.sockets.get_mut(handle);
let context = self.network_interface.context();
let key = heapless::String::try_from(hostname).map_err(|_| NetworkError::Unsupported)?;

if let Some(handle) = self.dns_lookups.get(&key) {
match dns_socket.get_query_result(*handle) {
Ok(addrs) => {
self.dns_lookups.remove(&key);
let addr = addrs.iter().next().ok_or(NetworkError::DnsFailure)?;
let smoltcp::wire::IpAddress::Ipv4(addr) = addr else {
panic!("Unexpected address return type");
};
return Ok(embedded_nal::IpAddr::V4(addr.0.into()));
}
Err(smoltcp::socket::dns::GetQueryResultError::Pending) => {}
Err(smoltcp::socket::dns::GetQueryResultError::Failed) => {
self.dns_lookups.remove(&key);
return Err(embedded_nal::nb::Error::Other(NetworkError::DnsFailure));
}
}
} else {
// Note: We only support A types because we are an Ipv4-only stack
let dns_query = dns_socket
.start_query(context, hostname, smoltcp::wire::DnsQueryType::A)
.map_err(NetworkError::DnsStart)?;
if self.dns_lookups.insert(key, dns_query).is_err() {
dns_socket.cancel_query(dns_query);
return Err(embedded_nal::nb::Error::Other(NetworkError::Unsupported));
}
}

Err(embedded_nal::nb::Error::WouldBlock)
}

fn get_host_by_address(
&mut self,
_addr: embedded_nal::IpAddr,
) -> embedded_nal::nb::Result<embedded_nal::heapless::String<256>, Self::Error> {
unimplemented!()
}
}
9 changes: 9 additions & 0 deletions src/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ where
forward! {send_to(socket: &mut S::UdpSocket, remote: embedded_nal::SocketAddr, buffer: &[u8]) -> embedded_nal::nb::Result<(), S::Error>}
forward! {bind(socket: &mut S::UdpSocket, local_port: u16) -> Result<(), S::Error>}
}
impl<'a, S> embedded_nal::Dns for NetworkStackProxy<'a, S>
where
S: embedded_nal::Dns,
{
type Error = S::Error;

forward! {get_host_by_name(hostname: &str, addr_type: embedded_nal::AddrType) -> embedded_nal::nb::Result<embedded_nal::IpAddr, Self::Error>}
forward! {get_host_by_address(addr: embedded_nal::IpAddr) -> embedded_nal::nb::Result<embedded_nal::heapless::String<256>, Self::Error>}
}

impl<'a, Device, Clock> NetworkManager<'a, Device, Clock>
where
Expand Down

0 comments on commit 8ba6776

Please sign in to comment.