Skip to content

Commit

Permalink
Try to make fill_pyramids handle additional dimension
Browse files Browse the repository at this point in the history
This is currently work in progress.
The output arrays need to be constructed accordingly and the additional dimensions need to be added to the windows.
  • Loading branch information
felixcremer committed Jul 2, 2024
1 parent 23ba10a commit 94634b1
Showing 1 changed file with 36 additions and 13 deletions.
49 changes: 36 additions & 13 deletions src/PyramidScheme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,40 @@ Fill the pyramids generated from the `data` with the aggregation function `func`
`recursive` indicates whether higher tiles are computed from lower tiles or directly from the original data.
This is an optimization which for functions like median might lead to misleading results.
"""
function fill_pyramids(data, outputs,func,recursive;runner=LocalRunner, kwargs...)
function fill_pyramids(data, outputs,func,recursive;runner=LocalRunner, verbose=false, outtype=:mem, kwargs...)
t = typeof(func(zeros(eltype(data), 2,2)))
n_level = compute_nlevels(data)
input_axes = pyramidedaxes(data)
nonpyramiddims = DD.otherdims(data, input_axes)
@show nonpyramiddims
if length(input_axes) != 2
throw(ArgumentError("Expected two spatial dimensions got $input_axes"))
end
verbose && println("Constructing output arrays")
spatialsize = size(data)[collect(DD.dimnum(data, input_axes))]
pyramid_sizes = [ceil.(Int, spatialsize ./ 2^i) for i in 1:n_level]
allsizes = [spatialsize..., [1 for o in nonpyramiddims]...]
sizeperm = [DD.dimnum(data, input_axes)..., DD.dimnum(data, nonpyramiddims)...]
permute!(allsizes, sizeperm)
@show allsizes
outputs = if outtype == :zarr
[output_zarr(n, input_axes, t, joinpath(path, string(n))) for n in 1:n_level]
elseif outtype == :mem
outmin = output_arrays(pyramid_sizes, t)
else
throw(ArgumentError("Output type not valied got $outtype expected :mem or :zarr"))
end

verbose && println("Start computation")
n_level = length(outputs)
@show typeof(data)
@show n_level
@show size.(outputs)
pixel_base_size = 2^n_level
pyramid_sizes = size.(outputs)
tmp_sizes = [ceil(Int,pixel_base_size / 2^i) for i in 1:n_level]

ia = InputArray(data, windows = arraywindows(size(data),pixel_base_size))
windows = arraywindows(allsizes,pixel_base_size)
ia = InputArray(data;windows)

oa = ntuple(i->create_outwindows(pyramid_sizes[i],windows = arraywindows(pyramid_sizes[i],tmp_sizes[i])),n_level)

Expand Down Expand Up @@ -254,6 +281,8 @@ Construct a list of `RegularWindows` for the size list in `s` for windows `w`.
??
"""
function arraywindows(s,w)
@show s
@show w
map(s) do l
RegularWindows(1,l,window=w)
end
Expand Down Expand Up @@ -307,6 +336,9 @@ Union of Dimensions which are assumed to be in space and are therefore used in t
"""
SpatialDim = Union{DD.Dimensions.XDim, DD.Dimensions.YDim}

pyramidedaxes(input) = filter(x-> x isa SpatialDim, DD.dims(input))


"""
buildpyramids(path; resampling_method=mean)
Build the pyramids for the zarr dataset at `path` and write the pyramid layers into the zarr folder.
Expand All @@ -323,15 +355,6 @@ function buildpyramids(path; resampling_method=mean, recursive=true, runner=Loca
# Build a loop for all variables in a dataset?
org = Cube(path)
# We run the method once to derive the output type
t = typeof(resampling_method(zeros(eltype(org), 2,2)))
n_level = compute_nlevels(org)
input_axes = filter(x-> x isa SpatialDim, DD.dims(org))
if length(input_axes) != 2
throw(ArgumentError("Expected two spatial dimensions got $input_axes"))
end
verbose && println("Constructing output arrays")
outarrs = [output_zarr(n, input_axes, t, joinpath(path, string(n))) for n in 1:n_level]
verbose && println("Start computation")
fill_pyramids(org, outarrs, resampling_method, recursive;runner)
pyraxs = [agg_axis.(input_axes, 2^n) for n in 1:n_level]
pyrlevels = DD.DimArray.(outarrs, pyraxs)
Expand Down Expand Up @@ -375,7 +398,7 @@ Compute the data of the pyramids of a given data cube `ras`.
This returns the data of the pyramids and the dimension values of the aggregated axes.
"""
function getpyramids(reducefunc, ras;recursive=true)
input_axes = DD.dims(ras)
input_axes = pyramidedaxes(ras)
n_level = compute_nlevels(ras)
if iszero(n_level)
@info "Array is smaller than the tilesize no pyramids are computed"
Expand Down

0 comments on commit 94634b1

Please sign in to comment.