
refactored preprocessor/merged mt feature extraction in encoder/fixed some bugs #86

Closed

Conversation

LeungTsang (Collaborator):

preprocessor:

1. The preprocessor is separated from the dataset. The dataset is no longer recursive; instead, it accepts a preprocessor object for data preprocessing.
2. Transformations can be defined in arbitrary order, e.g. crop and band filter can come first to reduce computation.
3. Tile is replaced by sliding inference in the evaluator, which avoids loading images from disk multiple times and allows generating a complete prediction for a whole image.
4. Enable verification of data size and dimensions.

encoder/decoder:
1. Incorporate the multi-temporal forward in encoder/base.py instead.
2. Squeeze the temporal dimension safely.

fix some bugs:
multi-temporal output
save best model

TO DO:
check that everything works
merge the segmentation and regression tasks as much as possible
fix bugs, e.g. in the regression evaluator

@LeungTsang requested a review from VMarsocci, October 4, 2024 08:41
@gle-bellier (Collaborator) left a comment:

Hi,
Thank you for the pull request; the preprocessing refactoring was needed and much appreciated, and the same goes for the sliding inference in the evaluator.
I left several comments in this review. Most of them concern the shapes of the tensors, but one requires special attention: we cannot merge multi-temporal frames onto the batch dimension, or we may run out of memory (OOM).
I have not run any experiments with the codebase yet, but I will soon.

pangaea/run.py Outdated
@@ -166,7 +161,7 @@ def main(cfg: DictConfig) -> None:
num_workers=cfg.num_workers,
pin_memory=True,
# persistent_workers=True causes memory leak
- persistent_workers=False,
+ persistent_workers=True,
Collaborator:

Have you checked the system memory usage during training? With the previous preprocessing it rocketed up when persistent_workers=True was used, which is why it was set to False.

Collaborator (Author):

ok

"""
raise NotImplementedError

def forward(self, image: dict[str, torch.Tensor]) -> list[torch.Tensor]:
Collaborator:

Can you document the code with the shapes of the tensors, please?
It is mostly missing docstrings that state the shapes of the input and output tensors explicitly.

Collaborator (Author):

ok


def naive_multi_temporal_forward(self, image: dict[str, torch.Tensor]) -> list[torch.Tensor]:
b, c, t, h, w = image[list(image.keys())[0]].shape
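# merge T into the batch dimension: (B, C, T, H, W) -> (B*T, C, 1, H, W)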
image = {k: v.transpose(1, 2).contiguous().view(-1, c, 1, h, w) for k, v in image.items()}
Collaborator:

I understand and agree that we can treat this as T single-temporal images. Nevertheless, it was previously done with a for loop, collecting the features one by one by asking the encoder to compute them from input images of shape (B, C, 1, H, W). With this implementation we are feeding the model inputs of size (B*T, C, 1, H, W), which will cause out-of-memory issues if we don't have a large memory headroom.

Collaborator (Author):

ok, I will make it optional
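A rough sketch of such an optional loop mode, assuming the per-frame encoder forward is called simple_forward as described later in 2.5 (illustrative, not the PR's final code):

    import torch

    def looped_multi_temporal_forward(self, image: dict[str, torch.Tensor]) -> list[torch.Tensor]:
        """Loop mode: feed one frame at a time so peak memory stays at (B, C, 1, H, W)."""
        t = next(iter(image.values())).shape[2]
        feats_per_frame = []
        for i in range(t):
            # slice out frame i while keeping the temporal dimension
            frame = {k: v[:, :, i:i + 1, :, :] for k, v in image.items()}
            feats_per_frame.append(self.simple_forward(frame))
        # regroup per-scale feature lists and re-stack along the temporal dim:
        # T lists of (B, C', H', W') -> list of (B, C', T, H', W')
        return [torch.stack(scale, dim=2) for scale in zip(*feats_per_frame)]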

pangaea/run.py Outdated
@@ -56,6 +57,8 @@ def main(cfg: DictConfig) -> None:
fix_seed(cfg.seed)
# distributed training variables
rank = int(os.environ["RANK"])
cfg.task.trainer.use_wandb = cfg.task.trainer.use_wandb and rank == 0
Collaborator:

I don't think this is the best way to do it; rewriting configs is not good practice. Instead, we can use
use_wandb = cfg.task.trainer.use_wandb and rank == 0
and override use_wandb when instantiating the trainer and evaluator.
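For instance, a minimal sketch of that suggestion, assuming the trainer is instantiated via hydra (names are illustrative):

    import hydra

    use_wandb = cfg.task.trainer.use_wandb and rank == 0
    # pass the computed flag as an override at instantiation time
    # instead of mutating cfg in place
    trainer = hydra.utils.instantiate(cfg.task.trainer, use_wandb=use_wandb)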

Collaborator (Author):

ok


# If the encoder handles only single temporal data, we apply multi_temporal_strategy
if self.multi_temporal_strategy == "linear":
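# feat: (B, C, T, H, W) -> permute -> (B, C, H, W, T) -> tmap (assumed linear map T -> 1) -> squeeze -> (B, C, H, W)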
feats = [self.tmap(feat.permute(0, 1, 3, 4, 2)).squeeze(-1) for feat in feats]
Collaborator:

Can you comment on the shape of the features (before and after the permute and squeeze), so it is easier to maintain in the future, please?

@KerekesDavid (Collaborator) left a comment:

I added some comments/questions with my thoughts as well, mostly about design decisions that were made deliberately in the previous version (and might be worth reconsidering).

@@ -9,7 +9,7 @@
from tifffile import imread

from pangaea.datasets.base import GeoFMDataset

from pangaea.engine.data_preprocessor import BasePreprocessor
@KerekesDavid (Collaborator), Oct 4, 2024:

I think having the datasets depend on a preprocessor is something of an anti-pattern. A dataset should return the data, and then we can do whatever we want with it. In that sense (I think, at least), the previous implementation was a lot clearer.

Also, this will just lead to repeating the following pattern in each dataset, and it's easy to miss, especially for new contributors:

        if self.preprocessor is not None:
            output = self.preprocessor(output)

        return output

Owner:

Hi, on this I disagree :)

I see your point, but in the deep learning community it is very common to have the transformations (i.e. the preprocessing) in the dataloader.

So I think it will be clear to people who read the code, and in general the code will be cleaner.

For similar reasons, I don't think people who want to add new datasets will forget, especially if we document it in the README.

@KerekesDavid (Collaborator), Oct 4, 2024:

In my 10 years of deep learning I've only ever seen it in repositories dealing with a single dataset, but it might be that I'm just old-school 😄 It works either way, so if you guys prefer this, we should go with it.
Maybe we could outsource the application of the preprocessing to the base class, so it doesn't have to be repeated?
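Something like this, perhaps (a sketch of the base-class idea; _load is a hypothetical name for the per-dataset loading hook):

    from torch.utils.data import Dataset

    class GeoFMDataset(Dataset):
        def __getitem__(self, idx):
            output = self._load(idx)  # each subclass implements only the raw loading
            if self.preprocessor is not None:
                output = self.preprocessor(output)
            return output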

Collaborator (Author):

I've rethought it: we can use a single dataset wrapper (not recursive) to combine a raw dataset and the transforms.
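A minimal sketch of that wrapper idea (illustrative names, assuming the raw dataset returns a sample dict):

    from torch.utils.data import Dataset

    class GeoFMDataset(Dataset):
        """Single, non-recursive wrapper combining a raw dataset and a preprocessor."""

        def __init__(self, raw_dataset, preprocessor=None):
            self.raw_dataset = raw_dataset
            self.preprocessor = preprocessor

        def __len__(self):
            return len(self.raw_dataset)

        def __getitem__(self, idx):
            output = self.raw_dataset[idx]
            if self.preprocessor is not None:
                output = self.preprocessor(output)
            return output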


from pangaea.encoders.base import Encoder

import matplotlib.pyplot as plt
Collaborator:

I think this is a leftover from somewhere. Do we depend on matplotlib at all? If we do, why? :D

Collaborator (Author):

ok


-    def __init__(self, dataset: GeoFMDataset, encoder: Encoder):
-        """Initialize the RichDataset.
+def build_preprocessor(preprocessing_cfg, dataset_cfg, encoder_cfg):
Collaborator:

Is there a reason why this is a factory function instead of just the constructor of the class?

Collaborator (Author):

ok

@@ -11,7 +12,7 @@
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from pangaea.utils.logger import RunningAverageMeter, sec_to_hm

import wandb
Collaborator:

This breaks the feature of being able to run the project without installing Wandb.

Collaborator (Author):

It seems wandb is in requirements.txt anyway. I think having to install it won't bother people who don't want to use it that much. If installing wandb is indeed a pain, we still need to rewrite the way wandb is optionally imported, as the previous approach was awkward.
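For reference, one common pattern for keeping wandb optional (a sketch, not necessarily what we should adopt):

    try:
        import wandb
    except ImportError:
        wandb = None  # wandb stays optional; only needed when use_wandb is set

    def init_logging(use_wandb: bool):
        if use_wandb:
            if wandb is None:
                raise ImportError("use_wandb=True but wandb is not installed")
            wandb.init(project="pangaea-bench")  # hypothetical project name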

Collaborator:

I'm fine with it either way; it was just a decision we discussed before, and that's why it was implemented the way it was. It's true that at that time we deliberately did not have wandb in the dependency files; I'm not sure when that changed.

@@ -77,6 +82,53 @@ def compute_metrics(self):
def log_metrics(self, metrics):
pass

@staticmethod
def sliding_inference(model, img, input_size, output_size=None, stride=None, max_batch=None):
@KerekesDavid (Collaborator), Oct 4, 2024:

Moving this here (instead of leaving it in the preprocessor) means that:

  1. We can't do tiling while training (which could be useful for large models on small GPUs). ----- Edit: I'm dumb, obviously we just random crop in this case. 😅
  2. We can't decide whether we want to do certain augmentations before or after tiling.
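For context, a minimal sliding-window inference sketch, assuming a plain (B, C, H, W) tensor input, a stride that tiles the image exactly, and simple averaging of overlapping logits (the PR's sliding_inference also takes output_size and max_batch):

    import torch

    @torch.no_grad()
    def sliding_inference(model, img, input_size, stride):
        b, _, h, w = img.shape
        logits, counts = None, torch.zeros(1, 1, h, w, device=img.device)
        for top in range(0, h - input_size + 1, stride):
            for left in range(0, w - input_size + 1, stride):
                out = model(img[:, :, top:top + input_size, left:left + input_size])
                if logits is None:
                    logits = torch.zeros(b, out.shape[1], h, w, device=img.device)
                logits[:, :, top:top + input_size, left:left + input_size] += out
                counts[:, :, top:top + input_size, left:left + input_size] += 1
        return logits / counts  # average predictions where windows overlap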

@KerekesDavid commented Oct 4, 2024:

tile is replaced by sliding inference in the evaluator, which avoids loading images from disk multiple times

I actually thought about this, and had a cache implementation for it in an iteration of the previous version. However, when I benchmarked it, it turned out the files were already cached in memory by the operating system (at least on my machines).

So it did not improve performance at all, and I left it out of the final pull request to reduce code complexity.

Did you see any performance improvements in your case?

@LeungTsang commented Oct 7, 2024:

Did you see any performance improvements in your case?

I know the files will be in memory after the first epoch. However, the image files still need to be decoded and preprocessed. The old tiling bothers me a lot, as it decodes the full, large images 25 to 64 times to finish one inference (with non-overlapping tiles, a 1024-pixel image splits into ceil(1024/224)^2 = 25 tiles at input size 224, or (1024/128)^2 = 64 tiles at input size 128), which makes no sense and costs huge overhead. Meanwhile, reading an intact image means you can easily visualize the corresponding prediction, perform test-time augmentation, and support other features to be implemented.

@KerekesDavid commented Oct 7, 2024:

The visualization can be easily solved by having an ID tag in the output dictionary so we can reconstruct the images.

However, the image files still need to be decoded and preprocessed. The old tiling bothers me a lot, as it decodes the full, large images 25 to 64 times to finish one inference (224/128 input size, 1024 image size), which makes no sense and costs huge overhead.

This is what I thought as well, but by tiling first we don't do a lot of unnecessary preprocessing, and it seemed that in the end my code gained no performance when I cached the preprocessed images and tiled at the end of the pipeline.

But if you can show me actual performance improvements, I'm all in for this change. 👍

@LeungTsang (Collaborator, Author):

1. The preprocessor is redesigned again.

1.1 The previous GeoFMDataset is renamed to RawFMDataset, and the current GeoFMDataset is a wrapper class combining RawFMDataset and the Preprocessor.

1.2 The Preprocessor is initialized from preprocessor_cfg, dataset_cfg, and encoder_cfg. A list of the defined preprocessors is then initialized, with data statistics/info (e.g., data mean/std) tracked so that the preprocessors work properly in any order.

1.3 BandAdaptor is split into BandFilter and BandPadding, so that BandFilter runs at the beginning and BandPadding at the end, avoiding operations on trivial bands.

1.4 Tile is replaced by sliding_inference in the evaluator. This avoids loading, decoding, and preprocessing images multiple times during evaluation, and its implementation is more straightforward. Getting intact images also facilitates potential operations during inference.

1.5 Add FocusRandomCrop for sparsely annotated data. For example, MADOS has most of its labels ignored, and RandomCrop can hardly crop regions with valid labels. FocusRandomCrop guarantees that crops contain valid labels (see the sketch after this list).

1.6 Remove ImportanceRandomCrop for the moment. I feel ImportanceRandomCrop will not work as expected: the class imbalance is mainly due to small classes appearing in only a few images, so cropping the most important region within a single image may not help, and the resulting data distribution is undetermined. I also don't think we should handle data imbalance ourselves; the foundation models should.

1.7 Add RandomResizedCrop.
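A sketch of the FocusRandomCrop idea from 1.5 (illustrative only; assumes a tensor target where ignore_index marks invalid pixels and at least one valid pixel exists): sample a random valid pixel and center the crop on it, so every crop contains valid labels.

    import torch

    def focus_random_crop(image, target, size, ignore_index=-1):
        h, w = target.shape[-2:]
        valid = (target != ignore_index).nonzero()  # coordinates of valid pixels
        y, x = valid[torch.randint(len(valid), (1,)).item()][-2:].tolist()
        # clamp the window so it stays inside the image bounds
        top = min(max(y - size // 2, 0), h - size)
        left = min(max(x - size // 2, 0), w - size)
        return (image[..., top:top + size, left:left + size],
                target[..., top:top + size, left:left + size])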

@LeungTsang (Collaborator, Author):

2. The encoders are partially refactored and bugs are fixed.

2.1 gfmswin, satlasnet, and spectralgpt did not restore the 2D feature layout correctly. They are fixed now.

2.2 For models whose output is already a multi-scale feature pyramid (swin), the Feature2Pyramid neck in upernet is skipped.

2.3 A default function (naive_reshape_to_2d) that reshapes transformer output to a 2D layout is used where possible, to make the code cleaner.

2.4 Single/multi-temporal models vs. single/multi-temporal inputs are handled in the forward method of the parent Encoder class instead of in the decoder. The native forward of each encoder is renamed to simple_forward. The code can be shared this way.

2.5 A naive method (naive_multi_temporal_forward) is added for single-temporal models to perform a multi-temporal forward. Loop mode iterates over the T dimension to extract features frame by frame, while batch mode merges the T dimension into B to extract features in one forward pass.

2.6 SiamUperNet feeds single frames to the encoder, but the encoder is configured to accept 2 frames. An enforce_single_temporal method is added to switch the encoder to a single-temporal setting.

2.7 I tried to redesign the encoder configs to make all of the above work, but I was not satisfied in the end, so I stopped commenting them. I tried to classify arguments into base encoder arguments (essential properties of the encoder) and model-specific arguments, and to merge those with the same meaning. However, it is cumbersome, and there are always special cases, as the encoder implementations are heterogeneous.

2.8 My current idea is to put the base encoder arguments into an independent dict accepted by the base class. Users will have to fill in these encoder properties strictly according to the config format, to ensure that other modules work with the encoder. The free-form model-specific arguments are passed to the actual init function of the encoder. Separating the two groups of arguments minimizes the work needed to add a new encoder.

example config:

    ### model-specific arguments ###
    img_size: 224
    embed_dim: 768
    ...

    ### base encoder arguments ###
    meta_base:
        model_name: good
        download_url: xxx
        multi_temporal: 2
        input_size: 224  # so no conflict with img_size

    class Encoder():
        def __init__(self, model_name, input_size, ...):
            self.input_size = input_size
            self.model_name = model_name
            ...

    class GoodEncoder(Encoder):
        def __init__(self, img_size, embed_dim, ..., meta):
            super().__init__(**meta)
            self.img_size = img_size
            self.block = Block(embed_dim)
            ...

@LeungTsang (Collaborator, Author):

3. Other changes

3.1 Computing the overall mean MSE in the regression evaluator was still incorrect, because the batch size is uneven during testing and the metric was not reduced across all GPUs. It is fixed now.

3.2 The current main does not save the best checkpoint correctly: the best checkpoint only stores a reference to the model weights, not their actual current values. A quick fix is to make a deepcopy, but I prefer to save it to disk immediately to avoid the memory cost (see the sketch after this list).

3.3 Skipping NaN losses is replaced by raising an error. The initial motivation was the same as 1.5, but the model was still not properly trained, and skipping might suppress other issues.

3.4 The parameters for the parent class (Dataset/Encoder) are moved into **kwargs to avoid assigning dozens of parameters repeatedly.

3.5 In the trainer, num_classes is obtained from the decoder instead of the dataset, which is easier.

3.6 Correct the remaining-time computation.
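A sketch of the fix described in 3.2 (variable names are illustrative): state_dict() returns references to the live parameter tensors, so a stored "best" dict would silently track later updates unless it is copied or written to disk.

    import copy
    import torch

    # inside the validation loop of the trainer (illustrative fragment):
    if val_metric > best_metric:
        best_metric = val_metric
        # option A: snapshot in memory (deepcopy detaches from the live parameters)
        best_state = copy.deepcopy(model.state_dict())
        # option B (preferred above, holds no extra memory): save straight to disk
        torch.save(model.state_dict(), "best_model.pth")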

@KerekesDavid commented Oct 11, 2024:

First of all, that's a lot of things, Liang; thanks for all the work!

Just some initial thoughts below. I don't have enough time to review this today, so I'll come back to it in detail on Monday.

  • This PR has become a bit too unfocused: there are ~3000 lines changed, and I think a lot of this could be broken up into smaller PRs that are easier to reason about. E.g., most of the fixes in the "3. Other changes" section could be merged independently of the work here.
  • Merge from main so there are no merge conflicts before the reviews.
  • The change detection work that @SebastianGer and @SebastianHafner are doing relies on importance cropping. The focus crop might be able to replace it, but they should take a look.
