Skip to content

Commit

Permalink
tch cuda cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
r730 committed Nov 6, 2023
1 parent 57f21b3 commit da6b739
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ export LIBTORCH_INCLUDE=/Users/tom/Downloads/libtorch
export LIBTORCH_LIB=/Users/tom/Downloads/libtorch
export LIBTORCH=/Users/tom/Downloads/libtorch
export DYLD_LIBRARY_PATH=/Users/tom/Downloads/libtorch/lib:$DYLD_LIBRARY_PATH
3. ubuntu22.04.3版本上使用 Tesla M40(NVIDIA-Linux-x86_64-535.129.03.run) cuda118版本,测试tch cuda版本通过(缺少包可以拷贝到debug下)
3. ubuntu22.04.3版本上使用 Tesla M40(NVIDIA-Linux-x86_64-535.129.03.run) cuda118版本,测试tch cuda版本通过(缺少包可以拷贝到debug下)
- `cargo test --test tch_test`
21 changes: 18 additions & 3 deletions tests/tch.rs → tests/tch_test.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#[cfg(test)]
mod alstm {
use serde_json::Value;
use tch::nn::OptimizerConfig;
use tch::{
kind,
Expand All @@ -15,11 +14,27 @@ mod alstm {
}

#[test]
fn alstm_build_model() {
fn alstm_build_model_cuda_works() {
let vs = nn::VarStore::new(Device::Cuda(0));
let my_module = my_module(vs.root(), 7);
let mut opt = nn::Sgd::default().build(&vs, 1e-2).unwrap();
for _idx in 1..5000 {
// Dummy mini-batches made of zeros.
let xs = Tensor::zeros(&[7], kind::FLOAT_CUDA);
let ys = Tensor::zeros(&[7], kind::FLOAT_CUDA);
let loss = (my_module.forward(&xs) - ys)
.pow_tensor_scalar(2)
.sum(kind::Kind::Float);
opt.backward_step(&loss);
}
}

#[test]
fn alstm_build_model_cpu_works() {
let vs = nn::VarStore::new(Device::Cpu);
let my_module = my_module(vs.root(), 7);
let mut opt = nn::Sgd::default().build(&vs, 1e-2).unwrap();
for _idx in 1..50 {
for _idx in 1..5000 {
// Dummy mini-batches made of zeros.
let xs = Tensor::zeros(&[7], kind::FLOAT_CPU);
let ys = Tensor::zeros(&[7], kind::FLOAT_CPU);
Expand Down

0 comments on commit da6b739

Please sign in to comment.