diff --git a/mobile_config/src/gateway_service.rs b/mobile_config/src/gateway_service.rs index b7892b620..b64546c3e 100644 --- a/mobile_config/src/gateway_service.rs +++ b/mobile_config/src/gateway_service.rs @@ -20,15 +20,14 @@ use helium_proto::{ Message, }; use sqlx::{Pool, Postgres}; -use std::sync::Arc; -use tokio::sync::RwLock; +use std::{collections::HashMap, sync::Arc}; use tonic::{Request, Response, Status}; pub struct GatewayService { key_cache: KeyCache, metadata_pool: Pool, signing_key: Arc, - tracked_radios_cache: Arc>, + tracked_radios_cache: TrackedRadiosMap, } impl GatewayService { @@ -36,7 +35,7 @@ impl GatewayService { key_cache: KeyCache, metadata_pool: Pool, signing_key: Keypair, - tracked_radios_cache: Arc>, + tracked_radios_cache: TrackedRadiosMap, ) -> Self { Self { key_cache, @@ -342,7 +341,7 @@ impl mobile_config::Gateway for GatewayService { async fn handle_updated_at( mut gateway_info: GatewayInfo, - updated_radios: &TrackedRadiosMap, + updated_radios: &HashMap>, min_updated_at: chrono::DateTime, ) -> Option { // Check mobile_radio_tracker HashMap diff --git a/mobile_config/src/main.rs b/mobile_config/src/main.rs index a3bdd118a..f06e0e6df 100644 --- a/mobile_config/src/main.rs +++ b/mobile_config/src/main.rs @@ -78,8 +78,7 @@ impl Daemon { let admin_svc = AdminService::new(settings, key_cache.clone(), key_cache_updater, pool.clone())?; - let tracked_radios_cache: Arc> = - Arc::new(RwLock::new(TrackedRadiosMap::new())); + let tracked_radios_cache = TrackedRadiosMap::new(); let gateway_svc = GatewayService::new( key_cache.clone(), diff --git a/mobile_config/src/mobile_radio_tracker.rs b/mobile_config/src/mobile_radio_tracker.rs index f19be4195..2c76bf5ac 100644 --- a/mobile_config/src/mobile_radio_tracker.rs +++ b/mobile_config/src/mobile_radio_tracker.rs @@ -1,3 +1,4 @@ +use std::time::UNIX_EPOCH; use std::{collections::HashMap, sync::Arc, time::Duration}; use chrono::{DateTime, Utc}; @@ -7,9 +8,52 @@ use sqlx::Row; use sqlx::{Pool, Postgres, QueryBuilder}; use std::str::FromStr; use task_manager::ManagedTask; -use tokio::sync::RwLock; +use tokio::sync::{RwLock, RwLockReadGuard}; + +#[derive(Clone)] +pub struct TrackedRadiosMap(Arc>>>); + +const GET_UPDATED_RADIOS: &str = + "SELECT entity_key, last_changed_at FROM mobile_radio_tracker WHERE last_changed_at >= $1"; + +pub async fn get_updated_radios( + pool: &Pool, + min_updated_at: DateTime, +) -> anyhow::Result>> { + sqlx::query(GET_UPDATED_RADIOS) + .bind(min_updated_at) + .fetch(pool) + .map_err(anyhow::Error::from) + .try_fold( + HashMap::new(), + |mut map: HashMap>, row| async move { + let entity_key_b = row.get::<&[u8], &str>("entity_key"); + let entity_key = bs58::encode(entity_key_b).into_string(); + let updated_at = row.get::, &str>("last_changed_at"); + map.insert(PublicKeyBinary::from_str(&entity_key)?, updated_at); + Ok(map) + }, + ) + .await +} + +impl TrackedRadiosMap { + pub fn new() -> Self { + TrackedRadiosMap(Arc::new(RwLock::new(HashMap::new()))) + } + + pub async fn update(&self, pool: &Pool) -> anyhow::Result<()> { + let new_data = get_updated_radios(pool, UNIX_EPOCH.into()).await?; + let mut write_guard = self.0.write().await; + *write_guard = new_data; + Ok(()) + } + + pub async fn read(&self) -> RwLockReadGuard<'_, HashMap>> { + self.0.read().await + } +} -pub type TrackedRadiosMap = HashMap>; type EntityKey = Vec; #[derive(Debug, Clone, sqlx::FromRow)] @@ -110,7 +154,8 @@ pub struct MobileRadioTracker { pool: Pool, metadata: Pool, interval: Duration, - tracked_radios_cache: Arc>, + // tracked_radios_cache: Arc>, + tracked_radios_cache: TrackedRadiosMap, } impl ManagedTask for MobileRadioTracker { @@ -132,7 +177,7 @@ impl MobileRadioTracker { pool: Pool, metadata: Pool, interval: Duration, - tracked_radios_cache: Arc>, + tracked_radios_cache: TrackedRadiosMap, ) -> Self { Self { pool, @@ -174,13 +219,7 @@ impl MobileRadioTracker { update_tracked_radios(&self.pool, updates).await?; tracing::info!("updating tracked radios cache"); - let tracked_radios_map: TrackedRadiosMap = - get_updated_radios(&self.pool, DateTime::UNIX_EPOCH).await?; - { - let mut map = self.tracked_radios_cache.write().await; - *map = tracked_radios_map; - } - + self.tracked_radios_cache.update(&self.pool).await?; tracing::info!("done"); Ok(()) } @@ -203,30 +242,6 @@ async fn identify_changes( .await } -const GET_UPDATED_RADIOS: &str = - "SELECT entity_key, last_changed_at FROM mobile_radio_tracker WHERE last_changed_at >= $1"; - -pub async fn get_updated_radios( - pool: &Pool, - min_updated_at: DateTime, -) -> anyhow::Result { - sqlx::query(GET_UPDATED_RADIOS) - .bind(min_updated_at) - .fetch(pool) - .map_err(anyhow::Error::from) - .try_fold( - HashMap::new(), - |mut map: HashMap>, row| async move { - let entity_key_b = row.get::<&[u8], &str>("entity_key"); - let entity_key = bs58::encode(entity_key_b).into_string(); - let updated_at = row.get::, &str>("last_changed_at"); - map.insert(PublicKeyBinary::from_str(&entity_key)?, updated_at); - Ok(map) - }, - ) - .await -} - pub async fn get_tracked_radios( pool: &Pool, ) -> anyhow::Result> {