Add candle `CudaDevice` and `MetalDevice` to avoid creating a new unique device each time #2290
Checklist
The `run-checks all` script has been executed.

Related Issues/PRs
Flagged by a Discord user: https://discord.com/channels/1038839012602941528/1091796857996451942/1286409544280576110
The candle device handling is clumsy. It turns out that every time you create a new device struct for the same device index, a new unique identifier is generated, so two structs that represent the same device do not compare as equal.
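To make the inequality concrete, here is a minimal, self-contained model of the behavior. It does not use candle: `ModelDevice`, `NEXT_ID`, and `same_device` are hypothetical stand-ins, and the assumption that the identifier comes from a global counter is only an illustration of "a new unique identifier per creation".

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Global counter used to hand out a fresh identifier per created handle.
// Assumption: this only models "unique id per creation", not candle's code.
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);

/// Hypothetical stand-in for a GPU device handle.
#[derive(Clone, Debug)]
struct ModelDevice {
    id: usize,      // unique per call to `new`
    ordinal: usize, // hardware device index
}

impl ModelDevice {
    /// Creates a handle for the given device index with a fresh identifier.
    fn new(ordinal: usize) -> Self {
        let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        Self { id, ordinal }
    }

    /// Two handles count as the same device only if their identifiers match.
    fn same_device(&self, other: &Self) -> bool {
        self.id == other.id
    }
}

fn main() {
    let a = ModelDevice::new(0);
    let b = ModelDevice::new(0);

    // Same hardware index, yet the handles are not considered equal.
    assert_eq!(a.ordinal, b.ordinal);
    assert!(!a.same_device(&b));

    // Cloning an existing handle keeps the identifier, so equality holds.
    let c = a.clone();
    assert!(a.same_device(&c));
}
```

In this model, the only way to get two equal handles is to create one and reuse or clone it.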
The original issue flagged a problem with Metal devices caused by this inequality, but in my tests the same behavior leads to an insane training time with CUDA (ETA > 2 hrs). This is because the `tensor.to_device(device)` method in Candle actually transfers the data via the CPU when the source and target CUDA devices are considered different, a terrible side effect of using different unique identifiers to represent the same device index.

Changes
The `CandleDevice` enum now captures the underlying device structs to avoid creating a new one with a unique identifier each time.

To keep the API simple (and similar to the previous usage), I've added creation methods for the `Cuda` and `Metal` variants.

Also had to remove the `Copy` implementation because the candle devices do not implement it.

Testing
Tested (locally) the `mnist` example with a new candle feature flag to use the candle backend. The training now completes in approx. 3-4 minutes.
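As a rough illustration of the approach described under Changes, the sketch below models an enum that stores the device handle instead of recreating it. It is not the code from this PR: `RawDevice` is a hypothetical stand-in for the backend handle, and the method names `cuda`, `metal`, `raw_id`, and `index` are assumptions made for the example.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

static NEXT_ID: AtomicUsize = AtomicUsize::new(0);

/// Hypothetical backend handle: fresh id per creation, shared context.
/// The `Arc` models shared ownership, which rules out deriving `Copy`.
#[derive(Clone, Debug)]
struct RawDevice {
    id: usize,
    _context: Arc<()>,
}

impl RawDevice {
    fn new() -> Self {
        Self {
            id: NEXT_ID.fetch_add(1, Ordering::Relaxed),
            _context: Arc::new(()),
        }
    }
}

#[derive(Clone, Debug)]
struct CudaDevice {
    raw: RawDevice,
    index: usize,
}

#[derive(Clone, Debug)]
struct MetalDevice {
    raw: RawDevice,
    index: usize,
}

/// The enum owns the handle, so it is created once and then cloned.
#[derive(Clone, Debug)] // `Copy` is not possible: the handle is not `Copy`
enum CandleDevice {
    Cpu,
    Cuda(CudaDevice),
    Metal(MetalDevice),
}

impl CandleDevice {
    /// Creation helper (assumed name) that builds the handle exactly once.
    fn cuda(index: usize) -> Self {
        CandleDevice::Cuda(CudaDevice { raw: RawDevice::new(), index })
    }

    /// Creation helper (assumed name) for the Metal variant.
    fn metal(index: usize) -> Self {
        CandleDevice::Metal(MetalDevice { raw: RawDevice::new(), index })
    }

    /// Identifier of the stored handle, if any.
    fn raw_id(&self) -> Option<usize> {
        match self {
            CandleDevice::Cpu => None,
            CandleDevice::Cuda(d) => Some(d.raw.id),
            CandleDevice::Metal(d) => Some(d.raw.id),
        }
    }

    /// Hardware index of the stored device, if any.
    fn index(&self) -> Option<usize> {
        match self {
            CandleDevice::Cpu => None,
            CandleDevice::Cuda(d) => Some(d.index),
            CandleDevice::Metal(d) => Some(d.index),
        }
    }
}

fn main() {
    let device = CandleDevice::cuda(0);
    let reused = device.clone();
    // The clone carries the same stored handle, hence the same identifier.
    assert_eq!(device.raw_id(), reused.raw_id());
    assert_eq!(CandleDevice::metal(1).index(), Some(1));
}
```

Because the handle lives inside the enum, passing the same `CandleDevice` value around (by reference or clone) keeps the identifier stable in this model, at the cost of giving up `Copy`.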