Skip to content

Commit

Permalink
Fix float rounding issues (#736)
Browse files — browse the repository at this point in the history
* Fix rounding bugs

* Mypy fixes

* Undo some changes

* Undo some changes

* mypy fixes

* Fix tests

* Increase size of ETCI2021 fake data

* Undo model changes

* line length fix
  • Loading branch information
adamjstewart committed Sep 3, 2022
1 parent bc6aa0c commit 7ac65a2
Show file tree
Hide file tree
Showing 27 changed files with 43 additions and 31 deletions.
18 changes: 15 additions & 3 deletions tests/data/etci2021/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,29 @@
{
"filename": "train.zip",
"directory": "train",
"subdirs": ["nebraska_20170108t002112", "bangladesh_20170314t115609"],
"subdirs": [
"nebraska_20170108t002112",
"bangladesh_20170314t115609",
"northal_20190302t234651",
],
},
{
"filename": "val_with_ref_labels.zip",
"directory": "test",
"subdirs": ["florence_20180510t231343", "florence_20180522t231344"],
"subdirs": [
"florence_20180510t231343",
"florence_20180522t231344",
"florence_20190302t234651",
],
},
{
"filename": "test_without_ref_labels.zip",
"directory": "test_internal",
"subdirs": ["redrivernorth_20190104t002247", "redrivernorth_20190116t002247"],
"subdirs": [
"redrivernorth_20190104t002247",
"redrivernorth_20190116t002247",
"redrivernorth_20190302t234651",
],
},
]

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/data/etci2021/test_without_ref_labels.zip
Binary file not shown.
Binary file modified tests/data/etci2021/train.zip
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/data/etci2021/val_with_ref_labels.zip
Binary file not shown.
10 changes: 5 additions & 5 deletions tests/datamodules/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ def test_dataset_split() -> None:

# Test only train/val set split
train_ds, val_ds = dataset_split(ds, val_pct=1 / 2)
assert len(train_ds) == num_samples // 2
assert len(val_ds) == num_samples // 2
assert len(train_ds) == round(num_samples / 2)
assert len(val_ds) == round(num_samples / 2)

