Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VOS loss #47

Merged
merged 18 commits into from
Apr 15, 2024
Merged

VOS loss #47

merged 18 commits into from
Apr 15, 2024

Conversation

tla93
Copy link
Collaborator

@tla93 tla93 commented Apr 9, 2024

add VOS loss, detector and segmentation example
pre-commit.config : update flake8 version to 5.0.2


def __init__(self, model: torch.nn.Module, weights_energy: torch.nn.Module):
"""
:param t: Temperature value :math:`T`. Default is 1.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

documentation inconsistent

:param weights_energy: energy weights as torch.nn.module
"""
# Permutation depends on shape of logits
tmp_scores_ = logits.permute(0, 2, 3, 1)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will this only work in segmentation settings?

"""

def __init__(
self, logistic_regression, weights_energy, alpha=0.1, device="cuda:0", reduction="mean"
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add type hints to parameters

# Create non-linear neural network functions
phi = torch.nn.Linear(1, 2)
weights_energy = torch.nn.Linear(num_classes, 1)
if 'cuda' in device:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use to .to(device) method here as well

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, these functions are linear

@kkirchheim kkirchheim added documentation Improvements or additions to documentation new method labels Apr 9, 2024
.vscode/settings.json
.pre-commit-config.yaml
.pre-commit-config.yaml
.gitignore
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are these files on the gitignore list?

input_for_lr = torch.cat((energy_x_in, energy_v_out), -1)
if "cuda" in self.device:
labels_for_lr = torch.cat(
(torch.ones(len(energy_x_in)).cuda(), torch.zeros(len(energy_v_out)).cuda()), -1
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use .to(device) function

(torch.ones(len(energy_x_in)), torch.zeros(len(energy_v_out))), -1
)

criterion = torch.nn.CrossEntropyLoss(reduction=self.reduction)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please construct cross-entropy in constructor and not on each update


import torch

from src.pytorch_ood.detector import VOSBased
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should fail since detector was renamed


class TestEnergyRegularization(unittest.TestCase):
"""
Test code for energy bounded learning
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not up to date

"""
:param logits: logits given by your model
"""
return self.score(logits)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will fail. please write some more tests.

@kkirchheim kkirchheim merged commit e0b3768 into kkirchheim:dev Apr 15, 2024
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Improvements or additions to documentation new method
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants