
Fix for julia 1.9 #718

Merged
oxinabox merged 4 commits into main from ox/1.9fixes on Jun 2, 2023
Conversation

oxinabox
Member

JuliaLang/julia#44358 broke things for us on Julia 1.9.
This gets it working again, but I don't know numerical linear algebra well, so I am not sure it doesn't cost us anything.
Also, inv is broken on GPU again.

@oxinabox oxinabox marked this pull request as ready for review May 19, 2023 15:08
@oxinabox oxinabox requested a review from mcabbott May 19, 2023 15:08
@oxinabox oxinabox self-assigned this May 19, 2023
@oxinabox
Member Author

oxinabox commented May 19, 2023

The test we need to pass is:

using ChainRulesTestUtils
using Test
using ChainRulesCore
using ChainRules       # the rrules for / and \ under test live here
using LinearAlgebra    # for Transpose, used below

@testset "$f" for f in (/, \)
        @testset "Matrix" begin
            for n in 3:5, m in 3:5
                A = randn(m, n)
                B = randn(m, n)
                test_rrule(f, A, B; check_inferred=false) # ChainRulesCore #407
            end
        end
        @testset "Vector" begin
            x = randn(10)
            y = randn(10)
            test_rrule(f, x, y; check_inferred=false) # ChainRulesCore #407
        end
        if f == (\)
            @testset "Matrix $f Vector" begin
                X = randn(10, 4)
                y = randn(10)
                test_rrule(f, X, y)
            end
            @testset "Vector $f Matrix" begin
                x = randn(10)
                Y = randn(10, 4)
                test_rrule(f, x, Y; output_tangent=Transpose(rand(4)))
            end
        else
            A = rand(2, 4)
            B = rand(4, 4)
            test_rrule(f, A, B; check_inferred=false) # ChainRulesCore #407
        end
end

@oxinabox
Member Author

Oh interesting, x86 (32-bit) segfaults:
https://github.com/JuliaDiff/ChainRules.jl/actions/runs/5025659382/jobs/9012921468?pr=718#step:6:213
Definitely a bug in Julia itself or in JuliaInterpreter, though.

@oxinabox
Member Author

@antoine-levitt you were going to review this?

@antoine-levitt

Yes, sorry, I commented on the other issue here: JuliaLang/julia#49915 (comment). I don't think that's the way to go, because it computes explicit pseudo-inverses. Instead, you can just do something like Ȳ isa AbstractArray || (Ȳ = [Ȳ]) to get back the old behavior.
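
A minimal sketch of what that suggestion amounts to inside a pullback, on Julia 1.9 (the helper name and dummy data are illustrative, not the actual ChainRules code):

# Sketch only: promote a scalar cotangent to a one-element vector so the
# pullback keeps hitting the array \ array code path, as it did before 1.9.
maybe_vec(x) = x isa AbstractArray ? x : [x]

A = randn(3)        # dummy data for the vector \ vector case
ȳ = 1.0             # on 1.9 the cotangent can arrive as a plain scalar
Ȳ = maybe_vec(ȳ)    # back to a 1-element vector
∂B = A' \ Ȳ         # an ordinary adjoint-row solve, returning a length-3 vector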

@oxinabox
Member Author

OK, cool. Are you able to review this PR and, if it looks good now, approve it so I can merge?

@antoine-levitt

I don't understand why you're doing this. Can't this PR literally be Ȳ isa AbstractArray || (Ȳ = [Ȳ])? There's no need to change anything else, is there?

@antoine-levitt

Ah, sorry, I didn't get that it was to factor only once. This is not an exact optimization, because A \ b follows slightly different logic from factorize(A) \ b, but I don't think that is a problem; it should do the right thing nonetheless. So yes, good to go I think.
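
For reference, a rough sketch of that factorize-once idea on made-up data (illustrative only, not the code in this PR):

using LinearAlgebra

# Factorize A once and reuse the factorization for the primal solve and for
# the adjoint solve in the pullback, instead of hitting A \ ... each time.
A = randn(4, 4)
B = randn(4, 3)
F = factorize(A)         # an LU factorization for a dense square A
Y = F \ B                # primal result
Ȳ = randn(size(Y)...)    # some cotangent of Y
B̄ = F' \ Ȳ               # pullback solve reusing the same factorization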

@antoine-levitt

But

julia> randn(2) / factorize(randn(2,2))
ERROR: MethodError: no method matching /(::Vector{Float64}, ::LU{Float64, Tridiagonal{Float64, Vector{Float64}}, …})

Presumably that / method needs to be added for factorizations for this to work; I'll try to do a PR on Julia. In the meantime, you can get by using the \ method of the factorization, which does work, and transposing.
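
A sketch of that workaround on made-up data (F and the variable names are illustrative; this is not code from the PR):

using LinearAlgebra

# Rewrite right-division by the factorization as an adjoint left-division,
# using the identity x / A == (A' \ x')', so only the backslash method of F is needed.
A = randn(2, 2)
F = factorize(A)
x = randn(2)'      # a 1×2 adjoint row vector
y = (F' \ x')'     # same result as x / A, computed via F's backslash method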

@oxinabox
Member Author

oxinabox commented Jun 1, 2023

So I pulled out everything except your suggested change, @antoine-levitt, and this still fails,
still with MethodError: no method matching /(::Float64, ::Vector{Float64}).
To make debugging easier I even removed the thunks; then I got a bunch of numerical errors, and
DimensionMismatch: variable with size(x) == (4, 4) cannot have a gradient with size(dx) == (2, 4)

(We can make a follow-up to pre-factorize once things are no longer broken on Julia 1.9.)

Not sure why. Do you have any clue?

I think it would be good to nail this down before JuliaLang/julia#49915 gets discussed by triage, because I think the situation changes if the fix isn't simple but is actually fairly complex.

@antoine-levitt

Yeah same problem just a bit below:

            Ā = add!!(Ā, A' \ Y * (Ȳ' - B̄'A))

That Y (but not the one above) needs to be []ified too.

@oxinabox
Member Author

oxinabox commented Jun 1, 2023

With that, the MethodError is gone, but the DimensionMismatch and the accuracy problem remain.

@antoine-levitt

What's your MWE? It looks like it's working for me. And what do you mean by accuracy?

@oxinabox
Member Author

oxinabox commented Jun 1, 2023

What's your MWE?

I am just running the tests. To reproduce the particular ones that fail:

julia> using ChainRules, ChainRulesCore, ChainRulesTestUtils, LinearAlgebra, Test

julia> @testset "$f" for f in (/, \)
               @testset "Matrix" begin
                   for n in 3:5, m in 3:5
                       A = randn(m, n)
                       B = randn(m, n)
                       test_rrule(f, A, B; check_inferred=false) # ChainRulesCore #407
                   end
               end
               @testset "Vector" begin
                   x = randn(10)
                   y = randn(10)
                   test_rrule(f, x, y; check_inferred=false) # ChainRulesCore #407
               end
               if f == (\)
                   @testset "Matrix $f Vector" begin
                       X = randn(10, 4)
                       y = randn(10)
                       test_rrule(f, X, y)
                   end
                   @testset "Vector $f Matrix" begin
                       x = randn(10)
                       Y = randn(10, 4)
                       test_rrule(f, x, Y; output_tangent=Transpose(rand(4)))
                   end
               else
                   A = rand(2, 4)
                   B = rand(4, 4)
                   test_rrule(f, A, B; check_inferred=false) # ChainRulesCore #407
               end
           end

You can see it failing in CI: https://github.com/JuliaDiff/ChainRules.jl/actions/runs/5143424642/jobs/9258419472?pr=718#step:6:804

By accuracy I mean: it is giving results that differ significantly from finite differencing.
It might be numerical issues, or the code might actually be wrong now. It's a pretty big difference.

@antoine-levitt

That's just a typo, you wrote

        ∂A = #@thunk begin
            B̄ = A' \ Ȳ

which is doing ∂A = B̄ = A' \ Ȳ (so ∂A ends up being just that first assignment), which is not what you want.
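
Schematically, what was presumably intended once the @thunk is stripped (dummy values, not the PR's actual code):

# With the @thunk removed, ∂A must still be the whole begin ... end block;
# otherwise ∂A is just the result of the first assignment, B̄ = A' \ Ȳ.
A = randn(3, 3); Y = randn(3, 2); Ȳ = randn(3, 2)   # dummy shapes
∂A = begin
    B̄ = A' \ Ȳ
    -B̄ * Y'
end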

@oxinabox
Member Author

oxinabox commented Jun 2, 2023

ah lol, in trying to simplify the problem I caused new ones

@oxinabox
Member Author

oxinabox commented Jun 2, 2023

Still getting errors.
I believe it is because when

│   typeof(Y) = Matrix{Float64} (alias for Array{Float64, 2})
│   typeof(A) = LinearAlgebra.Adjoint{Float64, Vector{Float64}}
└   typeof(B) = LinearAlgebra.Adjoint{Float64, Vector{Float64}}

we have that (B - A * Y) * B̄' is a Float64 (since it is effectively adjoint(vector) * vector).
When I assign that to a variable and wrap it in [t] when it is scalar, I get more errors.

The whole complexity of this really says to me that it is kind of necessary for / to accept a scalar and a vector, to match the fact that products of vectors and adjoint(vector)s produce scalars.
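
A small reproduction of that scalar case on Julia 1.9, using made-up data that mirrors the logged types above (not taken from the test suite):

using LinearAlgebra

A = randn(3)'             # Adjoint{Float64, Vector{Float64}}
B = randn(3)'             # Adjoint{Float64, Vector{Float64}}
Y = A \ B                 # a 3×3 Matrix{Float64}, as in the log
Ȳ = randn(size(Y)...)     # cotangent with the same shape as Y
B̄ = A' \ Ȳ                # a row-shaped result
t = (B - A * Y) * B̄'      # adjoint row times vector: collapses to a Float64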

@antoine-levitt

function rrule(A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
    Y = A \ B
    maybe_promote(x) = x isa AbstractArray ? x : [x]

    function backslash_pullback(ȳ)
        Ȳ = ȳ
        ∂A = begin
            B̄ = A' \ maybe_promote(Ȳ)
            Ā = -B̄ * Y'
            Ā += maybe_promote((B - A * Y) * B̄') / A'
            Ā += (A' \ maybe_promote(Y)) * (Ȳ' - B̄'A)
            (Ā)
        end
        ∂B = (A' \ maybe_promote(Ȳ))
        return ∂A, ∂B
    end
    return Y, backslash_pullback
end


A = randn(2)
B = randn(2)
rrule(A, B)[2](1.0)

works for me.

The whole complexity of this really says to me that it is kind of necessary for / to accept a scalar and a vector, to match the fact that products of vectors and adjoint(vector)s produce scalars.

This workaround is needed because of the interaction between vector' * vector being a scalar (which is reasonable and isn't really questioned) and vector \ vector being pinv(A) * B. I think there's a consensus that vector/vector and vector\vector working was a bad leftover from Matlab (where it's also a bad idea, but unavoidable, since there's no distinction between vectors and 2x1 matrices there), and that it should be removed in 2.0 in favor of matrix\vector and vector/matrix (in which case, to get the old A/B for vectors A and B, you'd do A/hcat(B)), so this issue just goes away. Again, I'm pretty sure this issue is the result of an overzealous implementation of the rule to support the vector\vector use case, which pretty much nobody uses (and nobody should use, since hopefully it will be gone in 2.0).
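
For concreteness, the two behaviours that interact here, shown on made-up data:

using LinearAlgebra

a = randn(3); b = randn(3)
s1 = a' * b    # adjoint(vector) * vector: a scalar
s2 = a \ b     # vector \ vector: pinv(a) * b on Julia 1.9, which is also a scalar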

Also make second Y not scalar

more coercing some things into arrays some of the time

cleaner def with a helper function
@oxinabox
Member Author

oxinabox commented Jun 2, 2023

OK, merging this, though we can revert if JuliaLang/julia#49915 goes through.

I am just going to note that 1.6 on 64-bit is broken; it seems unrelated to this PR and looks like a bug in Julia itself.

@oxinabox oxinabox merged commit 50d9d03 into main Jun 2, 2023
@oxinabox oxinabox deleted the ox/1.9fixes branch June 2, 2023 09:22