✨ Implement len function for XbatcherSlicerIterDataPipe #75

Merged 2 commits on Nov 22, 2022
docs/chipping.md: 4 additions & 5 deletions

@@ -126,8 +126,7 @@ This should give us about 12 chips in total, 6 from each of the 2 Sentinel-1
 images that were passed in.

 ```{code-cell}
-chips = [chip for chip in dp_xbatcher]
-print(f"Number of chips: {len(chips)}")
+print(f"Number of chips: {len(dp_xbatcher)}")
 ```

 Now, if you want to customize the sliding window (e.g. do overlapping strides),
@@ -145,14 +144,14 @@ Great, and this overlapping stride method should give us more 512x512 chips 🧮
 than before.

 ```{code-cell}
-chips = [chip for chip in dp_xbatcher]
-print(f"Number of chips: {len(chips)}")
+print(f"Number of chips: {len(dp_xbatcher)}")
 ```

 Double-check that single chips are of the correct dimensions
 (band: 1, y: 512, x: 512).

 ```{code-cell}
+chips = list(dp_xbatcher)
 sample = chips[0]
 sample
 ```
@@ -232,7 +231,7 @@ Then, pass this collate function to

 ```{code-cell}
 dp_collate = dp_batch.collate(collate_fn=xr_collate_fn)
-print(f"Number of mini-batches: {len(list(dp_collate))}")
+print(f"Number of mini-batches: {len(dp_collate)}")
 print(f"Mini-batch tensor shape: {list(dp_collate)[0].shape}")
 ```
docs/object-detection-boxes.md: 1 addition & 1 deletion

@@ -447,7 +447,7 @@ def boximg_collate_fn(samples) -> (list[torch.Tensor], torch.Tensor, list[dict])

 ```{code-cell}
 dp_collate = dp_batch.collate(collate_fn=boximg_collate_fn)
-print(f"Number of mini-batches: {len(list(dp_collate))}")
+print(f"Number of mini-batches: {len(dp_collate)}")
 mini_batch_box, mini_batch_img, mini_batch_metadata = list(dp_collate)[1]
 print(f"Mini-batch image tensor shape: {mini_batch_img.shape}")
 print(f"Mini-batch box tensors: {mini_batch_box}")
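
Both documentation pages can drop the `len(list(...))` workaround because torchdata propagates `__len__` through downstream datapipes such as `.batch()` and `.collate()`. The following is a minimal sketch of that propagation, assuming only torchdata is installed; the `IterableWrapper` over twelve integers and the batch size of 4 are stand-ins for the real chip pipeline, not code from the docs.

```python
# Minimal sketch, assuming torchdata is installed: once the upstream datapipe
# reports a length, .batch() and .collate() report theirs without iterating.
from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(12))    # stand-in for 12 chips
dp_batch = dp.batch(batch_size=4)  # 3 mini-batches of 4 chips each
dp_collate = dp_batch.collate()    # default collation keeps the same length

print(len(dp))          # 12
print(len(dp_batch))    # 3
print(len(dp_collate))  # 3, no list(dp_collate) materialisation needed
```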
zen3geo/datapipes/xbatcher.py: 7 additions & 4 deletions

@@ -33,7 +33,7 @@ class XbatcherSlicerIterDataPipe(IterDataPipe[Union[xr.DataArray, xr.Dataset]]):
         stacked into one dimension called ``batch``.

     kwargs : Optional
-        Extra keyword arguments to pass to :py:func:`xbatcher.BatchGenerator`.
+        Extra keyword arguments to pass to :py:class:`xbatcher.BatchGenerator`.

     Yields
     ------
@@ -88,7 +88,7 @@ def __init__(
         self,
         source_datapipe: IterDataPipe[Union[xr.DataArray, xr.Dataset]],
         input_dims: Dict[Hashable, int],
-        **kwargs: Optional[Dict[str, Any]]
+        **kwargs: Optional[Dict[str, Any]],
     ) -> None:
         if xbatcher is None:
             raise ModuleNotFoundError(
@@ -109,5 +109,8 @@ def __iter__(self) -> Iterator[Union[xr.DataArray, xr.Dataset]]:
         ):
             yield chip

-    # def __len__(self) -> int:
-    #     return len(self.source_datapipe)
+    def __len__(self) -> int:
+        return sum(
+            len(dataarray.batch.generator(input_dims=self.input_dims, **self.kwargs))
+            for dataarray in self.source_datapipe
+        )
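
The new `__len__` leans on xbatcher itself: for every array yielded by the source datapipe it builds a batch generator with the same `input_dims` and keyword arguments, takes its length, and sums the results. A rough standalone illustration of that count follows, assuming xbatcher, xarray, and numpy are available; the 128x128 array shape is an arbitrary choice for the example, not something defined in this repository.

```python
# Sketch of the counting logic behind __len__: a 128x128 array sliced into
# 64x64 chips yields 2 * 2 = 4 batches. Relies on xbatcher's BatchGenerator
# reporting its own length, as the implementation above does.
import numpy as np
import xarray as xr
import xbatcher  # noqa: F401  # registers the .batch accessor on DataArray

dataarray = xr.DataArray(data=np.zeros(shape=(128, 128)), dims=["y", "x"])
generator = dataarray.batch.generator(input_dims={"y": 64, "x": 64})
print(len(generator))  # 4
```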
zen3geo/tests/test_datapipes_xbatcher.py: 2 additions & 0 deletions

@@ -28,6 +28,7 @@ def test_xbatcher_slicer_dataarray():
     # Using functional form (recommended)
     dp_xbatcher = dp.slice_with_xbatcher(input_dims={"y": 64, "x": 64})

+    assert len(dp_xbatcher) == 4
     it = iter(dp_xbatcher)
     dataarray_chip = next(it)

@@ -55,6 +56,7 @@ def test_xbatcher_slicer_dataset():
     # Using functional form (recommended)
     dp_xbatcher = dp.slice_with_xbatcher(input_dims={"y": 16, "x": 16})

+    assert len(dp_xbatcher) == 4
     it = iter(dp_xbatcher)
     dataset_chip = next(it)
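
For context, here is a hypothetical end-to-end version of what the new assertions exercise, assuming zen3geo, torchdata, xarray, and numpy are installed. The 128x128 fixture shape is an assumption chosen so that 64x64 chips give exactly four slices, mirroring `assert len(dp_xbatcher) == 4`; it is not copied from the actual test fixtures.

```python
# Hedged sketch, not the real test module: build a one-item datapipe and check
# that its length is known up front, without iterating over any chips.
import numpy as np
import xarray as xr
import zen3geo  # noqa: F401  # assumed to register slice_with_xbatcher
from torchdata.datapipes.iter import IterableWrapper

dataarray = xr.DataArray(data=np.ones(shape=(128, 128)), dims=["y", "x"])
dp = IterableWrapper(iterable=[dataarray])
dp_xbatcher = dp.slice_with_xbatcher(input_dims={"y": 64, "x": 64})

assert len(dp_xbatcher) == 4
```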