Skip to content

Commit

Permalink
Make pull_timeline work with auth enabled.
Browse files Browse the repository at this point in the history
- Make safekeeper read SAFEKEEPER_AUTH_TOKEN env variable with JWT
  token to connect to other safekeepers.
- Set it in neon_local when auth is enabled.
- Create simple rust http client supporting it, and use it in pull_timeline
  implementation.
- Enable auth in all pull_timeline tests.
- Make sk http_client() by default generate safekeeper wide token, it makes
  easier enabling auth in all tests by default.
  • Loading branch information
arssher committed Jun 18, 2024
1 parent 29a41fc commit 4feb6ba
Show file tree
Hide file tree
Showing 10 changed files with 245 additions and 41 deletions.
15 changes: 14 additions & 1 deletion control_plane/src/safekeeper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use camino::Utf8PathBuf;
use postgres_connection::PgConnectionConfig;
use reqwest::{IntoUrl, Method};
use thiserror::Error;
use utils::auth::{Claims, Scope};
use utils::{http::error::HttpErrorBody, id::NodeId};

use crate::{
Expand Down Expand Up @@ -197,7 +198,7 @@ impl SafekeeperNode {
&datadir,
&self.env.safekeeper_bin(),
&args,
[],
self.safekeeper_env_variables()?,
background_process::InitialPidFile::Expect(self.pid_file()),
|| async {
match self.check_status().await {
Expand All @@ -210,6 +211,18 @@ impl SafekeeperNode {
.await
}

fn safekeeper_env_variables(&self) -> anyhow::Result<Vec<(String, String)>> {
// Generate a token to connect from safekeeper to peers
if self.conf.auth_enabled {
let token = self
.env
.generate_auth_token(&Claims::new(None, Scope::SafekeeperData))?;
Ok(vec![("SAFEKEEPER_AUTH_TOKEN".to_owned(), token)])
} else {
Ok(Vec::new())
}
}

///
/// Stop the server.
///
Expand Down
19 changes: 19 additions & 0 deletions safekeeper/src/bin/safekeeper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ use tokio::runtime::Handle;
use tokio::signal::unix::{signal, SignalKind};
use tokio::task::JoinError;
use toml_edit::Document;
use utils::logging::SecretString;

use std::env::{var, VarError};
use std::fs::{self, File};
use std::io::{ErrorKind, Write};
use std::str::FromStr;
Expand Down Expand Up @@ -287,6 +289,22 @@ async fn main() -> anyhow::Result<()> {
}
};

// Load JWT auth token to connect to other safekeepers for pull_timeline.
let sk_auth_token = match var("SAFEKEEPER_AUTH_TOKEN") {
Ok(v) => {
info!("loaded JWT token for authentication with safekeepers");
Some(SecretString::from(v))
}
Err(VarError::NotPresent) => {
info!("no JWT token for authentication with safekeepers detected");
None
}
Err(_) => {
warn!("JWT token for authentication with safekeepers is not unicode");
None
}
};

let conf = SafeKeeperConf {
workdir,
my_id: id,
Expand All @@ -307,6 +325,7 @@ async fn main() -> anyhow::Result<()> {
pg_auth,
pg_tenant_only_auth,
http_auth,
sk_auth_token,
current_thread_runtime: args.current_thread_runtime,
walsenders_keep_horizon: args.walsenders_keep_horizon,
partial_backup_enabled: args.partial_backup_enabled,
Expand Down
139 changes: 139 additions & 0 deletions safekeeper/src/http/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
//! Safekeeper http client.
//!
//! Partially copied from pageserver client; some parts might be better to be
//! united.
//!
//! It would be also good to move it out to separate crate, but this needs
//! duplication of internal-but-reported structs like WalSenderState, ServerInfo
//! etc.
use reqwest::{IntoUrl, Method, StatusCode};
use utils::{
http::error::HttpErrorBody,
id::{TenantId, TimelineId},
logging::SecretString,
};

use super::routes::TimelineStatus;

#[derive(Debug, Clone)]
pub struct Client {
mgmt_api_endpoint: String,
authorization_header: Option<SecretString>,
client: reqwest::Client,
}

#[derive(thiserror::Error, Debug)]
pub enum Error {
/// Failed to receive body (reqwest error).
#[error("receive body: {0}")]
ReceiveBody(reqwest::Error),

/// Status is not ok, but failed to parse body as `HttpErrorBody`.
#[error("receive error body: {0}")]
ReceiveErrorBody(String),

/// Status is not ok; parsed error in body as `HttpErrorBody`.
#[error("safekeeper API: {1}")]
ApiError(StatusCode, String),
}

pub type Result<T> = std::result::Result<T, Error>;

pub trait ResponseErrorMessageExt: Sized {
fn error_from_body(self) -> impl std::future::Future<Output = Result<Self>> + Send;
}

/// If status is not ok, try to extract error message from the body.
impl ResponseErrorMessageExt for reqwest::Response {
async fn error_from_body(self) -> Result<Self> {
let status = self.status();
if !(status.is_client_error() || status.is_server_error()) {
return Ok(self);
}

let url = self.url().to_owned();
Err(match self.json::<HttpErrorBody>().await {
Ok(HttpErrorBody { msg }) => Error::ApiError(status, msg),
Err(_) => {
Error::ReceiveErrorBody(format!("http error ({}) at {}.", status.as_u16(), url))
}
})
}
}

impl Client {
pub fn new(mgmt_api_endpoint: String, jwt: Option<SecretString>) -> Self {
Self::from_client(reqwest::Client::new(), mgmt_api_endpoint, jwt)
}

pub fn from_client(
client: reqwest::Client,
mgmt_api_endpoint: String,
jwt: Option<SecretString>,
) -> Self {
Self {
mgmt_api_endpoint,
authorization_header: jwt
.map(|jwt| SecretString::from(format!("Bearer {}", jwt.get_contents()))),
client,
}
}

pub async fn timeline_status(
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
) -> Result<TimelineStatus> {
let uri = format!(
"{}/v1/tenant/{}/timeline/{}",
self.mgmt_api_endpoint, tenant_id, timeline_id
);
let resp = self.get(&uri).await?;
resp.json().await.map_err(Error::ReceiveBody)
}

pub async fn snapshot(
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
) -> Result<reqwest::Response> {
let uri = format!(
"{}/v1/tenant/{}/timeline/{}/snapshot",
self.mgmt_api_endpoint, tenant_id, timeline_id
);
self.get(&uri).await
}

async fn get<U: IntoUrl>(&self, uri: U) -> Result<reqwest::Response> {
self.request(Method::GET, uri, ()).await
}

/// Send the request and check that the status code is good.
async fn request<B: serde::Serialize, U: reqwest::IntoUrl>(
&self,
method: Method,
uri: U,
body: B,
) -> Result<reqwest::Response> {
let res = self.request_noerror(method, uri, body).await?;
let response = res.error_from_body().await?;
Ok(response)
}

/// Just send the request.
async fn request_noerror<B: serde::Serialize, U: reqwest::IntoUrl>(
&self,
method: Method,
uri: U,
body: B,
) -> Result<reqwest::Response> {
let req = self.client.request(method, uri);
let req = if let Some(value) = &self.authorization_header {
req.header(reqwest::header::AUTHORIZATION, value.get_contents())
} else {
req
};
req.json(&body).send().await.map_err(Error::ReceiveBody)
}
}
1 change: 1 addition & 0 deletions safekeeper/src/http/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod client;
pub mod routes;
pub use routes::make_router;

Expand Down
3 changes: 2 additions & 1 deletion safekeeper/src/http/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,9 @@ async fn timeline_pull_handler(mut request: Request<Body>) -> Result<Response<Bo
check_permission(&request, None)?;

let data: pull_timeline::Request = json_request(&mut request).await?;
let conf = get_conf(&request);

let resp = pull_timeline::handle_request(data)
let resp = pull_timeline::handle_request(data, conf.sk_auth_token.clone())
.await
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, resp)
Expand Down
5 changes: 4 additions & 1 deletion safekeeper/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use tokio::runtime::Runtime;
use std::time::Duration;
use storage_broker::Uri;

use utils::{auth::SwappableJwtAuth, id::NodeId};
use utils::{auth::SwappableJwtAuth, id::NodeId, logging::SecretString};

mod auth;
pub mod broker;
Expand Down Expand Up @@ -78,6 +78,8 @@ pub struct SafeKeeperConf {
pub pg_auth: Option<Arc<JwtAuth>>,
pub pg_tenant_only_auth: Option<Arc<JwtAuth>>,
pub http_auth: Option<Arc<SwappableJwtAuth>>,
/// JWT token to connect to other safekeepers with.
pub sk_auth_token: Option<SecretString>,
pub current_thread_runtime: bool,
pub walsenders_keep_horizon: bool,
pub partial_backup_enabled: bool,
Expand Down Expand Up @@ -114,6 +116,7 @@ impl SafeKeeperConf {
pg_auth: None,
pg_tenant_only_auth: None,
http_auth: None,
sk_auth_token: None,
heartbeat_timeout: Duration::new(5, 0),
max_offloader_lag_bytes: defaults::DEFAULT_MAX_OFFLOADER_LAG_BYTES,
current_thread_runtime: false,
Expand Down
60 changes: 33 additions & 27 deletions safekeeper/src/pull_timeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ use tracing::{error, info, instrument};
use crate::{
control_file::{self, CONTROL_FILE_NAME},
debug_dump,
http::routes::TimelineStatus,
http::{
client::{self, Client},
routes::TimelineStatus,
},
safekeeper::Term,
timeline::{get_tenant_dir, get_timeline_dir, FullAccessTimeline, Timeline, TimelineError},
wal_storage::{self, open_wal_file, Storage},
Expand All @@ -36,6 +39,7 @@ use crate::{
use utils::{
crashsafe::{durable_rename, fsync_async_opt},
id::{TenantId, TenantTimelineId, TimelineId},
logging::SecretString,
lsn::Lsn,
pausable_failpoint,
};
Expand Down Expand Up @@ -163,6 +167,11 @@ impl FullAccessTimeline {

// We need to stream since the oldest segment someone (s3 or pageserver)
// still needs. This duplicates calc_horizon_lsn logic.
//
// We know that WAL wasn't removed up to this point because it cannot be
// removed further than `backup_lsn`. Since we're holding shared_state
// lock and setting `wal_removal_on_hold` later, it guarantees that WAL
// won't be removed until we're done.
let from_lsn = min(
shared_state.sk.state.remote_consistent_lsn,
shared_state.sk.state.backup_lsn,
Expand Down Expand Up @@ -255,7 +264,10 @@ pub struct DebugDumpResponse {
}

/// Find the most advanced safekeeper and pull timeline from it.
pub async fn handle_request(request: Request) -> Result<Response> {
pub async fn handle_request(
request: Request,
sk_auth_token: Option<SecretString>,
) -> Result<Response> {
let existing_tli = GlobalTimelines::get(TenantTimelineId::new(
request.tenant_id,
request.timeline_id,
Expand All @@ -264,26 +276,22 @@ pub async fn handle_request(request: Request) -> Result<Response> {
bail!("Timeline {} already exists", request.timeline_id);
}

let client = reqwest::Client::new();
let http_hosts = request.http_hosts.clone();

// Send request to /v1/tenant/:tenant_id/timeline/:timeline_id
let responses = futures::future::join_all(http_hosts.iter().map(|url| {
let url = format!(
"{}/v1/tenant/{}/timeline/{}",
url, request.tenant_id, request.timeline_id
);
client.get(url).send()
}))
.await;
// Figure out statuses of potential donors.
let responses: Vec<Result<TimelineStatus, client::Error>> =
futures::future::join_all(http_hosts.iter().map(|url| async {
let cclient = Client::new(url.clone(), sk_auth_token.clone());
let info = cclient
.timeline_status(request.tenant_id, request.timeline_id)
.await?;
Ok(info)
}))
.await;

let mut statuses = Vec::new();
for (i, response) in responses.into_iter().enumerate() {
let response = response.context(format!("fetching status from {}", http_hosts[i]))?;
response
.error_for_status_ref()
.context(format!("checking status from {}", http_hosts[i]))?;
let status: crate::http::routes::TimelineStatus = response.json().await?;
let status = response.context(format!("fetching status from {}", http_hosts[i]))?;
statuses.push((status, i));
}

Expand All @@ -303,10 +311,14 @@ pub async fn handle_request(request: Request) -> Result<Response> {
assert!(status.tenant_id == request.tenant_id);
assert!(status.timeline_id == request.timeline_id);

pull_timeline(status, safekeeper_host).await
pull_timeline(status, safekeeper_host, sk_auth_token).await
}

async fn pull_timeline(status: TimelineStatus, host: String) -> Result<Response> {
async fn pull_timeline(
status: TimelineStatus,
host: String,
sk_auth_token: Option<SecretString>,
) -> Result<Response> {
let ttid = TenantTimelineId::new(status.tenant_id, status.timeline_id);
info!(
"pulling timeline {} from safekeeper {}, commit_lsn={}, flush_lsn={}, term={}, epoch={}",
Expand All @@ -322,17 +334,11 @@ async fn pull_timeline(status: TimelineStatus, host: String) -> Result<Response>

let (_tmp_dir, tli_dir_path) = create_temp_timeline_dir(conf, ttid).await?;

let client = reqwest::Client::new();

let client = Client::new(host.clone(), sk_auth_token.clone());
// Request stream with basebackup archive.
let bb_resp = client
.get(format!(
"{}/v1/tenant/{}/timeline/{}/snapshot",
host, status.tenant_id, status.timeline_id
))
.send()
.snapshot(status.tenant_id, status.timeline_id)
.await?;
bb_resp.error_for_status_ref()?;

// Make Stream of Bytes from it...
let bb_stream = bb_resp.bytes_stream().map_err(std::io::Error::other);
Expand Down
1 change: 1 addition & 0 deletions safekeeper/tests/walproposer_sim/safekeeper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ pub fn run_server(os: NodeOs, disk: Arc<SafekeeperDisk>) -> Result<()> {
pg_auth: None,
pg_tenant_only_auth: None,
http_auth: None,
sk_auth_token: None,
current_thread_runtime: false,
walsenders_keep_horizon: false,
partial_backup_enabled: false,
Expand Down
Loading

1 comment on commit 4feb6ba

@github-actions
Copy link

@github-actions github-actions bot commented on 4feb6ba Jun 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3310 tests run: 3184 passed, 0 failed, 126 skipped (full report)


Flaky tests (1)

Postgres 16

  • test_subscriber_restart: debug

Code coverage* (full report)

  • functions: 32.4% (6831 of 21069 functions)
  • lines: 49.9% (53213 of 106597 lines)

* collected from Rust tests only


The comment gets automatically updated with the latest test results
4feb6ba at 2024-06-18T15:22:18.964Z :recycle:

Please sign in to comment.