Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SetR #43

Merged
merged 24 commits into from
Apr 11, 2024
Merged

SetR #43

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
torch
lightning
git+https://github.com/discovery-unicamp/hiaac-librep.git@0.0.4-dev
scipy
plotly
numpy
pandas
plotly
PyYAML
scipy
statsmodels
jsonargparse[all]
tifffile
torch
zarr
rich
torchmetrics
1 change: 1 addition & 0 deletions sslt/models/nets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .setr import _SetR_PUP
7 changes: 4 additions & 3 deletions sslt/models/nets/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict
import torch

import lightning as L
import torch


class SimpleSupervisedModel(L.LightningModule):
Expand All @@ -18,7 +19,7 @@ class SimpleSupervisedModel(L.LightningModule):
easier to implement new models by only changing the backbone model. More
complex models, that does not follow this pipeline, should not inherit from
this class.

Note that, for this class the input data is a tuple of tensors, where the
first tensor is the input data and the second tensor is the mask or label.
"""
Expand All @@ -38,7 +39,7 @@ def __init__(
backbone : torch.nn.Module
The backbone model. Usually the encoder/decoder part of the model.
fc : torch.nn.Module
The fully connected model, usually used to classification tasks.
The fully connected model, usually used to classification tasks.
Use `torch.nn.Identity()` if no FC model is needed.
loss_fn : torch.nn.Module
The function used to compute the loss.
Expand Down
Loading