Skip to content

Commit

Permalink
feat: add support for clamp and clamp! (#247)
Browse files Browse the repository at this point in the history
* feat: add support for `clamp` and `clamp!`

* chore: apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* test: add simple tests

* feat: allow mixed types

* fix: make testing approx

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
avik-pal and github-actions[bot] authored Nov 8, 2024
1 parent 1ff11c9 commit 3ba7c3e
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 1 deletion.
8 changes: 8 additions & 0 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -847,3 +847,11 @@ function maybe_expand_dims(x::AbstractArray{T,N}, dims) where {T,N}
dims ≤ N && return x
return reshape(x, ntuple(i -> i ≤ N ? size(x, i) : 1, dims))
end

# Generate `Base.clamp!` methods for a `TracedRArray` covering every pairing of
# `Number` and `TracedRNumber` lower/upper bounds (four methods in total).
for minT in (Number, TracedRNumber), maxT in (Number, TracedRNumber)
    @eval function Base.clamp!(x::TracedRArray{T}, min::$(minT), max::$(maxT)) where {T}
        # Compute the clamped values out of place via broadcasting ...
        clamped = clamp.(x, min, max)
        # ... then make `x` refer to the clamped result's underlying data, so the
        # caller observes the update through `x` itself.
        x.mlir_data = clamped.mlir_data
        return x
    end
end
13 changes: 13 additions & 0 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,19 @@ Base.abs2(x::TracedRNumber{<:Real}) = x^2

Base.log1p(x::TracedRNumber{T}) where {T} = log(x + one(T))

# Generate `Base.clamp` methods for a `TracedRNumber{T}` covering every pairing of
# `Number` and `TracedRNumber` lower/upper bounds (four methods in total).
for minT in (Number, TracedRNumber), maxT in (Number, TracedRNumber)
    @eval function Base.clamp(x::TracedRNumber{T}, min::$(minT), max::$(maxT)) where {T}
        # Bring both bounds to the element type of `x` before building the op.
        lo = promote_to(TracedRNumber{T}, min)
        hi = promote_to(TracedRNumber{T}, max)
        # The stablehlo op is called with (lower bound, operand, upper bound).
        clamp_op = MLIR.Dialects.stablehlo.clamp(lo.mlir_data, x.mlir_data, hi.mlir_data)
        # Wrap the op's first result as a traced number with the same `T` as `x`.
        return TracedRNumber{T}((), MLIR.IR.result(clamp_op, 1))
    end
end

struct TypeCast{T<:ReactantPrimitive} <: Function end

(::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = promote_to(TracedRNumber{T}, x)
Expand Down
19 changes: 19 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -538,3 +538,22 @@ end
@test Int(x) isa Int
@test float(x) isa ConcreteRNumber{Float64}
end

@testset "clamp" begin
    # In-place variant: the returned array and the input must both end up
    # inside the closed interval [0.0, 0.25].
    x = randn(2, 3)
    x_ra = Reactant.to_rarray(x)

    y = @jit(clamp!(x_ra, 0.0, 0.25))
    @test maximum(y) ≤ 0.25
    @test minimum(y) ≥ 0.0
    # `clamp!` mutates its argument, so `x_ra` must agree with the result.
    @test maximum(x_ra) == maximum(y)
    @test minimum(x_ra) == minimum(y)

    # Broadcasted out-of-place variant: the result is bounded, while the
    # input array must be left untouched.
    x = randn(2, 3)
    x_ra = Reactant.to_rarray(x)

    y = @jit(clamp.(x_ra, 0.0, 0.25))
    @test maximum(y) ≤ 0.25
    @test minimum(y) ≥ 0.0
    @test x_ra ≈ x
end
2 changes: 1 addition & 1 deletion test/nn/nnlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using NNlib, Reactant, Enzyme
x_act_ca = Reactant.ConcreteRArray(x_act)

@testset "Activation: $act" for act in (
identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2
identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2, relu6
)
f_compile = Reactant.compile(sumabs2, (act, x_act))

Expand Down

1 comment on commit 3ba7c3e

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reactant.jl Benchmarks

Benchmark suite Current: 3ba7c3e Previous: 1ff11c9 Ratio
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 5787425685 ns 6384860404 ns 0.91
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant 5292258390 ns 5301328245 ns 1.00
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 6086056532 ns 5203462205 ns 1.17
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 7587601119 ns 7305272771 ns 1.04
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux 28087750784 ns 34987991089 ns 0.80
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1563822331 ns 1561591277 ns 1.00
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant 1543677512 ns 1557307113 ns 0.99
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1553822136 ns 1542360845 ns 1.01
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3309603029 ns 3316796244 ns 1.00
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux 3236551447 ns 3043968768 ns 1.06
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 2198150190 ns 2146877682 ns 1.02
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant 2155687426 ns 2131581565 ns 1.01
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 2192886728 ns 2138865967 ns 1.03
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 3908194881 ns 3910063623 ns 1.00
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux 5993416352 ns 5708530964 ns 1.05
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1406808783.5 ns 1416785120 ns 0.99
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant 1407299141 ns 1421169200 ns 0.99
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1410969730 ns 1406671878 ns 1.00
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3156311368 ns 3159231592 ns 1.00
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux 1099155376.5 ns 1143922805 ns 0.96
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 1727787162 ns 1719831635 ns 1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant 1727804980 ns 1705729225 ns 1.01
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 1711663111 ns 1700558831 ns 1.01
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 3460051766 ns 3451683364 ns 1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux 3010659432 ns 3140401890 ns 0.96
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 2148427239 ns 2158380749 ns 1.00
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant 2170426380 ns 2164881935 ns 1.00
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 2187259107 ns 2126575115 ns 1.03
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 3958804601 ns 3935604871 ns 1.01
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux 6647100753 ns 5762722471.5 ns 1.15
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 3146044029 ns 3016029772 ns 1.04
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant 3146912971 ns 2981006759 ns 1.06
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 3047329260 ns 2989746439 ns 1.02
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 4862728550 ns 4846949677 ns 1.00
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux 12794226734 ns 10864111732 ns 1.18
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 3132478421 ns 3166559857 ns 0.99
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant 3179953038 ns 3144854123 ns 1.01
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 3185074336 ns 3165468463 ns 1.01
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 5092564084 ns 5000879641 ns 1.02
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux 12253319305 ns 15188174420 ns 0.81
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1855345054 ns 1819847550 ns 1.02
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant 1849809131 ns 1826616736 ns 1.01
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1855337197 ns 1850176920 ns 1.00
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3604644289 ns 3562720585 ns 1.01
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux 5868629461.5 ns 4485724051 ns 1.31

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.