Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scope self arguments using let block syntax #17

Merged
merged 9 commits into from
Aug 21, 2023
51 changes: 32 additions & 19 deletions src/compact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,18 @@ println(model) # "Linear(3 => 1)"
This can be useful when using `@compact` to hierarchically construct
complex models to be used inside a `Chain`.
"""
macro compact(fex, _kwexs...)
# check inputs
Meta.isexpr(fex, :(->)) || error("expects a do block")
macro compact(_exs...)
# check inputs, extracting function expression fex and unprocessed keyword arguments _kwexs
isempty(_exs) && error("expects at least one expression")
gaurav-arya marked this conversation as resolved.
Show resolved Hide resolved
if Meta.isexpr(_exs[1], :parameters)
length(_exs) >= 2 || error("expects an anonymous function")
fex = _exs[2]
_kwexs = (_exs[1], _exs[3:end]...)
else
fex = _exs[1]
_kwexs = _exs[2:end]
end
Meta.isexpr(fex, :(->)) || error("expects an anonymous function")
isempty(_kwexs) && error("expects keyword arguments")
all(ex -> Meta.isexpr(ex, (:kw,:(=),:parameters)), _kwexs) || error("expects only keyword arguments")

Expand All @@ -112,33 +121,37 @@ macro compact(fex, _kwexs...)
# make strings
layer = "@compact"
setup = NamedTuple(map(ex -> Symbol(string(ex.args[1])) => string(ex.args[2]), kwexs))
input = join(fex.args[1].args, ", ")
input =
try
fex_args = fex.args[1]
isa(fex_args, Symbol) ? string(fex_args) : join(fex_args.args, ", ")
catch e
@warn "Function stringifying does not yet handle all cases. Falling back to empty string for input arguments"
""
end
block = string(Base.remove_linenums!(fex).args[2])

# edit expressions
vars = map(ex -> ex.args[1], kwexs)
@gensym self
pushfirst!(fex.args[1].args, self)
addprefix!(fex, self, vars)
fex = supportself(fex, vars)

# assemble
return esc(quote
let
$CompactLayer($fex, $name, ($layer, $input, $block), $setup; $(kwexs...))
end
end)
return esc(:($CompactLayer($fex, $name, ($layer, $input, $block), $setup; $(kwexs...))))
end

function addprefix!(ex::Expr, self, vars)
for i = 1:length(ex.args)
if ex.args[i] in vars
ex.args[i] = :($self.$(ex.args[i]))
else
addprefix!(ex.args[i], self, vars)
function supportself(fex::Expr, vars)
@gensym self
@gensym curried_f
# To avoid having to manipulate fex's arguments and body explicitly, we form a curried function first
# that wraps the full fex expression, and then uncurry it programatically rather than syntactically.
let_exprs = map(var -> :($var = $self.$var), vars)
return quote
$curried_f = ($self) -> let $(let_exprs...)
$fex
end
($self, args...; kwargs...) -> $curried_f($self)(args...; kwargs...)
end
end
addprefix!(not_ex, self, vars) = nothing

struct CompactLayer{F,NT1<:NamedTuple,NT2<:NamedTuple}
fun::F
Expand Down
20 changes: 20 additions & 0 deletions test/compact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,5 +212,25 @@ end
end
@test model(2) == _a + _b * 2 + c * 2^2
end

@testset "Keyword arguments with anonymous function" begin
model = @test_nowarn @compact(x -> x+a+b; a=1, b=2)
@test model(3) == 1 + 2 + 3
expected_string = """@compact(
a = 1,
b = 2,
) do x
x + a + b
end"""
@test similar_strings(get_model_string(model), expected_string)
end

@testset "Scoping of parameter arguments" begin
model = @compact(w1 = 3, w2 = 5) do a
g(w1, w2) = 2 * w1 * w2
return (w1 + w2) * g(a, a)
end
@test model(2) == (3 + 5) * 2 * 2 * 2
end
gaurav-arya marked this conversation as resolved.
Show resolved Hide resolved
end