-
Notifications
You must be signed in to change notification settings - Fork 370
/
formula.jl
359 lines (320 loc) · 11.8 KB
/
formula.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
# Formulas for representing and working with linear-model-type expressions
# Original by Harlan D. Harris. Later modifications by John Myles White
# and Douglas M. Bates.
## Formulas are written as expressions and parsed by the Julia parser.
## For example :(y ~ a + b + log(c))
## In Julia the & operator is used for an interaction. What would be written
## in R as y ~ a + b + a:b is written :(y ~ a + b + a&b) in Julia.
## The equivalent shorthand, y ~ a*b, is written the same way in both R and Julia.
## The lhs of a one-sided formula is 'nothing'
## The rhs of a formula can be 1
## A model formula stored as two unevaluated sides.
##   lhs - the response side: a Symbol or Expr, or `nothing` (type Void)
##         for a one-sided formula
##   rhs - the predictor side: a Symbol, an Expr, or an Integer (e.g. 1)
type Formula
lhs::@compat(Union{Symbol, Expr, Void})
rhs::@compat(Union{Symbol, Expr, Integer})
end
## Macro named `~`: builds a call to the `Formula` constructor in which both
## sides are quoted, so they reach the constructor as unevaluated
## symbols/expressions rather than as values.
macro ~(lhs, rhs)
    quoted_lhs = Base.Meta.quot(lhs)
    quoted_rhs = Base.Meta.quot(rhs)
    return Expr(:call, :Formula, quoted_lhs, quoted_rhs)
end
## The parsed representation of a formula, as built by `Terms(f::Formula)`.
type Terms
terms::Vector # right-hand-side terms, sorted by increasing interaction order
eterms::Vector # evaluation terms (the response comes first when present)
factors::Matrix{Int8} # maps terms to evaluation terms: rows are evaluation terms, columns are terms; 1 where the evaluation term occurs in the term
order::Vector{Int} # orders of rhs terms; the constructor prepends a 1 for the response when present
response::Bool # indicator of a response, which is eterms[1] if present
intercept::Bool # is there an intercept column in the model matrix?
end
## The data needed to evaluate a set of terms, together with those terms.
type ModelFrame
df::AbstractDataFrame # one column per evaluation term, after the NA handler has been applied
terms::Terms # the terms the frame was built for
msng::BitArray # second value returned by the NA handler; for `na_omit` this is the `complete_cases` result
end
## A numeric model matrix plus the bookkeeping that ties columns to terms.
type ModelMatrix{T <: @compat(Union{Float32, Float64})}
m::Matrix{T} # the model matrix itself
assign::Vector{Int} # as filled in by ModelMatrix(mf): one entry per generated column, 0 for the intercept column and j for columns of the j-th fixed-effects term
end
## The size of a ModelMatrix is the size of the matrix it wraps; any extra
## dimension arguments are forwarded unchanged.
function Base.size(mm::ModelMatrix)
    return size(mm.m)
end
function Base.size(mm::ModelMatrix, dim...)
    return size(mm.m, dim...)
end
## Print a formula as "Formula: lhs ~ rhs"; a one-sided formula (lhs is
## `nothing`) is printed with an empty left-hand side.
function Base.show(io::IO, f::Formula)
    shown_lhs = f.lhs == nothing ? "" : f.lhs
    text = string("Formula: ", shown_lhs, " ~ ", f.rhs)
    print(io, text)
end
## Return, as a vector of symbols, the names of all the variables in
## an expression or a formula.
## For a call expression the function name (first argument) is skipped and
## the variables of every remaining argument are collected in order.
## Any expression that is not a call raises an error.
function allvars(ex::Expr)
    ex.head == :call || error("Non-call expression encountered")
    found = Symbol[]
    for arg in ex.args[2:end]
        append!(found, allvars(arg))
    end
    return found
end
## Variables of a formula: those of the right-hand side first, then those of
## the left-hand side, with duplicates removed.
function allvars(f::Formula)
    rhs_vars = allvars(f.rhs)
    lhs_vars = allvars(f.lhs)
    return unique(vcat(rhs_vars, lhs_vars))
