Skip to content

Commit

Permalink
api client reauthentication
Browse files Browse the repository at this point in the history
  • Loading branch information
ennasus4sun committed Aug 8, 2024
1 parent 4bcada5 commit 06b6b48
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 106 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- NEXT-37303 - [BREAKING] changed `sync` command argument `--schema` `-s` to `--profile` `-p`
- NEXT-37303 - [BREAKING] Fixed an issue where `row` values were always provided as strings in the deserialize script.
Now they are converted into their proper types before passed to the script.
- NEXT-37313 - Implemented re-authentication for API calls to handle expired bearer tokens

# v0.7.1

Expand Down
260 changes: 154 additions & 106 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ pub mod filter;
use crate::api::filter::{Criteria, CriteriaFilter};
use crate::config_file::Credentials;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{header, Client, Response, StatusCode};
use reqwest::{header, Client, Method, Response, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::fmt::Debug;
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;
use tokio::sync::Semaphore;
use tokio::sync::{Mutex, Semaphore};

#[derive(Debug, Clone)]
pub struct SwClient {
Expand Down Expand Up @@ -52,15 +52,13 @@ impl SwClient {
})
}

pub async fn sync<S: Into<String>, T: Serialize>(
pub async fn sync<S: Into<String>, T: Serialize + Debug>(
&self,
entity: S,
action: SyncAction,
payload: &[T],
) -> Result<(), SwApiError> {
let entity: String = entity.into();
// ToDo: implement retry on auth fail
let access_token = self.access_token.lock().unwrap().clone();
let body = SyncBody {
write_data: SyncOperation {
entity: entity.clone(),
Expand All @@ -69,31 +67,35 @@ impl SwClient {
},
};

let response = {
let _lock = self.in_flight_semaphore.acquire().await.unwrap();
let start_instant = Instant::now();
println!(
"sync {:?} '{}' with payload size {}",
action,
&entity,
payload.len()
);
let res = self
.client
.post(format!("{}/api/_action/sync", self.credentials.base_url))
.bearer_auth(access_token)
.header("single-operation", 1)
.header("indexing-behavior", "disable-indexing")
.header("sw-skip-trigger-flow", 1)
.json(&body)
.send()
.await?;
println!(
"sync request finished after {} ms",
start_instant.elapsed().as_millis()
);
res
};
println!(
"sync {:?} '{}' with payload size {}",
action,
&entity,
payload.len()
);

let mut headers = HeaderMap::new();
headers.insert("single-operation", HeaderValue::from_static("1"));
headers.insert(
"indexing-behavior",
HeaderValue::from_static("disable-indexing"),
);
headers.insert("sw-skip-trigger-flow", HeaderValue::from_static("1"));

let (response, duration) = self
.handle_authenticated_request(
Method::POST,
"/api/_action/sync",
Some(&body),
Some(headers),
true,
)
.await?;

println!(
"sync request finished after {} ms",
duration.unwrap().as_millis()
);

if !response.status().is_success() {
let status = response.status();
Expand All @@ -105,19 +107,15 @@ impl SwClient {
}

pub async fn entity_schema(&self) -> Result<Entity, SwApiError> {
// ToDo: implement retry on auth fail
let access_token = self.access_token.lock().unwrap().clone();
let response = {
let _lock = self.in_flight_semaphore.acquire().await.unwrap();
self.client
.get(format!(
"{}/api/_info/entity-schema.json",
self.credentials.base_url
))
.bearer_auth(access_token)
.send()
.await?
};
let (response, _) = self
.handle_authenticated_request::<()>(
Method::GET,
"/api/_info/entity-schema.json",
None,
None,
false,
)
.await?;

if !response.status().is_success() {
let status = response.status();
Expand All @@ -136,32 +134,27 @@ impl SwClient {
) -> Result<u64, SwApiError> {
// entity needs to be provided as kebab-case instead of snake_case
let entity = entity.replace('_', "-");
let body = json!({
"limit": 1,
"filter": filter,
"aggregations": [
{
"name": "count",
"type": "count",
"field": "id"
}
]
});

// ToDo: implement retry on auth fail
let access_token = self.access_token.lock().unwrap().clone();

let response = {
let _lock = self.in_flight_semaphore.acquire().await.unwrap();
self.client
.post(format!(
"{}/api/search/{}",
self.credentials.base_url, entity
))
.bearer_auth(access_token)
.json(&json!({
"limit": 1,
"filter": filter,
"aggregations": [
{
"name": "count",
"type": "count",
"field": "id"
}
]
}))
.send()
.await?
};
let (response, _) = self
.handle_authenticated_request(
Method::POST,
&format!("/api/search/{}", entity),
Some(&body),
None,
false,
)
.await?;

if !response.status().is_success() {
let status = response.status();
Expand Down Expand Up @@ -189,37 +182,29 @@ impl SwClient {
// entity needs to be provided as kebab-case instead of snake_case
let entity = entity.replace('_', "-");

// ToDo: implement retry on auth fail
let access_token = self.access_token.lock().unwrap().clone();
let response = {
let _lock = self.in_flight_semaphore.acquire().await.unwrap();
let start_instant = Instant::now();

if let Some(limit) = criteria.limit {
println!(
"fetching page {} of '{}' with limit {}",
criteria.page, entity, limit
);
} else {
println!("fetching page {} of '{}'", criteria.page, entity);
}

let res = self
.client
.post(format!(
"{}/api/search/{}",
self.credentials.base_url, entity
))
.bearer_auth(access_token)
.json(criteria)
.send()
.await?;
if let Some(limit) = criteria.limit {
println!(
"search request finished after {} ms",
start_instant.elapsed().as_millis()
"fetching page {} of '{}' with limit {}",
criteria.page, entity, limit
);
res
};
} else {
println!("fetching page {} of '{}'", criteria.page, entity);
}

let (response, duration) = self
.handle_authenticated_request(
Method::POST,
&format!("/api/search/{}", entity),
Some(criteria),
None,
true,
)
.await?;

println!(
"search request finished after {} ms",
duration.unwrap().as_millis()
);

if !response.status().is_success() {
let status = response.status();
Expand Down Expand Up @@ -261,14 +246,14 @@ impl SwClient {
}

pub async fn index(&self, skip: Vec<String>) -> Result<(), SwApiError> {
let access_token = self.access_token.lock().unwrap().clone();

let response = self
.client
.post(format!("{}/api/_action/index", self.credentials.base_url))
.bearer_auth(access_token)
.json(&IndexBody { skip })
.send()
let (response, _) = self
.handle_authenticated_request(
Method::POST,
"/api/_action/index",
Some(&IndexBody { skip }),
None,
false,
)
.await?;

if !response.status().is_success() {
Expand Down Expand Up @@ -312,6 +297,69 @@ impl SwClient {
});
worker_rx.await.unwrap()
}

async fn handle_authenticated_request<T: Serialize>(
&self,
method: Method,
path: &str,
body: Option<&T>,
additional_headers: Option<HeaderMap>,
measure_time: bool,
) -> Result<(Response, Option<Duration>), SwApiError> {
let url = format!("{}{}", self.credentials.base_url, path);
let mut retry_count = 0;
const MAX_RETRIES: u8 = 1;

let mut request_builder = self.client.request(method, &url);

if let Some(headers) = additional_headers {
request_builder = request_builder.headers(headers);
}

if let Some(body_value) = body {
request_builder = request_builder.json(body_value);
}

loop {
let access_token = self.access_token.lock().await.clone();
let request = request_builder
.try_clone()
.unwrap()
.bearer_auth(&access_token);

let _lock = self.in_flight_semaphore.acquire().await.unwrap();

let start_time = if measure_time {
Some(Instant::now())
} else {
None
};

let response = request.send().await?;

if response.status() == StatusCode::UNAUTHORIZED && retry_count < MAX_RETRIES {
// lock the access token
let mut access_token_guard = self.access_token.lock().await;
// compare the access token with the one we used to make the request
if *access_token_guard != access_token {
// Another thread has already re-authenticated
continue;
}

// Perform re-authentication
let auth_response = Self::authenticate(&self.client, &self.credentials).await?;
let new_token = auth_response.access_token;
*access_token_guard = new_token;

retry_count += 1;
continue;
}

let duration = start_time.map(|start_time| start_time.elapsed());

return Ok((response, duration));
}
}
}

#[derive(Debug, Serialize)]
Expand Down

0 comments on commit 06b6b48

Please sign in to comment.