EuroSATDataModule initialised with 3 bands gives error: mean length and number of channels do not match #1634
Comments
Tried to hack around this with
But error still arises |
If I understand correctly, what you want to use is SENTINEL2_RGB_MOCO for the weights and not SENTINEL2_ALL_MOCO which expects all the multispectral bands. |
Also, your hack wouldn't actually work. You would need to override the actual Normalize transform itself since it's already been instantiated. We prefer that users simply subclass the datamodule into their own custom datamodule if they want to override attributes. |
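To illustrate why reassigning the statistics after construction has no effect, here is a minimal sketch in plain Python. `FixedNormalize`, `ToyDataModule` and `RGBToyDataModule` are hypothetical stand-ins, not torchgeo or Kornia classes; the only assumption is that the datamodule builds its normalisation transform once, inside `__init__`, from class-level `mean`/`std`.

```python
class FixedNormalize:
    """Hypothetical transform that copies its statistics at construction time."""

    def __init__(self, mean, std):
        self.mean = list(mean)
        self.std = list(std)

    def __call__(self, pixel):
        # Mimics the channel-count check reported in the issue title.
        if len(pixel) != len(self.mean):
            raise ValueError("mean length and number of channels do not match")
        return [(v - m) / s for v, m, s in zip(pixel, self.mean, self.std)]


class ToyDataModule:
    """Hypothetical datamodule with 13-band statistics as class attributes."""

    mean = [0.0] * 13
    std = [1.0] * 13

    def __init__(self):
        # The transform is instantiated here, using whatever mean/std are visible now.
        self.aug = FixedNormalize(self.mean, self.std)


class RGBToyDataModule(ToyDataModule):
    """Subclass overriding the attributes before the transform is built."""

    mean = [0.0] * 3
    std = [1.0] * 3


# The "hack": overriding attributes after __init__ leaves the old transform in place.
hacked = ToyDataModule()
hacked.mean = [0.0] * 3
hacked.std = [1.0] * 3
try:
    hacked.aug([1.0, 2.0, 3.0])
    hack_worked = True
except ValueError:
    hack_worked = False

# The subclass builds its transform from the 3-band statistics.
subclassed = RGBToyDataModule()
rgb_out = subclassed.aug([1.0, 2.0, 3.0])
```

In this toy version the hacked instance still carries a 13-value transform and raises, while the subclass normalises a 3-channel pixel without error.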
With
But I should replace EuroSATDataModule with a subclass that re-implements the normalisation? My objective is a simple tutorial for a hackathon where people can experiment with number of bands, pretrained weights etc |
I've attempted the custom module below:
# extract for RGB later
mins = torch.tensor(
[
1013.0,
676.0,
448.0,
247.0,
269.0,
253.0,
243.0,
189.0,
61.0,
4.0,
33.0,
11.0,
186.0,
]
)
maxs = torch.tensor(
[
2309.0,
4543.05,
4720.2,
5293.05,
3902.05,
4473.0,
5447.0,
5948.05,
1829.0,
23.0,
4894.05,
4076.05,
5846.0,
]
)
# use values from https://github.com/microsoft/torchgeo/blob/main/torchgeo/datasets/eurosat.py
bands = {
"B01": "Coastal Aerosol",
"B02": "Blue",
"B03": "Green",
"B04": "Red",
"B05": "Vegetation Red Edge 1",
"B06": "Vegetation Red Edge 2",
"B07": "Vegetation Red Edge 3",
"B08": "NIR 1",
"B08A": "NIR 2",
"B09": "Water Vapour",
"B10": "SWIR 1",
"B11": "SWIR 2",
"B12": "SWIR 3",
}
rgb_bands = ("B04", "B03", "B02") # or experiment with all
# Get the indices of the keys represented in rgb_bands
rgb_indices = [list(bands.keys()).index(band) for band in rgb_bands]
mins = mins[rgb_indices]
maxs = maxs[rgb_indices]
class MinMaxNormalize(K.IntensityAugmentationBase2D):
    """Normalize channels to the range [0, 1] using min/max values."""

    def __init__(self, mins: Tensor, maxs: Tensor) -> None:
        super().__init__(p=1)
        self.flags = {"mins": mins.view(1, -1, 1, 1), "maxs": maxs.view(1, -1, 1, 1)}

    def apply_transform(
        self,
        input: Tensor,
        params: Dict[str, Tensor],
        flags: Dict[str, int],
        transform: Optional[Tensor] = None,
    ) -> Tensor:
        return (input - flags["mins"]) / (flags["maxs"] - flags["mins"] + 1e-10)

class RGBEuroSATDataModule(NonGeoDataModule):
    def __init__(self, data_dir: str, batch_size: int = 64):
        super().__init__(batch_size=batch_size, dataset_class=NonGeoDataset)
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.rgb_bands = ("B04", "B03", "B02")  # or experiment with all
        # Define transforms
        self.train_transforms = AugmentationSequential(
            MinMaxNormalize(mins, maxs),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
            K.RandomAffine(degrees=(0, 90), p=0.25),
            K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),
            K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),
            data_keys=["image"],
        )
        self.test_transforms = nn.Sequential(
            MinMaxNormalize(mins, maxs),
        )

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            self.train_dataset = EuroSAT(root=self.data_dir, split='train', transforms=self.train_transforms, download=True, bands=self.rgb_bands)
            self.val_dataset = EuroSAT(root=self.data_dir, split='val', transforms=self.test_transforms, download=True, bands=self.rgb_bands)
        if stage == 'test' or stage is None:
            self.test_dataset = EuroSAT(root=self.data_dir, split='test', transforms=self.test_transforms, download=True, bands=self.rgb_bands)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

datamodule = RGBEuroSATDataModule(
data_dir="data",
batch_size=batch_size,
)
task = ClassificationTask(
model="resnet18",
# weights=True, # standard Imagenet
# weights=ResNet18_Weights.SENTINEL2_ALL_MOCO, # or try sentinel 2 all bands
weights=ResNet18_Weights.SENTINEL2_RGB_MOCO, # or try sentinel 2 rgb bands
num_classes=10,
in_channels=len(datamodule.train_dataset.bands), # make sure to validate
loss="ce",
patience=10
)
trainer.fit(model=task, datamodule=datamodule)
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[19], line 1
----> 1 trainer.fit(model=task, datamodule=datamodule)
File /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:532, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
530 self.strategy._lightning_module = model
531 _verify_strategy_supports_compile(model, self.strategy)
--> 532 call._call_and_handle_interrupt(
533 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
534 )
File /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:43, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
41 if trainer.strategy.launcher is not None:
42 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 43 return trainer_fn(*args, **kwargs)
45 except _TunerExitException:
46 _call_teardown_hook(trainer)
File /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:571, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
561 self._data_connector.attach_data(
562 model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
563 )
565 ckpt_path = self._checkpoint_connector._select_ckpt_path(
566 self.state.fn,
...
--> 199 input_shape = in_tensor.shape
200 in_tensor = self.transform_tensor(in_tensor)
201 batch_shape = in_tensor.shape
AttributeError: 'dict' object has no attribute 'shape' |
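One plausible reading of this traceback, not confirmed in the thread, is that a transform written for tensors received the whole sample dictionary instead. The sketch below is plain Python with hypothetical names (`on_image`, `scale`) and lists standing in for tensors; it only shows the pattern of applying a tensor-only transform to the `"image"` entry while leaving the other keys untouched.

```python
from typing import Any, Callable, Dict


def on_image(transform: Callable[[Any], Any]) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
    """Wrap a transform so it acts on sample["image"] rather than on the dict itself."""

    def wrapped(sample: Dict[str, Any]) -> Dict[str, Any]:
        out = dict(sample)  # shallow copy so the caller's sample is not modified
        out["image"] = transform(sample["image"])
        return out

    return wrapped


def scale(image):
    """Hypothetical image-only transform: divide every value by 10."""
    return [v / 10.0 for v in image]


sample = {"image": [10.0, 20.0, 30.0], "label": 3}
transformed = on_image(scale)(sample)
```

Calling `scale(sample)` directly would fail in a similar way to the error above, because the function would be handed a dict rather than image data.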
OK, it appears I just overcomplicated it: get the correct MEAN and STD, then just
class RGBEuroSATDataModule(NonGeoDataModule):
    mean = MEAN
    std = STD
    rgb_bands = ("B04", "B03", "B02")  # or experiment with all

    def __init__(
        self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Initialize a new EuroSATDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.EuroSAT`.
        """
        super().__init__(EuroSAT, batch_size, num_workers, bands=self.rgb_bands, **kwargs) |
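Getting the correct MEAN and STD for a subset of bands amounts to indexing the 13-band statistics by band name, in the same way `mins` and `maxs` were indexed earlier in the thread. A minimal sketch follows, using the band order listed above. `select_stats` is a hypothetical helper, and the values in `ALL_MEAN` are made-up placeholders for illustration, not the real EuroSAT statistics.

```python
# Band order as listed earlier in this thread.
ALL_BANDS = (
    "B01", "B02", "B03", "B04", "B05", "B06", "B07",
    "B08", "B08A", "B09", "B10", "B11", "B12",
)

# Placeholder per-band means (100.0, 200.0, ..., 1300.0); NOT real statistics.
ALL_MEAN = tuple(100.0 * (i + 1) for i in range(len(ALL_BANDS)))


def select_stats(stats, bands, all_bands=ALL_BANDS):
    """Return the entries of `stats` matching `bands`, in the order requested."""
    index = {name: i for i, name in enumerate(all_bands)}
    missing = [b for b in bands if b not in index]
    if missing:
        raise KeyError(f"unknown bands: {missing}")
    return [stats[index[b]] for b in bands]


rgb_bands = ("B04", "B03", "B02")
rgb_mean = select_stats(ALL_MEAN, rgb_bands)
```

The same helper would be applied to the standard deviations, so that the length of both lists always matches the number of bands requested from the dataset.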
I personally consider it a bug that |
A quick check of OSCD shows it supports 'all' or 'rgb' normalisation. Likewise, So2Sat supports this via a 'band_set' approach. However, I prefer the more flexible approach I demonstrated above, which allows any combination of bands. Should I implement that, or the 'band_set' approach? |
I'm fine with the more flexible approach. You could even modify OSCD/So2Sat/other datasets to match, but that's too much work to ask you to do. |
@adamjstewart sure will do after #1646 |
Description
I have previously:
download=True
the EuroSATDataModule with 13 bands
I now:
Create a task:
Train:
This presumably occurs on the attempted normalisation, which is expecting 13 bands
Steps to reproduce
As above
Version
0.5.0