[cm] Adding customizable activations to TCN block
christhetree committed Dec 27, 2023
1 parent 9f34d5d commit 01566f3
Showing 1 changed file with 70 additions and 1 deletion.
neutone_sdk/tcn.py (71 changes: 70 additions & 1 deletion)
@@ -54,6 +54,7 @@ def __init__(self,
                  use_ln: bool = False,
                  temporal_dim: Optional[int] = None,
                  use_act: bool = True,
+                 act_name: str = "prelu",
                  use_res: bool = True,
                  cond_dim: int = 0,
                  use_film_bn: bool = True,  # TODO(cm): check if this should be false
@@ -62,6 +63,7 @@ def __init__(self,
         self.use_ln = use_ln
         self.temporal_dim = temporal_dim
         self.use_act = use_act
+        self.act_name = act_name
         self.use_res = use_res
         self.cond_dim = cond_dim
         self.use_film_bn = use_film_bn
@@ -75,7 +77,7 @@ def __init__(self,
 
         self.act = None
         if use_act:
-            self.act = nn.PReLU(out_channels)
+            self.act = self.get_activation(act_name, out_channels)
 
         self.conv = Conv1dGeneral(in_channels,
                                   out_channels,
@@ -177,6 +179,70 @@ def forward(self, x: Tensor, cond: Optional[Tensor] = None) -> Tensor:
         x += x_res
         return x
 
+    @staticmethod
+    def get_activation(act_name: str, out_ch: Optional[int] = None) -> nn.Module:
+        """
+        Given an activation name string, returns the corresponding activation
+        function. Most of the code and experimental results in this method are
+        from https://github.com/csteinmetz1/ronn
+
+        Args:
+            act_name: Name of the activation function.
+            out_ch: Optional number of output channels. Only used for
+                determining the number of parameters in the PReLU activation
+                function.
+
+        Returns:
+            act: PyTorch activation function.
+
+        Experimental results for randomized overdrive neural networks:
+        - ReLU: solid distortion
+        - LeakyReLU: somewhat veiled sound
+        - Tanh: insane levels of distortion with lots of aliasing (HF)
+        - Sigmoid: too gritty to be useful
+        - ELU: fading in and out
+        - RReLU: really interesting HF noise with a background sound
+        - SELU: rolled-off soft distortion sound
+        - GELU: roomy, not too interesting
+        - Softplus: heavily distorted signal but with a very rolled-off sound (nice)
+        - Softshrink: super distant sounding and somewhat roomy
+        """
+        act_name = act_name.lower()
+        if act_name == "relu":
+            act = nn.ReLU()
+        elif act_name == "leakyrelu":
+            act = nn.LeakyReLU()
+        elif act_name == "tanh":
+            act = nn.Tanh()
+        elif act_name == "sigmoid":
+            act = nn.Sigmoid()
+        elif act_name == "elu":
+            act = nn.ELU()
+        elif act_name == "rrelu":
+            act = nn.RReLU()
+        elif act_name == "selu":
+            act = nn.SELU()
+        elif act_name == "gelu":
+            act = nn.GELU()
+        elif act_name == "softplus":
+            act = nn.Softplus()
+        elif act_name == "softshrink":
+            act = nn.Softshrink()
+        elif act_name == "silu" or act_name == "swish":
+            act = nn.SiLU()
+        elif act_name == "prelu":
+            if out_ch is None:
+                act = nn.PReLU()
+            else:
+                act = nn.PReLU(out_ch)
+        elif act_name == "prelu1":
+            act = nn.PReLU()
+        else:
+            raise ValueError(f"Invalid activation name: '{act_name}'.")
+
+        return act
+
 
 class TCN(nn.Module):
     def __init__(self,
@@ -195,6 +261,7 @@ def __init__(self,
                  use_ln: bool = False,
                  temporal_dims: Optional[List[int]] = None,
                  use_act: bool = True,
+                 act_name: str = "prelu",
                  use_res: bool = True,
                  cond_dim: int = 0,
                  use_film_bn: bool = True,  # TODO(cm): check if this should be false
@@ -215,6 +282,7 @@ def __init__(self,
         self.use_ln = use_ln
         self.temporal_dims = temporal_dims
         self.use_act = use_act
+        self.act_name = act_name
         self.use_res = use_res
         self.cond_dim = cond_dim
         self.use_film_bn = use_film_bn
@@ -269,6 +337,7 @@ def __init__(self,
                                    use_ln,
                                    temp_dim,
                                    use_act,
+                                   act_name,
                                    use_res,
                                    cond_dim,
                                    use_film_bn,

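For reference, a minimal usage sketch of the new `act_name` option (not part of this commit). It assumes the block class in `neutone_sdk/tcn.py` is named `TCNBlock` (the class name falls outside the visible hunks) and relies only on the `get_activation` signature shown above.

```python
# Hypothetical usage sketch, not part of this commit. Assumes the block
# class in neutone_sdk/tcn.py is named TCNBlock (its name is outside the
# visible hunks) and that get_activation has the signature shown above.
import torch
from neutone_sdk.tcn import TCNBlock

# Look up activations by name, as the block now does internally.
tanh = TCNBlock.get_activation("tanh")                # parameter-free activation
prelu = TCNBlock.get_activation("prelu", out_ch=16)   # one learnable slope per channel

x = torch.randn(1, 16, 2048)  # (batch, channels, samples)
y = prelu(tanh(x))
print(y.shape)  # torch.Size([1, 16, 2048])
```

Passing `act_name` through `TCN.__init__` (as in the last hunk) applies the same lookup to every block in the stack.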