Skip to content

Commit

Permalink
feat: add resnet, upsample, downsample kernel_size/width, more upsamp…
Browse files Browse the repository at this point in the history
…le options
  • Loading branch information
flavioschneider committed Feb 3, 2023
1 parent ab3c5da commit e0933e7
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 15 deletions.
45 changes: 37 additions & 8 deletions a_unet/apex.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Sequential,
T,
Upsample,
UpsampleInterpolate,
default,
exists,
)
Expand All @@ -46,6 +47,7 @@ def DownsampleItem(
factor: Optional[int] = None,
in_channels: Optional[int] = None,
channels: Optional[int] = None,
downsample_width: int = 1,
**kwargs,
) -> nn.Module:
msg = "DownsampleItem requires dim, factor, in_channels, channels"
Expand All @@ -54,7 +56,11 @@ def DownsampleItem(
), msg
Item = SelectX(Downsample)
return Item( # type: ignore
dim=dim, factor=factor, in_channels=in_channels, out_channels=channels
dim=dim,
factor=factor,
width=downsample_width,
in_channels=in_channels,
out_channels=channels,
)


Expand All @@ -63,16 +69,34 @@ def UpsampleItem(
factor: Optional[int] = None,
channels: Optional[int] = None,
out_channels: Optional[int] = None,
upsample_mode: str = "nearest",
upsample_kernel_size: int = 3, # Used with upsample_mode != "transpose"
upsample_width: int = 1, # Used with upsample_mode == "transpose"
**kwargs,
) -> nn.Module:
msg = "UpsampleItem requires dim, factor, channels, out_channels"
assert (
exists(dim) and exists(factor) and exists(channels) and exists(out_channels)
), msg
Item = SelectX(Upsample)
return Item( # type: ignore
dim=dim, factor=factor, in_channels=channels, out_channels=out_channels
)
if upsample_mode == "transpose":
Item = SelectX(Upsample)
return Item( # type: ignore
dim=dim,
factor=factor,
width=upsample_width,
in_channels=channels,
out_channels=out_channels,
)
else:
Item = SelectX(UpsampleInterpolate)
return Item( # type: ignore
dim=dim,
factor=factor,
mode=upsample_mode,
kernel_size=upsample_kernel_size,
in_channels=channels,
out_channels=out_channels,
)


""" Main """
Expand All @@ -82,15 +106,20 @@ def ResnetItem(
dim: Optional[int] = None,
channels: Optional[int] = None,
resnet_groups: Optional[int] = None,
resnet_kernel_size: int = 3,
**kwargs,
) -> nn.Module:
msg = "ResnetItem requires dim, channels, and resnet_groups"
assert exists(dim) and exists(channels) and exists(resnet_groups), msg
Item = SelectX(ResnetBlock)
conv_block_t = T(ConvBlock)(norm_t=T(nn.GroupNorm)(num_groups=resnet_groups))
return Item(
dim=dim, in_channels=channels, out_channels=channels, conv_block_t=conv_block_t
) # type: ignore
return Item( # type: ignore
dim=dim,
in_channels=channels,
out_channels=channels,
kernel_size=resnet_kernel_size,
conv_block_t=conv_block_t,
)


def ConvNextV2Item(
Expand Down
57 changes: 51 additions & 6 deletions a_unet/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,57 @@ def Conv(dim: int, *args, **kwargs) -> nn.Module:
return [nn.Conv1d, nn.Conv2d, nn.Conv3d][dim - 1](*args, **kwargs)


def Downsample(dim: int, factor: int = 2, conv_t=Conv, **kwargs) -> nn.Module:
return conv_t(dim=dim, kernel_size=factor, stride=factor, **kwargs)
def ConvTranspose(dim: int, *args, **kwargs) -> nn.Module:
return [nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d][dim - 1](
*args, **kwargs
)


def Downsample(
dim: int, factor: int = 2, width: int = 1, conv_t=Conv, **kwargs
) -> nn.Module:
width = width if factor > 1 else 1
return conv_t(
dim=dim,
kernel_size=factor * width,
stride=factor,
padding=(factor * width - factor) // 2,
**kwargs,
)


def Upsample(
dim: int, factor: int = 2, mode: str = "nearest", conv_t=Conv, **kwargs
dim: int,
factor: int = 2,
width: int = 1,
conv_t=Conv,
conv_tranpose_t=ConvTranspose,
**kwargs,
) -> nn.Module:
width = width if factor > 1 else 1
return conv_tranpose_t(
dim=dim,
kernel_size=factor * width,
stride=factor,
padding=(factor * width - factor) // 2,
**kwargs,
)


def UpsampleInterpolate(
dim: int,
factor: int = 2,
kernel_size: int = 3,
mode: str = "nearest",
conv_t=Conv,
**kwargs,
) -> nn.Module:
assert kernel_size % 2 == 1, "upsample kernel size must be odd"
return nn.Sequential(
nn.Upsample(scale_factor=factor, mode="nearest"),
conv_t(dim=dim, kernel_size=3, padding=1, **kwargs),
nn.Upsample(scale_factor=factor, mode=mode),
conv_t(
dim=dim, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, **kwargs
),
)


Expand All @@ -165,10 +206,14 @@ def ResnetBlock(
dim: int,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
conv_block_t=ConvBlock,
conv_t=Conv,
**kwargs,
) -> nn.Module:
ConvBlock = T(conv_block_t)(dim=dim, kernel_size=3, padding=1)
ConvBlock = T(conv_block_t)(
dim=dim, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, **kwargs
)
Conv = T(conv_t)(dim=dim, kernel_size=1)

conv_block = Sequential(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="a-unet",
packages=find_packages(exclude=[]),
version="0.0.15",
version="0.0.16",
license="MIT",
description="A-UNet",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit e0933e7

Please sign in to comment.