Skip to content

Commit

Permalink
refactor: removed use of sudo + docker from cuda crate (#1517)
Browse files Browse the repository at this point in the history
  • Loading branch information
yourbuddyconner committed Sep 19, 2024
2 parents 23c97e9 + 2e8b0a8 commit 4ed9f05
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 37 deletions.
88 changes: 56 additions & 32 deletions crates/cuda/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
#[rustfmt::skip]
pub mod proto {
pub mod api;
}

use core::time::Duration;
use std::{
error::Error as StdError,
future::Future,
io::{BufReader, Read, Write},
process::{Command, Stdio},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::{Duration, Instant},
};

use crate::proto::api::ProverServiceClient;
Expand All @@ -27,6 +23,11 @@ use sp1_stark::ShardProof;
use tokio::task::block_in_place;
use twirp::{url::Url, Client};

#[rustfmt::skip]
pub mod proto {
pub mod api;
}

/// A remote client to [sp1_prover::SP1Prover] that runs inside a container.
///
/// This is currently used to provide experimental support for GPU hardware acceleration.
Expand Down Expand Up @@ -84,28 +85,31 @@ pub struct WrapRequestPayload {
impl SP1CudaProver {
/// Creates a new [SP1Prover] that runs inside a Docker container and returns a
/// [SP1ProverClient] that can be used to communicate with the container.
pub fn new() -> Self {
pub fn new() -> Result<Self, Box<dyn StdError>> {
let container_name = "sp1-gpu";
let image_name = "succinctlabs/sp1-gpu:v1.2.0-rc2";

let cleaned_up = Arc::new(AtomicBool::new(false));
let cleanup_name = container_name;
let cleanup_flag = cleaned_up.clone();

// Pull the docker image if it's not present.
Command::new("sudo")
.args(["docker", "pull", image_name])
.output()
.expect("failed to pull docker image");
// Check if Docker is available and the user has necessary permissions
if !Self::check_docker_availability()? {
return Err("Docker is not available or you don't have the necessary permissions. Please ensure Docker is installed and you are part of the docker group.".into());
}

// Start the docker container.
let rust_log_level = std::env::var("RUST_LOG").unwrap_or("none".to_string());
let mut child = Command::new("sudo")
// Pull the docker image if it's not present
if let Err(e) = Command::new("docker").args(["pull", image_name]).output() {
return Err(format!("Failed to pull Docker image: {}. Please check your internet connection and Docker permissions.", e).into());
}

// Start the docker container
let rust_log_level = std::env::var("RUST_LOG").unwrap_or_else(|_| "none".to_string());
let mut child = Command::new("docker")
.args([
"docker",
"run",
"-e",
format!("RUST_LOG={}", rust_log_level).as_str(),
&format!("RUST_LOG={}", rust_log_level),
"-p",
"3000:3000",
"--rm",
Expand All @@ -118,7 +122,7 @@ impl SP1CudaProver {
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.expect("failed to start Docker container");
.map_err(|e| format!("Failed to start Docker container: {}. Please check your Docker installation and permissions.", e))?;

let stdout = child.stdout.take().unwrap();
std::thread::spawn(move || {
Expand All @@ -136,7 +140,7 @@ impl SP1CudaProver {
}
});

// Kill the container on control-c.
// Kill the container on control-c
ctrlc::set_handler(move || {
tracing::debug!("received Ctrl+C, cleaning up...");
if !cleanup_flag.load(Ordering::SeqCst) {
Expand All @@ -147,37 +151,57 @@ impl SP1CudaProver {
})
.unwrap();

// Wait a few seconds for the container to start.
// Wait a few seconds for the container to start
std::thread::sleep(Duration::from_secs(2));

// Check if the container is ready.
// Check if the container is ready
let client = Client::from_base_url(
Url::parse("http://localhost:3000/twirp/").expect("failed to parse url"),
)
.expect("failed to create client");

let timeout = Duration::from_secs(60); // Set a 60-second timeout
let start_time = Instant::now();

block_on(async {
tracing::info!("waiting for proving server to be ready");
loop {
if start_time.elapsed() > timeout {
return Err("Timeout: proving server did not become ready within 60 seconds. Please check your Docker container and network settings.".to_string());
}

let request = ReadyRequest {};
let response = client.ready(request).await;
if let Ok(response) = response {
if response.ready {
match client.ready(request).await {
Ok(response) if response.ready => {
tracing::info!("proving server is ready");
break;
}
Ok(_) => {
tracing::info!("proving server is not ready, retrying...");
}
Err(e) => {
tracing::warn!("Error checking server readiness: {}", e);
}
}
tracing::info!("proving server is not ready, retrying...");
std::thread::sleep(Duration::from_secs(2));
tokio::time::sleep(Duration::from_secs(2)).await;
}
});
Ok(())
})?;

SP1CudaProver {
Ok(SP1CudaProver {
client: Client::from_base_url(
Url::parse("http://localhost:3000/twirp/").expect("failed to parse url"),
)
.expect("failed to create client"),
container_name: container_name.to_string(),
cleaned_up: cleaned_up.clone(),
})
}

/// Reports whether the `docker` CLI can be invoked successfully.
///
/// Runs `docker version` and yields `Ok(true)` only when that command exits with a
/// success status. Failing to launch the command at all is reported as `Ok(false)`
/// rather than as an error, so this function never actually returns `Err`.
fn check_docker_availability() -> Result<bool, Box<dyn std::error::Error>> {
    let available = Command::new("docker")
        .arg("version")
        .output()
        .map(|probe| probe.status.success())
        .unwrap_or(false);
    Ok(available)
}

Expand Down Expand Up @@ -258,7 +282,7 @@ impl SP1CudaProver {

impl Default for SP1CudaProver {
fn default() -> Self {
Self::new()
Self::new().expect("Failed to create SP1CudaProver")
}
}

Expand All @@ -274,8 +298,8 @@ impl Drop for SP1CudaProver {

/// Cleans up the docker container with the given name.
fn cleanup_container(container_name: &str) {
if let Err(e) = Command::new("sudo").args(["docker", "rm", "-f", container_name]).output() {
eprintln!("failed to remove container: {}", e);
if let Err(e) = Command::new("docker").args(["rm", "-f", container_name]).output() {
eprintln!("Failed to remove container: {}. You may need to manually remove it using 'docker rm -f {}'", e, container_name);
}
}

Expand Down Expand Up @@ -313,7 +337,7 @@ mod tests {
setup_logger();

let prover = SP1Prover::<DefaultProverComponents>::new();
let client = SP1CudaProver::new();
let client = SP1CudaProver::new().expect("Failed to create SP1CudaProver");
let (pk, vk) = prover.setup(FIBONACCI_ELF);

println!("proving core");
Expand Down
2 changes: 1 addition & 1 deletion crates/sdk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl ProverClient {
#[cfg(not(feature = "cuda"))]
prover: Box::new(CpuProver::new()),
#[cfg(feature = "cuda")]
prover: Box::new(CudaProver::new()),
prover: Box::new(CudaProver::new(SP1Prover::new())),
},
"network" => {
cfg_if! {
Expand Down
7 changes: 3 additions & 4 deletions crates/sdk/src/provers/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@ pub struct CudaProver {

impl CudaProver {
/// Creates a new [CudaProver].
pub fn new() -> Self {
let prover = SP1Prover::new();
pub fn new(prover: SP1Prover) -> Self {
let cuda_prover = SP1CudaProver::new();
Self { prover, cuda_prover }
Self { prover, cuda_prover: cuda_prover.expect("Failed to initialize CUDA prover") }
}
}

Expand Down Expand Up @@ -102,6 +101,6 @@ impl Prover<DefaultProverComponents> for CudaProver {

impl Default for CudaProver {
fn default() -> Self {
Self::new()
Self::new(SP1Prover::new())
}
}

0 comments on commit 4ed9f05

Please sign in to comment.