OptArgs with GPU array #8
Hmm, I can understand why you would expect … Currently, … Not that details of …
Trying out monkey-patching to see what overloads (…
First, this won't work as-is, because …
This works with

```julia
using AccessibleOptimization, DifferentiationInterface, Optimization, OptimizationOptimJL, Random, GPUArrays, LinearAlgebra, AMDGPU, Zygote

let n = 8
    data = ROCArray(rand(n))
    f2(θ::AbstractGPUArray, x) = θ⋅x
    θ₀ = ROCArray(rand(n))
    vars = OptArgs(@o _[∗])
    optfunc = OptimizationFunction(f2, Optimization.AutoZygote())
    ops = OptProblemSpec((@o optfunc(_, data)), θ₀, vars,)
    @time "solve" soln = solve(ops, BFGS())
end
```

I tried playing with

```julia
import Accessors

Accessors.setall(obj::AbstractArray, ::Accessors.Elements, vs::ROCArray) =
    (@assert length(obj) == length(vs); reshape(vs, size(obj)))
```

but there are so many moving parts in this MWE, with the autodiff and everything, that I kinda lose track of what's going on.
It should be automatic if not passed – basically, "whatever you get from …

I have basically no experience working with GPU arrays (nor with reverse autodiff), so I probably won't be of much help here. But if you find some fixes or self-contained examples we can add to tests (even if …

Also, note that the main use case for AccessibleOptimization (and Accessors in general, btw) is having a reasonably small number of parameters that you specify with optics. Arrays and tuples also work, of course, with …
If I understand correctly, `fromrawu` is supposed to change the `Vector` into a GPU array here, but it doesn't: the first argument of `f2` is based on a `Vector`.

Am I supposed to do `OptArgs(@o ROCArray(_[∗]))`, or overload `fromrawu` or `_convert`, or are we waiting on `Base.into`, or something else?
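The conversion being asked about is essentially "take the flat `Vector` of raw optimizer values and rebuild an array with the same container type and shape as `θ₀`". A minimal sketch of that generic Julia pattern, using only `Base` functions, is below. `like_template` is a hypothetical name and is not part of AccessibleOptimization or Accessors; the sketch does not show how (or whether) `fromrawu` or `_convert` could be made to call something like it. It assumes that `similar` on a GPU array returns an array of the same GPU type, which is how GPUArrays-based packages usually behave.

```julia
# Hypothetical helper, NOT part of AccessibleOptimization or Accessors:
# rebuild an array with the same container type and shape as `template`
# from a flat vector of raw values.
function like_template(template::AbstractArray, raw::AbstractVector)
    @assert length(template) == length(raw) "length mismatch"
    # `similar` keeps the container type of `template` (assumed: a ROCArray
    # template gives a ROCArray) but uses the element type of `raw`.
    out = similar(template, eltype(raw), size(template))
    # `copyto!` fills `out` in linear (column-major) order; for a GPU
    # `template` this is expected to copy the host data to the device.
    copyto!(out, raw)
    return out
end

# CPU usage; with AMDGPU loaded, `template` could be a ROCArray instead.
m = like_template(zeros(2, 2), [1.0, 2.0, 3.0, 4.0])
# m == [1.0 3.0; 2.0 4.0]
```

Going through `similar` rather than calling `ROCArray(...)` directly keeps the helper backend-agnostic, which matters if the same optimization code is meant to run on CPU arrays and on different GPU array types.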