diff --git a/Cargo.lock b/Cargo.lock index cd041a9..767381e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,3 +1,5 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. [[package]] name = "aho-corasick" version = "0.7.3" diff --git a/src/main.rs b/src/main.rs index 44f00dd..2fe5d89 100644 --- a/src/main.rs +++ b/src/main.rs @@ -40,7 +40,7 @@ use receiver::Receiver; use std::env; use std::net::SocketAddr; use stream::StreamManager; -use user::{Scope, User}; +use user::{OauthScope::*, Scope, User}; use warp::path; use warp::Filter as WarpFilter; @@ -110,37 +110,62 @@ fn main() { h: query::Hashtag, l: query::List, ws: warp::ws::Ws2| { - let unauthorized = Err(warp::reject::custom("Error: Invalid Access Token")); + let scopes = user.scopes.clone(); let timeline = match q.stream.as_ref() { // Public endpoints: tl @ "public" | tl @ "public:local" if m.is_truthy() => format!("{}:media", tl), tl @ "public:media" | tl @ "public:local:media" => tl.to_string(), tl @ "public" | tl @ "public:local" => tl.to_string(), - // User - "user" if user.id == -1 => return unauthorized, - "user" => format!("{}", user.id), - "user:notification" => { - user = user.with_notification_filter(); - format!("{}", user.id) - } // Hashtag endpoints: // TODO: handle missing query tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, h.tag), + // Private endpoints: User + "user" + if user.id > 0 + && (scopes.contains(&Read) || scopes.contains(&ReadStatuses)) => + { + format!("{}", user.id) + } + "user:notification" + if user.id > 0 + && (scopes.contains(&Read) || scopes.contains(&ReadNotifications)) => + { + user = user.with_notification_filter(); + format!("{}", user.id) + } // List endpoint: // TODO: handle missing query - "list" if user.authorized_for_list(l.list).is_err() => return unauthorized, - "list" => format!("list:{}", l.list), + "list" + if user.authorized_for_list(l.list).is_ok() + && (scopes.contains(&Read) || scopes.contains(&ReadList)) => + { + format!("list:{}", l.list) + } + // Direct endpoint: - "direct" if user.id == -1 => return unauthorized, - "direct" => "direct".to_string(), + "direct" + if user.id > 0 + && (scopes.contains(&Read) || scopes.contains(&ReadStatuses)) => + { + "direct".to_string() + } + // Reject unathorized access attempts for private endpoints + "user" | "user:notification" | "direct" | "list" => { + return Err(warp::reject::custom("Error: Invalid Access Token")) + } // Other endpoints don't exist: _ => return Err(warp::reject::custom("Error: Nonexistent WebSocket query")), }; + let token = user.access_token.clone(); let stream = redis_updates_ws.configure_copy(&timeline, user); - Ok(ws.on_upgrade(move |socket| ws::send_replies(socket, stream))) + Ok(( + ws.on_upgrade(move |socket| ws::send_replies(socket, stream)), + token, + )) }, - ); + ) + .map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token)); let address: SocketAddr = env::var("SERVER_ADDR") .unwrap_or("127.0.0.1:4000".to_owned()) diff --git a/src/stream.rs b/src/stream.rs index 069328b..928f97d 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -51,7 +51,10 @@ impl Stream for StreamManager { type Error = Error; fn poll(&mut self) -> Poll, Self::Error> { - let mut receiver = self.receiver.lock().expect("No other thread panic"); + let mut receiver = self + .receiver + .lock() + .expect("StreamManager: No other thread panic"); receiver.update(self.id, &self.target_timeline.clone()); match receiver.poll() { Ok(Async::Ready(Some(value))) => { @@ -61,19 +64,19 @@ impl Stream for StreamManager { .expect("Previously set current user"); let user_langs = user.langs.clone(); - let copy = value.clone(); - let event = copy["event"].as_str().expect("Redis string"); - let copy = value.clone(); - let payload = copy["payload"].to_string(); - let copy = value.clone(); - let toot_lang = copy["payload"]["language"] - .as_str() - .expect("redis str") - .to_string(); + let event = value["event"].as_str().expect("Redis string"); + let payload = value["payload"].to_string(); match (&user.filter, user_langs) { (Filter::Notification, _) if event != "notification" => Ok(Async::NotReady), - (Filter::Language, Some(ref langs)) if !langs.contains(&toot_lang) => { + (Filter::Language, Some(ref user_langs)) + if !user_langs.contains( + &value["payload"]["language"] + .as_str() + .expect("Redis str") + .to_string(), + ) => + { Ok(Async::NotReady) } _ => Ok(Async::Ready(Some(json!( diff --git a/src/user.rs b/src/user.rs index 5a5cbb2..11c2a66 100644 --- a/src/user.rs +++ b/src/user.rs @@ -27,18 +27,43 @@ pub enum Filter { #[derive(Clone, Debug, PartialEq)] pub struct User { pub id: i64, + pub access_token: String, + pub scopes: Vec, pub langs: Option>, pub logged_in: bool, pub filter: Filter, } +#[derive(Clone, Debug, PartialEq)] +pub enum OauthScope { + Read, + ReadStatuses, + ReadNotifications, + ReadList, + Other, +} +impl From<&str> for OauthScope { + fn from(scope: &str) -> Self { + use OauthScope::*; + match scope { + "read" => Read, + "read:statuses" => ReadStatuses, + "read:notifications" => ReadNotifications, + "read:lists" => ReadList, + _ => Other, + } + } +} impl User { /// Create a user from the access token supplied in the header or query paramaters - pub fn from_access_token(token: String, scope: Scope) -> Result { + pub fn from_access_token( + access_token: String, + scope: Scope, + ) -> Result { let conn = connect_to_postgres(); let result = &conn .query( " -SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages +SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes FROM oauth_access_tokens INNER JOIN users ON @@ -46,16 +71,25 @@ oauth_access_tokens.resource_owner_id = users.id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1", - &[&token], + &[&access_token], ) .expect("Hard-coded query will return Some([0 or more rows])"); if !result.is_empty() { let only_row = result.get(0); let id: i64 = only_row.get(1); + let scopes = only_row + .get::<_, String>(3) + .split(' ') + .map(|scope: &str| scope.into()) + .filter(|scope| scope != &OauthScope::Other) + .collect(); + dbg!(&scopes); let langs: Option> = only_row.get(2); info!("Granting logged-in access"); Ok(User { id, + access_token, + scopes, langs, logged_in: true, filter: Filter::None, @@ -64,6 +98,8 @@ LIMIT 1", info!("Granting public access to non-authenticated client"); Ok(User { id: -1, + access_token, + scopes: Vec::new(), langs: None, logged_in: false, filter: Filter::None, @@ -116,6 +152,8 @@ LIMIT 1", pub fn public() -> Self { User { id: -1, + access_token: String::new(), + scopes: Vec::new(), langs: None, logged_in: false, filter: Filter::None, @@ -130,16 +168,25 @@ pub enum Scope { } impl Scope { pub fn get_access_token(self) -> warp::filters::BoxedFilter<(String,)> { - let token_from_header = warp::header::header::("authorization") + let token_from_header_http_push = warp::header::header::("authorization") .map(|auth: String| auth.split(' ').nth(1).unwrap_or("invalid").to_string()); + let token_from_header_ws = + warp::header::header::("Sec-WebSocket-Protocol").map(|auth: String| auth); let token_from_query = warp::query().map(|q: query::Auth| q.access_token); + + let private_scopes = any_of!( + token_from_header_http_push, + token_from_header_ws, + token_from_query + ); + let public = warp::any().map(|| "no access token".to_string()); match self { // if they're trying to access a private scope without an access token, reject the request - Scope::Private => any_of!(token_from_query, token_from_header).boxed(), + Scope::Private => private_scopes.boxed(), // if they're trying to access a public scope without an access token, proceed - Scope::Public => any_of!(token_from_query, token_from_header, public).boxed(), + Scope::Public => any_of!(private_scopes, public).boxed(), } } }