Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

@lag works correctly with multiple indices #67

Merged
merged 2 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/linearize.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
##################################################################################
# This file is part of ModelBaseEcon.jl
# BSD 3-Clause License
# Copyright (c) 2020-2022, Bank of Canada
# Copyright (c) 2020-2024, Bank of Canada
# All rights reserved.
##################################################################################

Expand Down Expand Up @@ -219,11 +219,12 @@ function print_linearized(io::IO, model::Model; compact::Bool=true)
sort!(nonzerosofjay, by=x -> x[1] * base^2 + x[2])

# names of variables corresponding to columns of J
var_from_col = map(string, (
ModelBaseEcon.normal_ref(var.name, lag) for var in model.varshks for lag in -model.maxlag:model.maxlead
))
var_from_col = String[
string(Expr(:ref, var.name, normal_ref(lag)))
for var in model.varshks for lag in -model.maxlag:model.maxlead
]

io = IOContext(io, :compact=>get(io, :compact, compact))
io = IOContext(io, :compact => get(io, :compact, compact))

# loop over the non-zeros of J and print
this_r = 0
Expand Down
40 changes: 19 additions & 21 deletions src/metafuncs.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
##################################################################################
# This file is part of ModelBaseEcon.jl
# BSD 3-Clause License
# Copyright (c) 2020-2023, Bank of Canada
# Copyright (c) 2020-2024, Bank of Canada
# All rights reserved.
##################################################################################

Expand All @@ -12,7 +12,8 @@ has_t(sym::Symbol) = sym == :t
has_t(expr::Expr) = has_t(expr.args...)

# normalized :ref expression
normal_ref(var, lag) = Expr(:ref, var, lag == 0 ? :t : lag > 0 ? :(t + $lag) : :(t - $(-lag)))
# normal_ref(var, lag) = Expr(:ref, var, lag == 0 ? :t : lag > 0 ? :(t + $lag) : :(t - $(-lag)))
normal_ref(lag) = lag == 0 ? :t : lag > 0 ? :(t + $lag) : :(t - $(-lag))

