diff --git a/Cargo.lock b/Cargo.lock index e4c707338..4d172493c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2879,7 +2879,7 @@ dependencies = [ [[package]] name = "helium-proto" version = "0.1.0" -source = "git+https://github.com/helium/proto?branch=master#94bc62b3b65391260bd79a102f60a0b683fcd62f" +source = "git+https://github.com/helium/proto?branch=master#3061e06dff4f4a643dd9e2bf98bc24b462071de3" dependencies = [ "bytes", "prost", diff --git a/docker-compose.yml b/docker-compose.yml index 2ca033886..deb3d15d7 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -93,6 +93,8 @@ services: iot-verifier mobile-packet-verifier iot-packet-verifier + iot-price + mobile-price ORACLE_ID: oraclesecretid ORACLE_KEY: oraclesecretkey entrypoint: diff --git a/file_store/src/traits/msg_verify.rs b/file_store/src/traits/msg_verify.rs index fd0cf4d8a..e374273fd 100644 --- a/file_store/src/traits/msg_verify.rs +++ b/file_store/src/traits/msg_verify.rs @@ -46,15 +46,14 @@ impl_msg_verify!(iot_config::RouteGetEuisReqV1, signature); impl_msg_verify!(iot_config::RouteUpdateEuisReqV1, signature); impl_msg_verify!(iot_config::RouteGetDevaddrRangesReqV1, signature); impl_msg_verify!(iot_config::RouteUpdateDevaddrRangesReqV1, signature); +impl_msg_verify!(iot_config::RouteSkfListReqV1, signature); +impl_msg_verify!(iot_config::RouteSkfGetReqV1, signature); +impl_msg_verify!(iot_config::RouteSkfUpdateReqV1, signature); impl_msg_verify!(iot_config::GatewayLocationReqV1, signature); impl_msg_verify!(iot_config::GatewayRegionParamsReqV1, signature); impl_msg_verify!(iot_config::AdminAddKeyReqV1, signature); impl_msg_verify!(iot_config::AdminLoadRegionReqV1, signature); impl_msg_verify!(iot_config::AdminRemoveKeyReqV1, signature); -impl_msg_verify!(iot_config::SessionKeyFilterGetReqV1, signature); -impl_msg_verify!(iot_config::SessionKeyFilterListReqV1, signature); -impl_msg_verify!(iot_config::SessionKeyFilterStreamReqV1, signature); -impl_msg_verify!(iot_config::SessionKeyFilterUpdateReqV1, signature); impl_msg_verify!(iot_config::GatewayInfoReqV1, signature); impl_msg_verify!(iot_config::GatewayInfoStreamReqV1, signature); impl_msg_verify!(iot_config::RegionParamsReqV1, signature); diff --git a/iot_config.Dockerfile b/iot_config.Dockerfile index 99e704b90..20b80d07e 100644 --- a/iot_config.Dockerfile +++ b/iot_config.Dockerfile @@ -19,7 +19,7 @@ RUN mkdir ./iot_config/src \ && sed -i -e '/ingest/d' -e '/mobile_config/d' -e '/mobile_verifier/d' \ -e '/poc_entropy/d' -e '/iot_verifier/d' -e '/price/d' \ -e '/reward_index/d' -e '/denylist/d' -e '/iot_packet_verifier/d' \ - -e '/mobile_packet_verifier/d' \ + -e '/solana/d' -e '/mobile_packet_verifier/d' \ Cargo.toml \ && cargo build --package iot-config --release diff --git a/iot_config/migrations/8_skfs_by_route.sql b/iot_config/migrations/8_skfs_by_route.sql new file mode 100644 index 000000000..5fa76cd80 --- /dev/null +++ b/iot_config/migrations/8_skfs_by_route.sql @@ -0,0 +1,16 @@ +drop table session_key_filters; + +create table route_session_key_filters ( + route_id uuid not null references routes(id) on delete cascade, + devaddr int not null, + session_key text not null, + + inserted_at timestamptz not null default now(), + updated_at timestamptz not null default now(), + + primary key (route_id, devaddr, session_key) +); + +create index skf_devaddr_idx on route_session_key_filters (devaddr); + +select trigger_updated_at('route_session_key_filters'); diff --git a/iot_config/src/lib.rs b/iot_config/src/lib.rs index 90b067233..b7778e670 100644 --- a/iot_config/src/lib.rs +++ b/iot_config/src/lib.rs @@ -9,8 +9,6 @@ pub mod org_service; pub mod region_map; pub mod route; pub mod route_service; -pub mod session_key; -pub mod session_key_service; pub mod settings; pub mod telemetry; @@ -20,7 +18,6 @@ pub use gateway_service::GatewayService; use lora_field::{LoraField, NetIdField}; pub use org_service::OrgService; pub use route_service::RouteService; -pub use session_key_service::SessionKeyFilterService; pub use settings::Settings; use helium_crypto::PublicKey; diff --git a/iot_config/src/lora_field.rs b/iot_config/src/lora_field.rs index 35f927ca8..d5881b6fd 100644 --- a/iot_config/src/lora_field.rs +++ b/iot_config/src/lora_field.rs @@ -11,7 +11,7 @@ pub type EuiField = LoraField<16>; pub mod proto { pub use helium_proto::services::iot_config::{ - DevaddrConstraintV1, DevaddrRangeV1, EuiPairV1, OrgV1, + DevaddrConstraintV1, DevaddrRangeV1, EuiPairV1, OrgV1, SkfV1, }; } @@ -30,6 +30,10 @@ impl DevAddrRange { end_addr, } } + + pub fn contains_addr(&self, addr: DevAddrField) -> bool { + self.start_addr <= addr && self.end_addr >= addr + } } impl FromRow<'_, PgRow> for DevAddrRange { @@ -73,10 +77,6 @@ impl DevAddrConstraint { pub fn contains_range(&self, range: &DevAddrRange) -> bool { self.start_addr <= range.start_addr && self.end_addr >= range.end_addr } - - pub fn contains_addr(&self, addr: DevAddrField) -> bool { - self.start_addr <= addr && self.end_addr >= addr - } } #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] @@ -108,6 +108,35 @@ impl FromRow<'_, PgRow> for EuiPair { } } +#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] +pub struct Skf { + pub route_id: String, + pub devaddr: DevAddrField, + pub session_key: String, +} + +impl Skf { + pub fn new(route_id: String, devaddr: DevAddrField, session_key: String) -> Self { + Self { + route_id, + devaddr, + session_key, + } + } +} + +impl FromRow<'_, PgRow> for Skf { + fn from_row(row: &PgRow) -> sqlx::Result { + Ok(Self { + route_id: row + .try_get::("route_id")? + .to_string(), + devaddr: row.get::("devaddr").into(), + session_key: row.get::("session_key"), + }) + } +} + #[derive(thiserror::Error, Debug)] pub enum ParseError { #[error("char len mismatch: expected {0}, found {1}")] @@ -523,6 +552,46 @@ impl From<&EuiPair> for proto::EuiPairV1 { } } +impl From for Skf { + fn from(filter: proto::SkfV1) -> Self { + Self { + route_id: filter.route_id, + devaddr: filter.devaddr.into(), + session_key: filter.session_key, + } + } +} + +impl From<&proto::SkfV1> for Skf { + fn from(filter: &proto::SkfV1) -> Self { + Self { + route_id: filter.route_id.to_owned(), + devaddr: filter.devaddr.into(), + session_key: filter.session_key.to_owned(), + } + } +} + +impl From for proto::SkfV1 { + fn from(filter: Skf) -> Self { + Self { + route_id: filter.route_id, + devaddr: filter.devaddr.into(), + session_key: filter.session_key, + } + } +} + +impl From<&Skf> for proto::SkfV1 { + fn from(filter: &Skf) -> Self { + Self { + route_id: filter.route_id.to_owned(), + devaddr: filter.devaddr.into(), + session_key: filter.session_key.to_owned(), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/iot_config/src/main.rs b/iot_config/src/main.rs index cdff071bf..dc2d2623a 100644 --- a/iot_config/src/main.rs +++ b/iot_config/src/main.rs @@ -1,13 +1,11 @@ use anyhow::{Error, Result}; use clap::Parser; use futures_util::TryFutureExt; -use helium_proto::services::iot_config::{ - AdminServer, GatewayServer, OrgServer, RouteServer, SessionKeyFilterServer, -}; +use helium_proto::services::iot_config::{AdminServer, GatewayServer, OrgServer, RouteServer}; use iot_config::{ - admin::AuthCache, gateway_service::GatewayService, org_service::OrgService, - region_map::RegionMapReader, route_service::RouteService, - session_key_service::SessionKeyFilterService, settings::Settings, AdminService, + admin::AuthCache, admin_service::AdminService, gateway_service::GatewayService, + org_service::OrgService, region_map::RegionMapReader, route_service::RouteService, + settings::Settings, }; use std::{path::PathBuf, time::Duration}; use tokio::signal; @@ -115,12 +113,12 @@ impl Daemon { region_map.clone(), region_updater, )?; - let session_key_filter_svc = SessionKeyFilterService::new( - settings, - auth_cache.clone(), - pool.clone(), - shutdown_listener.clone(), - )?; + + let pubkey = settings + .signing_keypair() + .map(|keypair| keypair.public_key().to_string())?; + tracing::debug!("listening on {listen_addr}"); + tracing::debug!("signing as {pubkey}"); let server = transport::Server::builder() .http2_keepalive_interval(Some(Duration::from_secs(250))) @@ -129,7 +127,6 @@ impl Daemon { .add_service(OrgServer::new(org_svc)) .add_service(RouteServer::new(route_svc)) .add_service(AdminServer::new(admin_svc)) - .add_service(SessionKeyFilterServer::new(session_key_filter_svc)) .serve_with_shutdown(listen_addr, shutdown_listener) .map_err(Error::from); diff --git a/iot_config/src/org.rs b/iot_config/src/org.rs index c855cd20a..ab26fcc82 100644 --- a/iot_config/src/org.rs +++ b/iot_config/src/org.rs @@ -232,7 +232,10 @@ pub async fn get_org_pubkeys( ) -> Result, DbOrgError> { let org = get(oui, db).await?; - let mut pubkeys: Vec = vec![PublicKey::try_from(org.owner)?]; + let mut pubkeys: Vec = vec![ + PublicKey::try_from(org.owner)?, + PublicKey::try_from(org.payer)?, + ]; let mut delegate_pubkeys: Vec = org .delegate_keys @@ -262,7 +265,10 @@ pub async fn get_org_pubkeys_by_route( .fetch_one(db) .await?; - let mut pubkeys: Vec = vec![PublicKey::try_from(org.owner)?]; + let mut pubkeys: Vec = vec![ + PublicKey::try_from(org.owner)?, + PublicKey::try_from(org.payer)?, + ]; let mut delegate_keys: Vec = org .delegate_keys diff --git a/iot_config/src/route.rs b/iot_config/src/route.rs index 2b73ac9ce..c1bcbb115 100644 --- a/iot_config/src/route.rs +++ b/iot_config/src/route.rs @@ -1,6 +1,6 @@ use crate::{ broadcast_update, - lora_field::{DevAddrRange, EuiPair, NetIdField}, + lora_field::{DevAddrField, DevAddrRange, EuiPair, NetIdField, Skf}, }; use anyhow::anyhow; use chrono::Utc; @@ -141,15 +141,15 @@ pub async fn create_route( signer, signature: vec![], }; - signing_key + _ = signing_key .sign(&update.encode_to_vec()) - .map_err(|err| anyhow!(format!("error signing route stream response: {err:?}"))) + .map_err(|err| tracing::error!("error signing route stream response: {err:?}")) .and_then(|signature| { update.signature = signature; update_tx.send(update).map_err(|err| { - anyhow!(format!("error broadcasting route stream response: {err:?}")) + tracing::warn!("error broadcasting route stream response: {err:?}") }) - })?; + }); }; Ok(new_route) @@ -206,12 +206,12 @@ pub async fn update_route( _ = signing_key .sign(&update_res.encode_to_vec()) - .map_err(|err| anyhow!(format!("error signing route stream response: {err:?}"))) + .map_err(|err| tracing::error!("error signing route stream response: {err:?}")) .and_then(|signature| { update_res.signature = signature; - update_tx.send(update_res).map_err(|err| { - anyhow!(format!("error broadcasting route stream response: {err:?}")) - }) + update_tx + .send(update_res) + .map_err(|err| tracing::warn!("error broadcasting route stream response: {err:?}")) }); Ok(updated_route) @@ -562,6 +562,18 @@ pub fn devaddr_range_stream<'a>( .boxed() } +pub fn skf_stream<'a>(db: impl sqlx::PgExecutor<'a> + 'a + Copy) -> impl Stream + 'a { + sqlx::query_as::<_, Skf>( + r#" + select skf.route_id, skf.devaddr, skf.session_key + from route_session_key_filters skf + "#, + ) + .fetch(db) + .filter_map(|skf| async move { skf.ok() }) + .boxed() +} + pub async fn get_route(id: &str, db: impl sqlx::PgExecutor<'_>) -> anyhow::Result { let uuid = Uuid::try_parse(id)?; let route_row = sqlx::query_as::<_, StorageRoute>( @@ -645,6 +657,149 @@ pub async fn delete_route( Ok(()) } +pub fn list_skfs_for_route<'a>( + id: &str, + db: impl sqlx::PgExecutor<'a> + 'a + Copy, +) -> Result> + 'a, RouteStorageError> { + let id = Uuid::try_parse(id)?; + const SKF_SELECT_SQL: &str = r#" + select skf.route_id, skf.devaddr, skf.session_key + from route_session_key_filters skf + where skf.route_id = $1 + "#; + + Ok(sqlx::query_as::<_, Skf>(SKF_SELECT_SQL) + .bind(id) + .fetch(db) + .boxed()) +} + +pub fn list_skfs_for_route_and_devaddr<'a>( + id: &str, + devaddr: DevAddrField, + db: impl sqlx::PgExecutor<'a> + 'a + Copy, +) -> Result> + 'a, RouteStorageError> { + let id = Uuid::try_parse(id)?; + + Ok(sqlx::query_as::<_, Skf>( + r#" + select skf.route_id, skf.devaddr, skf.session_key + from route_session_key_filters skf + where skf.route_id = $1 and devaddr = $2 + "#, + ) + .bind(id) + .bind(i32::from(devaddr)) + .fetch(db) + .boxed()) +} + +pub async fn update_skfs( + to_add: &[Skf], + to_remove: &[Skf], + db: impl sqlx::PgExecutor<'_> + sqlx::Acquire<'_, Database = sqlx::Postgres> + Copy, + signing_key: Arc, + update_tx: Sender, +) -> anyhow::Result<()> { + let mut transaction = db.begin().await?; + + let added_updates: Vec<(Skf, proto::ActionV1)> = insert_skfs(to_add, &mut transaction) + .await? + .into_iter() + .map(|added_skf| (added_skf, proto::ActionV1::Add)) + .collect(); + + let removed_updates: Vec<(Skf, proto::ActionV1)> = remove_skfs(to_remove, &mut transaction) + .await? + .into_iter() + .map(|removed_skf| (removed_skf, proto::ActionV1::Remove)) + .collect(); + + transaction.commit().await?; + + tokio::spawn(async move { + let timestamp = Utc::now().encode_timestamp(); + let signer: Vec = signing_key.public_key().into(); + stream::iter([added_updates, removed_updates].concat()) + .map(Ok) + .try_for_each(|(update, action)| { + let mut skf_update = proto::RouteStreamResV1 { + action: i32::from(action), + data: Some(proto::route_stream_res_v1::Data::Skf(update.into())), + timestamp, + signer: signer.clone(), + signature: vec![], + }; + futures::future::ready(signing_key.sign(&skf_update.encode_to_vec())) + .map_err(|_| anyhow!("failed to sign session key filter update")) + .and_then(|signature| { + skf_update.signature = signature; + broadcast_update::(skf_update, update_tx.clone()) + .map_err(|_| anyhow!("failed to broadcast session key filter update")) + }) + }) + .await + }); + + Ok(()) +} + +async fn insert_skfs(skfs: &[Skf], db: impl sqlx::PgExecutor<'_>) -> anyhow::Result> { + if skfs.is_empty() { + return Ok(vec![]); + } + + let skfs = skfs + .iter() + .map(|filter| filter.try_into()) + .collect::, _>>()?; + + const SKF_INSERT_VALS: &str = + " insert into route_session_key_filters (route_id, devaddr, session_key) "; + const SKF_INSERT_CONFLICT: &str = + " on conflict (route_id, devaddr, session_key) do nothing returning * "; + + let mut query_builder: sqlx::QueryBuilder = + sqlx::QueryBuilder::new(SKF_INSERT_VALS); + query_builder + .push_values(skfs, |mut builder, (route_id, devaddr, session_key)| { + builder + .push_bind(route_id) + .push_bind(devaddr) + .push_bind(session_key); + }) + .push(SKF_INSERT_CONFLICT); + + Ok(query_builder.build_query_as::().fetch_all(db).await?) +} + +async fn remove_skfs(skfs: &[Skf], db: impl sqlx::PgExecutor<'_>) -> anyhow::Result> { + if skfs.is_empty() { + return Ok(vec![]); + } + + let skfs = skfs + .iter() + .map(|filter| filter.try_into()) + .collect::, _>>()?; + + const SKF_DELETE_VALS: &str = + " delete from route_session_key_filters where (route_id, devaddr, session_key) in "; + const SKF_DELETE_RETURN: &str = " returning * "; + let mut query_builder: sqlx::QueryBuilder = + sqlx::QueryBuilder::new(SKF_DELETE_VALS); + query_builder + .push_tuples(skfs, |mut builder, (route_id, devaddr, session_key)| { + builder + .push_bind(route_id) + .push_bind(devaddr) + .push_bind(session_key); + }) + .push(SKF_DELETE_RETURN); + + Ok(query_builder.build_query_as::().fetch_all(db).await?) +} + #[derive(Debug, Serialize)] pub struct RouteList { routes: Vec, @@ -707,6 +862,15 @@ impl TryFrom<&DevAddrRange> for (Uuid, i32, i32) { } } +impl TryFrom<&Skf> for (Uuid, i32, String) { + type Error = sqlx::types::uuid::Error; + + fn try_from(skf: &Skf) -> Result<(Uuid, i32, String), Self::Error> { + let uuid = Uuid::try_parse(&skf.route_id)?; + Ok((uuid, i32::from(skf.devaddr), skf.session_key.clone())) + } +} + pub type Port = u32; pub type GwmpMap = BTreeMap; diff --git a/iot_config/src/route_service.rs b/iot_config/src/route_service.rs index 7e475c805..a1e5ada1a 100644 --- a/iot_config/src/route_service.rs +++ b/iot_config/src/route_service.rs @@ -1,6 +1,6 @@ use crate::{ admin::{AuthCache, KeyType}, - lora_field::{DevAddrConstraint, DevAddrRange, EuiPair}, + lora_field::{DevAddrConstraint, DevAddrRange, EuiPair, Skf}, org::{self, DbOrgError}, route::{self, Route, RouteStorageError}, telemetry, update_channel, verify_public_key, GrpcResult, GrpcStreamRequest, GrpcStreamResult, @@ -16,20 +16,22 @@ use futures::{ use helium_crypto::{Keypair, PublicKey, Sign}; use helium_proto::{ services::iot_config::{ - self, route_stream_res_v1, ActionV1, DevaddrRangeV1, EuiPairV1, RouteCreateReqV1, - RouteDeleteReqV1, RouteDevaddrRangesResV1, RouteEuisResV1, RouteGetDevaddrRangesReqV1, - RouteGetEuisReqV1, RouteGetReqV1, RouteListReqV1, RouteListResV1, RouteResV1, - RouteStreamReqV1, RouteStreamResV1, RouteUpdateDevaddrRangesReqV1, RouteUpdateEuisReqV1, - RouteUpdateReqV1, RouteV1, + self, route_skf_update_req_v1, route_stream_res_v1, ActionV1, DevaddrRangeV1, EuiPairV1, + RouteCreateReqV1, RouteDeleteReqV1, RouteDevaddrRangesResV1, RouteEuisResV1, + RouteGetDevaddrRangesReqV1, RouteGetEuisReqV1, RouteGetReqV1, RouteListReqV1, + RouteListResV1, RouteResV1, RouteSkfGetReqV1, RouteSkfListReqV1, RouteSkfUpdateReqV1, + RouteSkfUpdateResV1, RouteStreamReqV1, RouteStreamResV1, RouteUpdateDevaddrRangesReqV1, + RouteUpdateEuisReqV1, RouteUpdateReqV1, RouteV1, SkfV1, }, Message, }; use sqlx::{Pool, Postgres}; -use std::sync::Arc; +use std::{pin::Pin, sync::Arc}; use tokio::sync::{broadcast, mpsc}; use tonic::{Request, Response, Status}; const UPDATE_BATCH_LIMIT: usize = 5_000; +const SKF_UPDATE_LIMIT: usize = 100; pub struct RouteService { auth_cache: AuthCache, @@ -135,6 +137,39 @@ impl RouteService { DevAddrEuiValidator::new(route_id, admin_keys, &self.pool, check_constraints).await } + + async fn validate_skf_devaddrs<'a>( + &self, + route_id: &'a str, + updates: &[route_skf_update_req_v1::RouteSkfUpdateV1], + ) -> Result<(), Status> { + let ranges: Vec = route::list_devaddr_ranges_for_route(route_id, &self.pool) + .map_err(|err| match err { + RouteStorageError::UuidParse(_) => { + Status::invalid_argument(format!("unable to parse route_id: {route_id}")) + } + _ => Status::internal("error retrieving devaddrs for route"), + })? + .filter_map(|range| async move { range.ok() }) + .collect() + .await; + + for update in updates { + let devaddr = update.devaddr.into(); + if !ranges.iter().any(|range| range.contains_addr(devaddr)) { + let ranges = ranges + .iter() + .map(|r| format!("{} -- {}", r.start_addr, r.end_addr)) + .collect::>() + .join(", "); + return Err(Status::invalid_argument(format!( + "devaddr {devaddr} not within registered ranges for route {route_id} :: {ranges}" + ))); + } + } + + Ok(()) + } } #[tonic::async_trait] @@ -344,6 +379,7 @@ impl iot_config::Route for RouteService { if stream_existing_routes(&pool, &signing_key, tx.clone()) .and_then(|_| stream_existing_euis(&pool, &signing_key, tx.clone())) .and_then(|_| stream_existing_devaddrs(&pool, &signing_key, tx.clone())) + .and_then(|_| stream_existing_skfs(&pool, &signing_key, tx.clone())) .await .is_err() { @@ -351,18 +387,18 @@ impl iot_config::Route for RouteService { } tracing::info!("existing routes sent; streaming updates as available"); - telemetry::stream_subscribe("route-stream"); + telemetry::route_stream_subscribe(); loop { let shutdown = shutdown_listener.clone(); tokio::select! { _ = shutdown => { - telemetry::stream_unsubscribe("route-stream"); + telemetry::route_stream_unsubscribe(); return } msg = route_updates.recv() => if let Ok(update) = msg { if tx.send(Ok(update)).await.is_err() { - telemetry::stream_unsubscribe("route-stream"); + telemetry::route_stream_unsubscribe(); return; } } @@ -428,88 +464,88 @@ impl iot_config::Route for RouteService { &self, request: GrpcStreamRequest, ) -> GrpcResult { - let mut request = request.into_inner(); + let request = request.into_inner(); telemetry::count_request("route", "update-euis"); - let mut to_add: Vec = vec![]; - let mut to_remove: Vec = vec![]; - let mut pending_updates: usize = 0; - - let mut validator: DevAddrEuiValidator = - if let Ok(Some(first_update)) = request.message().await { - if let Some(eui_pair) = &first_update.eui_pair { - let mut validator = self - .update_validator(&eui_pair.route_id, false) - .await - .map_err(|_| Status::internal("unable to verify updates"))?; - validator.validate_update(&first_update)?; - match first_update.action() { - ActionV1::Add => to_add.push(eui_pair.into()), - ActionV1::Remove => to_remove.push(eui_pair.into()), - }; - pending_updates += 1; - validator - } else { - return Err(Status::invalid_argument("no valid route_id for update")); + let mut incoming_stream = request.peekable(); + let mut validator: DevAddrEuiValidator = Pin::new(&mut incoming_stream) + .peek() + .await + .map(|first_update| async move { + match first_update { + Ok(ref update) => match update.eui_pair { + Some(ref eui_pair) => self + .update_validator(&eui_pair.route_id, false) + .await + .map_err(|err| { + Status::internal(format!("unable to verify updates: {err:?}")) + }), + None => Err(Status::invalid_argument("no eui pairs provided")), + }, + Err(_) => Err(Status::invalid_argument("no eui pairs provided")), } - } else { - return Err(Status::invalid_argument("no eui pair provided")); - }; + }) + .ok_or_else(|| Status::invalid_argument("no eui pairs provided"))? + .await?; - while let Ok(Some(update)) = request.message().await { - validator.validate_update(&update)?; - match (update.action(), update.eui_pair) { - (ActionV1::Add, Some(eui_pair)) => to_add.push(eui_pair.into()), - (ActionV1::Remove, Some(eui_pair)) => to_remove.push(eui_pair.into()), - _ => return Err(Status::invalid_argument("no eui pair provided")), - }; - pending_updates += 1; - if pending_updates >= UPDATE_BATCH_LIMIT { + incoming_stream + .map_ok(|update| match validator.validate_update(&update) { + Ok(()) => Ok(update), + Err(reason) => Err(Status::invalid_argument(format!( + "invalid update request: {reason:?}" + ))), + }) + .try_chunks(UPDATE_BATCH_LIMIT) + .map_err(|err| Status::internal(format!("eui pair updates failed to batch: {err:?}"))) + .and_then(|batch| async move { + batch + .into_iter() + .collect::, Status>>() + }) + .and_then(|batch| async move { + batch + .into_iter() + .map( + |update: RouteUpdateEuisReqV1| match (update.action(), update.eui_pair) { + (ActionV1::Add, Some(eui_pair)) => Ok((ActionV1::Add, eui_pair)), + (ActionV1::Remove, Some(eui_pair)) => Ok((ActionV1::Remove, eui_pair)), + _ => Err(Status::invalid_argument("invalid eui pair update request")), + }, + ) + .collect::, Status>>() + }) + .try_for_each(|batch: Vec<(ActionV1, EuiPairV1)>| async move { + let (to_add, to_remove): (Vec<(ActionV1, EuiPairV1)>, Vec<(ActionV1, EuiPairV1)>) = + batch + .into_iter() + .partition(|(action, _update)| action == &ActionV1::Add); telemetry::count_eui_updates(to_add.len(), to_remove.len()); tracing::debug!( adding = to_add.len(), removing = to_remove.len(), - "updating eui pairs", + "updating eui pairs" ); + let adds_update: Vec = + to_add.into_iter().map(|(_, add)| add.into()).collect(); + let removes_update: Vec = to_remove + .into_iter() + .map(|(_, remove)| remove.into()) + .collect(); route::update_euis( - &to_add, - &to_remove, + &adds_update, + &removes_update, &self.pool, self.signing_key.clone(), - self.update_channel.clone(), + self.clone_update_channel(), ) .await .map_err(|err| { tracing::error!("eui pair update failed: {err:?}"); - Status::internal("eui pair update failed") - })?; - to_add = vec![]; - to_remove = vec![]; - pending_updates = 0; - } - } - - if pending_updates > 0 { - telemetry::count_eui_updates(to_add.len(), to_remove.len()); - tracing::debug!( - adding = to_add.len(), - removing = to_remove.len(), - "updating euis", - ); + Status::internal(format!("eui pair update failed: {err:?}")) + }) + }) + .await?; - route::update_euis( - &to_add, - &to_remove, - &self.pool, - self.signing_key.clone(), - self.clone_update_channel(), - ) - .await - .map_err(|err| { - tracing::error!("eui update failed: {err:?}"); - Status::internal("eui update failed") - })?; - } let mut resp = RouteEuisResV1 { timestamp: Utc::now().encode_timestamp(), signer: self.signing_key.public_key().into(), @@ -576,88 +612,94 @@ impl iot_config::Route for RouteService { &self, request: GrpcStreamRequest, ) -> GrpcResult { - let mut request = request.into_inner(); + let request = request.into_inner(); telemetry::count_request("route", "update-devaddr-ranges"); - let mut to_add: Vec = vec![]; - let mut to_remove: Vec = vec![]; - let mut pending_updates: usize = 0; - - let mut validator: DevAddrEuiValidator = - if let Ok(Some(first_update)) = request.message().await { - if let Some(devaddr) = &first_update.devaddr_range { - let mut validator = self - .update_validator(&devaddr.route_id, true) - .await - .map_err(|_| Status::internal("unable to verify updates"))?; - validator.validate_update(&first_update)?; - match first_update.action() { - ActionV1::Add => to_add.push(devaddr.into()), - ActionV1::Remove => to_remove.push(devaddr.into()), - }; - pending_updates += 1; - validator - } else { - return Err(Status::invalid_argument("no valid route_id for update")); + let mut incoming_stream = request.peekable(); + let mut validator: DevAddrEuiValidator = Pin::new(&mut incoming_stream) + .peek() + .await + .map(|first_update| async move { + match first_update { + Ok(ref update) => match update.devaddr_range { + Some(ref devaddr_range) => self + .update_validator(&devaddr_range.route_id, true) + .await + .map_err(|err| { + Status::internal(format!("unable to verify update {err:?}")) + }), + None => Err(Status::invalid_argument("no devaddr range provided")), + }, + Err(_) => Err(Status::invalid_argument("no devaddr range provided")), } - } else { - return Err(Status::invalid_argument("no devaddr range provided")); - }; + }) + .ok_or_else(|| Status::invalid_argument("no devaddr range provided"))? + .await?; - while let Ok(Some(update)) = request.message().await { - validator.validate_update(&update)?; - match (update.action(), update.devaddr_range) { - (ActionV1::Add, Some(devaddr)) => to_add.push(devaddr.into()), - (ActionV1::Remove, Some(devaddr)) => to_remove.push(devaddr.into()), - _ => return Err(Status::invalid_argument("no devaddr range provided")), - }; - pending_updates += 1; - if pending_updates >= UPDATE_BATCH_LIMIT { + incoming_stream + .map_ok(|update| match validator.validate_update(&update) { + Ok(()) => Ok(update), + Err(reason) => Err(Status::invalid_argument(format!( + "invalid update request: {reason:?}" + ))), + }) + .try_chunks(UPDATE_BATCH_LIMIT) + .map_err(|err| { + Status::internal(format!("devaddr range update failed to batch: {err:?}")) + }) + .and_then(|batch| async move { + batch + .into_iter() + .collect::, Status>>() + }) + .and_then(|batch| async move { + batch + .into_iter() + .map(|update: RouteUpdateDevaddrRangesReqV1| { + match (update.action(), update.devaddr_range) { + (ActionV1::Add, Some(range)) => Ok((ActionV1::Add, range)), + (ActionV1::Remove, Some(range)) => Ok((ActionV1::Remove, range)), + _ => Err(Status::invalid_argument( + "invalid devaddr range update request", + )), + } + }) + .collect::, Status>>() + }) + .try_for_each(|batch: Vec<(ActionV1, DevaddrRangeV1)>| async move { + let (to_add, to_remove): ( + Vec<(ActionV1, DevaddrRangeV1)>, + Vec<(ActionV1, DevaddrRangeV1)>, + ) = batch + .into_iter() + .partition(|(action, _update)| action == &ActionV1::Add); telemetry::count_devaddr_updates(to_add.len(), to_remove.len()); tracing::debug!( adding = to_add.len(), removing = to_remove.len(), "updating devaddr ranges" ); + let adds_update: Vec = + to_add.into_iter().map(|(_, add)| add.into()).collect(); + let removes_update: Vec = to_remove + .into_iter() + .map(|(_, remove)| remove.into()) + .collect(); route::update_devaddr_ranges( - &to_add, - &to_remove, + &adds_update, + &removes_update, &self.pool, self.signing_key.clone(), - self.update_channel.clone(), + self.clone_update_channel(), ) .await .map_err(|err| { tracing::error!("devaddr range update failed: {err:?}"); Status::internal("devaddr range update failed") - })?; - to_add = vec![]; - to_remove = vec![]; - pending_updates = 0; - } - } - - if pending_updates > 0 { - telemetry::count_devaddr_updates(to_add.len(), to_remove.len()); - tracing::debug!( - adding = to_add.len(), - removing = to_remove.len(), - "updating devaddr ranges" - ); + }) + }) + .await?; - route::update_devaddr_ranges( - &to_add, - &to_remove, - &self.pool, - self.signing_key.clone(), - self.update_channel.clone(), - ) - .await - .map_err(|err| { - tracing::error!("devaddr range update failed: {err:?}"); - Status::internal("devaddr range update failed") - })?; - } let mut resp = RouteDevaddrRangesResV1 { timestamp: Utc::now().encode_timestamp(), signer: self.signing_key.public_key().into(), @@ -667,6 +709,182 @@ impl iot_config::Route for RouteService { Ok(Response::new(resp)) } + + type list_skfsStream = GrpcStreamResult; + async fn list_skfs( + &self, + request: Request, + ) -> GrpcResult { + let request = request.into_inner(); + telemetry::count_request("route", "list-skfs"); + + let signer = verify_public_key(&request.signer)?; + self.verify_request_signature(&signer, &request, OrgId::RouteId(&request.route_id)) + .await?; + + let pool = self.pool.clone(); + let (tx, rx) = tokio::sync::mpsc::channel(20); + + tracing::debug!( + route_id = request.route_id, + "listing session key filters for route" + ); + + tokio::spawn(async move { + let mut skf_stream = match route::list_skfs_for_route(&request.route_id, &pool) { + Ok(skfs) => skfs, + Err(RouteStorageError::UuidParse(err)) => { + _ = tx + .send(Err(Status::invalid_argument(format!("{}", err)))) + .await; + return; + } + Err(_) => { + _ = tx + .send(Err(Status::internal(format!( + "failed retrieving skfs for route {}", + &request.route_id + )))) + .await; + return; + } + }; + + while let Some(skf) = skf_stream.next().await { + let message = match skf { + Ok(skf) => Ok(skf.into()), + Err(bad_skf) => Err(Status::internal(format!("invalid skf: {:?}", bad_skf))), + }; + if tx.send(message).await.is_err() { + break; + } + } + }); + + Ok(Response::new(GrpcStreamResult::new(rx))) + } + + type get_skfsStream = GrpcStreamResult; + async fn get_skfs( + &self, + request: Request, + ) -> GrpcResult { + let request = request.into_inner(); + telemetry::count_request("route", "get-skfs"); + + let signer = verify_public_key(&request.signer)?; + self.verify_request_signature(&signer, &request, OrgId::RouteId(&request.route_id)) + .await?; + + let pool = self.pool.clone(); + let (tx, rx) = tokio::sync::mpsc::channel(20); + + tracing::debug!( + route_id = request.route_id, + "listing session key filters for route and devaddr" + ); + + tokio::spawn(async move { + let mut skf_stream = match route::list_skfs_for_route_and_devaddr( + &request.route_id, + request.devaddr.into(), + &pool, + ) { + Ok(skfs) => skfs, + Err(RouteStorageError::UuidParse(err)) => { + _ = tx + .send(Err(Status::invalid_argument(format!("{}", err)))) + .await; + return; + } + Err(_) => { + _ = tx + .send(Err(Status::internal(format!( + "failed retrieving skfs for route {} and devaddr {}", + &request.route_id, &request.devaddr + )))) + .await; + return; + } + }; + + while let Some(skf) = skf_stream.next().await { + let message = match skf { + Ok(skf) => Ok(skf.into()), + Err(bad_skf) => Err(Status::internal(format!("invalid skf: {:?}", bad_skf))), + }; + if tx.send(message).await.is_err() { + break; + } + } + }); + + Ok(Response::new(GrpcStreamResult::new(rx))) + } + + async fn update_skfs( + &self, + request: Request, + ) -> GrpcResult { + let request = request.into_inner(); + telemetry::count_request("route", "update-skfs"); + + if request.updates.len() > SKF_UPDATE_LIMIT { + return Err(Status::invalid_argument( + "exceeds 100 skf update limit per request", + )); + }; + + let signer = verify_public_key(&request.signer)?; + self.verify_request_signature(&signer, &request, OrgId::RouteId(&request.route_id)) + .await?; + + self.validate_skf_devaddrs(&request.route_id, &request.updates) + .await?; + + let (to_add, to_remove): (Vec<(ActionV1, Skf)>, Vec<(ActionV1, Skf)>) = request + .updates + .into_iter() + .map(|update: route_skf_update_req_v1::RouteSkfUpdateV1| { + ( + update.action(), + Skf::new( + request.route_id.clone(), + update.devaddr.into(), + update.session_key, + ), + ) + }) + .partition(|(action, _update)| action == &ActionV1::Add); + telemetry::count_skf_updates(to_add.len(), to_remove.len()); + tracing::debug!( + adding = to_add.len(), + removing = to_remove.len(), + "updating session key filters" + ); + let adds_update: Vec = to_add.into_iter().map(|(_, add)| add).collect(); + let removes_update: Vec = to_remove.into_iter().map(|(_, remove)| remove).collect(); + route::update_skfs( + &adds_update, + &removes_update, + &self.pool, + self.signing_key.clone(), + self.clone_update_channel(), + ) + .await + .map_err(|err| { + tracing::error!("session key update failed: {err:?}"); + Status::internal(format!("session key update failed {err:?}")) + })?; + + let mut resp = RouteSkfUpdateResV1 { + timestamp: Utc::now().encode_timestamp(), + signer: self.signing_key.public_key().into(), + signature: vec![], + }; + resp.signature = self.sign_response(&resp.encode_to_vec())?; + Ok(Response::new(resp)) + } } struct DevAddrEuiValidator { @@ -678,7 +896,7 @@ struct DevAddrEuiValidator { #[derive(thiserror::Error, Debug)] enum DevAddrEuiValidationError { #[error("devaddr range outside of constraint bounds {0}")] - RangeOutOfBounds(String), + DevAddrOutOfBounds(String), #[error("no route for update {0}")] NoRouteId(String), #[error("unauthorized signature {0}")] @@ -793,7 +1011,7 @@ where return Ok(update); } } - Err(DevAddrEuiValidationError::RangeOutOfBounds(format!( + Err(DevAddrEuiValidationError::DevAddrOutOfBounds(format!( "{update:?}" ))) } @@ -913,3 +1131,31 @@ async fn stream_existing_devaddrs( .try_fold((), |acc, _| async move { Ok(acc) }) .await } + +async fn stream_existing_skfs( + pool: &Pool, + signing_key: &Keypair, + tx: mpsc::Sender>, +) -> Result<()> { + let timestamp = Utc::now().encode_timestamp(); + let signer: Vec = signing_key.public_key().into(); + route::skf_stream(pool) + .then(|skf| { + let mut skf_res = RouteStreamResV1 { + action: ActionV1::Add.into(), + data: Some(route_stream_res_v1::Data::Skf(skf.into())), + timestamp, + signer: signer.clone(), + signature: vec![], + }; + if let Ok(signature) = signing_key.sign(&skf_res.encode_to_vec()) { + skf_res.signature = signature; + tx.send(Ok(skf_res)) + } else { + tx.send(Err(Status::internal("failed to sign session key filter"))) + } + }) + .map_err(|err| anyhow!(err)) + .try_fold((), |acc, _| async move { Ok(acc) }) + .await +} diff --git a/iot_config/src/session_key.rs b/iot_config/src/session_key.rs deleted file mode 100644 index c01d685e5..000000000 --- a/iot_config/src/session_key.rs +++ /dev/null @@ -1,227 +0,0 @@ -use crate::{broadcast_update, lora_field::DevAddrField}; -use anyhow::anyhow; -use chrono::Utc; -use file_store::traits::TimestampEncode; -use futures::{ - future::TryFutureExt, - stream::{self, Stream, StreamExt, TryStreamExt}, -}; -use helium_crypto::{Keypair, Sign}; -use helium_proto::{ - services::iot_config::{ActionV1, SessionKeyFilterStreamResV1, SessionKeyFilterV1}, - Message, -}; -use sqlx::{postgres::PgRow, FromRow, Row}; -use std::sync::Arc; -use tokio::sync::broadcast::Sender; - -#[derive(Clone, Debug)] -pub struct SessionKeyFilter { - pub oui: u64, - pub devaddr: DevAddrField, - pub session_key: String, -} - -impl FromRow<'_, PgRow> for SessionKeyFilter { - fn from_row(row: &PgRow) -> sqlx::Result { - Ok(Self { - oui: row.get::("oui") as u64, - devaddr: row.get::("devaddr").into(), - session_key: row.get::("session_key"), - }) - } -} - -pub fn list_stream<'a>( - db: impl sqlx::PgExecutor<'a> + 'a, -) -> impl Stream + 'a { - sqlx::query_as::<_, SessionKeyFilter>(r#" select * from session_key_filters "#) - .fetch(db) - .filter_map(|filter| async move { filter.ok() }) - .boxed() -} - -pub fn list_for_oui<'a>( - oui: u64, - db: impl sqlx::PgExecutor<'a> + 'a, -) -> impl Stream> + 'a { - sqlx::query_as::<_, SessionKeyFilter>( - r#" - select * from session_key_filters - where oui = $1 - "#, - ) - .bind(oui as i64) - .fetch(db) - .boxed() -} - -pub fn list_for_oui_and_devaddr<'a>( - oui: u64, - devaddr: DevAddrField, - db: impl sqlx::PgExecutor<'a> + 'a, -) -> impl Stream> + 'a { - sqlx::query_as::<_, SessionKeyFilter>( - r#" - select * from session_key_filters - where oui = $1 and devaddr = $2 - "#, - ) - .bind(oui as i64) - .bind(i32::from(devaddr)) - .fetch(db) - .boxed() -} - -pub async fn update_session_keys( - to_add: &[SessionKeyFilter], - to_remove: &[SessionKeyFilter], - db: impl sqlx::PgExecutor<'_> + sqlx::Acquire<'_, Database = sqlx::Postgres> + Copy, - signing_key: Arc, - update_tx: Sender, -) -> Result<(), sqlx::Error> { - let mut transaction = db.begin().await?; - - let added_updates: Vec<(SessionKeyFilter, ActionV1)> = - insert_session_key_filters(to_add, &mut transaction) - .await? - .into_iter() - .map(|added_skf| (added_skf, ActionV1::Add)) - .collect(); - - let removed_updates: Vec<(SessionKeyFilter, ActionV1)> = - remove_session_key_filters(to_remove, &mut transaction) - .await? - .into_iter() - .map(|removed_skf| (removed_skf, ActionV1::Remove)) - .collect(); - - transaction.commit().await?; - - tokio::spawn(async move { - let timestamp = Utc::now().encode_timestamp(); - let signer: Vec = signing_key.public_key().into(); - stream::iter([added_updates, removed_updates].concat()) - .map(Ok) - .try_for_each(|(update, action)| { - let mut skf_update = SessionKeyFilterStreamResV1 { - action: i32::from(action), - filter: Some(update.into()), - timestamp, - signer: signer.clone(), - signature: vec![], - }; - futures::future::ready(signing_key.sign(&skf_update.encode_to_vec())) - .map_err(|_| anyhow!("failed to sign session key filter update")) - .and_then(|signature| { - skf_update.signature = signature; - broadcast_update::( - skf_update, - update_tx.clone(), - ) - .map_err(|_| anyhow!("failed to broadcast session key filter update")) - }) - }) - .await - }); - - Ok(()) -} - -async fn insert_session_key_filters( - session_key_filters: &[SessionKeyFilter], - db: impl sqlx::PgExecutor<'_>, -) -> Result, sqlx::Error> { - if session_key_filters.is_empty() { - return Ok(vec![]); - } - - const SESSION_KEY_FILTER_INSERT_VALS: &str = - " insert into session_key_filters (oui, devaddr, session_key) "; - const SESSION_KEY_FILTER_INSERT_CONFLICT: &str = - " on conflict (oui, devaddr, session_key) do nothing returning * "; - - let mut query_builder: sqlx::QueryBuilder = - sqlx::QueryBuilder::new(SESSION_KEY_FILTER_INSERT_VALS); - query_builder - .push_values(session_key_filters, |mut builder, session_key_filter| { - builder - .push_bind(session_key_filter.oui as i64) - .push_bind(i32::from(session_key_filter.devaddr)) - .push_bind(session_key_filter.session_key.clone()); - }) - .push(SESSION_KEY_FILTER_INSERT_CONFLICT); - - query_builder - .build_query_as::() - .fetch_all(db) - .await -} - -async fn remove_session_key_filters( - session_key_filters: &[SessionKeyFilter], - db: impl sqlx::PgExecutor<'_>, -) -> Result, sqlx::Error> { - if session_key_filters.is_empty() { - return Ok(vec![]); - } - - const SESSION_KEY_FILTER_DELETE_VALS: &str = - " delete from session_key_filters where (oui, devaddr, session_key) in "; - const SESSION_KEY_FILTER_DELETE_RETURN: &str = " returning * "; - let mut query_builder: sqlx::QueryBuilder = - sqlx::QueryBuilder::new(SESSION_KEY_FILTER_DELETE_VALS); - query_builder - .push_tuples(session_key_filters, |mut builder, session_key_filter| { - builder - .push_bind(session_key_filter.oui as i64) - .push_bind(i32::from(session_key_filter.devaddr)) - .push_bind(session_key_filter.session_key.clone()); - }) - .push(SESSION_KEY_FILTER_DELETE_RETURN); - - query_builder - .build_query_as::() - .fetch_all(db) - .await -} - -impl From for SessionKeyFilter { - fn from(value: SessionKeyFilterV1) -> Self { - Self { - oui: value.oui, - devaddr: value.devaddr.into(), - session_key: value.session_key, - } - } -} - -impl From<&SessionKeyFilterV1> for SessionKeyFilter { - fn from(value: &SessionKeyFilterV1) -> Self { - Self { - oui: value.oui, - devaddr: value.devaddr.into(), - session_key: value.session_key.to_owned(), - } - } -} - -impl From for SessionKeyFilterV1 { - fn from(value: SessionKeyFilter) -> Self { - Self { - oui: value.oui, - devaddr: value.devaddr.into(), - session_key: value.session_key, - } - } -} - -impl From<&SessionKeyFilter> for SessionKeyFilterV1 { - fn from(value: &SessionKeyFilter) -> Self { - Self { - oui: value.oui, - devaddr: value.devaddr.into(), - session_key: value.session_key.to_owned(), - } - } -} diff --git a/iot_config/src/session_key_service.rs b/iot_config/src/session_key_service.rs deleted file mode 100644 index 69d406071..000000000 --- a/iot_config/src/session_key_service.rs +++ /dev/null @@ -1,473 +0,0 @@ -use crate::{ - admin::{AuthCache, KeyType}, - lora_field::DevAddrConstraint, - org::{self, DbOrgError}, - session_key::{self, SessionKeyFilter}, - telemetry, update_channel, verify_public_key, GrpcResult, GrpcStreamRequest, GrpcStreamResult, - Settings, -}; -use anyhow::{anyhow, Result}; -use chrono::Utc; -use file_store::traits::{MsgVerify, TimestampEncode}; -use futures::{ - future::TryFutureExt, - stream::{StreamExt, TryStreamExt}, -}; -use helium_crypto::{Keypair, PublicKey, Sign}; -use helium_proto::{ - services::iot_config::{ - self, ActionV1, SessionKeyFilterGetReqV1, SessionKeyFilterListReqV1, - SessionKeyFilterStreamReqV1, SessionKeyFilterStreamResV1, SessionKeyFilterUpdateReqV1, - SessionKeyFilterUpdateResV1, SessionKeyFilterV1, - }, - Message, -}; -use sqlx::{Pool, Postgres}; -use std::{pin::Pin, sync::Arc}; -use tokio::sync::{broadcast, mpsc}; -use tonic::{Request, Response, Status}; - -const UPDATE_BATCH_LIMIT: usize = 5_000; - -pub struct SessionKeyFilterService { - auth_cache: AuthCache, - pool: Pool, - update_channel: broadcast::Sender, - shutdown: triggered::Listener, - signing_key: Arc, -} - -impl SessionKeyFilterService { - pub fn new( - settings: &Settings, - auth_cache: AuthCache, - pool: Pool, - shutdown: triggered::Listener, - ) -> Result { - Ok(Self { - auth_cache, - pool, - update_channel: update_channel(), - shutdown, - signing_key: Arc::new(settings.signing_keypair()?), - }) - } - - fn subscribe_to_session_keys(&self) -> broadcast::Receiver { - self.update_channel.subscribe() - } - - fn clone_update_channel(&self) -> broadcast::Sender { - self.update_channel.clone() - } - - async fn verify_request_signature<'a, R>( - &self, - signer: &PublicKey, - request: &R, - id: u64, - ) -> Result<(), Status> - where - R: MsgVerify, - { - if self - .auth_cache - .verify_signature_with_type(KeyType::Administrator, signer, request) - .is_ok() - { - tracing::debug!(signer = signer.to_string(), "request authorized by admin"); - return Ok(()); - } - - let org_keys = org::get_org_pubkeys(id, &self.pool) - .await - .map_err(|_| Status::internal("auth verification error"))?; - - if org_keys.as_slice().contains(signer) && request.verify(signer).is_ok() { - tracing::debug!( - signer = signer.to_string(), - "request authorized by delegate" - ); - return Ok(()); - } - Err(Status::permission_denied("unauthorized request signature")) - } - - fn verify_stream_request_signature( - &self, - signer: &PublicKey, - request: &R, - ) -> Result<(), Status> - where - R: MsgVerify, - { - if self.auth_cache.verify_signature(signer, request).is_ok() { - tracing::debug!(signer = signer.to_string(), "request authorized"); - Ok(()) - } else { - Err(Status::permission_denied("unauthorized request signature")) - } - } - - fn sign_response(&self, response: &[u8]) -> Result, Status> { - self.signing_key - .sign(response) - .map_err(|_| Status::internal("response signing error")) - } - - async fn update_validator(&self, oui: u64) -> Result { - let admin_keys = self.auth_cache.get_keys_by_type(KeyType::Administrator); - - SkfValidator::new(oui, admin_keys, &self.pool).await - } -} - -#[tonic::async_trait] -impl iot_config::SessionKeyFilter for SessionKeyFilterService { - type listStream = GrpcStreamResult; - async fn list( - &self, - request: Request, - ) -> GrpcResult { - let request = request.into_inner(); - telemetry::count_request("session-key-filter", "list"); - - let signer = verify_public_key(&request.signer)?; - self.verify_request_signature(&signer, &request, request.oui) - .await?; - - let pool = self.pool.clone(); - let (tx, rx) = tokio::sync::mpsc::channel(20); - - tokio::spawn(async move { - let mut filters = session_key::list_for_oui(request.oui, &pool); - - while let Some(filter) = filters.next().await { - let message = match filter { - Ok(filter) => Ok(filter.into()), - Err(bad_filter) => Err(Status::internal(format!( - "invalid session key filter {bad_filter:?}" - ))), - }; - if tx.send(message).await.is_err() { - break; - } - } - }); - - Ok(Response::new(GrpcStreamResult::new(rx))) - } - - type getStream = GrpcStreamResult; - async fn get(&self, request: Request) -> GrpcResult { - let request = request.into_inner(); - telemetry::count_request("session-key-filter", "get"); - - let signer = verify_public_key(&request.signer)?; - self.verify_request_signature(&signer, &request, request.oui) - .await?; - - let (tx, rx) = tokio::sync::mpsc::channel(20); - let pool = self.pool.clone(); - - tokio::spawn(async move { - let mut filters = - session_key::list_for_oui_and_devaddr(request.oui, request.devaddr.into(), &pool); - - while let Some(filter) = filters.next().await { - let message = match filter { - Ok(filter) => Ok(filter.into()), - Err(bad_filter) => Err(Status::internal(format!( - "invalid session key filter {bad_filter:?}" - ))), - }; - if tx.send(message).await.is_err() { - break; - } - } - }); - - Ok(Response::new(GrpcStreamResult::new(rx))) - } - - async fn update( - &self, - request: GrpcStreamRequest, - ) -> GrpcResult { - let request = request.into_inner(); - telemetry::count_request("session-key-filter", "update"); - - let mut incoming_stream = request.peekable(); - let mut validator: SkfValidator = Pin::new(&mut incoming_stream) - .peek() - .await - .map(|first_update| async move { - match first_update { - Ok(ref update) => match update.filter { - Some(ref filter) => { - self.update_validator(filter.oui).await.map_err(|err| { - Status::internal(format!("unable to verify updates {err:?}")) - }) - } - None => Err(Status::invalid_argument("no session key filter provided")), - }, - Err(_) => Err(Status::invalid_argument("no session key filter provided")), - } - }) - .ok_or_else(|| Status::invalid_argument("no session key filter provided"))? - .await?; - - incoming_stream - .map_ok(|update| match validator.validate_update(&update) { - Ok(()) => Ok(update), - Err(reason) => Err(Status::invalid_argument(format!( - "invalid update request: {reason:?}" - ))), - }) - .try_chunks(UPDATE_BATCH_LIMIT) - .map_err(|err| Status::internal(format!("session key update failed to batch {err:?}"))) - .and_then(|batch| async move { - batch - .into_iter() - .collect::, Status>>() - }) - .and_then(|batch| async move { - batch - .into_iter() - .map(|update: SessionKeyFilterUpdateReqV1| { - match (update.action(), update.filter) { - (ActionV1::Add, Some(filter)) => Ok((ActionV1::Add, filter)), - (ActionV1::Remove, Some(filter)) => Ok((ActionV1::Remove, filter)), - _ => Err(Status::invalid_argument("invalid filter update request")), - } - }) - .collect::, Status>>() - }) - .try_for_each(|batch: Vec<(ActionV1, SessionKeyFilterV1)>| async move { - let (to_add, to_remove): ( - Vec<(ActionV1, SessionKeyFilterV1)>, - Vec<(ActionV1, SessionKeyFilterV1)>, - ) = batch - .into_iter() - .partition(|(action, _update)| action == &ActionV1::Add); - telemetry::count_skf_updates(to_add.len(), to_remove.len()); - tracing::debug!( - adding = to_add.len(), - removing = to_remove.len(), - "updating session key filters" - ); - let adds_update = to_add - .into_iter() - .map(|(_, add)| add.into()) - .collect::>(); - let removes_update = to_remove - .into_iter() - .map(|(_, remove)| remove.into()) - .collect::>(); - session_key::update_session_keys( - &adds_update, - &removes_update, - &self.pool, - self.signing_key.clone(), - self.clone_update_channel(), - ) - .await - .map_err(|err| { - tracing::error!("session key update failed: {err:?}"); - Status::internal(format!("session key update failed {err:?}")) - }) - }) - .await?; - - let mut resp = SessionKeyFilterUpdateResV1 { - timestamp: Utc::now().encode_timestamp(), - signer: self.signing_key.public_key().into(), - signature: vec![], - }; - resp.signature = self.sign_response(&resp.encode_to_vec())?; - Ok(Response::new(resp)) - } - - type streamStream = GrpcStreamResult; - async fn stream( - &self, - request: Request, - ) -> GrpcResult { - let request = request.into_inner(); - telemetry::count_request("session-key-filter", "stream"); - - let signer = verify_public_key(&request.signer)?; - self.verify_stream_request_signature(&signer, &request)?; - - tracing::info!("client subscribed to session key stream"); - - let pool = self.pool.clone(); - let shutdown_listener = self.shutdown.clone(); - let (tx, rx) = tokio::sync::mpsc::channel(20); - let signing_key = self.signing_key.clone(); - - let mut session_key_updates = self.subscribe_to_session_keys(); - - tokio::spawn(async move { - if stream_existing_skfs(&pool, signing_key, tx.clone()) - .await - .is_err() - { - return; - } - - tracing::info!("existing session keys sent; streaming updates as available"); - telemetry::stream_subscribe("session-key-filter-stream"); - loop { - let shutdown = shutdown_listener.clone(); - - tokio::select! { - _ = shutdown => { - telemetry::stream_unsubscribe("session-key-filter-stream"); - return - } - msg = session_key_updates.recv() => if let Ok(update) = msg { - if tx.send(Ok(update)).await.is_err() { - telemetry::stream_unsubscribe("session-key-filter-stream"); - return; - } - } - } - } - }); - - Ok(Response::new(GrpcStreamResult::new(rx))) - } -} - -async fn stream_existing_skfs( - pool: &Pool, - signing_key: Arc, - tx: mpsc::Sender>, -) -> Result<()> { - let timestamp = Utc::now().encode_timestamp(); - let signer: Vec = signing_key.public_key().into(); - session_key::list_stream(pool) - .then(|session_key_filter| { - let mut skf_resp = SessionKeyFilterStreamResV1 { - action: ActionV1::Add.into(), - filter: Some(session_key_filter.into()), - timestamp, - signer: signer.clone(), - signature: vec![], - }; - - futures::future::ready(signing_key.sign(&skf_resp.encode_to_vec())) - .map_err(|_| anyhow!("failed signing session key filter")) - .and_then(|signature| { - skf_resp.signature = signature; - tx.send(Ok(skf_resp)) - .map_err(|_| anyhow!("failed sending session key filter")) - }) - }) - .map_err(|err| anyhow!(err)) - .try_fold((), |acc, _| async move { Ok(acc) }) - .await -} - -struct SkfValidator { - oui: u64, - constraints: Vec, - signing_keys: Vec, -} - -#[derive(thiserror::Error, Debug)] -enum SkfValidatorError { - #[error("devaddr outside of constraint bounds {0}")] - AddrOutOfBounds(String), - #[error("wrong oui for session key filter {0}")] - WrongOui(String), - #[error("unauthorized signature {0}")] - UnauthorizedSignature(String), - #[error("invalid update {0}")] - InvalidUpdate(String), -} - -impl SkfValidator { - async fn new( - oui: u64, - mut admin_keys: Vec, - db: impl sqlx::PgExecutor<'_> + Copy, - ) -> Result { - let org = org::get_with_constraints(oui, db).await?; - let mut org_keys = org::get_org_pubkeys(oui, db).await?; - org_keys.append(&mut admin_keys); - - Ok(Self { - oui, - constraints: org.constraints, - signing_keys: org_keys, - }) - } - - fn validate_update<'a>( - &'a mut self, - request: &'a SessionKeyFilterUpdateReqV1, - ) -> Result<(), Status> { - validate_oui(request, self.oui) - .and_then(|update| validate_constraint_bounds(update, self.constraints.as_ref())) - .and_then(|update| validate_signature(update, &mut self.signing_keys)) - .map_err(|err| Status::invalid_argument(format!("{err:?}")))?; - Ok(()) - } -} - -fn validate_oui( - update: &SessionKeyFilterUpdateReqV1, - oui: u64, -) -> Result<&SessionKeyFilterUpdateReqV1, SkfValidatorError> { - let filter_oui = if let Some(ref filter) = update.filter { - filter.oui - } else { - return Err(SkfValidatorError::InvalidUpdate(format!("{update:?}"))); - }; - - if oui == filter_oui { - Ok(update) - } else { - Err(SkfValidatorError::WrongOui(format!( - "authorized oui: {oui}, update: {filter_oui}" - ))) - } -} - -fn validate_constraint_bounds<'a>( - update: &'a SessionKeyFilterUpdateReqV1, - constraints: &'a Vec, -) -> Result<&'a SessionKeyFilterUpdateReqV1, SkfValidatorError> { - let filter_addr = if let Some(ref filter) = update.filter { - filter.devaddr - } else { - return Err(SkfValidatorError::InvalidUpdate(format!("{update:?}"))); - }; - - for constraint in constraints { - if constraint.contains_addr(filter_addr.into()) { - return Ok(update); - } - } - Err(SkfValidatorError::AddrOutOfBounds(format!("{update:?}"))) -} - -fn validate_signature<'a, R>( - request: &'a R, - signing_keys: &mut [PublicKey], -) -> Result<&'a R, SkfValidatorError> -where - R: MsgVerify + std::fmt::Debug, -{ - for (idx, pubkey) in signing_keys.iter().enumerate() { - if request.verify(pubkey).is_ok() { - signing_keys.swap(idx, 0); - return Ok(request); - } - } - Err(SkfValidatorError::UnauthorizedSignature(format!( - "{request:?}" - ))) -} diff --git a/iot_config/src/telemetry.rs b/iot_config/src/telemetry.rs index c1bc242e4..32b3dbec8 100644 --- a/iot_config/src/telemetry.rs +++ b/iot_config/src/telemetry.rs @@ -56,10 +56,10 @@ pub fn count_devaddr_updates(adds: usize, removes: usize) { metrics::counter!(DEVADDR_REMOVE_COUNT_METRIC, removes as u64); } -pub fn stream_subscribe(stream: &'static str) { - metrics::increment_gauge!(STREAM_METRIC, 1.0, "stream" => stream); +pub fn route_stream_subscribe() { + metrics::increment_gauge!(STREAM_METRIC, 1.0); } -pub fn stream_unsubscribe(stream: &'static str) { - metrics::decrement_gauge!(STREAM_METRIC, 1.0, "stream" => stream); +pub fn route_stream_unsubscribe() { + metrics::decrement_gauge!(STREAM_METRIC, 1.0); }