refactored preprocessor/merged mt feature extraction in encoder/fixed some bugs #86
Conversation
Hi,
Thank you for the pull request; the preprocessing refactoring was needed and much appreciated. The same goes for the sliding inference in the evaluator.
I included several comments in this review. Most of them are about the shapes of the tensors, but one requires special attention: we cannot merge multi-temporal frames on the batch dimension, otherwise we may face OOM.
I have not run any experiments with the codebase yet, but I will soon.
pangaea/run.py
Outdated
@@ -166,7 +161,7 @@ def main(cfg: DictConfig) -> None:
             num_workers=cfg.num_workers,
             pin_memory=True,
-            # persistent_workers=True causes memory leak
-            persistent_workers=False,
+            persistent_workers=True,
Have you checked the system memory usage during training? With the previous preprocessing it rocketed when using persistent_workers=True, which was the reason it was set to False.
ok
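For reference, a minimal sketch of the DataLoader setting under discussion, using a toy dataset as a stand-in for the real one (names and sizes are illustrative only):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for the Pangaea training set.
dataset = TensorDataset(torch.zeros(8, 3, 4, 4), torch.zeros(8))

# persistent_workers=True keeps worker processes (and whatever they hold
# in memory) alive between epochs; combined with a preprocessor that
# accumulates per-worker state, this can look like a memory leak.
loader = DataLoader(
    dataset,
    batch_size=2,
    num_workers=0,             # >0 in real training; 0 keeps this sketch portable
    pin_memory=False,
    persistent_workers=False,  # safer default until memory usage is verified
)

batches = list(loader)
print(len(batches))  # 4
```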
pangaea/encoders/base.py
Outdated
""" | ||
raise NotImplementedError | ||
|
||
def forward(self, image: dict[str, torch.Tensor]) -> list[torch.Tensor]: |
Can you document the code with the shapes of the tensors, please?
Mostly, some docstrings are missing that state the shapes of the input and output tensors explicitly.
ok
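A hypothetical sketch of the kind of shape-documenting docstring being requested (a toy stand-in, not the actual Pangaea forward):

```python
import torch

def forward_sketch(image: dict[str, torch.Tensor]) -> list[torch.Tensor]:
    """Toy stand-in for Encoder.forward, illustrating shape documentation.

    Args:
        image: mapping from modality name (e.g. "optical") to a tensor of
            shape (B, C, T, H, W): batch, channels, time, height, width.

    Returns:
        A list of feature maps, each of shape (B, C, H, W) after the
        temporal dimension has been reduced (here by a simple mean).
    """
    feats = []
    for v in image.values():
        feats.append(v.mean(dim=2))  # collapse T: (B, C, T, H, W) -> (B, C, H, W)
    return feats

out = forward_sketch({"optical": torch.zeros(2, 3, 4, 8, 8)})
print(out[0].shape)  # torch.Size([2, 3, 8, 8])
```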
pangaea/encoders/base.py
Outdated
    def naive_multi_temporal_forward(self, image: dict[str, torch.Tensor]) -> list[torch.Tensor]:
        b, c, t, h, w = image[list(image.keys())[0]].shape
        image = {k: v.transpose(1, 2).contiguous().view(-1, c, 1, h, w) for k, v in image.items()}
I understand and agree that we can treat this as T single-temporal images to deal with. Nevertheless, it was previously done with a for loop, collecting features one by one by asking the encoder to compute them from input images of shape (B, C, 1, H, W). With this implementation, we are feeding the model inputs of shape (B×T, C, 1, H, W), which will cause out-of-memory issues if we don't have a large memory headroom.
ok, I will make it optional
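A minimal sketch contrasting the two modes discussed above, with a dummy function standing in for the encoder (all names here are illustrative, not the actual Pangaea implementation):

```python
import torch

def encoder_sketch(x: torch.Tensor) -> torch.Tensor:
    # Stands in for a single-temporal encoder: (B, C, 1, H, W) -> (B, D, H, W)
    return x.squeeze(2) * 2.0

def multi_temporal_forward(image: torch.Tensor, mode: str = "loop") -> torch.Tensor:
    """image: (B, C, T, H, W); returns features of shape (B, D, T, H, W)."""
    b, c, t, h, w = image.shape
    if mode == "loop":
        # One frame at a time: peak memory is that of one (B, C, 1, H, W) forward.
        feats = [encoder_sketch(image[:, :, i:i + 1]) for i in range(t)]
        return torch.stack(feats, dim=2)
    # Batch mode: fold T into B for a single (B*T, C, 1, H, W) forward;
    # faster, but the activation footprint grows T-fold (OOM risk).
    merged = image.transpose(1, 2).contiguous().view(b * t, c, 1, h, w)
    out = encoder_sketch(merged)                     # (B*T, D, H, W)
    return out.view(b, t, -1, h, w).transpose(1, 2)  # (B, D, T, H, W)

x = torch.ones(2, 3, 4, 8, 8)
assert torch.equal(multi_temporal_forward(x, "loop"),
                   multi_temporal_forward(x, "batch"))
```

Both modes produce identical features; the choice is purely a speed/memory trade-off, which is why making it optional seems reasonable.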
pangaea/run.py
Outdated
@@ -56,6 +57,8 @@ def main(cfg: DictConfig) -> None:
     fix_seed(cfg.seed)
     # distributed training variables
     rank = int(os.environ["RANK"])
+    cfg.task.trainer.use_wandb = cfg.task.trainer.use_wandb and rank == 0
I don't think this is the best way to do it; rewriting configs is not good practice. Instead, we can use:
use_wandb = cfg.task.trainer.use_wandb and rank == 0
and overwrite use_wandb in the trainer and evaluator instantiation.
ok
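A sketch of the suggested pattern, with a toy trainer class standing in for the real one (with Hydra this would be `instantiate(cfg.task.trainer, use_wandb=use_wandb)`; all names here are illustrative):

```python
import os

class TrainerSketch:
    # Hypothetical stand-in for the Pangaea trainer.
    def __init__(self, use_wandb: bool):
        self.use_wandb = use_wandb

# Instead of mutating cfg.task.trainer.use_wandb in place, derive a local
# value and pass it as an override at instantiation time.
cfg_use_wandb = True                     # what the config requested
rank = int(os.environ.get("RANK", "0"))  # distributed rank, 0 = main process
use_wandb = cfg_use_wandb and rank == 0  # only the main process logs

trainer = TrainerSketch(use_wandb=use_wandb)
print(trainer.use_wandb)
```

This keeps the config object a read-only description of the run, which also makes saved configs reproducible.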
pangaea/decoders/upernet.py
Outdated
        # If the encoder handles only single temporal data, we apply multi_temporal_strategy
        if self.multi_temporal_strategy == "linear":
            feats = [self.tmap(feat.permute(0, 1, 3, 4, 2)).squeeze(-1) for feat in feats]
Can you comment on the shape of the features (before and after the permute and squeeze), so it is easier to maintain in the future, please?
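For illustration, a shape walk-through of this linear temporal mapping with hypothetical dimensions (not the actual Pangaea code):

```python
import torch

b, d, t, h, w = 2, 16, 4, 8, 8
feat = torch.randn(b, d, t, h, w)  # encoder feature: (B, D, T, H, W)

tmap = torch.nn.Linear(t, 1)       # linear map over the temporal axis

x = feat.permute(0, 1, 3, 4, 2)    # -> (B, D, H, W, T): put T last for Linear
x = tmap(x)                        # -> (B, D, H, W, 1): T collapsed to 1
x = x.squeeze(-1)                  # -> (B, D, H, W)

print(x.shape)  # torch.Size([2, 16, 8, 8])
```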
I added some comments/questions with my thoughts as well, mostly about design decisions that were made deliberately in the previous version. (And might be worth considering)
pangaea/datasets/ai4smallfarms.py
Outdated
@@ -9,7 +9,7 @@
 from tifffile import imread

 from pangaea.datasets.base import GeoFMDataset

 from pangaea.engine.data_preprocessor import BasePreprocessor
I think datasets depending on a preprocessor is kind of an anti-pattern. A dataset should return the data, and then we can do whatever we want with it. In this sense (I think, at least), the previous implementation was a lot clearer.
Also, this will just lead to repeating the following pattern in each dataset, and it's easy to miss, especially for new contributors:
if self.preprocessor is not None:
    output = self.preprocessor(output)
return output
Hi, on this I disagree :)
I see your point, but in the deep learning community it is very common to have the transformations (i.e. the preprocessing) in the dataloader.
So I think it will be clear to people who read the code, and in general the code will be cleaner.
For similar reasons, I don't think people who want to add new datasets will forget, especially if we document it in the readme.
In my 10 years of deep learning I've only ever seen it in repositories dealing with a single dataset, but it might be that I'm just old-school 😄 It does work either way, so if you guys prefer this, we should go with it.
I'm thinking we could maybe outsource the application of the preprocessing to the base class, so it doesn't have to be repeated?
I rethought about it, and we can use a single dataset wrapper (not recursive) to combine a raw dataset and the transforms.
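A minimal sketch of such a non-recursive wrapper, with toy stand-ins for the raw dataset and the preprocessor (class and attribute names here are hypothetical):

```python
from torch.utils.data import Dataset

class PreprocessedDataset(Dataset):
    """Hypothetical wrapper pairing a raw dataset with an optional
    preprocessor, so individual datasets never call it themselves."""

    def __init__(self, raw_dataset, preprocessor=None):
        self.raw_dataset = raw_dataset
        self.preprocessor = preprocessor

    def __len__(self):
        return len(self.raw_dataset)

    def __getitem__(self, idx):
        sample = self.raw_dataset[idx]
        if self.preprocessor is not None:
            sample = self.preprocessor(sample)
        return sample

# Usage with toy stand-ins:
raw = list(range(5))
ds = PreprocessedDataset(raw, preprocessor=lambda x: x * 10)
print(ds[2])  # 20
```

This keeps each concrete dataset free of preprocessing logic while avoiding the recursive-wrapping of the previous design.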
pangaea/decoders/base.py
Outdated
from pangaea.encoders.base import Encoder

import matplotlib.pyplot as plt
I think this is a leftover from somewhere. Do we depend on matplotlib at all? If we do, why? :D
ok
pangaea/engine/data_preprocessor.py
Outdated
    def __init__(self, dataset: GeoFMDataset, encoder: Encoder):
        """Initialize the RichDataset.
def build_preprocessor(preprocessing_cfg, dataset_cfg, encoder_cfg):
Is there a reason why this is a factory function instead of just the constructor of the class?
ok
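One way to fold the factory into the class is a classmethod constructor; a toy sketch with hypothetical names (the toy "build rule" stands in for the real statistics-tracking construction):

```python
class PreprocessorSketch:
    """Hypothetical stand-in showing the factory logic living on the
    class itself rather than in a free build_preprocessor() function."""

    def __init__(self, steps):
        self.steps = steps

    @classmethod
    def from_config(cls, preprocessing_cfg, dataset_cfg, encoder_cfg):
        # Build the step list from the configs; a toy rule standing in
        # for the real construction that tracks data statistics.
        steps = list(preprocessing_cfg)
        return cls(steps)

    def __call__(self, sample):
        for _ in self.steps:
            sample = sample  # each real step would transform the sample
        return sample

pre = PreprocessorSketch.from_config(["normalize", "crop"], {}, {})
print(pre.steps)  # ['normalize', 'crop']
```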
pangaea/engine/trainer.py
Outdated
@@ -11,7 +12,7 @@
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
 from pangaea.utils.logger import RunningAverageMeter, sec_to_hm

 import wandb
This breaks the feature of being able to run the project without installing Wandb.
It seems wandb is in requirements.txt anyway. I think just installing it won't bother people who don't want to use it that much. If installing wandb is indeed a pain, we still need to rewrite the way wandb is optionally imported, as the previous approach was awkward.
I'm fine with it either way; it was just a decision we discussed before, which is why it was implemented the way it was. It's true that at that time we deliberately did not have wandb in the dependency files; I'm not sure when that changed.
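If an optional import is still wanted, a common guarded-import sketch looks like this (the helper name is illustrative):

```python
# Optional-dependency pattern: import wandb lazily and degrade gracefully
# when it is not installed, so use_wandb=False never touches the package.
try:
    import wandb
except ImportError:
    wandb = None

def maybe_log(metrics: dict, use_wandb: bool) -> bool:
    """Log to wandb only if requested and available; return whether logged."""
    if use_wandb and wandb is not None:
        wandb.log(metrics)
        return True
    return False

print(maybe_log({"loss": 0.5}, use_wandb=False))  # False
```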
pangaea/engine/evaluator.py
Outdated
@@ -77,6 +82,53 @@ def compute_metrics(self):
     def log_metrics(self, metrics):
         pass

     @staticmethod
     def sliding_inference(model, img, input_size, output_size=None, stride=None, max_batch=None):
Moving this here (instead of leaving it in the preprocessor) means that:
- We can't do tiling while training (which could be useful for large models on small GPUs). Edit: I'm dumb, obviously we just random crop in this case. 😅
- We can't decide whether we want to do certain augmentations before or after tiling.
I actually thought about this, and had a cache implementation for it in an iteration of the previous version. However, when I benchmarked it, it turned out the files were already memcached by the operating system (at least on my machines). So it did not improve performance at all, and I left it out of the final pull request to reduce code complexity. Did you see any performance improvements in your case?
I know that the files will be in memory after the first epoch. However, the image files still need to be decoded and preprocessed. The old tiling bothers me a lot, as it decodes the full, large images 25 to 64 times to finish one inference (224/128 input size, 1024 image size), which makes no sense and costs huge overhead. Meanwhile, reading an intact image means you can easily visualize the corresponding prediction, perform test-time augmentation, and implement other features later.
The visualization can be easily solved by having an ID tag in the output dictionary so we can reconstruct the images.
This is what I thought as well, but by tiling first we avoid a lot of unnecessary preprocessing, and in the end my code did not gain any performance when I cached the preprocessed images and tiled at the end of the pipeline. But if you can show me actual performance improvements, I'm all in for this change. 👍
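For context, a toy sketch of sliding-window inference over an intact image (simplified: no output_size/max_batch handling, and a dummy one-class model; not the actual evaluator code):

```python
import torch

def sliding_inference_sketch(model, img, input_size, stride):
    """Toy sliding-window inference: tile a (B, C, H, W) image, run the
    model per tile, and average overlapping logits back into place."""
    b, c, h, w = img.shape
    out = None
    count = torch.zeros(b, 1, h, w)
    for y in range(0, max(h - input_size, 0) + 1, stride):
        for x in range(0, max(w - input_size, 0) + 1, stride):
            tile = img[:, :, y:y + input_size, x:x + input_size]
            logits = model(tile)  # (B, K, input_size, input_size)
            if out is None:
                out = torch.zeros(b, logits.shape[1], h, w)
            out[:, :, y:y + input_size, x:x + input_size] += logits
            count[:, :, y:y + input_size, x:x + input_size] += 1
    return out / count  # average where tiles overlap

model = lambda t: t[:, :1] + 1.0  # dummy one-class "segmenter"
img = torch.zeros(1, 3, 8, 8)
pred = sliding_inference_sketch(model, img, input_size=4, stride=4)
print(pred.shape)  # torch.Size([1, 1, 8, 8])
```

The image is decoded and preprocessed once, and the stitched prediction covers the whole image, which is what enables easy visualization and test-time augmentation.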
1. The preprocessor is redesigned again.
1.1 The previous GeoFMDataset is renamed to RawFMDataset, and the current GeoFMDataset is a wrapper class combining RawFMDataset and Preprocessor.
1.2 The Preprocessor is initialized from preprocessor_cfg, dataset_cfg, and encoder_cfg. A list of the defined preprocessors is then initialized, with data statistics/info (e.g., data mean/std) being tracked to ensure the preprocessors work properly in any order.
1.3 BandAdaptor is split into BandFilter and BandPadding: BandFilter runs at the beginning and BandPadding at the end, to avoid operating on trivial bands.
1.4 Tile is replaced by sliding_inference in the evaluator. It avoids loading, decoding, and preprocessing images multiple times during evaluation, and its implementation is more straightforward. Getting intact images also facilitates potential operations during inference.
1.5 Added FocusRandomCrop for sparsely annotated data. For example, MADOS has most of its labels ignored, so RandomCrop can hardly crop regions with valid labels; FocusRandomCrop guarantees that crops have valid labels.
1.6 Removed ImportanceRandomCrop for the moment. I feel ImportanceRandomCrop will not work as expected: the class imbalance is mainly due to small classes appearing in only a few images, so cropping the most important region of a single image may not help, and the resulting data distribution is undetermined. I also don't think we should handle data imbalance ourselves; the foundation models should.
1.7 Added RandomResizedCrop.
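A toy sketch of the FocusRandomCrop idea from 1.5: resample crop windows until one contains a non-ignored label (function name and signature are illustrative, not the actual implementation):

```python
import torch

def focus_random_crop(image, target, size, ignore_index=-1, max_tries=50):
    """Toy FocusRandomCrop-style sketch: sample crop windows until the
    target crop contains a non-ignored label (fallback: last attempt)."""
    _, h, w = image.shape
    for _ in range(max_tries):
        top = torch.randint(0, h - size + 1, (1,)).item()
        left = torch.randint(0, w - size + 1, (1,)).item()
        tgt = target[top:top + size, left:left + size]
        if (tgt != ignore_index).any():  # found valid labels, stop searching
            break
    return image[:, top:top + size, left:left + size], tgt

img = torch.zeros(3, 16, 16)
lbl = torch.full((16, 16), -1)
lbl[2, 2] = 1  # a single valid pixel in a sea of ignore_index
crop_img, crop_lbl = focus_random_crop(img, lbl, size=4)
print(crop_img.shape, crop_lbl.shape)
```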
2. The encoders are partially refactored and bugs are fixed.
2.1 gfmswin, satlasnet, and spectralgpt did not restore the 2D feature layout correctly. They are fixed now.
2.2 For models whose output is already a multi-scale pyramid of features (swin), the Feature2Pyramid neck in upernet is skipped.
2.3 A default function (naive_reshape_to_2d) to reshape transformer output to a 2D layout is used when possible, to make the code cleaner.
2.4 Single/multi-temporal models vs. single/multi-temporal inputs are handled in the forward method of the parent Encoder class instead of in the decoder. The native forward of each encoder is renamed to simple_forward. The code can be shared this way.
2.5 The naive method (naive_multi_temporal_forward) for a single-temporal model to perform a multi-temporal forward is added. Loop mode iterates over the T dimension to extract features one by one, and batch mode merges the T dimension into B to extract features in one forward.
2.6 SiamUperNet feeds a single frame to the encoder, but the encoder is configured to accept 2 frames. An enforce_single_temporal method is added to switch the encoder to the single-temporal setting.
2.7 I tried to redesign the encoder configs to make all of the above work. I was not satisfied with the result in the end, so I stopped commenting them. I tried to classify arguments into base encoder arguments (essential properties of the encoder) and model-specific arguments, and to merge those with the same meaning. However, it is cumbersome, and there are always special cases, as the encoder implementations are heterogeneous.
2.8 My idea now is to put the base encoder arguments into an independent dict accepted by the base class. Users will have to fill in these properties strictly according to the format in the config, to ensure the other modules work with the encoder. The free-form model-specific arguments are passed to the actual init function of the encoder. Separating the two groups of arguments minimizes the work needed to add new encoders. Example config:
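The idea in 2.8 might look something like the following; this is purely illustrative, and every key and path below is hypothetical:

```yaml
# Hypothetical encoder config: strict base-encoder properties in one dict,
# free-form model-specific kwargs passed to the concrete encoder __init__.
encoder:
  _target_: pangaea.encoders.some_encoder.SomeEncoder  # hypothetical path
  base_args:                    # strict, shared by all encoders
    input_size: 224
    input_bands: [B2, B3, B4, B8]
    multi_temporal: false
    output_dim: 768
  model_args:                   # free-form, encoder-specific
    depth: 12
    num_heads: 12
```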
3. Other changes
3.1 Computing the overall mean MSE in the Regression evaluator was still incorrect, because the batch size is not even during testing and the metric was not reduced across all GPUs. It is fixed now.
3.2 The current main does not save the best checkpoint correctly: the best checkpoint only stores a reference to the model weights, not the actual current values. A quick fix is to make a deepcopy, but I prefer to immediately save it to disk to avoid the memory cost.
3.3 Skipping a NaN loss is replaced by raising an error. The initial motivation was the same as 1.5, but the model was still not properly trained, and skipping might suppress other issues.
3.4 The parameters for the parent class (Dataset/Encoder) are moved to **kwargs to avoid assigning dozens of parameters repeatedly.
3.5 In the trainer, num_classes is obtained from the decoder instead of the dataset, which is easier.
3.6 Corrected the remaining-time computation.
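A minimal sketch of the save-to-disk fix from 3.2, with a tiny model and dummy validation scores standing in for the real training loop:

```python
import os
import tempfile
import torch

model = torch.nn.Linear(4, 2)  # stand-in for the real model
ckpt_dir = tempfile.mkdtemp()
best_path = os.path.join(ckpt_dir, "best.pth")

best_metric = float("-inf")
for epoch, metric in enumerate([0.3, 0.7, 0.5]):  # dummy validation scores
    if metric > best_metric:
        best_metric = metric
        # Writing to disk materializes the current weights immediately,
        # unlike keeping a reference (which silently tracks later updates)
        # and without the extra memory cost of a deepcopy.
        torch.save(model.state_dict(), best_path)

print(os.path.exists(best_path), best_metric)  # True 0.7
```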
First of all, that's a lot of things, Liang, thanks for all the work! Just some initial thoughts below. I don't have enough time to review this today, so I'll come back to it in detail on Monday.
preprocessor:
1. The preprocessor is separated from the dataset. The dataset is no longer recursive; instead, it accepts a preprocessor object for data preprocessing.
2. Transformations can be defined in arbitrary order, e.g. crop and band filter can be the first steps to reduce computation.
3. Tile is replaced by sliding inference in the evaluator, which avoids loading images from disk multiple times and allows generating a complete prediction for a whole image.
4. Enable verification of data size and dimension.
encoder/decoder:
1. Incorporate the multi-temporal forward in encoder/base.py instead.
2. Squeeze the temporal dimension safely.
fix some bugs:
- multi-temporal output
- save best model
TO DO:
- check if everything works
- merge segmentation and regression tasks as much as possible
- fix bugs, e.g. regevaluator