end
## A bare symbol is itself a variable name.
allvars(sym::Symbol) = [sym]
## Anything else (numbers, `nothing`, ...) contributes no variables.
## Uses the `Symbol[]` literal, as the Expr method does, rather than the
## older `Array(Symbol, 0)` constructor form.
allvars(v::Any) = Symbol[]
# special operators in formulas
const specials = Set([:+, :-, :*, :/, :&, :|, :^])

## Expand the `*` operator of a formula into main effects and interactions.
## Calls to operators outside `specials` are returned untouched; calls to the
## other special operators are copied with their arguments expanded
## recursively. Any expression that is not a call raises an error.
function dospecials(ex::Expr)
    ex.head == :call || error("Non-call expression encountered")
    op = ex.args[1]
    op in specials || return ex
    expanded = copy(ex)
    expanded.args = vcat(op, map(dospecials, ex.args[2:end]))
    op == :* || return expanded
    operands = expanded.args
    left = operands[2]
    right = operands[3]
    if length(operands) > 3
        ## more than two factors: expand everything after the first factor
        ## as its own product and use that as the right-hand operand
        expanded.args = vcat(op, operands[3:end])
        right = dospecials(expanded)
    end
    ## this order of expansion gives the R-style ordering of interaction
    ## terms (after sorting in increasing interaction order) for higher-
    ## order interaction terms (e.g. x1 * x2 * x3 should expand to x1 +
    ## x2 + x3 + x1&x2 + x1&x3 + x2&x3 + x1&x2&x3)
    return :($left + $left & $right + $right)
end
## Non-expressions (symbols, numbers, ...) need no expansion.
dospecials(a::Any) = a
## Distribution of & over +
## Each key is an operator that gets distributed; its value is the operator
## it is distributed over.
const distributive = @compat Dict(:& => :+)

## Non-mutating form: apply the distributive property to a copy of `ex`.
function distribute(ex::Expr)
    return distribute!(copy(ex))
end
## Non-expressions are returned as they are.
distribute(a::Any) = a
## apply distributive property in-place
## `ex` must be a call expression (anything else raises an error). Its
## arguments are processed first; then, if the top-level operator is a key of
## `distributive`, the first argument that is a call to the corresponding
## distributing operator is expanded. The (modified) `ex` is returned.
function distribute!(ex::Expr)
if ex.head != :call error("Non-call expression encountered") end
## distribute within each argument first; the array built here is discarded,
## only the in-place modification of the arguments matters
[distribute!(a) for a in ex.args[2:end]]
## check that top-level can be distributed
a1 = ex.args[1]
if a1 in keys(distributive)
## which op is being DISTRIBUTED (e.g. &)?
## (assigned for readability only; not referenced again below)
distributed_op = a1
## which op is doing the distributing (e.g. +)?
distributing_op = distributive[a1]
## detect distributing sub-expression (first arg is, e.g. :+)
is_distributing_subex(e) =
typeof(e)==Expr && e.head == :call && e.args[1] == distributing_op
## find first distributing subex
first_distributing_subex = findfirst(is_distributing_subex, ex.args)
## a result of 0 is treated as "no distributing sub-expression found"
if first_distributing_subex != 0
## remove distributing subexpression from args
subex = splice!(ex.args, first_distributing_subex)
newargs = Any[distributing_op]
## generate one new sub-expression, which calls the distributed operation
## (e.g. &) on each of the distributing sub-expression's arguments, plus
## the non-distributed arguments of the original expression.
for a in subex.args[2:end]
## `ex` no longer contains the spliced-out sub-expression at this point,
## so the copy holds only the remaining arguments; `a` is appended last
new_subex = copy(ex)
push!(new_subex.args, a)
## need to recurse here, in case there are any other
## distributing operations in the sub expression
distribute!(new_subex)
push!(newargs, new_subex)
end
## the original expression becomes a call to the distributing operator
ex.args = newargs
end
end
ex
end
## Non-expressions are returned as they are.
distribute!(a::Any) = a
const associative = Set([:+,:*,:&]) # associative special operators

## If the expression is a call to the function s return its arguments
## (flattened recursively, so nested calls to s disappear as well).
## Otherwise return the condensed expression.
## Any expression that is not a call raises an error.
function ex_or_args(ex::Expr, s::Symbol)
    ex.head == :call || error("Non-call expression encountered")
    ## not a :call to s, return condensed version of ex
    ex.args[1] == s || return condense(ex)
    ## recurse in case there are more :calls of s below
    flattened = [ex_or_args(arg, s) for arg in ex.args[2:end]]
    return vcat(flattened...)
