From 06b6b4820fed169cf461d65b0ba245abb6788441 Mon Sep 17 00:00:00 2001 From: Susanne Hartung Date: Wed, 31 Jul 2024 13:47:38 +0200 Subject: [PATCH] api client reauthentication --- CHANGELOG.md | 1 + src/api/mod.rs | 260 +++++++++++++++++++++++++++++-------------------- 2 files changed, 155 insertions(+), 106 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d905fc..ae1c082 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - NEXT-37303 - [BREAKING] changed `sync` command argument `--schema` `-s` to `--profile` `-p` - NEXT-37303 - [BREAKING] Fixed an issue where `row` values were always provided as strings in the deserialize script. Now they are converted into their proper types before passed to the script. +- NEXT-37313 - Implemented re-authentication for API calls to handle expired bearer tokens # v0.7.1 diff --git a/src/api/mod.rs b/src/api/mod.rs index 72e2756..3fb0823 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -5,14 +5,14 @@ pub mod filter; use crate::api::filter::{Criteria, CriteriaFilter}; use crate::config_file::Credentials; use reqwest::header::{HeaderMap, HeaderValue}; -use reqwest::{header, Client, Response, StatusCode}; +use reqwest::{header, Client, Method, Response, StatusCode}; use serde::{Deserialize, Serialize}; use serde_json::json; use std::fmt::Debug; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::time::{Duration, Instant}; use thiserror::Error; -use tokio::sync::Semaphore; +use tokio::sync::{Mutex, Semaphore}; #[derive(Debug, Clone)] pub struct SwClient { @@ -52,15 +52,13 @@ impl SwClient { }) } - pub async fn sync, T: Serialize>( + pub async fn sync, T: Serialize + Debug>( &self, entity: S, action: SyncAction, payload: &[T], ) -> Result<(), SwApiError> { let entity: String = entity.into(); - // ToDo: implement retry on auth fail - let access_token = self.access_token.lock().unwrap().clone(); let body = SyncBody { write_data: SyncOperation { entity: entity.clone(), @@ -69,31 +67,35 @@ impl SwClient { }, }; - let response = { - let _lock = self.in_flight_semaphore.acquire().await.unwrap(); - let start_instant = Instant::now(); - println!( - "sync {:?} '{}' with payload size {}", - action, - &entity, - payload.len() - ); - let res = self - .client - .post(format!("{}/api/_action/sync", self.credentials.base_url)) - .bearer_auth(access_token) - .header("single-operation", 1) - .header("indexing-behavior", "disable-indexing") - .header("sw-skip-trigger-flow", 1) - .json(&body) - .send() - .await?; - println!( - "sync request finished after {} ms", - start_instant.elapsed().as_millis() - ); - res - }; + println!( + "sync {:?} '{}' with payload size {}", + action, + &entity, + payload.len() + ); + + let mut headers = HeaderMap::new(); + headers.insert("single-operation", HeaderValue::from_static("1")); + headers.insert( + "indexing-behavior", + HeaderValue::from_static("disable-indexing"), + ); + headers.insert("sw-skip-trigger-flow", HeaderValue::from_static("1")); + + let (response, duration) = self + .handle_authenticated_request( + Method::POST, + "/api/_action/sync", + Some(&body), + Some(headers), + true, + ) + .await?; + + println!( + "sync request finished after {} ms", + duration.unwrap().as_millis() + ); if !response.status().is_success() { let status = response.status(); @@ -105,19 +107,15 @@ impl SwClient { } pub async fn entity_schema(&self) -> Result { - // ToDo: implement retry on auth fail - let access_token = self.access_token.lock().unwrap().clone(); - let response = { - let _lock = self.in_flight_semaphore.acquire().await.unwrap(); - self.client - .get(format!( - "{}/api/_info/entity-schema.json", - self.credentials.base_url - )) - .bearer_auth(access_token) - .send() - .await? - }; + let (response, _) = self + .handle_authenticated_request::<()>( + Method::GET, + "/api/_info/entity-schema.json", + None, + None, + false, + ) + .await?; if !response.status().is_success() { let status = response.status(); @@ -136,32 +134,27 @@ impl SwClient { ) -> Result { // entity needs to be provided as kebab-case instead of snake_case let entity = entity.replace('_', "-"); + let body = json!({ + "limit": 1, + "filter": filter, + "aggregations": [ + { + "name": "count", + "type": "count", + "field": "id" + } + ] + }); - // ToDo: implement retry on auth fail - let access_token = self.access_token.lock().unwrap().clone(); - - let response = { - let _lock = self.in_flight_semaphore.acquire().await.unwrap(); - self.client - .post(format!( - "{}/api/search/{}", - self.credentials.base_url, entity - )) - .bearer_auth(access_token) - .json(&json!({ - "limit": 1, - "filter": filter, - "aggregations": [ - { - "name": "count", - "type": "count", - "field": "id" - } - ] - })) - .send() - .await? - }; + let (response, _) = self + .handle_authenticated_request( + Method::POST, + &format!("/api/search/{}", entity), + Some(&body), + None, + false, + ) + .await?; if !response.status().is_success() { let status = response.status(); @@ -189,37 +182,29 @@ impl SwClient { // entity needs to be provided as kebab-case instead of snake_case let entity = entity.replace('_', "-"); - // ToDo: implement retry on auth fail - let access_token = self.access_token.lock().unwrap().clone(); - let response = { - let _lock = self.in_flight_semaphore.acquire().await.unwrap(); - let start_instant = Instant::now(); - - if let Some(limit) = criteria.limit { - println!( - "fetching page {} of '{}' with limit {}", - criteria.page, entity, limit - ); - } else { - println!("fetching page {} of '{}'", criteria.page, entity); - } - - let res = self - .client - .post(format!( - "{}/api/search/{}", - self.credentials.base_url, entity - )) - .bearer_auth(access_token) - .json(criteria) - .send() - .await?; + if let Some(limit) = criteria.limit { println!( - "search request finished after {} ms", - start_instant.elapsed().as_millis() + "fetching page {} of '{}' with limit {}", + criteria.page, entity, limit ); - res - }; + } else { + println!("fetching page {} of '{}'", criteria.page, entity); + } + + let (response, duration) = self + .handle_authenticated_request( + Method::POST, + &format!("/api/search/{}", entity), + Some(criteria), + None, + true, + ) + .await?; + + println!( + "search request finished after {} ms", + duration.unwrap().as_millis() + ); if !response.status().is_success() { let status = response.status(); @@ -261,14 +246,14 @@ impl SwClient { } pub async fn index(&self, skip: Vec) -> Result<(), SwApiError> { - let access_token = self.access_token.lock().unwrap().clone(); - - let response = self - .client - .post(format!("{}/api/_action/index", self.credentials.base_url)) - .bearer_auth(access_token) - .json(&IndexBody { skip }) - .send() + let (response, _) = self + .handle_authenticated_request( + Method::POST, + "/api/_action/index", + Some(&IndexBody { skip }), + None, + false, + ) .await?; if !response.status().is_success() { @@ -312,6 +297,69 @@ impl SwClient { }); worker_rx.await.unwrap() } + + async fn handle_authenticated_request( + &self, + method: Method, + path: &str, + body: Option<&T>, + additional_headers: Option, + measure_time: bool, + ) -> Result<(Response, Option), SwApiError> { + let url = format!("{}{}", self.credentials.base_url, path); + let mut retry_count = 0; + const MAX_RETRIES: u8 = 1; + + let mut request_builder = self.client.request(method, &url); + + if let Some(headers) = additional_headers { + request_builder = request_builder.headers(headers); + } + + if let Some(body_value) = body { + request_builder = request_builder.json(body_value); + } + + loop { + let access_token = self.access_token.lock().await.clone(); + let request = request_builder + .try_clone() + .unwrap() + .bearer_auth(&access_token); + + let _lock = self.in_flight_semaphore.acquire().await.unwrap(); + + let start_time = if measure_time { + Some(Instant::now()) + } else { + None + }; + + let response = request.send().await?; + + if response.status() == StatusCode::UNAUTHORIZED && retry_count < MAX_RETRIES { + // lock the access token + let mut access_token_guard = self.access_token.lock().await; + // compare the access token with the one we used to make the request + if *access_token_guard != access_token { + // Another thread has already re-authenticated + continue; + } + + // Perform re-authentication + let auth_response = Self::authenticate(&self.client, &self.credentials).await?; + let new_token = auth_response.access_token; + *access_token_guard = new_token; + + retry_count += 1; + continue; + } + + let duration = start_time.map(|start_time| start_time.elapsed()); + + return Ok((response, duration)); + } + } } #[derive(Debug, Serialize)]