Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added code and tests for wisenet model #18

Merged
merged 3 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions sslt/models/nets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ class SimpleSupervisedModel(L.LightningModule):
this class.

Note that, for this class the input data is a tuple of tensors, where the
first tensor is the input data and the second tensor is the mask, with the
same shape as the input data.
first tensor is the input data and the second tensor is the mask or label.
"""

def __init__(
Expand Down Expand Up @@ -108,7 +107,9 @@ def _single_step(
batch_idx : int
The index of the batch.
step_name : str
The name of the step. It will be used to log the loss.
The name of the step. It will be used to log the loss. The possible
values are: "train", "val" and "test". The loss will be logged as
"{step_name}_loss".

Returns
-------
Expand Down
137 changes: 137 additions & 0 deletions sslt/models/nets/wisenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import torch
from sslt.models.nets.base import SimpleSupervisedModel


class _WiseNet(torch.nn.Module):
def __init__(self, in_channels=1, out_channels=1):
super().__init__()

self.in_channels = in_channels
self.out_channels = out_channels

# self.norm = nn.BatchNorm3d(1)
self.conv1 = torch.nn.Conv3d(
in_channels=in_channels,
out_channels=32,
kernel_size=3,
stride=1,
padding=1,
)
self.relu = torch.nn.ReLU()
self.pool1 = torch.nn.MaxPool3d(kernel_size=2, stride=2, ceil_mode=True)
self.conv2 = torch.nn.Conv3d(
in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1
)
self.pool2 = torch.nn.MaxPool3d(kernel_size=2, stride=2, ceil_mode=True)
self.conv3 = torch.nn.Conv3d(
in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1
)
self.pool3 = torch.nn.MaxPool3d(kernel_size=2, stride=2, ceil_mode=True)
self.conv4 = torch.nn.Conv3d(
in_channels=128,
out_channels=256,
kernel_size=3,
stride=1,
padding=1,
)
self.pool4 = torch.nn.MaxPool3d(kernel_size=(2, 1, 1), stride=(2, 1, 1))
self.conv5 = torch.nn.ConvTranspose2d(
in_channels=256,
out_channels=128,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
)
self.conv6 = torch.nn.ConvTranspose2d(
in_channels=128,
out_channels=64,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
)
self.conv7 = torch.nn.ConvTranspose2d(
in_channels=64,
out_channels=32,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
)
self.conv8 = torch.nn.Conv2d(
in_channels=32,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1,
)

def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.pool1(x)
x = self.conv2(x)
x = self.relu(x)
x = self.pool2(x)
x = self.conv3(x)
x = self.relu(x)
x = self.pool3(x)
x = self.conv4(x)
x = self.relu(x)
x = self.pool4(x)
x = x.view(
x.size(0), x.size(1), x.size(3), x.size(4)
) # (batch_size, channels, height, width)
x = self.conv5(x)
x = self.relu(x)
x = self.conv6(x)
x = self.relu(x)
x = self.conv7(x)
x = self.relu(x)
x = self.conv8(x)
return x


class WiseNet(SimpleSupervisedModel):
def __init__(
self,
in_channels: int = 1,
out_channels: int = 1,
loss_fn: torch.nn.Module = None,
learning_rate: float = 1e-3,
):
super().__init__(
backbone=_WiseNet(
in_channels=in_channels, out_channels=out_channels
),
fc=torch.nn.Identity(),
loss_fn=loss_fn or torch.nn.MSELoss(),
learning_rate=learning_rate,
flatten=False,
)

def _single_step(
self, batch: torch.Tensor, batch_idx: int, step_name: str
) -> torch.Tensor:
x, y = batch
y_hat = self.forward(x)
y_hat = y_hat[:, :, : y.size(2), : y.size(3)]

loss = self._loss_func(y_hat, y)
self.log(
f"{step_name}_loss",
loss,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
return loss


def predict_step(self, batch, batch_idx, dataloader_idx=None):
x, y = batch
y_hat = self.forward(x)
y_hat = y_hat[:, :, : y.size(2), : y.size(3)]
return y_hat
48 changes: 48 additions & 0 deletions tests/models/nets/test_wisenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import torch
from sslt.models.nets.wisenet import WiseNet


def test_wisenet_loss():
model = WiseNet()
batch_size = 2
mask_shape = (batch_size, 1, 500, 500) # (2, 1, 500, 500)
input_shape = *mask_shape[:2], 17, *mask_shape[2:] # (2, 1, 17, 500, 500)

# Input X is volume of 17 slices of 500x500, 1 channel.
# So, it is a 5D tensor of shape (B, C, D, H, W), where B is the batch
# size, C is the number of channels, D is the depth, H is the
# height and W is the width.
x = torch.rand(*input_shape)
# The mask is a single 2-D panel of 500x500, 1 channel.
# It is a 4D tensor of shape (B, C, H, W), where B is the batch size,
# C is the number of channels, H is the height and W is the width.
mask = torch.rand(*mask_shape)

# Do the training step
loss = model.training_step((x, mask), 0).item()
assert loss is not None
assert loss >= 0, f"Expected non-negative loss, but got {loss}"


def test_wisenet_predict():
model = WiseNet()
batch_size = 2
mask_shape = (batch_size, 1, 500, 500) # (2, 1, 500, 500)
input_shape = *mask_shape[:2], 17, *mask_shape[2:] # (2, 1, 17, 500, 500)

# Input X is volume of 17 slices of 500x500, 1 channel.
# So, it is a 5D tensor of shape (B, C, D, H, W), where B is the batch
# size, C is the number of channels, D is the depth, H is the
# height and W is the width.
x = torch.rand(*input_shape)
# The mask is a single 2-D panel of 500x500, 1 channel.
# It is a 4D tensor of shape (B, C, H, W), where B is the batch size,
# C is the number of channels, H is the height and W is the width.
mask = torch.rand(*mask_shape)

# Do the prediction step
preds = model.predict_step((x, mask), 0)
assert preds is not None
assert (
preds.shape == mask_shape
), f"Expected shape {mask_shape}, but got {preds.shape}"