Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GridGeoSampler: Sample beyond roi limit if needed #629

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 47 additions & 8 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,13 @@ def sampler(self, dataset: CustomGeoDataset, request: SubRequest) -> GridGeoSamp

def test_iter(self, sampler: GridGeoSampler) -> None:
for query in sampler:
assert sampler.roi.minx <= query.minx <= query.maxx <= sampler.roi.maxx
assert sampler.roi.miny <= query.miny <= query.miny <= sampler.roi.maxy
assert sampler.roi.mint <= query.mint <= query.maxt <= sampler.roi.maxt
assert sampler.roi.minx <= query.minx
assert sampler.roi.miny <= query.miny
assert sampler.roi.mint <= query.mint
if query.maxx > sampler.roi.maxx:
assert (query.maxx - sampler.roi.maxx) < sampler.size[1]
if query.maxy > sampler.roi.maxy:
assert (query.maxy - sampler.roi.maxy) < sampler.size[0]

assert math.isclose(query.maxx - query.minx, sampler.size[1])
assert math.isclose(query.maxy - query.miny, sampler.size[0])
Expand All @@ -182,24 +186,59 @@ def test_iter(self, sampler: GridGeoSampler) -> None:
)

def test_len(self, sampler: GridGeoSampler) -> None:
rows = int((100 - sampler.size[0]) // sampler.stride[0]) + 1
cols = int((100 - sampler.size[1]) // sampler.stride[1]) + 1
rows = math.ceil((100 - sampler.size[0] + sampler.stride[0]) / sampler.stride[0])
cols = math.ceil((100 - sampler.size[1] + sampler.stride[1]) / sampler.stride[1])
length = rows * cols * 2
assert len(sampler) == length

def test_len_larger(self, sampler: GridGeoSampler) -> None:
entire_rows = (100 - sampler.size[0] + sampler.stride[0]) // sampler.stride[0]
entire_cols = (100 - sampler.size[1] + sampler.stride[1]) // sampler.stride[1]
leftover_row = (100 - sampler.size[0] + sampler.stride[0]) \
/ sampler.stride[0] - entire_rows
leftover_col = (100 - sampler.size[1] + sampler.stride[1]) \
/ sampler.stride[1] - entire_cols
assert len(sampler) == (entire_rows + math.ceil(leftover_row)) * \
(entire_cols + math.ceil(leftover_col)) * 2

def test_roi(self, dataset: CustomGeoDataset) -> None:
roi = BoundingBox(0, 50, 200, 250, 400, 450)
sampler = GridGeoSampler(dataset, 2, 1, roi=roi)
for query in sampler:
assert query in roi

def test_small_area(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 1, 0, 1, 0, 1))
sampler = GridGeoSampler(ds, 2, 10)
assert len(sampler) == 1
for bbox in sampler:
assert bbox == BoundingBox(minx=0.0, maxx=20.0, miny=0.0, maxy=20.0, mint=0.0, maxt=1.0)

# TODO: skip patches with area=0 when two tiles are side-by-side with an overlapping edge face.
def test_tiles_side_by_side(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (20, 21, 20, 21, 20, 21))
ds.index.insert(0, (0, 10, 10, 20, 0, 10))
sampler = GridGeoSampler(ds, 2, 10)
for _ in sampler:
continue
for bbox in sampler:
assert bbox.area > 0

def test_equal_area(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(ds, 10, 10, units=Units.CRS)
assert len(sampler) == 1
for bbox in sampler:
assert bbox == BoundingBox(minx=0.0, maxx=10.0, miny=0.0, maxy=10.0, mint=0.0, maxt=10.0)

def test_larger_area(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 6, 0, 5, 0, 10))
sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
assert len(sampler) == 2
assert list(sampler)[0] == BoundingBox(minx=0.0, maxx=5.0, miny=0.0, maxy=5.0, mint=0.0, maxt=10.0)
assert list(sampler)[1] == BoundingBox(minx=5.0, maxx=10.0, miny=0.0, maxy=5.0, mint=0.0, maxt=10.0)

@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
Expand Down
37 changes: 18 additions & 19 deletions torchgeo/samplers/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""TorchGeo samplers."""

import abc
import math
from typing import Callable, Iterable, Iterator, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -107,12 +108,8 @@ def __init__(
areas = []
for hit in self.index.intersection(tuple(self.roi), objects=True):
bounds = BoundingBox(*hit.bounds)
if (
bounds.maxx - bounds.minx >= self.size[1]
and bounds.maxy - bounds.miny >= self.size[0]
):
self.hits.append(hit)
areas.append(bounds.area)
self.hits.append(hit)
areas.append(bounds.area)

# torch.multinomial requires float probabilities > 0
self.areas = torch.tensor(areas, dtype=torch.float)
Expand Down Expand Up @@ -188,29 +185,27 @@ def __init__(
.. versionchanged:: 0.3
Added ``units`` parameter, changed default to pixel units
"""
super().__init__(dataset, roi)
super().__init__(dataset=dataset, roi=roi)
self.size = _to_tuple(size)
self.stride = _to_tuple(stride)

if units == Units.PIXELS:
self.size = (self.size[0] * self.res, self.size[1] * self.res)
self.stride = (self.stride[0] * self.res, self.stride[1] * self.res)

self.hits = []
for hit in self.index.intersection(tuple(self.roi), objects=True):
bounds = BoundingBox(*hit.bounds)
if (
bounds.maxx - bounds.minx > self.size[1]
and bounds.maxy - bounds.miny > self.size[0]
):
self.hits.append(hit)
self.hits = list(self.index.intersection(tuple(self.roi), objects=True))

self.length: int = 0
for hit in self.hits:
bounds = BoundingBox(*hit.bounds)

rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1
cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1
# last patch samples outside the bounds
rows = math.ceil(
(bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) / self.stride[0]
)
cols = math.ceil(
(bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) / self.stride[1]
)
self.length += rows * cols

def __iter__(self) -> Iterator[BoundingBox]:
Expand All @@ -223,8 +218,12 @@ def __iter__(self) -> Iterator[BoundingBox]:
for hit in self.hits:
bounds = BoundingBox(*hit.bounds)

rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1
cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1
rows = math.ceil(
(bounds.maxy - bounds.miny - self.size[0] + self.stride[0]) / self.stride[0]
)
cols = math.ceil(
(bounds.maxx - bounds.minx - self.size[1] + self.stride[1]) / self.stride[1]
)

mint = bounds.mint
maxt = bounds.maxt
Expand Down