Fix torch compile, script, export #1031
Conversation
```diff
-    def get_stages(self):
-        return [
-            nn.Identity(),
-            nn.Sequential(
-                self.features[0].conv, self.features[0].bn, self.features[0].act
-            ),
-            nn.Sequential(
-                self.features[0].pool, self.features[1 : self._stage_idxs[0]]
-            ),
-            self.features[self._stage_idxs[0] : self._stage_idxs[1]],
-            self.features[self._stage_idxs[1] : self._stage_idxs[2]],
-            self.features[self._stage_idxs[2] : self._stage_idxs[3]],
-        ]
-
-    def forward(self, x):
-        stages = self.get_stages()
-
-        features = []
-        for i in range(self._depth + 1):
-            x = stages[i](x)
-            if isinstance(x, (list, tuple)):
-                features.append(F.relu(torch.cat(x, dim=1), inplace=True))
-            else:
-                features.append(x)
+    def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]:
+        return {
+            16: [self.features[self._stage_idxs[1] : self._stage_idxs[2]]],
+            32: [self.features[self._stage_idxs[2] : self._stage_idxs[3]]],
+        }
+
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        features = [x]
+
+        if self._depth >= 1:
+            x = self.features[0].conv(x)
+            x = self.features[0].bn(x)
+            x = self.features[0].act(x)
+            features.append(x)
+
+        if self._depth >= 2:
+            x = self.features[0].pool(x)
+            x = self.features[1 : self._stage_idxs[0]](x)
+            skip = F.relu(torch.cat(x, dim=1), inplace=True)
+            features.append(skip)
+
+        if self._depth >= 3:
+            x = self.features[self._stage_idxs[0] : self._stage_idxs[1]](x)
+            skip = F.relu(torch.cat(x, dim=1), inplace=True)
+            features.append(skip)
+
+        if self._depth >= 4:
+            x = self.features[self._stage_idxs[1] : self._stage_idxs[2]](x)
+            skip = F.relu(torch.cat(x, dim=1), inplace=True)
+            features.append(skip)
+
+        if self._depth >= 5:
+            x = self.features[self._stage_idxs[2] : self._stage_idxs[3]](x)
+            features.append(x)
```
This PR refactors the `self.get_stages` method, which previously used a for loop to return multi-scale features. Now the `self.forward` method uses if-else statements to handle different feature scales.

What prompted this change? Is it intended to better support more complex models?
Hey @brianhou0208! This is prompted by compatibility with Torch script/export. I'm not sure it's easier to read, but it's still not that complicated and is very explicit.

@adamjstewart would you like to have a look? 😄 Otherwise it's ready to be merged. The next one will move all encoders to the hf-hub for faster loading and download stats. Actually, I've moved them already and just need to update the URLs and the way of loading.
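To make the scripting constraint concrete, here is a toy module (hypothetical, not the library's encoder) showing why the old pattern of indexing a stage list with a loop variable cannot go through `torch.jit.script`, which only accepts integer-literal indexing into a `ModuleList`:

```python
import torch
import torch.nn as nn

class LoopStages(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.stages = nn.ModuleList(nn.Conv2d(3, 3, 3, padding=1) for _ in range(3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for i in range(len(self.stages)):
            x = self.stages[i](x)  # runtime index into a ModuleList
        return x

try:
    torch.jit.script(LoopStages())
except Exception as err:
    # TorchScript rejects non-literal ModuleList indexing at compile time
    print("not scriptable:", err)
```

Unrolling the loop into explicit `if self._depth >= k` branches sidesteps this, since every stage access becomes static.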
Btw, I additionally tested all models and encoders against the main branch with the following script, to ensure the outputs match and the weights can be loaded with no issues:

```python
import os

import torch
import segmentation_models_pytorch as smp
from tqdm import tqdm

TMP_FOLDER = "tmp-model"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def run_on_main():
    # On `main`: save a reference input/output/state_dict for every encoder.
    encoder_names = sorted(smp.encoders.get_encoder_names())
    for encoder_name in tqdm(encoder_names):
        model = smp.Unet(encoder_name, decoder_channels=[4, 4, 4, 4, 4], encoder_weights=None)
        model = model.eval().to(DEVICE)
        sample = torch.randn(2, 3, 256, 256).to(DEVICE)
        with torch.no_grad():
            output = model(sample)
        os.makedirs(os.path.join(TMP_FOLDER, encoder_name), exist_ok=True)
        torch.save(sample, os.path.join(TMP_FOLDER, encoder_name, "input.pth"))
        torch.save(output, os.path.join(TMP_FOLDER, encoder_name, "output.pth"))
        torch.save(model.state_dict(), os.path.join(TMP_FOLDER, encoder_name, "state_dict.pth"))


def run_on_branch():
    # On the PR branch: reload the references and compare outputs.
    encoder_names = os.listdir(TMP_FOLDER)
    for encoder_name in tqdm(encoder_names):
        sample = torch.load(os.path.join(TMP_FOLDER, encoder_name, "input.pth"), weights_only=True)
        expected_output = torch.load(os.path.join(TMP_FOLDER, encoder_name, "output.pth"), weights_only=True)
        state_dict = torch.load(os.path.join(TMP_FOLDER, encoder_name, "state_dict.pth"), weights_only=True)
        model = smp.Unet(encoder_name, decoder_channels=[4, 4, 4, 4, 4], encoder_weights=None).eval().to(DEVICE)
        try:
            model.load_state_dict(state_dict)
        except Exception as e:
            print(f"Error loading state dict for {encoder_name}: {e}")
            raise e
        with torch.no_grad():
            output = model(sample)
        if not torch.allclose(output, expected_output):
            diff = torch.abs(output - expected_output).max().item()
            print(f"Encoder {encoder_name} has different output with max diff {diff:.6f}")


if __name__ == "__main__":
    import git

    repo = git.Repo(".")
    if repo.active_branch.name == "main":
        print("\n--- Running on main branch ---\n")
        run_on_main()
    else:
        print(f"\n--- Running on {repo.active_branch.name} branch ---\n")
        run_on_branch()
```
This PR is a bit too big for me to properly review, but I added a few comments on things that caught my eye. Thanks for adding more type hints!
```diff
         super().__init__()
         if policy not in ["add", "cat"]:
             raise ValueError(
                 "`merge_policy` must be one of: ['add', 'cat'], got {}".format(policy)
             )
         self.policy = policy

-    def forward(self, x):
+    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
```
Technically `List` is a bit too strict; it could be any `collections.abc.Sequence`. This includes things like tuples. Likely true for a lot of other places in the code base as well.
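For illustration, a sketch of the looser annotation (the function here is hypothetical). Note that TorchScript's type system only understands concrete `List`/`Tuple` annotations, not `typing.Sequence`, which is likely why the stricter hint is kept in scripted code paths:

```python
from typing import Sequence
import torch

def merge_cat(features: Sequence[torch.Tensor]) -> torch.Tensor:
    # Accepts any sequence type (list, tuple, ...) from plain Python callers,
    # but this annotation would not compile under torch.jit.script.
    return torch.cat(list(features), dim=1)

a, b = torch.randn(1, 8, 4, 4), torch.randn(1, 8, 4, 4)
assert merge_cat([a, b]).shape == merge_cat((a, b)).shape  # both work
```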
```diff
@@ -220,7 +247,7 @@ def __init__(
             upscale_mode=upscale_mode,
         )

-    def forward(self, *features):
+    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
```
This is a pretty big change. You went from `model(x1, x2, x3)` to `model([x1, x2, x3])`. Is this intentional, or was this an accidental type hint change?
This was an intentional change for `torchscript` compatibility. Yeah, while the model interface does not change, the decoder interface has changed, so it might break something for those who use the building blocks.
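For anyone calling decoder blocks directly, the change looks roughly like this (toy decoder with illustrative names, not the library's class):

```python
from typing import List
import torch
import torch.nn as nn

class ToyDecoder(nn.Module):
    # Stand-in for a decoder building block after this PR:
    # a single typed list argument instead of *features varargs.
    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        return torch.cat(features, dim=1)

decoder = ToyDecoder()
x1, x2 = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
out = decoder([x1, x2])  # new style; the old style was decoder(x1, x2)
```

A `List[torch.Tensor]` parameter is something TorchScript can type-check, whereas `*features` varargs are not supported in scripted modules.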
I'm fine with that if it's only the building blocks that changed and not the outward-facing encoders/decoders.
I hope it will be fine as well. While the decoders themselves are not private, they are not advertised as a public API. The main use case is the model API, which is fully backward compatible.
```python
def __init__(
    self,
    in_channels: int,
    sizes: Tuple[int, ...] = (1, 2, 3, 6),
```
Would any `collections.abc.Sequence` be valid here?
```python
x = self.features.transition3.pool(x)
x = self.features.denseblock4(x)
x = self.features.norm5(x)
features.append(x)
```
I honestly prefer the for-loop here, but I'm guessing that makes it not possible to compile. What happens if depth > 5? Is there an assert statement to prevent that invalid input?
Added validation in d121fec.
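Presumably something along these lines (a sketch of the idea, not the exact commit):

```python
def _validate_depth(depth: int) -> None:
    # The explicit `if self._depth >= k` branches only cover five stages,
    # so anything outside [1, 5] should be rejected up front.
    if not 1 <= depth <= 5:
        raise ValueError(f"depth should be in range [1, 5], got {depth}")
```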
Yes, it's more about torchscript limitations. I left the loop in a few places, but it's not possible to loop over indexes, only over layers, and `break` is not supported in torchscript.
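For illustration, a sketch of the loop form TorchScript does accept — iterating over a `ModuleList`'s layers directly rather than over indexes (toy module, not library code):

```python
import torch
import torch.nn as nn

class Blocks(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(4, 4) for _ in range(3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scriptable: iterate over the layers themselves.
        # Not scriptable: `self.blocks[i]` with a runtime index i.
        for block in self.blocks:
            x = block(x)
        return x

scripted = torch.jit.script(Blocks())
print(scripted(torch.randn(2, 4)).shape)  # torch.Size([2, 4])
```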
I also agree with what @adamjstewart said. Maybe it could be divided into multiple PRs for easier review, such as the type hint changes or deleting the `timm-` encoders.
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
@brianhou0208, thanks for the review! The typehint corrections were needed for torchscript as well... so yeah, not happy the PR became so huge either.
I believe in the previously improved tests and my additional testing against the main branch 🤞
Agree with @adamjstewart, typehints are not ideal and can be improved further, but that's out of scope for this PR, so I'm going to merge them as is.
Huge PR, but many things are dependent, so everything is included here:

- Remove `timm-` encoders (map weights to `tu-`, except for EfficientNet and SKNet).
- Fix `torch.compile` for encoders and add tests (only EfficientNet is currently skipped, but it can be fixed once we copy-paste the code).
- Fix `torch.export.export`.
- Fix `torch.jit.script`.

Fixes:
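For reference, a minimal smoke test of the three entry points this PR targets (a sketch: model choice and input shape are arbitrary, and `torch.export` assumes PyTorch 2.x):

```python
import torch
import segmentation_models_pytorch as smp

model = smp.Unet("resnet18", encoder_weights=None).eval()
x = torch.randn(1, 3, 256, 256)

scripted = torch.jit.script(model)           # torch.jit.script path
exported = torch.export.export(model, (x,))  # torch.export.export path
compiled = torch.compile(model)              # torch.compile path

with torch.no_grad():
    assert compiled(x).shape == model(x).shape
```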