diff --git a/neutone_sdk/tcn.py b/neutone_sdk/tcn.py index c63b850..ddc5dd2 100644 --- a/neutone_sdk/tcn.py +++ b/neutone_sdk/tcn.py @@ -54,6 +54,7 @@ def __init__(self, use_ln: bool = False, temporal_dim: Optional[int] = None, use_act: bool = True, + act_name: str = "prelu", use_res: bool = True, cond_dim: int = 0, use_film_bn: bool = True, # TODO(cm): check if this should be false @@ -62,6 +63,7 @@ def __init__(self, self.use_ln = use_ln self.temporal_dim = temporal_dim self.use_act = use_act + self.act_name = act_name self.use_res = use_res self.cond_dim = cond_dim self.use_film_bn = use_film_bn @@ -75,7 +77,7 @@ def __init__(self, self.act = None if use_act: - self.act = nn.PReLU(out_channels) + self.act = self.get_activation(act_name, out_channels) self.conv = Conv1dGeneral(in_channels, out_channels, @@ -177,6 +179,70 @@ def forward(self, x: Tensor, cond: Optional[Tensor] = None) -> Tensor: x += x_res return x + @staticmethod + def get_activation(act_name: str, out_ch: Optional[int] = None) -> nn.Module: + """ + Most of the code and experimental results in this method are from + https://github.com/csteinmetz1/ronn + + Given an activation name string, returns the corresponding activation function. + + Args: + act_name: Name of the activation function. + out_ch: Optional number of output channels. Only used for determining the + number of parameters in the PReLU activation function. + + Returns: + act: PyTorch activation function. + + Experimental results for randomized overdrive neural networks. + ---------------------- + - ReLU: solid distortion + - LeakyReLU: somewhat veiled sound + - Tanh: insane levels of distortion with lots of aliasing (HF) + - Sigmoid: too gritty to be useful + - ELU: fading in and out + - RReLU: really interesting HF noise with a background sound + - SELU: rolled off soft distortion sound + - GELU: roomy, not too interesting + - Softplus: heavily distorted signal but with a very rolled off sound. (nice) + - Softshrink: super distant sounding and somewhat roomy + """ + act_name = act_name.lower() + if act_name == "relu": + act = nn.ReLU() + elif act_name == "leakyrelu": + act = nn.LeakyReLU() + elif act_name == "tanh": + act = nn.Tanh() + elif act_name == "sigmoid": + act = nn.Sigmoid() + elif act_name == "elu": + act = nn.ELU() + elif act_name == "rrelu": + act = nn.RReLU() + elif act_name == "selu": + act = nn.SELU() + elif act_name == "gelu": + act = nn.GELU() + elif act_name == "softplus": + act = nn.Softplus() + elif act_name == "softshrink": + act = nn.Softshrink() + elif act_name == "silu" or act_name == "swish": + act = nn.SiLU() + elif act_name == "prelu": + if out_ch is None: + act = nn.PReLU() + else: + act = nn.PReLU(out_ch) + elif act_name == "prelu1": + act = nn.PReLU() + else: + raise ValueError(f"Invalid activation name: '{act_name}'.") + + return act + class TCN(nn.Module): def __init__(self, @@ -195,6 +261,7 @@ def __init__(self, use_ln: bool = False, temporal_dims: Optional[List[int]] = None, use_act: bool = True, + act_name: str = "prelu", use_res: bool = True, cond_dim: int = 0, use_film_bn: bool = True, # TODO(cm): check if this should be false @@ -215,6 +282,7 @@ def __init__(self, self.use_ln = use_ln self.temporal_dims = temporal_dims self.use_act = use_act + self.act_name = act_name self.use_res = use_res self.cond_dim = cond_dim self.use_film_bn = use_film_bn @@ -269,6 +337,7 @@ def __init__(self, use_ln, temp_dim, use_act, + act_name, use_res, cond_dim, use_film_bn,