From e10f1b25ba3a138174c2edebd278726af379cdb1 Mon Sep 17 00:00:00 2001 From: Susanne Hartung Date: Wed, 31 Jul 2024 13:47:38 +0200 Subject: [PATCH] api client reauthentication --- CHANGELOG.md | 5 +- src/api/mod.rs | 233 +++++++++++++++++++++++++++---------------------- 2 files changed, 132 insertions(+), 106 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ff32fb..5bf2d8f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,8 +6,9 @@ - NEXT-37316 - Added `index` command, to trigger the indexing of the Shopware shop - NEXT-37315 - Trigger indexing of the shop by default at the end of an import (can be disabled with flag `-d` `--disable-index`) - NEXT-37303 - [BREAKING] changed `sync` command argument `--schema` `-s` to `--profile` `-p` -- NEXT-37303 - [BREAKING] Fixed an issue where `row` values were always provided as strings in the deserialize script. - Now they are converted into their proper types before passed to the script. +- NEXT-37303 - [BREAKING] Fixed an issue where `row` values were always provided as strings in the deserialize script. +Now they are converted into their proper types before passed to the script. +- NEXT-37313 - Implemented re-authentication for API calls to handle expired bearer tokens # v0.7.1 diff --git a/src/api/mod.rs b/src/api/mod.rs index be31208..94166d5 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -4,9 +4,9 @@ pub mod filter; use crate::api::filter::{Criteria, CriteriaFilter}; use crate::config_file::Credentials; -use reqwest::blocking::{Client, Response}; +use reqwest::blocking::{Client, RequestBuilder, Response}; use reqwest::header::{HeaderMap, HeaderValue}; -use reqwest::{header, StatusCode}; +use reqwest::{header, Method, StatusCode}; use serde::{Deserialize, Serialize}; use serde_json::json; use std::collections::HashMap; @@ -52,7 +52,6 @@ impl SwClient { let total = self.get_total("language", &[])?; - let access_token = self.access_token.lock().unwrap().clone(); while language_list.len() < total as usize { let mut criteria = Criteria { page, @@ -63,13 +62,15 @@ impl SwClient { criteria.add_association("locale"); - let response = { - self.client - .post(format!("{}/api/search/language", self.credentials.base_url)) - .bearer_auth(&access_token) - .json(&criteria) - .send()? - }; + let request_builder = self + .client + .request( + Method::POST, + format!("{}/api/search/language", self.credentials.base_url), + ) + .json(&criteria); + + let response = self.handle_authenticated_request(request_builder)?; let value: LanguageLocaleSearchResponse = Self::deserialize(response)?; for item in value.data { @@ -84,15 +85,13 @@ impl SwClient { }) } - pub fn sync, T: Serialize>( + pub fn sync, T: Serialize + Debug>( &self, entity: S, action: SyncAction, payload: &[T], ) -> Result<(), SwApiError> { let entity: String = entity.into(); - // ToDo: implement retry on auth fail - let access_token = self.access_token.lock().unwrap().clone(); let body = SyncBody { write_data: SyncOperation { entity: entity.clone(), @@ -101,29 +100,25 @@ impl SwClient { }, }; - let response = { - let start_instant = Instant::now(); - println!( - "sync {:?} '{}' with payload size {}", - action, - &entity, - payload.len() - ); - let res = self - .client - .post(format!("{}/api/_action/sync", self.credentials.base_url)) - .bearer_auth(access_token) - .header("single-operation", 1) - .header("indexing-behavior", "disable-indexing") - .header("sw-skip-trigger-flow", 1) - .json(&body) - .send()?; - println!( - "sync request finished after {} ms", - start_instant.elapsed().as_millis() - ); - res - }; + println!( + "sync {:?} '{}' with payload size {}", + action, + &entity, + payload.len() + ); + + let request_builder = self + .client + .request( + Method::POST, + format!("{}/api/_action/sync", self.credentials.base_url), + ) + .header("single-operation", "1") + .header("indexing-behavior", "disable-indexing") + .header("sw-skip-trigger-flow", "1") + .json(&body); + + let response = self.handle_authenticated_request(request_builder)?; if !response.status().is_success() { let status = response.status(); @@ -135,17 +130,12 @@ impl SwClient { } pub fn entity_schema(&self) -> Result { - // ToDo: implement retry on auth fail - let access_token = self.access_token.lock().unwrap().clone(); - let response = { - self.client - .get(format!( - "{}/api/_info/entity-schema.json", - self.credentials.base_url - )) - .bearer_auth(access_token) - .send()? - }; + let request_builder = self.client.request( + Method::GET, + format!("{}/api/_info/entity-schema.json", self.credentials.base_url), + ); + + let response = self.handle_authenticated_request(request_builder)?; if !response.status().is_success() { let status = response.status(); @@ -160,30 +150,27 @@ impl SwClient { pub fn get_total(&self, entity: &str, filter: &[CriteriaFilter]) -> Result { // entity needs to be provided as kebab-case instead of snake_case let entity = entity.replace('_', "-"); + let body = json!({ + "limit": 1, + "filter": filter, + "aggregations": [ + { + "name": "count", + "type": "count", + "field": "id" + } + ] + }); - // ToDo: implement retry on auth fail - let access_token = self.access_token.lock().unwrap().clone(); - - let response = { - self.client - .post(format!( - "{}/api/search/{}", - self.credentials.base_url, entity - )) - .bearer_auth(access_token) - .json(&json!({ - "limit": 1, - "filter": filter, - "aggregations": [ - { - "name": "count", - "type": "count", - "field": "id" - } - ] - })) - .send()? - }; + let request_builder = self + .client + .request( + Method::POST, + format!("{}/api/search/{}", self.credentials.base_url, entity), + ) + .json(&body); + + let response = self.handle_authenticated_request(request_builder)?; if !response.status().is_success() { let status = response.status(); @@ -207,35 +194,24 @@ impl SwClient { // entity needs to be provided as kebab-case instead of snake_case let entity = entity.replace('_', "-"); - // ToDo: implement retry on auth fail - let access_token = self.access_token.lock().unwrap().clone(); - let response = { - let start_instant = Instant::now(); - - if let Some(limit) = criteria.limit { - println!( - "fetching page {} of '{}' with limit {}", - criteria.page, entity, limit - ); - } else { - println!("fetching page {} of '{}'", criteria.page, entity); - } - - let res = self - .client - .post(format!( - "{}/api/search/{}", - self.credentials.base_url, entity - )) - .bearer_auth(access_token) - .json(criteria) - .send()?; + if let Some(limit) = criteria.limit { println!( - "search request finished after {} ms", - start_instant.elapsed().as_millis() + "fetching page {} of '{}' with limit {}", + criteria.page, entity, limit ); - res - }; + } else { + println!("fetching page {} of '{}'", criteria.page, entity); + } + + let request_builder = self + .client + .request( + Method::POST, + format!("{}/api/search/{}", self.credentials.base_url, entity), + ) + .json(criteria); + + let response = self.handle_authenticated_request(request_builder)?; if !response.status().is_success() { let status = response.status(); @@ -276,14 +252,15 @@ impl SwClient { } pub fn index(&self, skip: Vec) -> Result<(), SwApiError> { - let access_token = self.access_token.lock().unwrap().clone(); - - let response = self + let request_builder = self .client - .post(format!("{}/api/_action/index", self.credentials.base_url)) - .bearer_auth(access_token) - .json(&IndexBody { skip }) - .send()?; + .request( + Method::POST, + format!("{}/api/_action/index", self.credentials.base_url), + ) + .json(&IndexBody { skip }); + + let response = self.handle_authenticated_request(request_builder)?; if !response.status().is_success() { let status = response.status(); @@ -320,6 +297,54 @@ impl SwClient { }; result } + + fn handle_authenticated_request( + &self, + request_builder: RequestBuilder, + ) -> Result { + let mut retry_count = 0; + const MAX_RETRIES: u8 = 1; + let binding = request_builder.try_clone().unwrap().build().unwrap(); + let path = binding.url().path(); + + loop { + let access_token = self.access_token.lock().unwrap().clone(); + let request = request_builder + .try_clone() + .unwrap() + .bearer_auth(&access_token); + + let start_time = Instant::now(); + let response = request.send()?; + + if response.status() == StatusCode::UNAUTHORIZED && retry_count < MAX_RETRIES { + // lock the access token + let mut access_token_guard = self.access_token.lock().unwrap(); + // compare the access token with the one we used to make the request + if *access_token_guard != access_token { + // Another thread has already re-authenticated + continue; + } + + // Perform re-authentication + let auth_response = Self::authenticate(&self.client, &self.credentials)?; + let new_token = auth_response.access_token; + *access_token_guard = new_token; + + retry_count += 1; + continue; + } + + let duration = start_time.elapsed(); + println!( + "{} request finished after {} ms", + path, + duration.as_millis() + ); + + return Ok(response); + } + } } #[derive(Debug, Serialize)]