
Fix torch compile, script, export #1031

Merged: 58 commits, merged into main on Jan 15, 2025
Conversation

@qubvel (Collaborator) commented Jan 13, 2025

Huge PR, but many things depend on each other, so everything is included here (the three entry points are sketched just after this list):

  • Deprecate timm- encoders (map weights to tu- equivalents, except for EfficientNet and SKNet).
  • Fix torch.compile for encoders and add tests (only EfficientNet is currently skipped, but it can be fixed once we copy-paste the code).
  • Fix torch.export.export and add tests.
  • Fix torch.jit.script and add tests.
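For reference, here is a minimal sketch (not from the PR itself) of the three entry points this PR targets, assuming a post-PR install and an illustrative encoder name:

import torch
import segmentation_models_pytorch as smp

model = smp.Unet("resnet18", encoder_weights=None).eval()
sample = torch.randn(1, 3, 256, 256)

compiled = torch.compile(model)                   # torch.compile
scripted = torch.jit.script(model)                # torch.jit.script
exported = torch.export.export(model, (sample,))  # torch.export.export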

Fixes:


codecov bot commented Jan 13, 2025

Codecov Report

Attention: Patch coverage is 93.98496% with 40 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...entation_models_pytorch/encoders/timm_universal.py | 40.00% | 12 Missing ⚠️ |
| segmentation_models_pytorch/base/model.py | 62.50% | 9 Missing ⚠️ |
| segmentation_models_pytorch/base/utils.py | 45.45% | 6 Missing ⚠️ |
| segmentation_models_pytorch/encoders/_base.py | 78.94% | 4 Missing ⚠️ |
| ...ntation_models_pytorch/decoders/deeplabv3/model.py | 75.00% | 3 Missing ⚠️ |
| ...ntation_models_pytorch/encoders/mix_transformer.py | 96.92% | 2 Missing ⚠️ |
| ...egmentation_models_pytorch/decoders/fpn/decoder.py | 94.73% | 1 Missing ⚠️ |
| ...egmentation_models_pytorch/decoders/pan/decoder.py | 94.73% | 1 Missing ⚠️ |
| ...ation_models_pytorch/decoders/segformer/decoder.py | 83.33% | 1 Missing ⚠️ |
| ...ation_models_pytorch/encoders/timm_efficientnet.py | 96.87% | 1 Missing ⚠️ |
| Files with missing lines | Coverage Δ |
| --- | --- |
| segmentation_models_pytorch/base/hub_mixin.py | 98.33% <100.00%> (+0.05%) ⬆️ |
| ...ation_models_pytorch/decoders/deeplabv3/decoder.py | 98.68% <100.00%> (+0.17%) ⬆️ |
| ...ntation_models_pytorch/decoders/linknet/decoder.py | 100.00% <100.00%> (ø) |
| ...mentation_models_pytorch/decoders/manet/decoder.py | 97.75% <100.00%> (ø) |
| ...entation_models_pytorch/decoders/pspnet/decoder.py | 100.00% <100.00%> (ø) |
| ...gmentation_models_pytorch/decoders/unet/decoder.py | 91.37% <100.00%> (ø) |
| ...on_models_pytorch/decoders/unetplusplus/decoder.py | 92.85% <100.00%> (+0.10%) ⬆️ |
| ...tion_models_pytorch/decoders/unetplusplus/model.py | 95.00% <100.00%> (+0.26%) ⬆️ |
| ...ntation_models_pytorch/decoders/upernet/decoder.py | 98.00% <100.00%> (ø) |
| ...mentation_models_pytorch/decoders/upernet/model.py | 100.00% <100.00%> (ø) |

... and 23 more

... and 1 file with indirect coverage changes

Comment on lines -46 to +87. Old implementation (removed):
def get_stages(self):
    return [
        nn.Identity(),
        nn.Sequential(
            self.features[0].conv, self.features[0].bn, self.features[0].act
        ),
        nn.Sequential(
            self.features[0].pool, self.features[1 : self._stage_idxs[0]]
        ),
        self.features[self._stage_idxs[0] : self._stage_idxs[1]],
        self.features[self._stage_idxs[1] : self._stage_idxs[2]],
        self.features[self._stage_idxs[2] : self._stage_idxs[3]],
    ]

def forward(self, x):
    stages = self.get_stages()

    features = []
    for i in range(self._depth + 1):
        x = stages[i](x)
        if isinstance(x, (list, tuple)):
            features.append(F.relu(torch.cat(x, dim=1), inplace=True))
        else:
            features.append(x)

New implementation (added):

def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]:
    return {
        16: [self.features[self._stage_idxs[1] : self._stage_idxs[2]]],
        32: [self.features[self._stage_idxs[2] : self._stage_idxs[3]]],
    }

def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
    features = [x]

    if self._depth >= 1:
        x = self.features[0].conv(x)
        x = self.features[0].bn(x)
        x = self.features[0].act(x)
        features.append(x)

    if self._depth >= 2:
        x = self.features[0].pool(x)
        x = self.features[1 : self._stage_idxs[0]](x)
        skip = F.relu(torch.cat(x, dim=1), inplace=True)
        features.append(skip)

    if self._depth >= 3:
        x = self.features[self._stage_idxs[0] : self._stage_idxs[1]](x)
        skip = F.relu(torch.cat(x, dim=1), inplace=True)
        features.append(skip)

    if self._depth >= 4:
        x = self.features[self._stage_idxs[1] : self._stage_idxs[2]](x)
        skip = F.relu(torch.cat(x, dim=1), inplace=True)
        features.append(skip)

    if self._depth >= 5:
        x = self.features[self._stage_idxs[2] : self._stage_idxs[3]](x)
        features.append(x)
@brianhou0208 (Contributor) commented:

This PR refactors the get_stages method, which previously used a for loop to return multi-scale features; now the forward method uses explicit if statements to handle the different feature scales.
What prompted this change? Is it intended to better support more complex models?

@qubvel (Collaborator, Author) replied Jan 14, 2025:
Hey @brianhou0208! This was prompted by compatibility with torch script/export. I'm not sure it's easier to read, but it's still not that complicated, and it is very explicit.
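For illustration, a toy sketch (hypothetical names, not the library code) of the explicit, script/export-friendly pattern: fixed `if self._depth >= k` blocks instead of looping over a heterogeneous list of stages returned by get_stages():

import torch
from torch import nn
from typing import List

class TinyEncoder(nn.Module):
    def __init__(self, depth: int = 3) -> None:
        super().__init__()
        self._depth = depth
        self.stem = nn.Conv2d(3, 8, 3, stride=2, padding=1)
        self.stage2 = nn.Conv2d(8, 16, 3, stride=2, padding=1)
        self.stage3 = nn.Conv2d(16, 32, 3, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        features = [x]
        # Each scale is an explicit, statically analyzable block.
        if self._depth >= 1:
            x = self.stem(x)
            features.append(x)
        if self._depth >= 2:
            x = self.stage2(x)
            features.append(x)
        if self._depth >= 3:
            x = self.stage3(x)
            features.append(x)
        return features

scripted = torch.jit.script(TinyEncoder())  # scripts cleanly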

@qubvel (Collaborator, Author) commented Jan 15, 2025:
@adamjstewart would you like to have a look? 😄 Otherwise it's ready to be merged.

The next PR will move all encoders to the HF Hub for faster loading and download stats. Actually, I've already moved them; I just need to update the URLs and the loading mechanism.

@qubvel (Collaborator, Author) commented Jan 15, 2025:
BTW, I additionally tested all models and encoders against the main branch with the following script, to ensure outputs match and weights load without issues:

import os
import torch
import segmentation_models_pytorch as smp

from tqdm import tqdm

TMP_FOLDER = "tmp-model"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def run_on_main():
    # Record reference inputs, outputs, and weights for every encoder.
    encoder_names = sorted(smp.encoders.get_encoder_names())

    for encoder_name in tqdm(encoder_names):
        model = smp.Unet(encoder_name, decoder_channels=[4, 4, 4, 4, 4], encoder_weights=None)
        model = model.eval().to(DEVICE)
        sample = torch.randn(2, 3, 256, 256).to(DEVICE)

        with torch.no_grad():
            output = model(sample)

        os.makedirs(os.path.join(TMP_FOLDER, encoder_name), exist_ok=True)
        torch.save(sample, os.path.join(TMP_FOLDER, encoder_name, "input.pth"))
        torch.save(output, os.path.join(TMP_FOLDER, encoder_name, "output.pth"))
        torch.save(model.state_dict(), os.path.join(TMP_FOLDER, encoder_name, "state_dict.pth"))


def run_on_branch():
    # Rebuild each model on the PR branch, load the reference weights,
    # and compare outputs against the references recorded on main.
    encoder_names = os.listdir(TMP_FOLDER)

    for encoder_name in tqdm(encoder_names):
        sample = torch.load(os.path.join(TMP_FOLDER, encoder_name, "input.pth"), weights_only=True)
        expected_output = torch.load(os.path.join(TMP_FOLDER, encoder_name, "output.pth"), weights_only=True)
        state_dict = torch.load(os.path.join(TMP_FOLDER, encoder_name, "state_dict.pth"), weights_only=True)

        model = smp.Unet(encoder_name, decoder_channels=[4, 4, 4, 4, 4], encoder_weights=None).eval().to(DEVICE)
        try:
            model.load_state_dict(state_dict)
        except Exception as e:
            print(f"Error loading state dict for {encoder_name}: {e}")
            raise e

        with torch.no_grad():
            output = model(sample)

        if not torch.allclose(output, expected_output):
            diff = torch.abs(output - expected_output).max().item()
            print(f"Encoder {encoder_name} has different output with max diff {diff:.6f}")


if __name__ == "__main__":
    # Requires GitPython. The script picks its mode from the checked-out branch:
    # run once on main to record references, then again on the PR branch to compare.
    import git

    repo = git.Repo(".")

    if repo.active_branch.name == "main":
        print("\n--- Running on main branch ---\n")
        run_on_main()
    else:
        print(f"\n--- Running on {repo.active_branch.name} branch ---\n")
        run_on_branch()

@adamjstewart (Collaborator) left a review:

This PR is a bit too big for me to properly review, but I added a few comments on things that caught my eye. Thanks for adding more type hints!

        super().__init__()
        if policy not in ["add", "cat"]:
            raise ValueError(
                "`merge_policy` must be one of: ['add', 'cat'], got {}".format(policy)
            )
        self.policy = policy

-   def forward(self, x):
+   def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
@adamjstewart (Collaborator) commented:

Technically, List is a bit too strict; it could be any collections.abc.Sequence, which includes things like tuples. Likely true for a lot of other places in the code base as well.
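For illustration, a minimal sketch (not from the PR) of the looser annotation suggested here. One caveat, which I believe is why the PR sticks with List: torch.jit.script does not accept Sequence annotations, only concrete container types such as List and Tuple.

from typing import Sequence

import torch

# Duck-typed annotation: accepts lists, tuples, or any other sequence of tensors.
def merge(x: Sequence[torch.Tensor]) -> torch.Tensor:
    return torch.cat(list(x), dim=1)

out = merge((torch.randn(1, 2), torch.randn(1, 3)))  # a tuple works too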

@@ -220,7 +247,7 @@ def __init__(
            upscale_mode=upscale_mode,
        )

-    def forward(self, *features):
+    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
@adamjstewart (Collaborator) commented:
This is a pretty big change. You went from model(x1, x2, x3) to model([x1, x2, x3]). Is this intentional, or was this an accidental type hint change?

@qubvel (Collaborator, Author) replied:
This was an intentional change for torchscript compatibility. While the model interface does not change, the decoder interface has changed, so it might break something for those who use the building blocks directly.
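To make the change concrete, a toy sketch (ToyDecoder is hypothetical, not the library's decoder) of why the variadic signature had to go:

from typing import List

import torch
from torch import nn

class ToyDecoder(nn.Module):
    # torch.jit.script cannot script *args, so the variadic signature
    # forward(self, *features) becomes a single typed list.
    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        return torch.cat(features, dim=1)

decoder = torch.jit.script(ToyDecoder())
xs = [torch.randn(2, 4, 8, 8) for _ in range(3)]
out = decoder(xs)  # new style: decoder([x1, x2, x3]) rather than decoder(x1, x2, x3)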

@adamjstewart (Collaborator) replied:
I'm fine with that if it's only the building blocks that changed and not the outward-facing encoders/decoders.

@qubvel (Collaborator, Author) replied:
I hope it will be fine as well. While the decoders themselves are not private, they are not advertised as a public API; the main use case is the model API, which is fully backward compatible.

def __init__(
    self,
    in_channels: int,
    sizes: Tuple[int, ...] = (1, 2, 3, 6),
@adamjstewart (Collaborator) commented:

Would any collections.abc.Sequence be valid here?

(Resolved, now-outdated review threads on segmentation_models_pytorch/decoders/pspnet/decoder.py and segmentation_models_pytorch/encoders/__init__.py.)
x = self.features.transition3.pool(x)
x = self.features.denseblock4(x)
x = self.features.norm5(x)
features.append(x)
@adamjstewart (Collaborator) commented:

I honestly prefer the for-loop here, but I'm guessing that makes it not possible to compile. What happens if depth > 5? Is there an assert statement to prevent that invalid input?

@qubvel (Collaborator, Author) replied:

Added validation in d121fec.
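The actual check is in commit d121fec; a hypothetical sketch of the idea (names assumed):

# Hypothetical sketch only; see commit d121fec for the real validation.
def validate_depth(depth: int) -> None:
    if depth < 1 or depth > 5:
        raise ValueError(f"depth should be in range [1, 5], got {depth}")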

@qubvel (Collaborator, Author) replied Jan 15, 2025:

Yes, it's more a torchscript limitation. I left the loop in a few places, but it's not possible to loop over indexes, only over layers, and break is not supported in torchscript.
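A minimal sketch (hypothetical module, not library code) of the torchscript loop rules being described: iterating directly over an nn.ModuleList scripts fine, while indexing it with a loop variable does not.

from typing import List

import torch
from torch import nn

class LoopEncoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Conv2d(8, 8, 3, padding=1) for _ in range(4)
        )

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        features: List[torch.Tensor] = []
        # Scriptable: loop over the layers themselves.
        for block in self.blocks:
            x = block(x)
            features.append(x)
        # Not scriptable: `for i in range(len(self.blocks)): x = self.blocks[i](x)`,
        # i.e. looping over indexes into the ModuleList.
        return features

scripted = torch.jit.script(LoopEncoder())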

@brianhou0208 (Contributor) commented Jan 15, 2025:

I also agree with what @adamjstewart said.

Maybe it could be divided into multiple PRs for easier review, such as the type-hint changes or the removal of the timm- encoders.

qubvel and others added 3 commits January 15, 2025 12:33
@qubvel (Collaborator, Author) commented Jan 15, 2025:

@brianhou0208, thanks for the review! The type-hint corrections were needed for torchscript as well, so yeah, I'm not happy the PR became so huge either.

@qubvel (Collaborator, Author) commented Jan 15, 2025:

I'm trusting the previously improved tests and my additional testing against the main branch 🤞

@qubvel (Collaborator, Author) commented Jan 15, 2025:

Agree with @adamjstewart: the type hints are not ideal and can be improved further, but that's out of scope for this PR, so I'm going to merge them as is.

@qubvel merged commit 456871a into main on Jan 15, 2025; 17 checks passed.