# Test train/val/test set split
train_ds, val_ds, test_ds = dataset_split(ds, val_pct=1 / 3, test_pct=1 / 3)
assert len(train_ds) == num_samples // 3
assert len(val_ds) == num_samples // 3
assert len(test_ds) == num_samples // 3
assert len(train_ds) == round(num_samples / 3)
assert len(val_ds) == round(num_samples / 3)
assert len(test_ds) == round(num_samples / 3)
8 changes: 4 additions & 4 deletions tests/datasets/test_etci2021.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,19 @@ def dataset(
metadata = {
"train": {
"filename": "train.zip",
"md5": "ebbd2e65cd10621bc2e90a230b474b8b",
"md5": "bd55f2116e43a35d5b94a765938be2aa",
"directory": "train",
"url": os.path.join(data_dir, "train.zip"),
},
"val": {
"filename": "val_with_ref_labels.zip",
"md5": "efdd1fe6c90f5dfd267c88b86b237c2b",
"md5": "96ed69904043e514c13c14ffd3ec45cd",
"directory": "test",
"url": os.path.join(data_dir, "val_with_ref_labels.zip"),
},
"test": {
"filename": "test_without_ref_labels.zip",
"md5": "bf1180143de5705fe95fa8490835d6d1",
"md5": "1b66d85e22c8f5b0794b3542c5ea09ef",
"directory": "test_internal",
"url": os.path.join(data_dir, "test_without_ref_labels.zip"),
},
Expand All @@ -67,7 +67,7 @@ def test_getitem(self, dataset: ETCI2021) -> None:
assert x["mask"].shape[0] == 1

def test_len(self, dataset: ETCI2021) -> None:
assert len(dataset) == 2
assert len(dataset) == 3

def test_already_downloaded(self, dataset: ETCI2021) -> None:
ETCI2021(root=dataset.root, download=True)
Expand Down
6 changes: 3 additions & 3 deletions tests/models/test_changestar.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
IN_CHANNELS = [64, 128]
INNNR_CHANNELS = [16, 32, 64]
NC = [1, 2, 4]
SF = [4.0, 8.0, 1.0]
SF = [4, 8, 1]


class TestChangeStar:
Expand Down Expand Up @@ -65,15 +65,15 @@ def test_invalid_changestar_farseg_backbone(self) -> None:
"inc,innerc,nc,sf", list(itertools.product(IN_CHANNELS, INNNR_CHANNELS, NC, SF))
)
def test_changemixin_output_size(
self, inc: int, innerc: int, nc: int, sf: float
self, inc: int, innerc: int, nc: int, sf: int
) -> None:
m = ChangeMixin(
in_channels=inc, inner_channels=innerc, num_convs=nc, scale_factor=sf
)

y = m(torch.rand(3, 2, inc // 2, 32, 32))
assert y[0].shape == y[1].shape
assert y[0].shape == (3, 1, int(32 * sf), int(32 * sf))
assert y[0].shape == (3, 1, 32 * sf, 32 * sf)

@torch.no_grad() # type: ignore[misc]
def test_changestar(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ def test_iter(self, sampler: GridGeoSampler) -> None:
)

def test_len(self, sampler: GridGeoSampler) -> None:
rows = int((100 - sampler.size[0]) // sampler.stride[0]) + 1
cols = int((100 - sampler.size[1]) // sampler.stride[1]) + 1
rows = ((100 - sampler.size[0]) // sampler.stride[0]) + 1
cols = ((100 - sampler.size[1]) // sampler.stride[1]) + 1
length = rows * cols * 2
assert len(sampler) == length

Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datamodules/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(
self.patch_size = patch_size
# This is a rough estimate of how large a patch we need to sample in
# EPSG:3857 in order to guarantee a large enough patch in the local CRS.
self.original_patch_size = int(patch_size * 2.0)
self.original_patch_size = patch_size * 2
self.batch_size = batch_size
self.num_workers = num_workers
self.class_set = class_set
Expand Down Expand Up @@ -151,8 +151,8 @@ def center_crop(
def center_crop_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
_, height, width = sample["image"].shape

y1 = (height - size) // 2
x1 = (width - size) // 2
y1 = round((height - size) / 2)
x1 = round((width - size) / 2)
sample["image"] = sample["image"][:, y1 : y1 + size, x1 : x1 + size]
sample["mask"] = sample["mask"][:, y1 : y1 + size, x1 : x1 + size]

Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datamodules/etci2021.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def setup(self, stage: Optional[str] = None) -> None:
)

size_train_val = len(train_val_dataset)
size_train = int(0.8 * size_train_val)
size_train = round(0.8 * size_train_val)
size_val = size_train_val - size_train

self.train_dataset, self.val_dataset = random_split(
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datamodules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ def dataset_split(
a list of the subset datasets. Either [train, val] or [train, val, test]
"""
if test_pct is None:
val_length = int(len(dataset) * val_pct)
val_length = round(len(dataset) * val_pct)
train_length = len(dataset) - val_length
return random_split(dataset, [train_length, val_length])
else:
val_length = int(len(dataset) * val_pct)
test_length = int(len(dataset) * test_pct)
val_length = round(len(dataset) * val_pct)
test_length = round(len(dataset) * test_pct)
train_length = len(dataset) - (val_length + test_length)
return random_split(dataset, [train_length, val_length, test_length])
4 changes: 2 additions & 2 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,8 @@ def _merge_files(self, filepaths: Sequence[str], query: BoundingBox) -> Tensor:
bounds = (query.minx, query.miny, query.maxx, query.maxy)
if len(vrt_fhs) == 1:
src = vrt_fhs[0]
out_width = int(round((query.maxx - query.minx) / self.res))
out_height = int(round((query.maxy - query.miny) / self.res))
out_width = round((query.maxx - query.minx) / self.res)
out_height = round((query.maxy - query.miny) / self.res)
out_shape = (src.count, out_height, out_width)
dest = src.read(
out_shape=out_shape, window=from_bounds(*bounds, src.transform)
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/openbuildings.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,11 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
)
if shapes:
masks = rasterio.features.rasterize(
shapes, out_shape=(int(height), int(width)), transform=transform
shapes, out_shape=(round(height), round(width)), transform=transform
)
masks = torch.tensor(masks).unsqueeze(0)
else:
masks = torch.zeros(size=(1, int(height), int(width)))
masks = torch.zeros(size=(1, round(height), round(width)))

sample = {"mask": masks, "crs": self.crs, "bbox": query}

Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/so2sat.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def __init__(
raise RuntimeError("Dataset not found or corrupted.")

with h5py.File(self.fn, "r") as f:
self.size = int(f["label"].shape[0])
self.size: int = f["label"].shape[0]

def __getitem__(self, index: int) -> Dict[str, Tensor]:
"""Return an index within the dataset.
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/spacenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,7 @@ def _load_mask(
speed_arr_bin = np.arange(min_speed_bin, max_speed_bin + 1)
bin_size_mph = 10.0
speed_cls_arr: "np.typing.NDArray[np.int_]" = np.array(
[int(math.ceil(s / bin_size_mph)) for s in speed_arr_bin]
[math.ceil(s / bin_size_mph) for s in speed_arr_bin]
)

try:
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/samplers/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def __init__(
):
self.hits.append(hit)

self.length: int = 0
self.length = 0
for hit in self.hits:
bounds = BoundingBox(*hit.bounds)

Expand Down

0 comments on commit 7ac65a2

Please sign in to comment.