diff --git a/docs/src/releasenotes.md b/docs/src/releasenotes.md index 350bcc9d..ad0ef12f 100644 --- a/docs/src/releasenotes.md +++ b/docs/src/releasenotes.md @@ -1,6 +1,7 @@ # Release Notes ## Unreleased +* `Py` is now treated as a scalar when broadcasting. * Bug fixes. ## 0.9.15 (2023-10-25) diff --git a/src/Py/Py.jl b/src/Py/Py.jl index 754430ce..396b39d8 100644 --- a/src/Py/Py.jl +++ b/src/Py/Py.jl @@ -334,6 +334,8 @@ Base.in(v, x::Py) = pycontains(x, v) Base.hash(x::Py, h::UInt) = reinterpret(UInt, Int(pyhash(x))) - 3h +Base.broadcastable(x::Py) = Ref(x) + (f::Py)(args...; kwargs...) = pycall(f, args...; kwargs...) # comparisons diff --git a/test/compat.jl b/test/compat.jl index b09f5037..4d5c1745 100644 --- a/test/compat.jl +++ b/test/compat.jl @@ -1,3 +1,16 @@ +@testitem "Base.jl" begin + @testset "broadcast" begin + # Py always broadcasts as a scalar + x = [1 2; 3 4] .+ Py(1) + @test isequal(x, [Py(2) Py(3); Py(4) Py(5)]) + x = Py("foo") .* [1 2; 3 4] + @test isequal(x, [Py("foo") Py("foofoo"); Py("foofoofoo") Py("foofoofoofoo")]) + # this previously treated the list as a shape (2,) object + # but now tries to do `1 + [1, 2]` which properly fails + @test_throws PyException [1 2; 3 4] .+ pylist([1, 2]) + end +end + @testitem "pywith" begin @testset "no error" begin tdir = pyimport("tempfile").TemporaryDirectory()