"""
at_lag(expr[, n=1])
Expand All @@ -24,18 +25,22 @@ function at_lag(expr::Expr, n=1)
if n == 0
return expr
elseif expr.head == :ref
var, index = expr.args
if has_t(index)
if @capture(index, t + lag_)
return normal_ref(var, lag - n)
elseif index == :t
return normal_ref(var, -n)
elseif @capture(index, t - lag_)
return normal_ref(var, -lag - n)
else
error("Must use `t`, `t+n` or `t-n`, not $index")
var, index... = expr.args
for i = eachindex(index)
ind_expr = index[i]
if has_t(ind_expr)
if @capture(ind_expr, t + lag_)
index[i] = normal_ref(lag - n)
elseif ind_expr == :t
index[i] = normal_ref(-n)
elseif @capture(ind_expr, t - lag_)
index[i] = normal_ref(-lag - n)
else
error("Must use `t`, `t+n` or `t-n`, not $(ind_expr)")
end
end
end
return Expr(:ref, var, index...)
end
return Expr(expr.head, at_lag.(expr.args, n)...)
end
Expand Down Expand Up @@ -85,14 +90,6 @@ function at_d(expr::Expr, n=1, s=0)
end
end
return ret
#### old implementation
# if s > 0
# expr = :($expr - $(at_lag(expr, s)))
# end
# for i = 1:n
# expr = :($expr - $(at_lag(expr)))
# end
# return expr
end

"""
Expand Down Expand Up @@ -128,10 +125,11 @@ at_movav(expr::Expr, n::Integer) = MacroTools.unblock(:($(at_movsum(expr, n)) /

"""
at_movsumw(expr, n, weights)
at_movsumw(expr, n, w1, w2, ..., wn)

Apply moving weighted sum with n periods backwards to the given expression with
the given weights.
For example: `at_movsumw(x[t], w, 3) = w[1]*x[t] + w[2]*x[t-1] + w[3]*x[t-2]`
For example: `at_movsumw(x[t], 3, w) = w[1]*x[t] + w[2]*x[t-1] + w[3]*x[t-2]`

See also [`at_lag`](@ref).
"""
Expand Down
17 changes: 10 additions & 7 deletions src/model.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
##################################################################################
# This file is part of ModelBaseEcon.jl
# BSD 3-Clause License
# Copyright (c) 2020-2023, Bank of Canada
# Copyright (c) 2020-2024, Bank of Canada
# All rights reserved.
##################################################################################

Expand Down Expand Up @@ -902,7 +902,7 @@ export islog, islin
islog(eq::AbstractEquation) = flag(eq, :log)
islin(eq::AbstractEquation) = flag(eq, :lin)

function error_process(msg, expr, mod)
function error_process(msg, expr, mod)
err = ArgumentError("$msg\n During processing of\n $(expr)")
mod.eval(:(throw($err)))
end
Expand Down Expand Up @@ -1040,25 +1040,28 @@ function process_equation(model::Model, expr::Expr;
end
if ex.head == :ref
# expression is an indexing expression
name, index = ex.args
name, index... = ex.args
if haskey(model.parameters, name)
# indexing in a parameter - leave it alone, but keep track
add_pref(name)
if has_t(index)
if any(has_t, index)
error_process("Indexing parameters on time not allowed: $ex", expr, modelmodule)
end
return Expr(:ref, name, modelmodule.eval(index))
return Expr(:ref, name, modelmodule.eval.(index)...)
end
vind = indexin([name], allvars)[1] # the index of the variable
if vind !== nothing
# indexing in a time series
if length(index) != 1
error_process("Multiple indexing of variable or shock: $ex", expr, modelmodule)
end
tind = modelmodule.eval(:(
let t = 0
$index
$(index[1])
end
)) # the lag or lead value
add_tsref(allvars[vind], tind)
return normal_ref(name, tind)
return Expr(:ref, name, normal_ref(tind))
end
error_process("Undefined reference $(ex).", expr, modelmodule)
end
Expand Down
2 changes: 1 addition & 1 deletion src/steadystate.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
##################################################################################
# This file is part of ModelBaseEcon.jl
# BSD 3-Clause License
# Copyright (c) 2020-2023, Bank of Canada
# Copyright (c) 2020-2024, Bank of Canada
# All rights reserved.
##################################################################################

Expand Down
13 changes: 11 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
##################################################################################
# This file is part of ModelBaseEcon.jl
# BSD 3-Clause License
# Copyright (c) 2020-2023, Bank of Canada
# Copyright (c) 2020-2024, Bank of Canada
# All rights reserved.
##################################################################################

Expand Down Expand Up @@ -504,6 +504,8 @@ end
@test @movavw(a[t] + b[t+1], 2, p) == :((p[1] * (a[t] + b[t+1]) + p[2] * (a[t-1] + b[t])) / (p[1] + p[2]))
@test @movsumw(a[t] + b[t+1], 2, q, p) == :(q * (a[t] + b[t+1]) + p * (a[t-1] + b[t]))
@test @movavw(a[t] + b[t+1], 2, q, p) == :((q * (a[t] + b[t+1]) + p * (a[t-1] + b[t])) / (q + p))
@test @lead(v[t, 2]) == :(v[t+1, 2])
@test @dlog(v[t-1, z, t+2], 1) == :(log(v[t-1, z, t+2]) - log(v[t-2, z, t+1]))
end

module MetaTest
Expand Down Expand Up @@ -1672,6 +1674,13 @@ end
# this version of @test_throws requires Julia 1.8
@test_throws r".*Indexing parameters on time not allowed: p[t]*"i @initialize model
end

# do not allow multiple indexing of variables
@equations model begin
@delete :_EQ1
y[t, 1] = p[t] * y[t-1] + y_shk[t]
end
@test_throws ArgumentError @initialize model
Base.VERSION >= v"1.8" && @test_throws r".*Multiple indexing of variable or shock: y[t, 1]*"i @initialize model
end
end

Loading