end
## Non-expressions are returned as they are.
ex_or_args(a, s::Symbol) = a

## Condense calls like :(+(a,+(b,c))) to :(+(a,b,c))
## Only operators in `associative` are flattened; a call to any other
## function is returned unchanged. The input expression is not modified.
function condense(ex::Expr)
    ex.head == :call || error("Non-call expression encountered")
    op = ex.args[1]
    op in associative || return ex
    merged = copy(ex)
    pieces = [ex_or_args(arg, op) for arg in ex.args[2:end]]
    merged.args = vcat(op, pieces...)
    return merged
end
## Non-expressions are returned as they are.
condense(a::Any) = a
## always return an ARRAY of terms
## A call to `+` yields its arguments; any other expression yields a
## one-element Expr array holding the expression itself.
function getterms(ex::Expr)
    if ex.head == :call && ex.args[1] == :+
        return ex.args[2:end]
    end
    return Expr[ex]
end
## A non-expression is wrapped in a one-element array.
getterms(a::Any) = Any[a]
## Order of a term: the number of operands for a call to `&`, otherwise 1.
function ord(ex::Expr)
    is_interaction = ex.head == :call && ex.args[1] == :&
    return is_interaction ? length(ex.args) - 1 : 1
end
ord(a::Any) = 1
const nonevaluation = Set([:&,:|]) # operators constructed from other evaluations

## evaluation terms - the (filtered) arguments for :& and :|, otherwise the term itself
## For `&` and `|` the terms of every operand are gathered and numeric
## entries are dropped. Any expression that is not a call raises an error.
function evt(ex::Expr)
    ex.head == :call || error("Non-call expression encountered")
    ex.args[1] in nonevaluation || return ex
    operands = vcat(map(getterms, ex.args[2:end])...)
    return filter(x -> !isa(x, Number), operands)
end
## A non-expression is its own single evaluation term.
evt(a) = Any[a]
## Build the Terms representation of a formula.
## The right-hand side is expanded (`dospecials`), distributed and condensed,
## then split into unique terms. Explicit 1's are dropped; a 0 or -1 term
## switches the intercept off. Terms are ordered by increasing interaction
## order, and the response (if any) is placed first among the evaluation terms.
function Terms(f::Formula)
    rhs = condense(distribute(dospecials(f.rhs)))
    rhsterms = unique(getterms(rhs))
    rhsterms = rhsterms[!(rhsterms .== 1)]       # drop any explicit 1's
    noint = (rhsterms .== 0) | (rhsterms .== -1) # should also handle :(-(expr,1))
    rhsterms = rhsterms[!noint]
    orders = Int[ord(t) for t in rhsterms]       # orders of interaction terms
    if !issorted(orders)                         # sort terms by increasing order
        perm = sortperm(orders)
        rhsterms = rhsterms[perm]
        orders = orders[perm]
    end
    evalterms = map(evt, rhsterms)
    hasresponse = f.lhs != nothing
    if hasresponse
        ## the response becomes the first evaluation term, with order 1
        unshift!(evalterms, Any[f.lhs])
        unshift!(orders, 1)
    end
    ev = unique(vcat(evalterms...))
    sets = [Set(x) for x in evalterms]
    ## rows: evaluation terms, columns: terms; 1 where the row occurs in the column
    membership = Int8[t in s for t in ev, s in sets]
    return Terms(rhsterms, ev, membership, orders, hasresponse, !any(noint))
end
## Return a new Terms object without the response.
## The argument itself is never modified. When no response is present the
## result is a shallow copy; otherwise the leading entry of `order` and
## `eterms` and the first row and column of `factors` are dropped.
function remove_response(t::Terms)
    if !t.response
        # shallow copy of the original terms
        return Terms(t.terms, t.eterms, t.factors, t.order, t.response, t.intercept)
    end
    return Terms(t.terms,
                 t.eterms[2:end],
                 t.factors[2:end, 2:end],
                 t.order[2:end],
                 false,
                 t.intercept)
end
## Default NA handler. Others can be added as keyword arguments
## Returns the rows of `df` selected by `complete_cases(df)` together with
## that selection.
function na_omit(df::DataFrame)
    complete = complete_cases(df)
    return df[complete, :], complete
