Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the fitpredict function in Interpolate and InterpolateNeighbors #14

Merged
merged 1 commit into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ ColumnSelectors = "0.1"
Combinatorics = "1.0"
DataScienceTraits = "0.1"
Distances = "0.10"
GeoStatsModels = "0.1"
GeoStatsModels = "0.2"
GeoTables = "1.9"
Meshes = "0.35"
TableDistances = "0.3"
Expand Down
2 changes: 1 addition & 1 deletion src/GeoStatsTransforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ using LinearAlgebra
using Statistics

using Unitful: AffineQuantity
using GeoStatsModels: GeoStatsModel, fit, predict, predictprob
using GeoStatsModels: GeoStatsModel, fitpredict
using ColumnSelectors: ColumnSelector, SingleColumnSelector
using ColumnSelectors: Column, AllSelector, NoneSelector
using ColumnSelectors: selector, selectsingle
Expand Down
71 changes: 5 additions & 66 deletions src/interpneighbors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,11 @@ InterpolateNeighbors(domain, pairs::Pair{<:Any,<:GeoStatsModel}...; kwargs...) =
isrevertible(::Type{<:InterpolateNeighbors}) = false

function apply(transform::InterpolateNeighbors, geotable::AbstractGeoTable)
dom = domain(geotable)
tab = values(geotable)
cols = Tables.columns(tab)
vars = Tables.columnnames(cols)

idom = transform.domain
domain = transform.domain
selectors = transform.selectors
models = transform.models
minneighbors = transform.minneighbors
Expand All @@ -103,73 +102,13 @@ function apply(transform::InterpolateNeighbors, geotable::AbstractGeoTable)
point = transform.point
prob = transform.prob

nobs = nelements(dom)
if maxneighbors > nobs || maxneighbors < 1
@warn "Invalid maximum number of neighbors. Adjusting to $nobs..."
maxneighbors = nobs
end

if minneighbors > maxneighbors || minneighbors < 1
@warn "Invalid minimum number of neighbors. Adjusting to 1..."
minneighbors = 1
end

data = if point
pset = PointSet(centroid(dom, i) for i in 1:nobs)
_adjustunits(georef(values(geotable), pset))
else
_adjustunits(geotable)
end

# preprocess variable models
varmodels = mapreduce(vcat, selectors, models) do selector, model
interps = map(selectors, models) do selector, model
svars = selector(vars)
[var => model for var in svars]
data = geotable[:, svars]
fitpredict(model, data, domain; point, prob, minneighbors, maxneighbors, neighborhood, distance)
end

# determine bounded search method
searcher = searcher_ui(dom, maxneighbors, distance, neighborhood)

# pre-allocate memory for neighbors
neighbors = Vector{Int}(undef, maxneighbors)

# prediction order
inds = traverse(idom, LinearPath())

# predict variable values
function pred(var, model)
map(inds) do ind
# centroid of estimation
center = centroid(idom, ind)

# find neighbors with data
nneigh = search!(neighbors, center, searcher)

# predict if enough neighbors
if nneigh ≥ minneighbors
# final set of neighbors
ninds = view(neighbors, 1:nneigh)

# view neighborhood with data
samples = view(data, ninds)

# fit model to data
fmodel = fit(model, samples)

# save prediction
geom = point ? center : idom[ind]
pfun = prob ? predictprob : predict
pfun(fmodel, var, geom)
else # missing prediction
missing
end
end
end

pairs = (var => pred(var, model) for (var, model) in varmodels)
newtab = (; pairs...) |> Tables.materializer(tab)

newgeotable = georef(newtab, idom)
newgeotable = reduce(hcat, interps)

newgeotable, nothing
end
34 changes: 5 additions & 29 deletions src/interpolate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,47 +45,23 @@ Interpolate(domain, pairs::Pair{<:Any,<:GeoStatsModel}...; kwargs...) =
isrevertible(::Type{<:Interpolate}) = false

function apply(transform::Interpolate, geotable::AbstractGeoTable)
dom = domain(geotable)
tab = values(geotable)
cols = Tables.columns(tab)
vars = Tables.columnnames(cols)

idom = transform.domain
domain = transform.domain
selectors = transform.selectors
models = transform.models
point = transform.point
prob = transform.prob

data = if point
pset = PointSet(centroid(dom, i) for i in 1:nelements(dom))
_adjustunits(georef(values(geotable), pset))
else
_adjustunits(geotable)
end

# preprocess variable models
varmodels = mapreduce(vcat, selectors, models) do selector, model
fmodel = fit(model, data)
interps = map(selectors, models) do selector, model
svars = selector(vars)
[var => fmodel for var in svars]
end

# prediction order
inds = traverse(idom, LinearPath())

# predict variable values
function pred(var, fmodel)
map(inds) do ind
geom = point ? centroid(idom, ind) : idom[ind]
pfun = prob ? predictprob : predict
pfun(fmodel, var, geom)
end
data = geotable[:, svars]
fitpredict(model, data, domain; point, prob, neighbors=false)
end

pairs = (var => pred(var, fmodel) for (var, fmodel) in varmodels)
newtab = (; pairs...) |> Tables.materializer(tab)

newgeotable = georef(newtab, idom)
newgeotable = reduce(hcat, interps)

newgeotable, nothing
end
16 changes: 0 additions & 16 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,6 @@
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

"""
searcher_ui(domain, maxneighbors, distance, neighborhood)

Return the appropriate search method over the `domain` based on
end-user inputs such as `maxneighbors`, `distance` and `neighborhood`.
"""
function searcher_ui(domain, maxneighbors, distance, neighborhood)
if isnothing(neighborhood)
# nearest neighbor search with a metric
KNearestSearch(domain, maxneighbors; metric=distance)
else
# neighbor search with ball neighborhood
KBallSearch(domain, maxneighbors, neighborhood)
end
end

#-------------
# AGGREGATION
#-------------
Expand Down