cache dns-server reverse map
bdbai committed Mar 19, 2024
1 parent 68350d3 commit 813f82d
Showing 5 changed files with 164 additions and 24 deletions.
31 changes: 27 additions & 4 deletions ytflow/src/config/plugin/dns_server.rs
@@ -4,6 +4,7 @@ use serde::Deserialize;

use crate::config::factory::*;
use crate::config::*;
use crate::data::PluginId;

#[cfg_attr(not(feature = "plugins"), allow(dead_code))]
#[derive(Deserialize)]
@@ -16,12 +17,17 @@ pub struct DnsServerFactory<'a> {
tcp_map_back: HashSet<&'a str>,
#[serde(borrow)]
udp_map_back: HashSet<&'a str>,
#[serde(skip)]
plugin_id: Option<PluginId>,
}

impl<'de> DnsServerFactory<'de> {
pub(in super::super) fn parse(plugin: &'de Plugin) -> ConfigResult<ParsedPlugin<'de, Self>> {
let Plugin { name, param, .. } = plugin;
let config: Self = parse_param(name, param)?;
let Plugin {
name, param, id, ..
} = plugin;
let mut config: Self = parse_param(name, param)?;
config.plugin_id = *id;
let resolver = config.resolver;
Ok(ParsedPlugin {
requires: [Descriptor {
@@ -61,10 +67,24 @@ impl<'de> DnsServerFactory<'de> {
impl<'de> Factory for DnsServerFactory<'de> {
#[cfg(feature = "plugins")]
fn load(&mut self, plugin_name: String, set: &mut PartialPluginSet) -> LoadResult<()> {
use crate::data::PluginCache;
use crate::plugin::dns_server;
use crate::plugin::null::Null;
use crate::plugin::reject::RejectHandler;

let db = set
.db
.ok_or_else(|| LoadError::DatabaseRequired {
plugin: plugin_name.clone(),
})?
.clone();
let cache = PluginCache::new(
self.plugin_id.ok_or_else(|| LoadError::DatabaseRequired {
plugin: plugin_name.clone(),
})?,
Some(db.clone()),
);

let mut err = None;
let factory = Arc::new_cyclic(|weak| {
set.datagram_handlers
@@ -75,7 +95,7 @@ impl<'de> Factory for DnsServerFactory<'de> {
err = Some(e);
Arc::downgrade(&(Arc::new(Null) as _))
});
dns_server::DnsDatagramHandler::new(self.concurrency_limit as usize, resolver, self.ttl)
dns_server::DnsServer::new(self.concurrency_limit as usize, resolver, self.ttl, cache)
});
if let Some(e) = err {
set.errors.push(e);
@@ -127,7 +147,10 @@ impl<'de> Factory for DnsServerFactory<'de> {

set.fully_constructed
.datagram_handlers
.insert(plugin_name + ".udp", factory);
.insert(plugin_name + ".udp", factory.clone());
set.fully_constructed
.long_running_tasks
.push(tokio::spawn(dns_server::cache_writer(factory)));
Ok(())
}
}
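The factory threads the plugin's database identity into the runtime handler: the config struct gains a `#[serde(skip)]` field that serde defaults instead of reading from user params, and `parse` fills it in from the `Plugin` record afterwards. A minimal standalone sketch of that injection pattern (the struct, field names and use of serde_json here are illustrative, not ytflow's actual parser):

use serde::Deserialize;

// Hypothetical config: plugin_id is not part of the user-facing schema, so
// serde skips it (defaulting to None) and the loader injects it after parsing.
#[derive(Deserialize)]
struct ExampleConfig {
    ttl: u32,
    #[serde(skip)]
    plugin_id: Option<u32>,
}

fn parse_with_id(raw: &str, id: u32) -> serde_json::Result<ExampleConfig> {
    let mut cfg: ExampleConfig = serde_json::from_str(raw)?;
    cfg.plugin_id = Some(id); // injected post-deserialization, as parse() does above
    Ok(cfg)
}

The id is what later allows `load` to build a `PluginCache` scoped to this plugin, and the spawned `cache_writer` task is registered under `long_running_tasks` so it lives for the duration of the loaded plugin set.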
108 changes: 95 additions & 13 deletions ytflow/src/plugin/dns_server/datagram.rs
@@ -1,40 +1,100 @@
use std::collections::BTreeMap;
use std::hash::Hash;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::num::NonZeroUsize;
use std::sync::{Arc, Mutex, Weak};

use futures::future::poll_fn;
use lru::LruCache;
use tokio::sync::Semaphore;
use serde::{Deserialize, Serialize};
use tokio::sync::{Notify, Semaphore};
use trust_dns_resolver::proto::op::{Message as DnsMessage, MessageType, ResponseCode};
use trust_dns_resolver::proto::rr::{RData, Record, RecordType};
use trust_dns_resolver::proto::serialize::binary::BinDecodable;

use crate::data::PluginCache;
use crate::flow::*;

const CACHE_CAPAICTY: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(1024) };
const CACHE_CAPACITY: NonZeroUsize = NonZeroUsize::new(1024).unwrap();
const REVERSE_MAPPING_V4_CACHE_KEY: &str = "rev_v4";
const REVERSE_MAPPING_V6_CACHE_KEY: &str = "rev_v6";

pub struct DnsDatagramHandler {
pub struct DnsServer {
concurrency_limit: Arc<Semaphore>,
resolver: Weak<dyn Resolver>,
ttl: u32,
pub(super) reverse_mapping_v4: Arc<Mutex<LruCache<Ipv4Addr, String>>>,
pub(super) reverse_mapping_v6: Arc<Mutex<LruCache<Ipv6Addr, String>>>,
plugin_cache: PluginCache,
pub(super) new_notify: Arc<Notify>,
}

impl DnsDatagramHandler {
pub fn new(concurrency_limit: usize, resolver: Weak<dyn Resolver>, ttl: u32) -> Self {
#[derive(Debug, Clone, Default, PartialOrd, Ord, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]
struct ReverseMappingCache<T: Ord>(BTreeMap<T, String>);

impl DnsServer {
pub fn new(
concurrency_limit: usize,
resolver: Weak<dyn Resolver>,
ttl: u32,
plugin_cache: PluginCache,
) -> Self {
let concurrency_limit = Arc::new(Semaphore::new(concurrency_limit));
DnsDatagramHandler {
let mut reverse_mapping_v4 = LruCache::new(CACHE_CAPACITY);
let mut reverse_mapping_v6 = LruCache::new(CACHE_CAPACITY);
if let Some(reverse_mapping_v4_cache) = plugin_cache
.get::<ReverseMappingCache<_>>(REVERSE_MAPPING_V4_CACHE_KEY)
.ok()
.flatten()
{
for (k, v) in reverse_mapping_v4_cache.0 {
reverse_mapping_v4.put(k, v);
}
}
if let Some(reverse_mapping_v6_cache) = plugin_cache
.get::<ReverseMappingCache<_>>(REVERSE_MAPPING_V6_CACHE_KEY)
.ok()
.flatten()
{
for (k, v) in reverse_mapping_v6_cache.0 {
reverse_mapping_v6.put(k, v);
}
}
DnsServer {
concurrency_limit,
resolver,
ttl,
reverse_mapping_v4: Arc::new(Mutex::new(LruCache::new(CACHE_CAPAICTY))),
reverse_mapping_v6: Arc::new(Mutex::new(LruCache::new(CACHE_CAPAICTY))),
reverse_mapping_v4: Arc::new(Mutex::new(reverse_mapping_v4)),
reverse_mapping_v6: Arc::new(Mutex::new(reverse_mapping_v6)),
plugin_cache,
new_notify: Arc::new(Notify::new()),
}
}

fn save_reverse_mapping_cache<T: Serialize + Hash + Eq + Ord + Clone>(
&self,
cache: &Mutex<LruCache<T, String>>,
key: &str,
) {
let cache = {
let inner = cache.lock().unwrap();
ReverseMappingCache(
(&*inner)
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect(),
)
};
self.plugin_cache.set(key, &cache).ok();
}
pub(crate) fn save_cache(&self) {
self.save_reverse_mapping_cache(&self.reverse_mapping_v4, REVERSE_MAPPING_V4_CACHE_KEY);
self.save_reverse_mapping_cache(&self.reverse_mapping_v6, REVERSE_MAPPING_V6_CACHE_KEY);
}
}

impl DatagramSessionHandler for DnsDatagramHandler {
impl DatagramSessionHandler for DnsServer {
fn on_session(&self, mut session: Box<dyn DatagramSession>, _context: Box<FlowContext>) {
let resolver = match self.resolver.upgrade() {
Some(resolver) => resolver,
@@ -44,6 +104,7 @@ impl DatagramSessionHandler for DnsDatagramHandler {
let ttl = self.ttl;
let reverse_mapping_v4 = self.reverse_mapping_v4.clone();
let reverse_mapping_v6 = self.reverse_mapping_v6.clone();
let new_notify = self.new_notify.clone();
tokio::spawn(async move {
let mut send_ready = true;
while let Some((dest, buf)) = poll_fn(|cx| {
@@ -65,18 +126,25 @@
};
let mut res_code = ResponseCode::NoError;
let mut ans_records = Vec::with_capacity(msg.queries().len());
let mut notify_cache_update = false;
for query in msg.queries() {
let name = query.name();
let name_str = name.to_lowercase().to_ascii();
#[allow(unreachable_code)]
match query.query_type() {
RecordType::A => {
let ips = match resolver.resolve_ipv4(name_str.clone()).await {
Ok(addrs) => addrs,
Err(_) => (res_code = ResponseCode::NXDomain, continue).1,
Err(_) => {
res_code = ResponseCode::NXDomain;
continue;
}
};
let mut reverse_mapping = reverse_mapping_v4.lock().unwrap();
for ip in &ips {
notify_cache_update |= reverse_mapping
.peek_mut(ip)
.filter(|n| *n == &name_str)
.is_none();
reverse_mapping.get_or_insert(*ip, || name_str.clone());
}
ans_records.extend(
@@ -88,20 +156,34 @@
RecordType::AAAA => {
let ips = match resolver.resolve_ipv6(name_str.clone()).await {
Ok(addrs) => addrs,
Err(_) => (res_code = ResponseCode::NXDomain, continue).1,
Err(_) => {
res_code = ResponseCode::NXDomain;
continue;
}
};
let mut reverse_mapping = reverse_mapping_v6.lock().unwrap();
for ip in &ips {
notify_cache_update |= reverse_mapping
.peek_mut(ip)
.filter(|n| *n == &name_str)
.is_none();
reverse_mapping.get_or_insert(*ip, || name_str.clone());
}
ans_records.extend(ips.into_iter().map(|addr| {
Record::from_rdata(name.clone(), ttl, RData::AAAA(addr))
}))
}
// TODO: SRV
_ => (res_code = ResponseCode::NotImp, continue).1,
_ => {
res_code = ResponseCode::NotImp;
continue;
}
}
}
if notify_cache_update {
new_notify.notify_one();
}

*msg.set_message_type(MessageType::Response)
.set_response_code(res_code)
.answers_mut() = ans_records;
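The reverse mappings now survive restarts: each `LruCache` is flattened into the `BTreeMap`-backed `ReverseMappingCache` for serialization through `PluginCache`, and replayed entry by entry on startup. A self-contained sketch of that round trip with the lru crate (capacity, key type and function names are illustrative):

use std::collections::BTreeMap;
use std::net::Ipv4Addr;
use std::num::NonZeroUsize;

use lru::LruCache;

// Flatten the LRU into an ordered map so it can be serialized.
fn snapshot(cache: &LruCache<Ipv4Addr, String>) -> BTreeMap<Ipv4Addr, String> {
    cache.iter().map(|(ip, name)| (*ip, name.clone())).collect()
}

// Replay a stored snapshot into a fresh LRU on startup.
fn restore(snap: BTreeMap<Ipv4Addr, String>) -> LruCache<Ipv4Addr, String> {
    let mut lru = LruCache::new(NonZeroUsize::new(1024).unwrap());
    for (ip, name) in snap {
        lru.put(ip, name);
    }
    lru
}

Using a `BTreeMap` keeps the serialized form independent of the LRU's internal order, at the cost of losing recency information across restarts. The `peek_mut(..).filter(..).is_none()` check in the handler only trips `notify_cache_update` when an address is new or now maps to a different name, so unchanged answers do not wake the cache writer.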
7 changes: 3 additions & 4 deletions ytflow/src/plugin/dns_server/map_back.rs
@@ -5,7 +5,7 @@ use std::task::{ready, Context, Poll};

use lru::LruCache;

use super::DnsDatagramHandler;
use super::DnsServer;
use crate::flow::*;

#[derive(Clone)]
@@ -46,8 +46,7 @@ pub struct MapBackStreamHandler {
}

impl MapBackStreamHandler {
pub fn new(handler: &DnsDatagramHandler, next: Weak<dyn StreamHandler>) -> Self {
// TODO: persist mapping into cache
pub fn new(handler: &DnsServer, next: Weak<dyn StreamHandler>) -> Self {
Self {
back_mapper: BackMapper {
reverse_mapping_v4: handler.reverse_mapping_v4.clone(),
@@ -104,7 +103,7 @@ struct MapBackDatagramSession {
}

impl MapBackDatagramSessionHandler {
pub fn new(handler: &DnsDatagramHandler, next: Weak<dyn DatagramSessionHandler>) -> Self {
pub fn new(handler: &DnsServer, next: Weak<dyn DatagramSessionHandler>) -> Self {
Self {
back_mapper: BackMapper {
reverse_mapping_v4: handler.reverse_mapping_v4.clone(),
38 changes: 37 additions & 1 deletion ytflow/src/plugin/dns_server/mod.rs
@@ -1,5 +1,41 @@
mod datagram;
mod map_back;

pub use datagram::DnsDatagramHandler;
use std::sync::Arc;

pub use datagram::DnsServer;
pub use map_back::{MapBackDatagramSessionHandler, MapBackStreamHandler};

pub async fn cache_writer(plugin: Arc<DnsServer>) {
let (plugin, notify) = {
let notify = plugin.new_notify.clone();
let weak = Arc::downgrade(&plugin);
drop(plugin);
(weak, notify)
};
if plugin.strong_count() == 0 {
panic!("dns-server has no strong reference left for cache_writer");
}

use tokio::select;
use tokio::time::{sleep, Duration};
loop {
let mut notified_fut = notify.notified();
let mut sleep_fut = sleep(Duration::from_secs(3600));
'debounce: loop {
select! {
_ = notified_fut => {
notified_fut = notify.notified();
sleep_fut = sleep(Duration::from_secs(3));
}
_ = sleep_fut => {
break 'debounce;
}
}
}
match plugin.upgrade() {
Some(plugin) => plugin.save_cache(),
None => break,
}
}
}
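`cache_writer` immediately downgrades the `Arc<DnsServer>` it is given to a `Weak`, so the background task never extends the plugin's lifetime; it upgrades only for the brief moment a flush is needed and exits once the last strong reference is gone (the `strong_count() == 0` panic just catches the wiring mistake of handing it the sole owner). The debounce is the `select!` loop: each notification restarts a short timer, so a flush happens a few seconds after the last new mapping in a burst, and once an hour if nothing new arrives at all. A stripped-down illustration of the same loop with an arbitrary flush closure (function name and durations are illustrative; unlike the real task, this one never exits):

use std::sync::Arc;
use tokio::select;
use tokio::sync::Notify;
use tokio::time::{sleep, Duration};

async fn debounced_flush<F: FnMut()>(notify: Arc<Notify>, mut flush: F) {
    loop {
        let mut notified_fut = notify.notified();
        let mut sleep_fut = sleep(Duration::from_secs(3600));
        loop {
            select! {
                _ = notified_fut => {
                    // A new event arrived: re-arm the notification and shorten the
                    // timer so the flush runs shortly after the burst ends.
                    notified_fut = notify.notified();
                    sleep_fut = sleep(Duration::from_secs(3));
                }
                _ = sleep_fut => break,
            }
        }
        flush();
    }
}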
4 changes: 2 additions & 2 deletions ytflow/src/plugin/fakeip.rs
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@ use tokio::sync::Notify;
use crate::data::PluginCache;
use crate::flow::*;

const CACHE_SIZE: NonZeroUsize = NonZeroUsize::new(1000).unwrap();
const CACHE_CAPACITY: NonZeroUsize = NonZeroUsize::new(1000).unwrap();
const PLUGIN_CACHE_KEY: &str = "map";

struct Inner {
@@ -35,7 +35,7 @@ pub struct FakeIp {

impl FakeIp {
pub fn new(prefix_v4: [u8; 2], prefix_v6: [u8; 14], plugin_cache: PluginCache) -> Self {
let mut lru = LruCache::new(CACHE_SIZE);
let mut lru = LruCache::new(CACHE_CAPACITY);
let inner = match plugin_cache
.get::<InnerCache>(PLUGIN_CACHE_KEY)
.ok()
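The fakeip change itself is only the rename to match dns_server's `CACHE_CAPACITY`, but its context shows the same best-effort restore idiom dns_server now uses: a failed or missing `PluginCache` read simply means starting with an empty map. Roughly, the shape of that idiom, with a plain `Result<Option<T>, E>` standing in for the cache read since only part of the call site is shown here:

// Stand-in for the PluginCache read: Err means the read or deserialization
// failed, Ok(None) means the key was never written.
fn load_or_default<T: Default, E>(read: Result<Option<T>, E>) -> T {
    // .ok().flatten() collapses both failure modes into None,
    // so the plugin falls back to a fresh, empty state.
    read.ok().flatten().unwrap_or_default()
}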
