Skip to content

Commit

Permalink
fmap_with_keypath
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Mar 8, 2024
1 parent 4019d4b commit 811fc92
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 3 deletions.
28 changes: 27 additions & 1 deletion src/Functors.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module Functors

export @functor, @flexiblefunctor, fmap, fmapstructure, fcollect, execute,
KeyPath
KeyPath, fmap_with_keypath

include("functor.jl")
include("walks.jl")
Expand Down Expand Up @@ -305,4 +305,30 @@ julia> fcollect(m, exclude = v -> Functors.isleaf(v))
"""
fcollect


""""
fmap_with_keypath(f, x, ys...; exclude = Functors.isleaf, walk = Functors.DefaultWalkWithKeyPath())
Like [`fmap`](@ref), but also passes a `KeyPath` to `f` for each node in the
recursion. The `KeyPath` is a tuple of the indices used to reach the current
node from the root of the recursion. The `KeyPath` is constructed by the
`walk` function, and can be used to reconstruct the path to the current node
from the root of the recursion.
`f` should accept two arguments: the value of the current node, and the associated `KeyPath`.
`exclude` also receives the `KeyPath` as its first argument.
# Examples
```jldoctest
julia> x = ([1, 2, 3], 4, (a=5, b=Dict("A"=>6, "B"=>7), c=Dict("C"=>8, "D"=>9)));
julia> fexclude(kp, x) = kp == KeyPath(3, :c) || Functors.isleaf(x)
julia> fmap_with_keypath((kp, x) -> x isa Dict ? nothing : x.^2, x; exclude = fexclude)
([1, 4, 9], 16, (a = 25, b = Dict("B" => 49, "A" => 36), c = nothing))
```
"""
fmap_with_keypath

end # module
1 change: 1 addition & 0 deletions src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ macro functor(args...)
end

isleaf(@nospecialize(x)) = children(x) === NoChildren()
isleaf(::KeyPath, @nospecialize(x)) = isleaf(x)

children(x) = functor(x)[1]

Expand Down
13 changes: 11 additions & 2 deletions src/keypath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,16 @@ end
Base.getindex(kp::KeyPath, i::Int) = kp.keys[i]
Base.length(kp::KeyPath) = length(kp.keys)
Base.iterate(kp::KeyPath, state=1) = iterate(kp.keys, state)
Base.:(==)(kp1::KeyPath, kp2::KeyPath) = kp1.keys == kp2.keys

function Base.show(io::IO, kp::KeyPath)
compat = get(io, :compact, false)
if compat
print(io, keypathstr(kp))
else
print(io, "KeyPath$(kp.keys)")
end
end

Base.show(io::IO, kp::KeyPath) = print(io, "KeyPath$(kp.keys)")

keypathstr(kp::KeyPath) = join(kp.keys, ".")

26 changes: 26 additions & 0 deletions src/maps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,32 @@ function fmap(f, x, ys...; exclude = isleaf,
execute(_walk, x, ys...)
end

""""
fmap_with_keypath(f, x, ys...; exclude = Functors.isleaf, walk = Functors.DefaultWalkWithKeyPath())
Like [`fmap`](@ref), but also passes a `KeyPath` to `f` for each node in the
recursion. The `KeyPath` is a tuple of the indices used to reach the current
node from the root of the recursion. The `KeyPath` is constructed by the
`walk` function, and can be used to reconstruct the path to the current node
from the root of the recursion.
# Examples
```jldoctest
julia> fmap_with_keypath((x, kp) -> (x, kp), (1, (2, 3)))
(1, ())
(2, (1,))
(3, (2,))
```
"""
function fmap_with_keypath(f, x, ys...;
exclude = isleaf,
walk = DefaultWalkWithKeyPath())

_walk = ExcludeWalkWithKeyPath(walk, f, exclude)
return execute(_walk, KeyPath(), x, ys...)
end

fmapstructure(f, x; kwargs...) = fmap(f, x; walk = StructuralWalk(), kwargs...)

fcollect(x; exclude = v -> false) =
Expand Down
25 changes: 25 additions & 0 deletions src/walks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ _map(f, x::Dict, ys...) = Dict(k => f(v, (y[k] for y in ys)...) for (k, v) in x)
_values(x) = x
_values(x::Dict) = values(x)

_keys(x::Dict) = Dict(k => k for k in keys(x))
_keys(x::Tuple) = (keys(x)...,)
_keys(x::AbstractArray) = collect(x)
_keys(x::NamedTuple{Ks}) where Ks = NamedTuple{Ks}(Ks)

"""
AbstractWalk
Expand Down Expand Up @@ -76,6 +81,16 @@ function (::DefaultWalk)(recurse, x, ys...)
re(_map(recurse, func, yfuncs...))
end

struct DefaultWalkWithKeyPath <: AbstractWalk end

function (::DefaultWalkWithKeyPath)(recurse, kp::KeyPath, x, ys...)
x_children, re = functor(x)
kps = _map(c -> KeyPath(kp, c), _keys(x_children)) # use _keys and _map to preserve x_children type
ys_children = map(children, ys)
re(_map(recurse, kps, x_children, ys_children...))
end


"""
StructuralWalk()
Expand Down Expand Up @@ -106,6 +121,16 @@ end
(walk::ExcludeWalk)(recurse, x, ys...) =
walk.exclude(x) ? walk.fn(x, ys...) : walk.walk(recurse, x, ys...)

struct ExcludeWalkWithKeyPath{T, F, G} <: AbstractWalk
walk::T
fn::F
exclude::G
end

(walk::ExcludeWalkWithKeyPath)(recurse, kp::KeyPath, x, ys...) =
walk.exclude(kp, x) ? walk.fn(kp, x, ys...) : walk.walk(recurse, kp, x, ys...)


struct NoKeyword end

usecache(::Union{AbstractDict, AbstractSet}, x) =
Expand Down

0 comments on commit 811fc92

Please sign in to comment.