Skip to content

Commit

Permalink
[cm] Adding doc strings
Browse files Browse the repository at this point in the history
  • Loading branch information
christhetree committed Feb 6, 2024
1 parent 3ee2f4b commit 3f9ecf4
Showing 1 changed file with 68 additions and 30 deletions.
98 changes: 68 additions & 30 deletions neutone_sdk/non_realtime_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class NonRealtimeMetadata(NamedTuple):

class NonRealtimeBase(NeutoneModel):
def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
"""
Wraps a PyTorch model for use in a non-realtime context.
Compatible with the Neutone Gen plugin.
"""
super().__init__(model, use_debug_mode)
self.progress_percentage = 0
self.cancel_forward_pass_requested = False
Expand Down Expand Up @@ -202,10 +206,24 @@ def _get_all_neutone_parameters(self) -> List[NeutoneParameter]:

@abstractmethod
def get_audio_in_channels(self) -> List[int]:
"""
Returns a list of the number of audio channels that the model expects as input.
If the model does not require audio input, an empty list should be returned.
Currently only supports mono and stereo audio.
Example value: [2]
"""
pass

@abstractmethod
def get_audio_out_channels(self) -> List[int]:
"""
Returns a list of the number of audio channels that the model outputs.
Models must output at least one audio track.
Currently only supports mono and stereo audio.
Example value: [2]
"""
pass

@abstractmethod
Expand All @@ -215,7 +233,7 @@ def get_native_sample_rates(self) -> List[int]:
with. If the list is empty, all common sample rates are assumed to be
supported.
Example value: [44100, 48000]
Example value: [44100]
"""
pass

Expand All @@ -234,7 +252,8 @@ def get_native_buffer_sizes(self) -> List[int]:
def is_one_shot_model(self) -> bool:
"""
Returns True if the model is a one-shot model, i.e. it must process the entire
input buffer in one go.
input audio and / or parameters at once. If this is False, it is assumed that
the model can process audio and parameters in blocks.
"""
pass

Expand All @@ -249,28 +268,32 @@ def do_forward_pass(self,
The inputs to this method should be treated as read-only.
Args:
x:
torch Tensor of shape [num_channels, num_samples]
num_channels is 1 if `is_input_mono` is set to True, otherwise 2
num_samples will be one of the sizes specified in `get_native_buffer_sizes`
If a look-behind buffer is being used, see `get_look_behind_samples` for details on the shape of x.
curr_block_idx:
The index of the current block being processed. This is only relevant if
the model is not a one-shot model and will always be 0 otherwise.
audio_in:
List of torch Tensors of shape [num_channels, num_samples].
num_samples will be one of the sizes specified in
`get_native_buffer_sizes()` if not a one-shot model.
The sample rate of the audio will also be one of the ones specified in
`get_native_sample_rates`.
The best combination is chosen based on the DAW parameters at runtime. If
unsure, provide only one value in the lists.
`get_native_sample_rates()`.
cont_params:
Python dictionary mapping from parameter names (defined by the values in
get_parameters) to values. By default, we aggregate the sample values over the
entire buffer and provide the mean value.
Override the `aggregate_params` method for more fine grained control.
Python dictionary mapping from continuous and categorical (knob)
parameter names (defined by the values in `get_neutone_parameters()` to
values. By default, we aggregate the parameters to a single value per
parameter for the current audio being processed.
Overwrite `aggregate_continuous_params` and
`aggregate_categorical_params` for more fine-grained control.
text_params:
List of strings containing the text parameters. Will be empty if the
model does not have any text parameters.
Returns:
torch Tensor of shape [num_channels, num_samples]
The shape of the output must match the shape of the input.
List of torch Tensors of shape [num_channels, num_samples] representing the
output audio. The number of channels of the output audio tracks should match
the values returned by `get_audio_out_channels()`. The sample rate of the
output audio tracks should be the same as the input audio tracks which will
be one of the values specified in `get_native_sample_rates()`.
"""
pass

Expand Down Expand Up @@ -304,25 +327,35 @@ def reset_model(self) -> bool:

def aggregate_continuous_params(self, cont_params: Tensor) -> Tensor:
"""
Aggregates parameters of size (MAX_N_PARAMS, buffer_size) to single values.
Aggregates parameters of shape (n_cont_params, buffer_size) to single values.
By default we take the mean value along dimension 1 to provide a single value for each parameter
for the current buffer.
For more fine grained control, override this method as required.
By default we take the mean value along dimension 1 to provide a single value
for each parameter for the current buffer.
For more fine-grained control, override this method as required.
"""
if self.use_debug_mode:
assert cont_params.ndim == 2
return tr.mean(cont_params, dim=1, keepdim=True)

def aggregate_categorical_params(self, cat_params: Tensor) -> Tensor:
"""
Aggregates parameters of shape (n_cat_params, buffer_size) to single values.
By default we take the first value for each parameter for the current buffer.
For more fine-grained control, override this method as required.
"""
if self.use_debug_mode:
assert cat_params.ndim == 2
return cat_params[:, :1]

def set_progress_percentage(self, progress_percentage: int) -> None:
"""
Sets the progress percentage of the model.
This can be used to indicate the progress of the model to the user. This is
especially useful for long-running one-shot models. The progress percentage
should be between 0 and 100.
"""
if self.use_debug_mode:
assert 0 <= progress_percentage <= 100, \
"Progress percentage must be between 0 and 100"
Expand All @@ -334,9 +367,12 @@ def forward(self,
knob_params: Optional[Tensor] = None,
text_params: Optional[List[str]] = None) -> List[Tensor]:
"""
Internal forward pass for a WaveformToWaveform model.
Internal forward pass for a NonRealtimeBase wrapped model.
If `knob_params` or `text_params` is None, they are populated with their
default values when applicable.
If params is None, we fill in the default values.
This method should not be overwritten by SDK users.
"""
if text_params is None:
# Needed for TorchScript typing
Expand Down Expand Up @@ -439,15 +475,16 @@ def calc_model_delay_samples(self) -> int:
def set_sample_rate_and_buffer_size(self, sample_rate: int, n_samples: int) -> bool:
"""
Sets the sample_rate and buffer size of the wrapper.
This should not be overwritten by SDK users, instead please check out
This should not be overwritten by SDK users, instead please override
the 'set_model_sample_rate_and_buffer_size' method.
Args:
sample_rate: The sample rate to use.
n_samples: The number of samples to use.
Returns:
bool: True if 'set_model_sample_rate_and_buffer_size' is implemented and successful, otherwise False.
bool: True if 'set_model_sample_rate_and_buffer_size' is implemented and
successful, otherwise False.
"""
if self.use_debug_mode:
if self.get_native_buffer_sizes():
Expand All @@ -461,7 +498,8 @@ def set_sample_rate_and_buffer_size(self, sample_rate: int, n_samples: int) -> b
def reset(self) -> bool:
"""
Resets the wrapper.
This should not be overwritten by SDK users, instead please check out the 'reset_model' method.
This should not be overwritten by SDK users, instead please override the
'reset_model' method.
Returns:
bool: True if 'reset_model' is implemented and successful, otherwise False.
Expand Down

0 comments on commit 3f9ecf4

Please sign in to comment.