`StatsBase.transform` fails on `CuArray` #426
Comments
I figured it out myself: the issue is caused by the fact that the transform's mean and scale are on the CPU. When I transfer them to the GPU with …
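Below is a minimal sketch of the fix described above, written in plain Julia. The fitted mean and scale are stored in arrays of the same kind as the data, and the normalisation is done only with broadcasting. `BroadcastZScore`, `fit_zscore`, `to_device_of` and `apply_zscore` are hypothetical names chosen for illustration. They are not StatsBase API, and the sketch depends only on the `Statistics` standard library. The GPU path is an assumption: it relies on `similar` and `copyto!` returning an array of the same kind as `X` (for example a `CuArray`). The snippet itself runs on the CPU.

```julia
using Statistics

# Hypothetical helper type (not part of StatsBase): z-score parameters kept as
# arrays shaped so that they broadcast directly against the data matrix.
struct BroadcastZScore{A<:AbstractArray}
    mean::A
    scale::A
end

# Fit along `dims`, keeping the reduced dimension (e.g. an n×1 matrix for
# dims=2) so the parameters broadcast against X without reshaping.
function fit_zscore(X::AbstractMatrix; dims::Integer=2)
    μ = mean(X; dims=dims)
    σ = std(X; dims=dims)   # corrected (n - 1) standard deviation
    return BroadcastZScore(μ, σ)
end

# Copy the parameters into arrays of the same kind and element type as X.
# Assumption: for a CuArray X, `similar` allocates on the GPU, so the
# parameters end up on the same device as the data.
function to_device_of(t::BroadcastZScore, X::AbstractArray)
    T = eltype(X)
    move(p) = copyto!(similar(X, T, size(p)), T.(p))
    return BroadcastZScore(move(t.mean), move(t.scale))
end

# Pure broadcasting, with no scalar indexing, so it is array-type agnostic.
apply_zscore(t::BroadcastZScore, X::AbstractArray) = (X .- t.mean) ./ t.scale

# Usage on the CPU: normalise each row of a 3×4 matrix.
X = 5 .* randn(3, 4) .+ 2
t = fit_zscore(X; dims=2)
Z = apply_zscore(to_device_of(t, X), X)
```

The parameters are kept with the reduced dimension rather than as vectors so that `apply_zscore` needs no `reshape` and no loop. With CUDA.jl loaded, calling the same functions with `Xc = cu(X)` is expected, though not guaranteed, to keep the whole computation on the GPU.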
@sdewaele, could you please post how you used `cu()` to transfer the transform's mean and scale?
I actually created a type similar to …
Using this `zscore-transform.jl`, the following seems to do the job:

```julia
using CUDA
using StatsBase
using Test

include("zscore-transform.jl")  # defines ZScoreTransformGeneric

# Generic version, fitted and applied on the GPU
Ac = cu(5randn(3, 4) .+ 2)
tc = fit(ZScoreTransformGeneric, Ac; dims=2)
Agc = StatsBase.transform(tc, Ac)

# Compare to the original ZScoreTransform result on the CPU
A = collect(Ac)
t = fit(ZScoreTransform, A; dims=2)
Ag = StatsBase.transform(t, A)
@test collect(Agc) ≈ Ag
```

I may submit a PR to … Note that there is also …
When completed, this PR will resolve this issue: JuliaStats/StatsBase.jl#622
@lorrp1 The PR was merged into the master branch of StatsBase.
The following code fails on a `CuArray`:

Error message:

This condensed reproduction of the `transform!` code may facilitate debugging:

The error occurs on this line in `StatsBase.transform!`.
I know that it is not too hard to write code from scratch for this type of normalisation (for example, see …); `Flux.normalise` does something similar. However, I want to use a precomputed mean and scale, which is why I tried `StatsBase`. However: … `CUDA.jl`. Or perhaps `StatsBase`?

`Manifest.toml`:

Version information:

CUDA version information:

NOTE: I have also reproduced the issue in Julia 1.5.0 on Linux, on a system where the CUDA.jl warning below does not occur and the CUDA binaries were downloaded with BinaryBuilder.