-
Notifications
You must be signed in to change notification settings - Fork 22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[cm] Custom Neutone exceptions #58
base: main
Are you sure you want to change the base?
Conversation
@@ -112,7 +112,7 @@ def save_neutone_model( | |||
|
|||
sqw = SampleQueueWrapper(model) | |||
|
|||
with tr.no_grad(): | |||
with tr.inference_mode(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need this inference_mode
block here?
@@ -212,34 +216,36 @@ def profile_sqw( | |||
sqw.prepare_for_inference() | |||
if convert_to_torchscript: | |||
log.info("Converting to TorchScript") | |||
with torch.no_grad(): | |||
with torch.inference_mode(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also similar for this one, do we actually need it?
|
||
|
||
# TODO(cm): constant for now, but if we need more of these we could use a factory method | ||
INFERENCE_MODE_EXCEPTION = NeutoneException( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could this be a list of possible exceptions that you iterate through it in the try/catch block?
block. | ||
""", | ||
trigger_type=RuntimeError, | ||
trigger_str="Inference tensors cannot be saved for backward." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am a bit worried about:
- Is this specific enough to not accidentally catch other exceptions?
- We need to test it with different pytorch versions and see if we get the same error string.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed, I used only a snippet of the text, but it's probably better to change it to use almost the entire text. I think it would also make sense to change this to a list of strings that can each trigger the exception such that if the text is slightly different in older pytorch versions we can simply add those messages to the list.
sort_by="self_cpu_memory_usage", row_limit=5 | ||
log.info("Displaying Total CPU Time") | ||
log.info(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) | ||
# log.info(prof.key_averages(group_by_stack_n=5).table(sort_by="cpu_time_total", row_limit=10)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you remove this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good overall, thank you!
Left a couple of comments. I like the raise from
approach, looks like we could also do a raise from None
to suppress the original message. I think that's fine in this case, but maybe too much in general?
This PR implements Neutone exceptions which are used when other exceptions have cryptic error messages that should be elaborated on.
Currently takes advantage of the
raise from
functionality introduced in python 3, let me know if you think there's a nicer way to do this.