Skip to content

Commit

Permalink
feat: support CUDA
Browse files Browse the repository at this point in the history
  • Loading branch information
Rorical committed Jul 7, 2023
1 parent 2410ce4 commit 4e5db87
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 6 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ on:
- '*'
env:
ORT_STRATEGY: system
ORT_LIB_LOCATION: ./onnxruntime-linux-x64-1.14.1/lib/
ORT_LIB_LOCATION: ./onnxruntime-linux-x64-gpu-1.14.1/lib/
jobs:
build:
runs-on: ubuntu-latest
Expand All @@ -21,7 +21,7 @@ jobs:
- name: Install requirement
run: sudo apt-get update && sudo apt-get install -y protobuf-compiler
- name: Download Onnx runtime
run: wget https://github.com/microsoft/onnxruntime/releases/download/v1.14.1/onnxruntime-linux-x64-1.14.1.tgz && tar -xzvf onnxruntime-linux-x64-1.14.1.tgz
run: wget https://github.com/microsoft/onnxruntime/releases/download/v1.14.1/onnxruntime-linux-x64-gpu-1.14.1.tgz && tar -xzvf onnxruntime-linux-x64-gpu-1.14.1.tgz
- name: Build
run: cargo build --bin clip-as-service-server --release
- name: Upload artifact
Expand Down
178 changes: 177 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 8 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@ license = "MIT"
ndarray = "0.15.0"
itertools = "0.10.5"
num_cpus = "1.15.0"
ort = { git = "https://github.com/pykeio/ort.git", default-features = false }
ort = { git = "https://github.com/pykeio/ort.git", features = ["cuda"] }
prost = "0.11.8"
tokenizers = "0.13.2"
tokio = { version = "1.26.0", features = ["macros", "rt-multi-thread"] }
tonic = "0.8.3"
image = "0.24.6"
clap = { version = "4.2.4", features = ["derive"] }
tracing-subscriber = { version = "0.3", features = [ "env-filter", "fmt" ] }

[build-dependencies]
tonic-build = "0.8.4"
Expand All @@ -31,3 +32,9 @@ rustflags = ["-C", "linker-flavor=ld.lld"]

[target.x86_64-pc-windows-msvc]
rustflags = ["-C", "target-feature=+crt-static"]

[profile.dev]
rpath = true

[profile.release]
rpath = true
8 changes: 6 additions & 2 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,17 @@ impl Encoder for EncoderService {

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();

let args = Args::parse();

let addr: &String = &args.listen;

let environment =
Arc::new(Environment::builder()
.with_name("clip")
.build().unwrap());
.with_name("clip")
.with_execution_providers([ExecutionProvider::cuda()])
.build().unwrap());

let server = EncoderService::new(&environment, &args);

Expand Down

0 comments on commit 4e5db87

Please sign in to comment.