Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce a variable skip_unrolling in class Metric #3258

Merged
merged 14 commits into from
Jul 1, 2024

Conversation

simeetnayan81
Copy link
Contributor

@simeetnayan81 simeetnayan81 commented Jun 25, 2024

Fixes #2940

Description:
Introduce a variable skip_unrolling in class Metric as discussed here https://discord.com/channels/831462531327328276/1110662056622964860/1253769540710567977

Check list:

  • New tests are added (if a new feature is added)
  • New doc strings: description and/or example code are in RST format
  • Documentation is updated (if required)

@github-actions github-actions bot added the module: metrics Metrics module label Jun 25, 2024
Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR @simeetnayan81
Left a comment about the implementation

ignite/metrics/metric.py Outdated Show resolved Hide resolved
@@ -300,7 +300,7 @@ def compute(self):
_required_output_keys = required_output_keys

def __init__(
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), skip_unrolling : bool =False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also update above docstring by adding the new argument. Please also check CONTRIBUTING guideline the part about how to add .. versionadded:: tag in the bottom of the docstring. Version to put should be 0.5.1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I add .. versionchanged:: instead?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On it. Thanks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have made the changes. Kindly review.

Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the update! I added few minor comments. Let's add new feature tests and run the CI to see if any failures

ignite/metrics/metric.py Outdated Show resolved Hide resolved
ignite/metrics/metric.py Outdated Show resolved Hide resolved
ignite/metrics/metric.py Outdated Show resolved Hide resolved
simeetnayan81 and others added 2 commits June 26, 2024 16:04
Co-authored-by: vfdev <vfdev.5@gmail.com>
Co-authored-by: vfdev <vfdev.5@gmail.com>
@simeetnayan81
Copy link
Contributor Author

Tests should be added to end of the test_metric.py file?

@vfdev-5
Copy link
Collaborator

vfdev-5 commented Jun 27, 2024

Yes, you can add it in the end of the file

@simeetnayan81
Copy link
Contributor Author

skip_unrolling = False is already covered by all the prior tests. I have added a test for when skip_unrolling = True. Kindly review and let me know the changes.

Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good @simeetnayan81 , thanks, let's just add an example of usage of the flag in the docstring and it will be good to go, once CI is green (except unrelated failures)

ignite/metrics/metric.py Show resolved Hide resolved
@simeetnayan81
Copy link
Contributor Author

simeetnayan81 commented Jun 28, 2024

@vfdev-5 Before adding the example in the docstring, I wanted to confirm, to make skip_unrolling effective for the loss function, we might also need to change this.
https://github.com/pytorch/ignite/blob/master/ignite/metrics/loss.py#L77
Prev:

def __init__(
        self,
        loss_fn: Callable,
        output_transform: Callable = lambda x: x,
        batch_size: Callable = len,
        device: Union[str, torch.device] = torch.device("cpu"),
    ):
        super(Loss, self).__init__(output_transform, device=device)

Change to:

def __init__(
        self,
        loss_fn: Callable,
        output_transform: Callable = lambda x: x,
        batch_size: Callable = len,
        device: Union[str, torch.device] = torch.device("cpu"),
        skip_unrolling=False
    ):
        super(Loss, self).__init__(output_transform, device=device, skip_unrolling=skip_unrolling)

@vfdev-5
Copy link
Collaborator

vfdev-5 commented Jun 28, 2024

@simeetnayan81 yes, you are right, we need to add this new arg to all metrics defining a constructor. Let's update Loss metric here and update other metrics in a follow-up PR.

@simeetnayan81
Copy link
Contributor Author

simeetnayan81 commented Jun 28, 2024

Things to do in a follow-up PR.

  • Add test for updated Loss class
  • Update other sub-classes of Metric with skip_unrolling arg as required, add tests and docstring

@vfdev-5
Copy link
Collaborator

vfdev-5 commented Jun 28, 2024

Thanks for the updates and the TODO. Can we do this point here ?

Add test for updated Loss class

@simeetnayan81
Copy link
Contributor Author

Alright @vfdev-5

ignite/metrics/metric.py Outdated Show resolved Hide resolved
@simeetnayan81
Copy link
Contributor Author

Have made the changes, the new test works locally.

ignite/metrics/loss.py Outdated Show resolved Hide resolved
ignite/metrics/metric.py Outdated Show resolved Hide resolved
ignite/metrics/metric.py Outdated Show resolved Hide resolved
ignite/metrics/metric.py Outdated Show resolved Hide resolved
ignite/metrics/metric.py Outdated Show resolved Hide resolved
@simeetnayan81
Copy link
Contributor Author

The test is failing because list[torch.Tensor, torch.Tensor] is supported on python 3.9 and above. Let me modify it a bit.

Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @simeetnayan81 , lgtm

@vfdev-5 vfdev-5 merged commit d715807 into pytorch:master Jul 1, 2024
19 of 20 checks passed
@simeetnayan81 simeetnayan81 mentioned this pull request Jul 13, 2024
28 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: metrics Metrics module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Metric with multiple input runs in an unexpected way.
2 participants