Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: add Base.isfusing function to allow containers to disable fusion #22063

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@ using Base.Cartesian
using Base: linearindices, tail, OneTo, to_shape,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache,
nullable_returntype, null_safe_op, hasvalue, isoperator
import Base: broadcast, broadcast!
export broadcast_getindex, broadcast_setindex!, dotview, @__dot__
import Base: broadcast, broadcast!, @_pure_meta
export broadcast_getindex, broadcast_setindex!, dotview, @__dot__, isfusing

const ScalarType = Union{Type{Any}, Type{Nullable}}

# containers can override isfusing to disable broadcast fusion for specific container types
isfusing(args...) = (@_pure_meta; true)

## Broadcasting utilities ##
# fallbacks for some special cases
@inline broadcast(f, x::Number...) = f(x...)
Expand Down
29 changes: 19 additions & 10 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -1727,7 +1727,9 @@
oldarg))
fargs args)))
(let ,fbody ,@(reverse (fuse-lets fargs args '()))))))
(define (dot-to-fuse e) ; convert e == (. f (tuple args)) to (fuse f args)
; convert e == (. f (tuple args)) to (fuse f args),
; recursively for nested calls, fusing if fuse? is true.
(define (dot-to-fuse e fuse?)
(define (make-fuse f args) ; check for nested (fuse f args) exprs and combine
(define (split-kwargs args) ; return (cons keyword-args positional-args) extracted from args
(define (sk args kwargs pargs)
Expand All @@ -1742,8 +1744,8 @@
(let* ((kws.args (split-kwargs args))
(kws (car kws.args))
(args (cdr kws.args)) ; fusing occurs on positional args only
(args_ (map dot-to-fuse args)))
(if (anyfuse? args_)
(args_ (map (lambda (e) (dot-to-fuse e fuse?)) args)))
(if (and fuse? (anyfuse? args_))
`(fuse ,(fuse-funcs (to-lambda f args kws) args_) ,(fuse-args args_))
`(fuse ,(to-lambda f args kws) ,args_))))
(if (and (pair? e) (eq? (car e) '|.|))
Expand Down Expand Up @@ -1801,15 +1803,22 @@
(cons farg new-fargs) (cons arg new-args) renames varfarg vararg))))))
(cf (cdadr f) args '() '() '() '() '()))
e)) ; (not (fuse? e))
(let ((e (compress-fuse (dot-to-fuse rhs))) ; an expression '(fuse func args) if expr is a dot call
(lhs-view (ref-to-view lhs))) ; x[...] expressions on lhs turn in to view(x, ...) to update x in-place
; convert fuse expressions to ordinary broadcast calls, or broadcast! if lhs != null:
(define (to-broadcast lhs e)
(if (fuse? e)
(let ((bargs (map (lambda (e) (to-broadcast '() e)) (caddr e))))
(if (null? lhs)
`(call (top broadcast) ,(from-lambda (cadr e)) ,@bargs)
`(call (top broadcast!) ,(from-lambda (cadr e)) ,(ref-to-view lhs) ,@bargs)))
(if (null? lhs)
(expand-forms `(call (top broadcast) ,(from-lambda (cadr e)) ,@(caddr e)))
(expand-forms `(call (top broadcast!) ,(from-lambda (cadr e)) ,lhs-view ,@(caddr e))))
(if (null? lhs)
(expand-forms e)
(expand-forms `(call (top broadcast!) (top identity) ,lhs-view ,e))))))
e
`(call (top broadcast!) (top identity) ,(ref-to-view lhs) ,e))))
(let ((e (compress-fuse (dot-to-fuse rhs #t))) ; an expression '(fuse func args) if expr is a dot call
(e0 (dot-to-fuse rhs #f))) ; e without fusion
(if (fuse? e)
(expand-forms `(if (call (top isfusing) ,@(caddr e))
,(to-broadcast lhs e) ,(to-broadcast lhs e0)))
(expand-forms (to-broadcast lhs e)))))

(define (expand-where body var)
(let* ((bounds (analyze-typevar var))
Expand Down
2 changes: 1 addition & 1 deletion test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ end

# make sure scalars are inlined, which causes f.(x,scalar) to lower to a "thunk"
import Base.Meta: isexpr
@test isexpr(expand(:(f.(x,y))), :call)
@test isexpr(expand(:(f.(x,y))), :body)
@test isexpr(expand(:(f.(x,1))), :thunk)
@test isexpr(expand(:(f.(x,1.0))), :thunk)
@test isexpr(expand(:(f.(x,$π))), :thunk)
Expand Down