Skip to content

Commit

Permalink
also include feature maps from conditioning encoder in the decoder, i…
Browse files Browse the repository at this point in the history
…f skip_connect_condition_fmaps is set to True
  • Loading branch information
lucidrains committed Nov 30, 2022
1 parent 453b495 commit 0707009
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
23 changes: 14 additions & 9 deletions med_seg_diff_pytorch/med_seg_diff_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch.nn.functional as F
from torch.fft import fft2, ifft2

from einops import rearrange, reduce
from einops import rearrange, reduce, pack
from einops.layers.torch import Rearrange

from tqdm.auto import tqdm
Expand Down Expand Up @@ -235,7 +235,8 @@ def __init__(
channels = 3,
self_condition = False,
resnet_block_groups = 8,
conditioning_klass = Conditioning
conditioning_klass = Conditioning,
skip_connect_condition_fmaps = False # whether to concatenate the conditioning fmaps in the latter decoder upsampling portion of unet
):
super().__init__()

Expand Down Expand Up @@ -274,6 +275,8 @@ def __init__(

self.conditioners = nn.ModuleList([])

self.skip_connect_condition_fmaps = skip_connect_condition_fmaps

# downsampling encoding blocks

self.downs = nn.ModuleList([])
Expand Down Expand Up @@ -314,9 +317,11 @@ def __init__(
for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
is_last = ind == (len(in_out) - 1)

skip_connect_dim = dim_in * (2 if self.skip_connect_condition_fmaps else 1)

self.ups.append(nn.ModuleList([
block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
block_klass(dim_out + skip_connect_dim, dim_out, time_emb_dim = time_dim),
block_klass(dim_out + skip_connect_dim, dim_out, time_emb_dim = time_dim),
Residual(LinearAttention(dim_out)),
Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
]))
Expand All @@ -333,7 +338,7 @@ def forward(
cond,
x_self_cond = None
):
dtype = x.dtype
dtype, skip_connect_c = x.dtype, self.skip_connect_condition_fmaps

if self.self_condition:
x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
Expand All @@ -352,7 +357,7 @@ def forward(
x = block1(x, t)
c = cond_block1(c, t)

h.append(x)
h.append([x, c] if skip_connect_c else [x])

x = block2(x, t)
c = cond_block2(c, t)
Expand All @@ -365,7 +370,7 @@ def forward(

c = conditioner(x, c)

h.append(x)
h.append([x, c] if skip_connect_c else [x])

x = downsample(x)
c = cond_downsample(c)
Expand All @@ -379,10 +384,10 @@ def forward(
x = self.mid_block2(x, t)

for block1, block2, attn, upsample in self.ups:
x = torch.cat((x, h.pop()), dim = 1)
x = torch.cat((x, *h.pop()), dim = 1)
x = block1(x, t)

x = torch.cat((x, h.pop()), dim = 1)
x = torch.cat((x, *h.pop()), dim = 1)
x = block2(x, t)
x = attn(x)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'med-seg-diff-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.4',
version = '0.0.5',
license='MIT',
description = 'MedSegDiff - SOTA medical image segmentation - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 0707009

Please sign in to comment.