Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement tracked radios cache #928

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion boost_manager/src/updater.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ where
Ok(())
}

async fn confirm_txn<'a>(&self, txn_row: &TxnRow) -> Result<()> {
async fn confirm_txn(&self, txn_row: &TxnRow) -> Result<()> {
if self.solana.confirm_transaction(&txn_row.txn_id).await? {
tracing::info!("txn_id {} confirmed on chain, updated db", txn_row.txn_id);
db::update_verified_txns_onchain(&self.pool, &txn_row.txn_id).await?
Expand Down
2 changes: 1 addition & 1 deletion file_store/src/file_sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ mod tests {
.file_name()
.to_str()
.and_then(|file_name| FileInfo::from_str(file_name).ok())
.map_or(false, |file_info| {
.is_some_and(|file_info| {
FileType::from_str(&file_info.prefix).expect("entropy report prefix")
== FileType::EntropyReport
})
Expand Down
12 changes: 6 additions & 6 deletions iot_config/src/route_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ impl RouteService {
self.update_channel.clone()
}

async fn verify_request_signature<'a, R>(
async fn verify_request_signature<R>(
&self,
signer: &PublicKey,
request: &R,
id: OrgId<'a>,
id: OrgId<'_>,
) -> Result<(), Status>
where
R: MsgVerify,
Expand Down Expand Up @@ -117,11 +117,11 @@ impl RouteService {
}
}

async fn verify_request_signature_or_stream<'a, R>(
async fn verify_request_signature_or_stream<R>(
&self,
signer: &PublicKey,
request: &R,
id: OrgId<'a>,
id: OrgId<'_>,
) -> Result<(), Status>
where
R: MsgVerify,
Expand Down Expand Up @@ -151,9 +151,9 @@ impl RouteService {
DevAddrEuiValidator::new(route_id, admin_keys, &self.pool, check_constraints).await
}

async fn validate_skf_devaddrs<'a>(
async fn validate_skf_devaddrs(
&self,
route_id: &'a str,
route_id: &str,
updates: &[route_skf_update_req_v1::RouteSkfUpdateV1],
) -> Result<(), Status> {
let ranges: Vec<DevAddrRange> = route::list_devaddr_ranges_for_route(route_id, &self.pool)
Expand Down
7 changes: 3 additions & 4 deletions iot_verifier/src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ where
.witness_updater
.get_last_witness(&beacon_report.report.pub_key)
.await?;
Ok(last_witness.map_or(false, |lw| {
Ok(last_witness.is_some_and(|lw| {
beacon_report.received_timestamp - lw.timestamp < *RECIPROCITY_WINDOW
}))
}
Expand Down Expand Up @@ -544,9 +544,8 @@ where
) -> anyhow::Result<bool> {
let last_beacon_recip =
LastBeaconReciprocity::get(&self.pool, &report.report.pub_key).await?;
Ok(last_beacon_recip.map_or(false, |lw| {
report.received_timestamp - lw.timestamp < *RECIPROCITY_WINDOW
}))
Ok(last_beacon_recip
.is_some_and(|lw| report.received_timestamp - lw.timestamp < *RECIPROCITY_WINDOW))
}
}

Expand Down
46 changes: 2 additions & 44 deletions mobile_config/src/gateway_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,13 +363,10 @@ pub(crate) mod db {
use super::{DeviceType, GatewayInfo, GatewayMetadata};
use crate::gateway_info::DeploymentInfo;
use chrono::{DateTime, Utc};
use futures::{
stream::{Stream, StreamExt},
TryStreamExt,
};
use futures::stream::{Stream, StreamExt};
use helium_crypto::PublicKeyBinary;
use sqlx::{types::Json, PgExecutor, Row};
use std::{collections::HashMap, str::FromStr};
use std::str::FromStr;

const GET_METADATA_SQL: &str = r#"
select kta.entity_key, infos.location::bigint, infos.device_type,
Expand All @@ -380,50 +377,11 @@ pub(crate) mod db {
const BATCH_SQL_WHERE_SNIPPET: &str = " where kta.entity_key = any($1::bytea[]) ";
const DEVICE_TYPES_WHERE_SNIPPET: &str = " where device_type::text = any($1) ";

const GET_UPDATED_RADIOS: &str =
"SELECT entity_key, last_changed_at FROM mobile_radio_tracker WHERE last_changed_at >= $1";

const GET_UPDATED_AT: &str =
"SELECT last_changed_at FROM mobile_radio_tracker WHERE entity_key = $1";

lazy_static::lazy_static! {
static ref BATCH_METADATA_SQL: String = format!("{GET_METADATA_SQL} {BATCH_SQL_WHERE_SNIPPET}");
static ref DEVICE_TYPES_METADATA_SQL: String = format!("{GET_METADATA_SQL} {DEVICE_TYPES_WHERE_SNIPPET}");
}

pub async fn get_updated_radios(
db: impl PgExecutor<'_>,
min_updated_at: DateTime<Utc>,
) -> anyhow::Result<HashMap<PublicKeyBinary, DateTime<Utc>>> {
sqlx::query(GET_UPDATED_RADIOS)
.bind(min_updated_at)
.fetch(db)
.map_err(anyhow::Error::from)
.try_fold(
HashMap::new(),
|mut map: HashMap<PublicKeyBinary, DateTime<Utc>>, row| async move {
let entity_key_b = row.get::<&[u8], &str>("entity_key");
let entity_key = bs58::encode(entity_key_b).into_string();
let updated_at = row.get::<DateTime<Utc>, &str>("last_changed_at");
map.insert(PublicKeyBinary::from_str(&entity_key)?, updated_at);
Ok(map)
},
)
.await
}

pub async fn get_updated_at(
db: impl PgExecutor<'_>,
address: &PublicKeyBinary,
) -> anyhow::Result<Option<DateTime<Utc>>> {
let entity_key = bs58::decode(address.to_string()).into_vec()?;
sqlx::query_scalar(GET_UPDATED_AT)
.bind(entity_key)
.fetch_optional(db)
.await
.map_err(anyhow::Error::from)
}

pub async fn get_info(
db: impl PgExecutor<'_>,
address: &PublicKeyBinary,
Expand Down
52 changes: 24 additions & 28 deletions mobile_config/src/gateway_service.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use crate::{
gateway_info::{self, db::get_updated_radios, DeviceType, GatewayInfo},
gateway_info::{self, DeviceType, GatewayInfo},
key_cache::KeyCache,
mobile_radio_tracker::TrackedRadiosMap,
telemetry, verify_public_key, GrpcResult, GrpcStreamResult,
};
use chrono::{DateTime, TimeZone, Utc};
use file_store::traits::{MsgVerify, TimestampEncode};
use futures::{
future,
stream::{Stream, StreamExt, TryStreamExt},
TryFutureExt,
};
Expand All @@ -20,28 +20,29 @@ use helium_proto::{
Message,
};
use sqlx::{Pool, Postgres};
use std::{collections::HashMap, sync::Arc};
use std::sync::Arc;
use tokio::sync::RwLock;
use tonic::{Request, Response, Status};

pub struct GatewayService {
key_cache: KeyCache,
mobile_config_db_pool: Pool<Postgres>,
metadata_pool: Pool<Postgres>,
signing_key: Arc<Keypair>,
tracked_radios_cache: Arc<RwLock<TrackedRadiosMap>>,
}

impl GatewayService {
pub fn new(
key_cache: KeyCache,
metadata_pool: Pool<Postgres>,
signing_key: Keypair,
mobile_config_db_pool: Pool<Postgres>,
tracked_radios_cache: Arc<RwLock<TrackedRadiosMap>>,
) -> Self {
Self {
key_cache,
metadata_pool,
signing_key: Arc::new(signing_key),
mobile_config_db_pool,
tracked_radios_cache,
}
}

Expand Down Expand Up @@ -129,11 +130,10 @@ impl mobile_config::Gateway for GatewayService {
let pubkey: PublicKeyBinary = request.address.into();
tracing::debug!(pubkey = pubkey.to_string(), "fetching gateway info (v2)");

let updated_at = gateway_info::db::get_updated_at(&self.mobile_config_db_pool, &pubkey)
.await
.map_err(|_| {
Status::internal("error fetching updated_at field for gateway info (v2)")
})?;
let updated_at = {
let tracked_radios = self.tracked_radios_cache.read().await;
tracked_radios.get(&pubkey).cloned()
};

gateway_info::db::get_info(&self.metadata_pool, &pubkey)
.await
Expand Down Expand Up @@ -230,7 +230,6 @@ impl mobile_config::Gateway for GatewayService {
);

let metadata_db_pool = self.metadata_pool.clone();
let mobile_config_db_pool = self.mobile_config_db_pool.clone();
let signing_key = self.signing_key.clone();
let batch_size = request.batch_size;
let addresses = request
Expand All @@ -241,18 +240,17 @@ impl mobile_config::Gateway for GatewayService {

let (tx, rx) = tokio::sync::mpsc::channel(100);

let radios_cache = Arc::clone(&self.tracked_radios_cache);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.tracked_radios_cache is already an Arc, you should be able to just self.tracked_radios_cache.clone() no?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Arc::clone(...) and self.tracked_radios_cache.clone() are actually the same under the hood.
When I use Arc::clone I just want to emphasize (clarity) in the code that it clones the reference (not the value).

Since it is not crucial for me, I've replaced as you ask

tokio::spawn(async move {
let min_updated_at = DateTime::UNIX_EPOCH;
let updated_radios = get_updated_radios(&mobile_config_db_pool, min_updated_at).await?;

let binding = Arc::clone(&radios_cache);
let radios_cache = binding.read().await;

let stream = gateway_info::db::batch_info_stream(&metadata_db_pool, &addresses)?;
let stream = stream
.filter_map(|gateway_info| {
future::ready(handle_updated_at(
gateway_info,
&updated_radios,
min_updated_at,
))
handle_updated_at(gateway_info, &radios_cache, min_updated_at)
})
.boxed();
stream_multi_gateways_info(stream, tx.clone(), signing_key.clone(), batch_size).await
Expand Down Expand Up @@ -307,7 +305,6 @@ impl mobile_config::Gateway for GatewayService {
self.verify_request_signature(&signer, &request)?;

let metadata_db_pool = self.metadata_pool.clone();
let mobile_config_db_pool = self.mobile_config_db_pool.clone();
let signing_key = self.signing_key.clone();
let batch_size = request.batch_size;

Expand All @@ -320,6 +317,7 @@ impl mobile_config::Gateway for GatewayService {
device_types
);

let radios_cache = Arc::clone(&self.tracked_radios_cache);
tokio::spawn(async move {
let min_updated_at = Utc
.timestamp_opt(request.min_updated_at as i64, 0)
Expand All @@ -328,15 +326,13 @@ impl mobile_config::Gateway for GatewayService {
"Invalid min_refreshed_at argument",
))?;

let updated_radios = get_updated_radios(&mobile_config_db_pool, min_updated_at).await?;
let binding = Arc::clone(&radios_cache);
let radios_cache = binding.read().await;

let stream = gateway_info::db::all_info_stream(&metadata_db_pool, &device_types);
let stream = stream
.filter_map(|gateway_info| {
future::ready(handle_updated_at(
gateway_info,
&updated_radios,
min_updated_at,
))
handle_updated_at(gateway_info, &radios_cache, min_updated_at)
})
.boxed();
stream_multi_gateways_info(stream, tx.clone(), signing_key.clone(), batch_size).await
Expand All @@ -346,20 +342,20 @@ impl mobile_config::Gateway for GatewayService {
}
}

fn handle_updated_at(
async fn handle_updated_at(
mut gateway_info: GatewayInfo,
updated_radios: &HashMap<PublicKeyBinary, chrono::DateTime<Utc>>,
updated_radios: &TrackedRadiosMap,
min_updated_at: chrono::DateTime<Utc>,
) -> Option<GatewayInfo> {
// Check mobile_radio_tracker HashMap
if let Some(updated_at) = updated_radios.get(&gateway_info.address) {
// It could be already filtered by min_updated_at but recheck won't hurt
if updated_at >= &min_updated_at {
gateway_info.updated_at = Some(*updated_at);
return Some(gateway_info);
}
return None;
}

// Fallback solution #1. Try to use refreshed_at as updated_at field and check
// min_updated_at
if let Some(refreshed_at) = gateway_info.refreshed_at {
Expand Down
38 changes: 27 additions & 11 deletions mobile_config/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@ use helium_proto::services::mobile_config::{
HexBoostingServer,
};
use mobile_config::{
admin_service::AdminService, authorization_service::AuthorizationService,
carrier_service::CarrierService, entity_service::EntityService,
gateway_service::GatewayService, hex_boosting_service::HexBoostingService, key_cache::KeyCache,
mobile_radio_tracker::MobileRadioTracker, settings::Settings,
admin_service::AdminService,
authorization_service::AuthorizationService,
carrier_service::CarrierService,
entity_service::EntityService,
gateway_service::GatewayService,
hex_boosting_service::HexBoostingService,
key_cache::KeyCache,
mobile_radio_tracker::{MobileRadioTracker, TrackedRadiosMap},
settings::Settings,
};
use std::{net::SocketAddr, path::PathBuf, time::Duration};
use std::{net::SocketAddr, path::PathBuf, sync::Arc, time::Duration};
use task_manager::{ManagedTask, TaskManager};
use tokio::sync::RwLock;
use tonic::transport;

#[derive(Debug, clap::Parser)]
Expand Down Expand Up @@ -71,11 +77,15 @@ impl Daemon {

let admin_svc =
AdminService::new(settings, key_cache.clone(), key_cache_updater, pool.clone())?;

let tracked_radios_cache: Arc<RwLock<TrackedRadiosMap>> =
Arc::new(RwLock::new(TrackedRadiosMap::new()));

let gateway_svc = GatewayService::new(
key_cache.clone(),
metadata_pool.clone(),
settings.signing_keypair()?,
pool.clone(),
Arc::clone(&tracked_radios_cache),
);
let auth_svc = AuthorizationService::new(key_cache.clone(), settings.signing_keypair()?);
let entity_svc = EntityService::new(
Expand Down Expand Up @@ -107,13 +117,19 @@ impl Daemon {
hex_boosting_svc,
};

let mobile_tracker = MobileRadioTracker::new(
pool.clone(),
metadata_pool.clone(),
settings.mobile_radio_tracker_interval,
Arc::clone(&tracked_radios_cache),
);
// (Pre)initialize tracked_radios_cache to avoid race condition in GatewayService
mobile_tracker.track_changes().await?;

tracing::info!("Starting grpc server");
TaskManager::builder()
.add_task(grpc_server)
.add_task(MobileRadioTracker::new(
pool.clone(),
metadata_pool.clone(),
settings.mobile_radio_tracker_interval,
))
.add_task(mobile_tracker)
.build()
.start()
.await
Expand Down
Loading