end
## Trim the pool field of da to only those levels that occur in the refs
## The refs are renumbered 1, 2, ... in the sorted order of the distinct
## values found in them, and the pool is subset to match. Nothing is changed
## when the number of distinct refs equals the pool length.
function dropUnusedLevels!(da::PooledDataArray)
    refs = da.refs
    used = unique(refs)
    if length(used) == length(da.pool)
        return da
    end
    T = eltype(refs)
    sort!(used)
    ## old ref value => new, consecutive ref value
    remap = Dict(zip(used, one(T):convert(T, length(used))))
    da.refs = map(x -> remap[x], refs)
    da.pool = da.pool[used]
    return da
end
## Anything that is not a PooledDataArray is left alone.
dropUnusedLevels!(x) = x
## Build a ModelFrame from terms and a data frame: pick one column of `d`
## per evaluation term, apply `na_omit`, name the columns after the terms and
## drop unused levels from each column.
function ModelFrame(trms::Terms, d::AbstractDataFrame)
    selected = DataFrame(map(x -> d[x], trms.eterms))
    df, msng = na_omit(selected)
    names!(df, convert(Vector{Symbol}, map(string, trms.eterms)))
    for c in eachcol(df)
        dropUnusedLevels!(c[2])
    end
    return ModelFrame(df, trms, msng)
end
## Convenience form: parse the formula into Terms first.
function ModelFrame(f::Formula, d::AbstractDataFrame)
    return ModelFrame(Terms(f), d)
end
## Convenience form for a raw expression.
## NOTE(review): no one-argument Formula(::Expr) method is visible in this
## part of the file - confirm it is defined elsewhere.
function ModelFrame(ex::Expr, d::AbstractDataFrame)
    return ModelFrame(Formula(ex), d)
end
## Extract the response of a model frame as an Array.
## Raises an error for a one-sided formula. The response columns are
## selected with the first column of the factors matrix, and the first of
## them is returned.
function StatsBase.model_response(mf::ModelFrame)
    if !mf.terms.response
        error("Model formula one-sided")
    end
    response_mask = round(Bool, mf.terms.factors[:, 1])
    return convert(Array, mf.df[response_mask][:, 1])
end
## Treatment contrasts for a factor with `n` levels.
##   n         - number of levels; must be at least 2
##   contrasts - when false the full n-by-n indicator matrix is returned
##   sparse    - when true the matrix is `speye(n)` as is; otherwise it is
##               the element-wise comparison `eye(n) .== 1.`
##   base      - index of the baseline level, whose column is dropped;
##               must lie in 1:n (only checked when `contrasts` is true)
## Raises an error when n < 2 or when `base` is out of range.
function contr_treatment(n::Integer, contrasts::Bool, sparse::Bool, base::Integer)
    if n < 2 error("not enough degrees of freedom to define contrasts") end
    ## parentheses only make the existing precedence explicit: the comparison
    ## applies to the dense branch alone
    contr = sparse ? speye(n) : (eye(n) .== 1.)
    contrasts || return contr
    (1 <= base <= n) || error("base = $base is not allowed for n = $n")
    ## keep every column except the baseline one
    return contr[:, vcat(1:(base-1), (base+1):end)]
end
## Shorter forms: `base` defaults to 1, `sparse` to false, `contrasts` to true.
contr_treatment(n::Integer,contrasts::Bool,sparse::Bool) = contr_treatment(n,contrasts,sparse,1)
contr_treatment(n::Integer,contrasts::Bool) = contr_treatment(n,contrasts,false,1)
contr_treatment(n::Integer) = contr_treatment(n,true,false,1)
## Numeric columns contributed by one data column.
## A pooled vector yields the rows of the default treatment-contrast matrix
## selected by its refs; the other vectors are converted to Vector{Float64}.
function cols(v::PooledDataVector)
    contrasts = contr_treatment(length(v.pool))
    return contrasts[v.refs, :]
end
function cols(v::DataVector)
    return convert(Vector{Float64}, v.data)
end
function cols(v::Vector)
    return convert(Vector{Float64}, v)
end
## true for fixed-effects terms, i.e. for every call expression except a
## call to `|`. Any expression that is not a call raises an error.
function isfe(ex::Expr)
    ex.head == :call || error("Non-call expression encountered")
    return ex.args[1] != :|
