Skip to content

Commit

Permalink
Fix sum() and prod() for tuples (#41510)
Browse files Browse the repository at this point in the history
This PR aims to fix #39182 and #39183 by using the universal implementation of `prod` and `sum` from https://github.com/JuliaLang/julia/blob/97f817a379b0c3c5f9bb803427fe88a018ebfe18/base/reduce.jl#L588

However, the file `abstractarray.jl` is included way sooner, and it is crucial to have already a simplified version of `prod` function.

We can specify a simplified version or `prod` only for a system-wide `Int` type that is sufficient to compile `Base`.
``` julia
prod(x::Tuple{}) = 1
# This is consistent with the regular prod because there is no need for size promotion
# if all elements in the tuple are of system size.
prod(x::Tuple{Int, Vararg{Int}}) = *(x...)
```

Although the implementations are different, they lead to the same binary code for tuples containing ~~`UInt` and~~ `Int`.

``` julia
julia> a = (1,2,3)
(1, 2, 3)

# Simplified version for tuples containing Int only
julia> prod_simplified(x::Tuple{Int, Vararg{Int}}) = *(x...)

julia> @code_native prod_simplified(a)
	.text
; ┌ @ REPL[1]:1 within `prod_simplified'
; │┌ @ operators.jl:560 within `*' @ int.jl:88
	movq	8(%rdi), %rax
	imulq	(%rdi), %rax
	imulq	16(%rdi), %rax
; │└
	retq
	nop
; └
```
``` julia
# Regular prod without the simplification
julia> @code_native prod(a)
        .text
; ┌ @ reduce.jl:588 within `prod`
; │┌ @ reduce.jl:588 within `#prod#247`
; ││┌ @ reduce.jl:289 within `mapreduce`
; │││┌ @ reduce.jl:289 within `#mapreduce#240`
; ││││┌ @ reduce.jl:162 within `mapfoldl`
; │││││┌ @ reduce.jl:162 within `#mapfoldl#236`
; ││││││┌ @ reduce.jl:44 within `mapfoldl_impl`
; │││││││┌ @ reduce.jl:48 within `foldl_impl`
; ││││││││┌ @ tuple.jl:276 within `_foldl_impl`
; │││││││││┌ @ operators.jl:613 within `afoldl`
; ││││││││││┌ @ reduce.jl:81 within `BottomRF`
; │││││││││││┌ @ reduce.jl:38 within `mul_prod`
; ││││││││││││┌ @ int.jl:88 within `*`
        movq    8(%rdi), %rax
        imulq   (%rdi), %rax
; │││││││││└└└└
; │││││││││┌ @ operators.jl:614 within `afoldl`
; ││││││││││┌ @ reduce.jl:81 within `BottomRF`
; │││││││││││┌ @ reduce.jl:38 within `mul_prod`
; ││││││││││││┌ @ int.jl:88 within `*`
        imulq   16(%rdi), %rax
; │└└└└└└└└└└└└
        retq
        nop
; └
```

(cherry picked from commit bada80c)
  • Loading branch information
petvana authored and KristofferC committed Sep 3, 2021
1 parent 7cdfa10 commit a7f8825
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
15 changes: 5 additions & 10 deletions base/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -467,17 +467,12 @@ reverse(t::Tuple) = revargs(t...)

## specialized reduction ##

# TODO: these definitions cannot yet be combined, since +(x...)
# where x might be any tuple matches too many methods.
# TODO: this is inconsistent with the regular sum in cases where the arguments
# require size promotion to system size.
sum(x::Tuple{Any, Vararg{Any}}) = +(x...)

# NOTE: should remove, but often used on array sizes
# TODO: this is inconsistent with the regular prod in cases where the arguments
# require size promotion to system size.
prod(x::Tuple{}) = 1
prod(x::Tuple{Any, Vararg{Any}}) = *(x...)
# This is consistent with the regular prod because there is no need for size promotion
# if all elements in the tuple are of system size.
# It is defined here separately in order to support bootstrap, because it's needed earlier
# than the general prod definition is available.
prod(x::Tuple{Int, Vararg{Int}}) = *(x...)

all(x::Tuple{}) = true
all(x::Tuple{Bool}) = x[1]
Expand Down
18 changes: 18 additions & 0 deletions test/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,24 @@ end
@test prod(()) === 1
@test prod((1,2,3)) === 6

# issue 39182
@test sum((0xe1, 0x1f)) === sum([0xe1, 0x1f])
@test sum((Int8(3),)) === Int(3)
@test sum((UInt8(3),)) === UInt(3)
@test sum((3,)) === Int(3)
@test sum((3.0,)) === 3.0
@test sum(("a",)) == sum(["a"])
@test sum((0xe1, 0x1f), init=0x0) == sum([0xe1, 0x1f], init=0x0)

# issue 39183
@test prod((Int8(100), Int8(100))) === 10000
@test prod((Int8(3),)) === Int(3)
@test prod((UInt8(3),)) === UInt(3)
@test prod((3,)) === Int(3)
@test prod((3.0,)) === 3.0
@test prod(("a",)) == prod(["a"])
@test prod((0xe1, 0x1f), init=0x1) == prod([0xe1, 0x1f], init=0x1)

@testset "all" begin
@test all(()) === true
@test all((false,)) === false
Expand Down

0 comments on commit a7f8825

Please sign in to comment.