diff --git a/Project.toml b/Project.toml index a70d464..ed1fda3 100644 --- a/Project.toml +++ b/Project.toml @@ -30,7 +30,7 @@ ColumnSelectors = "0.1" Combinatorics = "1.0" DataScienceTraits = "0.1" Distances = "0.10" -GeoStatsModels = "0.1" +GeoStatsModels = "0.2" GeoTables = "1.9" Meshes = "0.35" TableDistances = "0.3" diff --git a/src/GeoStatsTransforms.jl b/src/GeoStatsTransforms.jl index 3d50036..c502e49 100644 --- a/src/GeoStatsTransforms.jl +++ b/src/GeoStatsTransforms.jl @@ -23,7 +23,7 @@ using LinearAlgebra using Statistics using Unitful: AffineQuantity -using GeoStatsModels: GeoStatsModel, fit, predict, predictprob +using GeoStatsModels: GeoStatsModel, fitpredict using ColumnSelectors: ColumnSelector, SingleColumnSelector using ColumnSelectors: Column, AllSelector, NoneSelector using ColumnSelectors: selector, selectsingle diff --git a/src/interpneighbors.jl b/src/interpneighbors.jl index 9b915a1..584691a 100644 --- a/src/interpneighbors.jl +++ b/src/interpneighbors.jl @@ -88,12 +88,11 @@ InterpolateNeighbors(domain, pairs::Pair{<:Any,<:GeoStatsModel}...; kwargs...) = isrevertible(::Type{<:InterpolateNeighbors}) = false function apply(transform::InterpolateNeighbors, geotable::AbstractGeoTable) - dom = domain(geotable) tab = values(geotable) cols = Tables.columns(tab) vars = Tables.columnnames(cols) - idom = transform.domain + domain = transform.domain selectors = transform.selectors models = transform.models minneighbors = transform.minneighbors @@ -103,73 +102,13 @@ function apply(transform::InterpolateNeighbors, geotable::AbstractGeoTable) point = transform.point prob = transform.prob - nobs = nelements(dom) - if maxneighbors > nobs || maxneighbors < 1 - @warn "Invalid maximum number of neighbors. Adjusting to $nobs..." - maxneighbors = nobs - end - - if minneighbors > maxneighbors || minneighbors < 1 - @warn "Invalid minimum number of neighbors. Adjusting to 1..." - minneighbors = 1 - end - - data = if point - pset = PointSet(centroid(dom, i) for i in 1:nobs) - _adjustunits(georef(values(geotable), pset)) - else - _adjustunits(geotable) - end - - # preprocess variable models - varmodels = mapreduce(vcat, selectors, models) do selector, model + interps = map(selectors, models) do selector, model svars = selector(vars) - [var => model for var in svars] + data = geotable[:, svars] + fitpredict(model, data, domain; point, prob, minneighbors, maxneighbors, neighborhood, distance) end - # determine bounded search method - searcher = searcher_ui(dom, maxneighbors, distance, neighborhood) - - # pre-allocate memory for neighbors - neighbors = Vector{Int}(undef, maxneighbors) - - # prediction order - inds = traverse(idom, LinearPath()) - - # predict variable values - function pred(var, model) - map(inds) do ind - # centroid of estimation - center = centroid(idom, ind) - - # find neighbors with data - nneigh = search!(neighbors, center, searcher) - - # predict if enough neighbors - if nneigh ≥ minneighbors - # final set of neighbors - ninds = view(neighbors, 1:nneigh) - - # view neighborhood with data - samples = view(data, ninds) - - # fit model to data - fmodel = fit(model, samples) - - # save prediction - geom = point ? center : idom[ind] - pfun = prob ? predictprob : predict - pfun(fmodel, var, geom) - else # missing prediction - missing - end - end - end - - pairs = (var => pred(var, model) for (var, model) in varmodels) - newtab = (; pairs...) |> Tables.materializer(tab) - - newgeotable = georef(newtab, idom) + newgeotable = reduce(hcat, interps) newgeotable, nothing end diff --git a/src/interpolate.jl b/src/interpolate.jl index f3f8e33..05c8389 100644 --- a/src/interpolate.jl +++ b/src/interpolate.jl @@ -45,47 +45,23 @@ Interpolate(domain, pairs::Pair{<:Any,<:GeoStatsModel}...; kwargs...) = isrevertible(::Type{<:Interpolate}) = false function apply(transform::Interpolate, geotable::AbstractGeoTable) - dom = domain(geotable) tab = values(geotable) cols = Tables.columns(tab) vars = Tables.columnnames(cols) - idom = transform.domain + domain = transform.domain selectors = transform.selectors models = transform.models point = transform.point prob = transform.prob - data = if point - pset = PointSet(centroid(dom, i) for i in 1:nelements(dom)) - _adjustunits(georef(values(geotable), pset)) - else - _adjustunits(geotable) - end - - # preprocess variable models - varmodels = mapreduce(vcat, selectors, models) do selector, model - fmodel = fit(model, data) + interps = map(selectors, models) do selector, model svars = selector(vars) - [var => fmodel for var in svars] - end - - # prediction order - inds = traverse(idom, LinearPath()) - - # predict variable values - function pred(var, fmodel) - map(inds) do ind - geom = point ? centroid(idom, ind) : idom[ind] - pfun = prob ? predictprob : predict - pfun(fmodel, var, geom) - end + data = geotable[:, svars] + fitpredict(model, data, domain; point, prob, neighbors=false) end - pairs = (var => pred(var, fmodel) for (var, fmodel) in varmodels) - newtab = (; pairs...) |> Tables.materializer(tab) - - newgeotable = georef(newtab, idom) + newgeotable = reduce(hcat, interps) newgeotable, nothing end diff --git a/src/utils.jl b/src/utils.jl index 26cac40..4966926 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,22 +2,6 @@ # Licensed under the MIT License. See LICENSE in the project root. # ------------------------------------------------------------------ -""" - searcher_ui(domain, maxneighbors, distance, neighborhood) - -Return the appropriate search method over the `domain` based on -end-user inputs such as `maxneighbors`, `distance` and `neighborhood`. -""" -function searcher_ui(domain, maxneighbors, distance, neighborhood) - if isnothing(neighborhood) - # nearest neighbor search with a metric - KNearestSearch(domain, maxneighbors; metric=distance) - else - # neighbor search with ball neighborhood - KBallSearch(domain, maxneighbors, neighborhood) - end -end - #------------- # AGGREGATION #-------------