end
## Non-expressions are always fixed-effects terms.
isfe(a) = true
## Combine the column blocks of one term into its model-matrix columns.
## A single block is simply converted to a Float64 array. For several blocks
## every column of the first is multiplied element-wise with every column of
## the expansion of the rest; the column index of the first block varies
## fastest in the result.
function expandcols(trm::Vector)
    first_block = convert(Array{Float64}, trm[1])
    length(trm) == 1 && return first_block
    rest = expandcols(trm[2:end])
    products = Any[]
    for j in 1:size(rest, 2), i in 1:size(first_block, 2)
        push!(products, first_block[:, i] .* rest[:, j])
    end
    return hcat(products...)
end
## Number of model-matrix columns a term expands to: the product of the
## second-dimension sizes of its blocks, or 0 when there are no blocks.
function nc(trm::Vector)
    isempty(trm) && return 0
    return prod(Int[size(x, 2) for x in trm])
end
## Build the Float64 model matrix for a model frame.
## The first block is a column of ones when the terms have an intercept
## (zero columns otherwise); one block follows per fixed-effects term. The
## `assign` vector records 0 for the intercept column and j for each column
## of the j-th fixed-effects term.
function ModelMatrix(mf::ModelFrame)
    trms = mf.terms
    nintercept = @compat(Int(trms.intercept))
    blocks = Any[Any[ones(size(mf.df,1), nintercept)]]
    asgn = zeros(Int, nintercept)
    fetrms = Bool[isfe(t) for t in trms.terms]
    if trms.response
        ## the response column of `factors` is never a model-matrix term
        unshift!(fetrms, false)
    end
    ff = trms.factors[:, fetrms]
    ## need to be cautious here to avoid evaluating cols for a factor with many levels
    ## if the factor doesn't occur in the fetrms
    rows = Bool[x != 0 for x in sum(ff, 2)]
    ff = ff[rows, :]
    cc = [cols(col) for col in columns(mf.df[:, rows])]
    for j in 1:size(ff,2)
        trm = cc[round(Bool, ff[:, j])]
        push!(blocks, trm)
        append!(asgn, fill(j, nc(trm)))
    end
    return ModelMatrix{Float64}(hcat([expandcols(t) for t in blocks]...), asgn)
end
## Coefficient names contributed by one term.
## The generic method returns just the term name. For a pooled column there
## is one name "term - level" for every level except the first.
function termnames(term::Symbol, col)
    return [string(term)]
end
function termnames(term::Symbol, col::PooledDataArray)
    levs = levels(col)
    return [string(term, " - ", lev) for lev in levs[2:end]]
end
## Names of the coefficients of a model frame, in term order, starting with
## "(Intercept)" when the terms have an intercept. Terms that are calls to
## `|` are skipped; any expression term other than a call to `&` or `|`
## raises an error.
function coefnames(mf::ModelFrame)
    vnames = mf.terms.intercept ? Compat.UTF8String["(Intercept)"] : Compat.UTF8String[]
    ## all pairwise combinations of two sets of names
    pairnames(a, b) = vec([string(lev1, " & ", lev2) for lev1 in a, lev2 in b])
    # Need to only include active levels
    for term in mf.terms.terms
        if !isa(term, Expr)
            append!(vnames, termnames(term, mf.df[term]))
            continue
        end
        term.head == :call || error("unrecognized term $term")
        op = term.args[1]
        if op == :|
            continue # skip random-effects terms
        elseif op == :&
            ## for an interaction term, combine term names pairwise,
            ## starting with rightmost terms
            parts = map(x -> termnames(x, mf.df[x]), term.args[2:end])
            append!(vnames, foldr(pairnames, parts))
        else
            error("unrecognized term $term")
        end
    end
    return vnames
end
## Rebuild a Formula from Terms.
## The left-hand side is the first evaluation term when a response is
## present and `nothing` otherwise. The right-hand side is a call to `+`
## over an optional leading 1 (the intercept) followed by all terms.
function Formula(t::Terms)
    lhs = t.response ? t.eterms[1] : nothing
    rhsargs = Any[:+]
    t.intercept && push!(rhsargs, 1)
    append!(rhsargs, t.terms)
    return Formula(lhs, Expr(:call, rhsargs...))
end