From 3f9ecf4ef3aff67f1140bb3b5ebf552b79fcc041 Mon Sep 17 00:00:00 2001 From: christhetree Date: Tue, 6 Feb 2024 23:17:46 +0000 Subject: [PATCH] [cm] Adding doc strings --- neutone_sdk/non_realtime_wrapper.py | 98 ++++++++++++++++++++--------- 1 file changed, 68 insertions(+), 30 deletions(-) diff --git a/neutone_sdk/non_realtime_wrapper.py b/neutone_sdk/non_realtime_wrapper.py index 5370778..15b5e92 100644 --- a/neutone_sdk/non_realtime_wrapper.py +++ b/neutone_sdk/non_realtime_wrapper.py @@ -44,6 +44,10 @@ class NonRealtimeMetadata(NamedTuple): class NonRealtimeBase(NeutoneModel): def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None: + """ + Wraps a PyTorch model for use in a non-realtime context. + Compatible with the Neutone Gen plugin. + """ super().__init__(model, use_debug_mode) self.progress_percentage = 0 self.cancel_forward_pass_requested = False @@ -202,10 +206,24 @@ def _get_all_neutone_parameters(self) -> List[NeutoneParameter]: @abstractmethod def get_audio_in_channels(self) -> List[int]: + """ + Returns a list of the number of audio channels that the model expects as input. + If the model does not require audio input, an empty list should be returned. + Currently only supports mono and stereo audio. + + Example value: [2] + """ pass @abstractmethod def get_audio_out_channels(self) -> List[int]: + """ + Returns a list of the number of audio channels that the model outputs. + Models must output at least one audio track. + Currently only supports mono and stereo audio. + + Example value: [2] + """ pass @abstractmethod @@ -215,7 +233,7 @@ def get_native_sample_rates(self) -> List[int]: with. If the list is empty, all common sample rates are assumed to be supported. - Example value: [44100, 48000] + Example value: [44100] """ pass @@ -234,7 +252,8 @@ def get_native_buffer_sizes(self) -> List[int]: def is_one_shot_model(self) -> bool: """ Returns True if the model is a one-shot model, i.e. 
it must process the entire - input buffer in one go. + input audio and / or parameters at once. If this is False, it is assumed that + the model can process audio and parameters in blocks. """ pass @@ -249,28 +268,32 @@ def do_forward_pass(self, The inputs to this method should be treated as read-only. Args: - x: - torch Tensor of shape [num_channels, num_samples] - num_channels is 1 if `is_input_mono` is set to True, otherwise 2 - num_samples will be one of the sizes specified in `get_native_buffer_sizes` - If a look-behind buffer is being used, see `get_look_behind_samples` for details on the shape of x. - + curr_block_idx: + The index of the current block being processed. This is only relevant if + the model is not a one-shot model and will always be 0 otherwise. + audio_in: + List of torch Tensors of shape [num_channels, num_samples]. + num_samples will be one of the sizes specified in + `get_native_buffer_sizes()` if not a one-shot model. The sample rate of the audio will also be one of the ones specified in - `get_native_sample_rates`. - - The best combination is chosen based on the DAW parameters at runtime. If - unsure, provide only one value in the lists. + `get_native_sample_rates()`. cont_params: - Python dictionary mapping from parameter names (defined by the values in - get_parameters) to values. By default, we aggregate the sample values over the - entire buffer and provide the mean value. - - Override the `aggregate_params` method for more fine grained control. + Python dictionary mapping from continuous and categorical (knob) + parameter names (defined by the values in `get_neutone_parameters()`) to + values. By default, we aggregate the parameters to a single value per + parameter for the current audio being processed. + Override `aggregate_continuous_params` and + `aggregate_categorical_params` for more fine-grained control. + text_params: + List of strings containing the text parameters. 
Will be empty if the + model does not have any text parameters. Returns: - torch Tensor of shape [num_channels, num_samples] - - The shape of the output must match the shape of the input. + List of torch Tensors of shape [num_channels, num_samples] representing the + output audio. The number of channels of the output audio tracks should match + the values returned by `get_audio_out_channels()`. The sample rate of the + output audio tracks should be the same as the input audio tracks which will + be one of the values specified in `get_native_sample_rates()`. """ pass @@ -304,12 +327,11 @@ def reset_model(self) -> bool: def aggregate_continuous_params(self, cont_params: Tensor) -> Tensor: """ - Aggregates parameters of size (MAX_N_PARAMS, buffer_size) to single values. + Aggregates parameters of shape (n_cont_params, buffer_size) to single values. - By default we take the mean value along dimension 1 to provide a single value for each parameter - for the current buffer. - - For more fine grained control, override this method as required. + By default we take the mean value along dimension 1 to provide a single value + for each parameter for the current buffer. + For more fine-grained control, override this method as required. """ if self.use_debug_mode: assert cont_params.ndim == 2 @@ -317,12 +339,23 @@ def aggregate_continuous_params(self, cont_params: Tensor) -> Tensor: def aggregate_categorical_params(self, cat_params: Tensor) -> Tensor: """ + Aggregates parameters of shape (n_cat_params, buffer_size) to single values. + + By default we take the first value for each parameter for the current buffer. + For more fine-grained control, override this method as required. """ if self.use_debug_mode: assert cat_params.ndim == 2 return cat_params[:, :1] def set_progress_percentage(self, progress_percentage: int) -> None: + """ + Sets the progress percentage of the model. + + This can be used to indicate the progress of the model to the user. 
This is + especially useful for long-running one-shot models. The progress percentage + should be between 0 and 100. + """ if self.use_debug_mode: assert 0 <= progress_percentage <= 100, \ "Progress percentage must be between 0 and 100" @@ -334,9 +367,12 @@ def forward(self, knob_params: Optional[Tensor] = None, text_params: Optional[List[str]] = None) -> List[Tensor]: """ - Internal forward pass for a WaveformToWaveform model. + Internal forward pass for a NonRealtimeBase wrapped model. + + If `knob_params` or `text_params` is None, they are populated with their + default values when applicable. - If params is None, we fill in the default values. + This method should not be overwritten by SDK users. """ if text_params is None: # Needed for TorchScript typing @@ -439,7 +475,7 @@ def calc_model_delay_samples(self) -> int: def set_sample_rate_and_buffer_size(self, sample_rate: int, n_samples: int) -> bool: """ Sets the sample_rate and buffer size of the wrapper. - This should not be overwritten by SDK users, instead please check out + This should not be overwritten by SDK users, instead please override the 'set_model_sample_rate_and_buffer_size' method. Args: @@ -447,7 +483,8 @@ def set_sample_rate_and_buffer_size(self, sample_rate: int, n_samples: int) -> b n_samples: The number of samples to use. Returns: - bool: True if 'set_model_sample_rate_and_buffer_size' is implemented and successful, otherwise False. + bool: True if 'set_model_sample_rate_and_buffer_size' is implemented and + successful, otherwise False. """ if self.use_debug_mode: if self.get_native_buffer_sizes(): @@ -461,7 +498,8 @@ def set_sample_rate_and_buffer_size(self, sample_rate: int, n_samples: int) -> b def reset(self) -> bool: """ Resets the wrapper. - This should not be overwritten by SDK users, instead please check out the 'reset_model' method. + This should not be overwritten by SDK users, instead please override the + 'reset_model' method. 
Returns: bool: True if 'reset_model' is implemented and successful, otherwise False.