From ede097f8ab6aaa9c553cdcb21ea8b54b87f5e40d Mon Sep 17 00:00:00 2001
From: Johan Gustafsson
Date: Thu, 28 Apr 2016 21:15:26 +0200
Subject: [PATCH] ModelMatrix needs to be able to align factors (pooled arrays used in ModelFrame).

---
 src/statsmodels/formula.jl    | 52 ++++++++++++++++++++++++++++++++++-
 src/statsmodels/statsmodel.jl |  2 +-
 test/statsmodel.jl            |  3 +--
 3 files changed, 53 insertions(+), 4 deletions(-)

diff --git a/src/statsmodels/formula.jl b/src/statsmodels/formula.jl
index f211e36899..6ee21429b9 100644
--- a/src/statsmodels/formula.jl
+++ b/src/statsmodels/formula.jl
@@ -292,13 +292,63 @@ function nc(trm::Vector)
     n
 end
 
-function ModelMatrix(mf::ModelFrame)
+# Re-code the pooled array `x` so that it shares the level pool of `y`.
+#
+# Returns `x` unchanged when both pools are already equal. Otherwise returns a
+# new PooledDataArray built on `y.pool` whose refs select the same level values
+# as `x` did. Raises an error naming the first level of `x.pool` that has no
+# counterpart in `y.pool`.
+function alignpool{T,Rx,Ry,N}(x::PooledDataArray{T, Rx, N}, y::PooledDataArray{T, Ry, N})
+    if x.pool == y.pool
+        return x
+    end
+
+    # Pair each level of x (position :xi in x.pool) with its position :yi in
+    # y.pool; the left join keeps every level of x and leaves :yi as NA when
+    # y.pool has no such level.
+    xi = DataFrame(pool=x.pool, xi=1:length(x.pool))
+    yi = DataFrame(pool=y.pool, yi=1:length(y.pool))
+    d = join(xi, yi, on=:pool, kind=:left)
+
+    # Validate that x has support in y
+    i = findfirst(isna(d[:yi]))
+    if i>0
+        if length(y.pool)<10
+            error("Unknown level: ", d[i,:pool], ". Expected one of: ", y.pool)
+        else
+            error("Unknown level: ", d[i,:pool], ". Expected one of ", length(y.pool), " levels in reference dataframe.")
+        end
+    end
+
+    # Zero-fill instead of allocating uninitialized memory: entries of x.refs
+    # that match no :xi below (a zero ref is how DataArrays marks NA) would
+    # otherwise come out as arbitrary garbage refs.
+    newrefs = zeros(Ry, size(x.refs))
+    for r in eachrow(d)
+        newrefs[x.refs .== r[:xi]] = r[:yi]
+    end
+
+    PooledDataArray(DataArrays.RefArray(newrefs), y.pool)
+end
+
+# Non-pooled data arrays carry no level pool, so there is nothing to align.
+function alignpool(x::DataArray, ::DataArray)
+    return x
+end
+
+function ModelMatrix(mf::ModelFrame, referece_df = mf.df)
     trms = mf.terms
     aa = Any[Any[ones(size(mf.df,1), @compat(Int(trms.intercept)))]]
     asgn = zeros(Int, @compat(Int(trms.intercept)))
     fetrms = Bool[isfe(t) for t in trms.terms]
     if trms.response unshift!(fetrms, false) end
     ff = trms.factors[:, fetrms]
+
+    # need to use the same levels in predictions as for regression
+    for n in trms.eterms
+        mf.df[n] = alignpool(mf.df[n], referece_df[n])
+    end
+
     ## need to be cautious here to avoid evaluating cols for a factor with many levels
     ## if the factor doesn't occur in the fetrms
     rows = Bool[x != 0 for x in sum(ff, 2)]
diff --git a/src/statsmodels/statsmodel.jl b/src/statsmodels/statsmodel.jl
index 2094aa4530..4b524048e6 100644
--- a/src/statsmodels/statsmodel.jl
+++ b/src/statsmodels/statsmodel.jl
@@ -79,7 +79,7 @@ function StatsBase.predict(mm::DataFrameRegressionModel, df::AbstractDataFrame)
     newTerms = remove_response(mm.mf.terms)
     # create new model frame/matrix
     mf = ModelFrame(newTerms, df)
-    newX = ModelMatrix(mf).m
+    newX = ModelMatrix(mf, mm.mf.df[1:0,:]).m
     yp = predict(mm, newX)
     out = DataArray(eltype(yp), size(df, 1))
     out[mf.msng] = yp
diff --git a/test/statsmodel.jl b/test/statsmodel.jl
index ffe0cab63e..bfd92e61c2 100644
--- a/test/statsmodel.jl
+++ b/test/statsmodel.jl
@@ -66,8 +66,7 @@ m2 = fit(DummyMod, f2, d)
 @test coeftable(m2).rownms == ["(Intercept)", "x1p - 6", "x1p - 7", "x1p - 8"]
 
 ## predict w/ new data missing levels
-## FAILS: mismatch between number of model matrix columns
-## @test predict(m2, d[2:4, :]) == predict(m2)[2:4]
+@test predict(m2, d[2:4, :]) == predict(m2)[2:4]
 
 
 ## Another dummy model type to test fall-through show method