Skip to content

Commit

Permalink
feat: support file:// URLs for snapshot locations
Browse files Browse the repository at this point in the history
LW-11112

This is very useful for live demos, where you don’t want to repeatedly
download from Google Cloud Storage because of the wait time and cost.
  • Loading branch information
michalrus committed Aug 9, 2024
1 parent 52a7beb commit f8fd719
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ mithril-infra/.terraform*
mithril-infra/terraform.tfstate*
mithril-infra/*.tfvars
justfile

result
90 changes: 66 additions & 24 deletions mithril-client/src/snapshot_downloader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ use async_trait::async_trait;
use futures::StreamExt;
use reqwest::{Response, StatusCode};
use slog::{debug, Logger};
use std::fs;
use std::path::Path;
use tokio::fs::File;
use tokio::io::AsyncReadExt;

#[cfg(test)]
use mockall::automock;
Expand Down Expand Up @@ -51,6 +54,8 @@ pub struct HttpSnapshotDownloader {
logger: Logger,
}

const FILE_SCHEME: &str = "file://";

impl HttpSnapshotDownloader {
/// Constructs a new `HttpSnapshotDownloader`.
pub fn new(feedback_sender: FeedbackSender, logger: Logger) -> MithrilResult<Self> {
Expand Down Expand Up @@ -98,7 +103,6 @@ impl SnapshotDownloader for HttpSnapshotDownloader {
)?;
}
let mut downloaded_bytes: u64 = 0;
let mut remote_stream = self.get(location).await?.bytes_stream();
let (sender, receiver) = flume::bounded(5);

let dest_dir = target_dir.to_path_buf();
Expand All @@ -107,21 +111,50 @@ impl SnapshotDownloader for HttpSnapshotDownloader {
unpacker.unpack_snapshot(receiver, compression_algorithm, &dest_dir)
});

while let Some(item) = remote_stream.next().await {
let chunk = item.with_context(|| "Download: Could not read from byte stream")?;

sender.send_async(chunk.to_vec()).await.with_context(|| {
format!("Download: could not write {} bytes to stream.", chunk.len())
})?;

downloaded_bytes += chunk.len() as u64;
self.feedback_sender
.send_event(MithrilEvent::SnapshotDownloadProgress {
download_id: download_id.to_owned(),
downloaded_bytes,
size: snapshot_size,
})
.await
if location.starts_with(FILE_SCHEME) {
// Stream the `location` directly from the local filesystem
let local_path = &location[FILE_SCHEME.len()..];

Check warning (Code scanning / clippy): stripping a prefix manually. Prefer `location.strip_prefix(FILE_SCHEME)` over checking `starts_with` and then slicing with `FILE_SCHEME.len()`.
let mut file = File::open(local_path).await?;

loop {
// We can either allocate here each time, or clone a shared buffer into the sender.
// A larger read buffer is faster because it means fewer context switches:
let mut buffer = vec![0; 16 * 1024 * 1024];
let bytes_read = file.read(&mut buffer).await?;
if bytes_read == 0 {
break;
}
buffer.truncate(bytes_read);
sender.send_async(buffer).await.with_context(|| {
format!("Local file read: could not write {} bytes to stream.", bytes_read)
})?;
downloaded_bytes += bytes_read as u64;
self.feedback_sender
.send_event(MithrilEvent::SnapshotDownloadProgress {
download_id: download_id.to_owned(),
downloaded_bytes,
size: snapshot_size,
})
.await
}
} else {
let mut remote_stream = self.get(location).await?.bytes_stream();
while let Some(item) = remote_stream.next().await {
let chunk = item.with_context(|| "Download: Could not read from byte stream")?;

sender.send_async(chunk.to_vec()).await.with_context(|| {
format!("Download: could not write {} bytes to stream.", chunk.len())
})?;

downloaded_bytes += chunk.len() as u64;
self.feedback_sender
.send_event(MithrilEvent::SnapshotDownloadProgress {
download_id: download_id.to_owned(),
downloaded_bytes,
size: snapshot_size,
})
.await
}
}

drop(sender); // Signal EOF
Expand All @@ -143,15 +176,24 @@ impl SnapshotDownloader for HttpSnapshotDownloader {
async fn probe(&self, location: &str) -> MithrilResult<()> {
debug!(self.logger, "HEAD Snapshot location='{location}'.");

let request_builder = self.http_client.head(location);
let response = request_builder.send().await.with_context(|| {
format!("Cannot perform a HEAD for snapshot at location='{location}'")
})?;
if location.starts_with(FILE_SCHEME) {
let local_path = &location[FILE_SCHEME.len()..];

Check warning (Code scanning / clippy): stripping a prefix manually. Prefer `location.strip_prefix(FILE_SCHEME)` over checking `starts_with` and then slicing with `FILE_SCHEME.len()`.
if fs::metadata(local_path).is_ok() {
Ok(())
} else {
Err(anyhow!("Local snapshot location='{location}' not found"))
}
} else {
let request_builder = self.http_client.head(location);
let response = request_builder.send().await.with_context(|| {
format!("Cannot perform a HEAD for snapshot at location='{location}'")
})?;

match response.status() {
StatusCode::OK => Ok(()),
StatusCode::NOT_FOUND => Err(anyhow!("Snapshot location='{location} not found")),
status_code => Err(anyhow!("Unhandled error {status_code}")),
match response.status() {
StatusCode::OK => Ok(()),
StatusCode::NOT_FOUND => Err(anyhow!("Snapshot location='{location} not found")),
status_code => Err(anyhow!("Unhandled error {status_code}")),
}
}
}
}

0 comments on commit f8fd719

Please sign in to comment.