Add candle CudaDevice and MetalDevice to avoid creating a new unique device each time #2290

Merged: 4 commits merged into main on Sep 25, 2024

Conversation

laggui (Member) commented Sep 20, 2024

Checklist

  • Confirmed that the run-checks all script has been executed.

Related Issues/PRs

Flagged by a Discord user: https://discord.com/channels/1038839012602941528/1091796857996451942/1286409544280576110

The candle device handling is clumsy. It turns out that every time a new device struct is created for the same device index, it is assigned a new unique identifier, so two structs representing the same device do not compare as equal.

The original report flagged a problem with the Metal device due to this inequality, but in my tests the same behavior leads to unreasonably long training times with CUDA (ETA > 2 hrs).

This happens because Candle's tensor.to_device(device) method transfers data through the CPU when the source and target CUDA devices are considered different, a costly side effect of using distinct unique identifiers to represent the same device index. A minimal sketch of the issue is shown below.
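
The following sketch illustrates the problem, assuming candle-core's Device::new_cuda and Device::same_device APIs (this is not code from the PR, and it requires the cuda feature plus an available GPU):

use candle_core::Device;

fn main() -> candle_core::Result<()> {
    // Two handles that refer to the same physical GPU (index 0)...
    let a = Device::new_cuda(0)?;
    let b = Device::new_cuda(0)?;

    // ...but each creation gets a fresh internal identifier, so candle does
    // not treat them as the same device. Moving a tensor that lives on `a`
    // to `b` then falls back to a copy through the CPU.
    assert!(!a.same_device(&b));
    Ok(())
}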

Changes

The CandleDevice enum now holds the underlying candle device structs, so a new struct with a fresh unique identifier is no longer created on every use.

To keep the API simple (and similar to the previous usage), I've added creation methods for Cuda and Metal variants:

// Create a Cuda device from its index
let device = CandleDevice::cuda(0);
// Create a Metal device from its index
let device = CandleDevice::metal(0);

The Copy implementation also had to be removed, because the candle device structs do not implement it. A rough sketch of the resulting enum is shown below.
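
For illustration, the reworked enum could look roughly like the following sketch; the variant payloads and derives are assumptions based on the description above, not the actual PR diff (the candle device types are gated behind the cuda and metal feature flags):

use candle_core::{CudaDevice, MetalDevice};

// No `Copy` derive: candle's device structs are not `Copy`.
#[derive(Clone, Debug)]
pub enum CandleDevice {
    Cpu,
    // The underlying candle handles are created once by the
    // `CandleDevice::cuda(index)` / `CandleDevice::metal(index)` helpers
    // and then cloned, so the same internal identifier is reused.
    Cuda(usize, CudaDevice),
    Metal(usize, MetalDevice),
}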

Testing

Tested locally with the mnist example, using a new candle feature flag to select the candle backend. Training now completes in approximately 3-4 minutes.

codecov bot commented Sep 20, 2024

Codecov Report

Attention: Patch coverage is 17.02128% with 39 lines in your changes missing coverage. Please review.

Project coverage is 85.42%. Comparing base (a6f7a5e) to head (6c99a64).
Report is 2 commits behind head on main.

Files with missing lines                    Patch %   Missing lines
crates/burn-candle/src/backend.rs             0.00%   37
crates/burn-candle/src/ops/base.rs           80.00%   1
crates/burn-candle/src/ops/int_tensor.rs     66.66%   1
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2290      +/-   ##
==========================================
- Coverage   85.44%   85.42%   -0.03%     
==========================================
  Files         766      766              
  Lines       97916    97948      +32     
==========================================
+ Hits        83667    83669       +2     
- Misses      14249    14279      +30     


ivnsch added a commit to ivnsch/tns_brn that referenced this pull request on Sep 22, 2024: "for better training speed; needs tracel-ai/burn#2290, thus temporary switch to main dep"
laggui merged commit 112f09e into main on Sep 25, 2024 (11 checks passed)
laggui deleted the fix/candle/device branch on Sep 25, 2024 at 17:08