From 1f9b011f76a3c7b524113a9392c3fd0fe3a83973 Mon Sep 17 00:00:00 2001
From: christhetree
Date: Mon, 5 Feb 2024 23:56:16 +0000
Subject: [PATCH] [cm] Fixing torchscript typing errors

---
 examples/example_clipper_gen.py     | 21 +++++++----
 examples/music_gen.py               | 25 +++++++------
 neutone_sdk/non_realtime_wrapper.py | 54 ++++++++++++++++++++++++++---
 3 files changed, 79 insertions(+), 21 deletions(-)

diff --git a/examples/example_clipper_gen.py b/examples/example_clipper_gen.py
index 196f1dd..03e534d 100644
--- a/examples/example_clipper_gen.py
+++ b/examples/example_clipper_gen.py
@@ -8,7 +8,8 @@
 import torch.nn as nn
 from torch import Tensor
 
-from neutone_sdk import NeutoneParameter, KnobNeutoneParameter, TextNeutoneParameter
+from neutone_sdk import NeutoneParameter, TextNeutoneParameter, \
+    ContinuousNeutoneParameter, CategoricalNeutoneParameter
 from neutone_sdk.non_realtime_wrapper import NonRealtimeBase
 
 logging.basicConfig()
@@ -59,10 +60,12 @@ def is_experimental(self) -> bool:
 
     def get_neutone_parameters(self) -> List[NeutoneParameter]:
         return [
-            KnobNeutoneParameter("min", "min clip threshold", default_value=0.15),
-            KnobNeutoneParameter("max", "max clip threshold", default_value=0.15),
             TextNeutoneParameter("text_param", "testing"),
-            KnobNeutoneParameter("gain", "scale clip threshold", default_value=1.0),
+            ContinuousNeutoneParameter("min", "min clip threshold", default_value=0.15),
+            CategoricalNeutoneParameter("cat", "catty", n_values=3, default_value=2),
+            ContinuousNeutoneParameter("max", "max clip threshold", default_value=0.15),
+            ContinuousNeutoneParameter("gain", "scale clip threshold", default_value=1.0),
+            # ContinuousNeutoneParameter("gain2", "scale clip threshold", default_value=1.0),
         ]
 
     @tr.jit.export
@@ -93,6 +96,9 @@ def do_forward_pass(self,
                         audio_in: List[Tensor],
                         cont_params: Dict[str, Tensor],
                         text_params: List[str]) -> List[Tensor]:
+        # print(cont_params)
+        # print(text_params)
+        # exit()
         min_val, max_val, gain = cont_params["min"], cont_params["max"], cont_params["gain"]
         audio_out = []
         for x in audio_in:
@@ -109,7 +115,10 @@ def do_forward_pass(self,
     model = ClipperModel()
     wrapper = ClipperModelWrapper(model)
-    wrapper.forward(0, [tr.rand(2, 2048)], text_params=["ayy"])
+    # wrapper.forward(0, [tr.rand(2, 2048)], knob_params=tr.tensor([[0.5], [0.1], [0.2], [0.3]]))
+    # wrapper.forward(0, [tr.rand(2, 2048)], text_params=["ayy"])
+    wrapper.forward(0, [tr.rand(2, 2048)])
     ts = tr.jit.script(wrapper)
-    ts.forward(0, [tr.rand(2, 2048)], text_params=["ayy"])
+    ts.forward(0, [tr.rand(2, 2048)])
+    # ts.forward(0, [tr.rand(2, 2048)], text_params=["ayy"])
diff --git a/examples/music_gen.py b/examples/music_gen.py
index 56e0345..92933d3 100644
--- a/examples/music_gen.py
+++ b/examples/music_gen.py
@@ -1,12 +1,12 @@
 import logging
 import os
-from typing import Dict, List, Optional
+from typing import Dict, List
 
 import torch as tr
 from torch import Tensor
 
 from neutone_sdk import NeutoneParameter, TextNeutoneParameter, \
-    ContinuousNeutoneParameter
+    CategoricalNeutoneParameter
 from neutone_sdk.non_realtime_wrapper import NonRealtimeBase
 
 logging.basicConfig()
@@ -47,9 +47,15 @@ def is_experimental(self) -> bool:
 
     def get_neutone_parameters(self) -> List[NeutoneParameter]:
         return [
-            TextNeutoneParameter("prompt", "testing"),
-            # TODO(cm): convert to categorical
-            ContinuousNeutoneParameter("duration", "how many seconds to generate", default_value=0.0),
+            TextNeutoneParameter(name="prompt",
+                                 description="text prompt for generation",
+                                 max_n_chars=128,
+                                 default_value="techno kick drum"),
+            CategoricalNeutoneParameter(name="duration",
+                                        description="how many seconds to generate",
+                                        n_values=8,
+                                        default_value=0,
+                                        labels=[str(idx) for idx in range(1, 9)]),
         ]
 
     @tr.jit.export
@@ -75,12 +81,9 @@ def is_one_shot_model(self) -> bool:
     def do_forward_pass(self,
                         curr_block_idx: int,
                         audio_in: List[Tensor],
-                        cont_params: Dict[str, Tensor],
+                        knob_params: Dict[str, Tensor],
                         text_params: List[str]) -> List[Tensor]:
-        raw_time = cont_params["duration"]
-        # Limit duration between 1 and 8 seconds
-        n_seconds = tr.clamp(raw_time * 7.0 + 1.0, 1.0, 8.0).item()
-        n_seconds = int(n_seconds)
+        n_seconds = knob_params["duration"].item() + 1
         # Convert duration to number of tokens
         n_tokens = (n_seconds * 50) + 4
         if self.use_debug_mode:
@@ -121,7 +124,7 @@ def do_forward_pass(self,
     wrapper = MusicGenModelWrapper(model)
     audio_out = wrapper.forward(curr_block_idx=0,
                                 audio_in=[],
-                                knob_params=tr.tensor([0.0, 1.0]).unsqueeze(1),
+                                knob_params=tr.tensor([0.0]).unsqueeze(1),
                                 text_params=["testing"])
     log.info(audio_out[0].shape)
diff --git a/neutone_sdk/non_realtime_wrapper.py b/neutone_sdk/non_realtime_wrapper.py
index 16807c0..4d11e99 100644
--- a/neutone_sdk/non_realtime_wrapper.py
+++ b/neutone_sdk/non_realtime_wrapper.py
@@ -102,6 +102,20 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
                 self.text_param_max_n_chars.append(p.max_n_chars)
                 self.text_param_default_values.append(p.default_value)
 
+        # This is needed for TorchScript typing since it doesn't allow empty lists etc.
+        if not self.n_cont_params:
+            self.cont_param_names.append("__torchscript_typing")
+            self.cont_param_indices.append(-1)
+        if not self.n_cat_params:
+            self.cat_param_names.append("__torchscript_typing")
+            self.cat_param_indices.append(-1)
+            self.cat_param_n_values["__torchscript_typing"] = -1
+            self.cat_param_labels["__torchscript_typing"] = ["__torchscript_typing"]
+        if not self.n_text_params:
+            self.text_param_names.append("__torchscript_typing")
+            self.text_param_max_n_chars.append(-1)
+            self.text_param_default_values.append("__torchscript_typing")
+
         self.n_knob_params = self.n_cont_params + self.n_cat_params
 
         assert self.n_knob_params <= constants.NEUTONE_GEN_N_KNOB_PARAMS, (
@@ -118,8 +132,8 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
         all_neutone_parameters = self._get_all_neutone_parameters()
         assert len(all_neutone_parameters) == self._get_max_n_params()
 
-        # This overrides the base class definitions to remove the text param since
-        # it is handled separately in the UI.
+        # This overrides the base class definitions to remove the text param or extra
+        # base param since it is handled separately in the UI.
         if self.has_text_param:
             self.neutone_parameter_names = [p.name for p in all_neutone_parameters
                                             if p.type != NeutoneParameterType.TEXT]
@@ -131,6 +145,15 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
                                             if p.type != NeutoneParameterType.TEXT]
             self.neutone_parameter_used = [p.used for p in all_neutone_parameters
                                            if p.type != NeutoneParameterType.TEXT]
+        else:
+            self.neutone_parameter_names = self.neutone_parameter_names[
+                :constants.NEUTONE_GEN_N_KNOB_PARAMS]
+            self.neutone_parameter_descriptions = self.neutone_parameter_descriptions[
+                :constants.NEUTONE_GEN_N_KNOB_PARAMS]
+            self.neutone_parameter_types = self.neutone_parameter_types[
+                :constants.NEUTONE_GEN_N_KNOB_PARAMS]
+            self.neutone_parameter_used = self.neutone_parameter_used[
+                :constants.NEUTONE_GEN_N_KNOB_PARAMS]
 
         assert (len(self.get_default_param_names())
                 == constants.NEUTONE_GEN_N_KNOB_PARAMS)
@@ -311,7 +334,11 @@ def forward(self,
         If params is None, we fill in the default values.
         """
         if text_params is None:
-            text_params = self.text_param_default_values
+            # Needed for TorchScript typing
+            if self.n_text_params:
+                text_params = self.text_param_default_values
+            else:
+                text_params = []
 
         if self.use_debug_mode:
             assert len(audio_in) == len(self.get_audio_in_channels())
@@ -355,6 +382,7 @@ def forward(self,
                 remapped_knob_params[self.cont_param_names[idx]] = cont_params[idx]
         # Aggregate and remap the categorical parameters
        if self.n_cat_params > 0:
+            # TODO(cm): param validation
             cat_params = knob_params[self.cat_param_indices, :]
             cat_params = self.aggregate_categorical_params(cat_params)
             if self.use_debug_mode:
@@ -362,7 +390,7 @@ def forward(self,
                 assert cat_params.size(0) == self.n_cat_params
             for idx in range(self.n_cat_params):
                 remapped_knob_params[self.cat_param_names[idx]] = (
-                    int(cat_params[idx]))
+                    cat_params[idx].int())
 
         if self.should_cancel_forward_pass():
             return []
@@ -459,6 +487,8 @@ def get_continuous_param_names(self) -> List[str]:
         """
         Returns the names of the continuous parameters.
         """
+        if not self.n_cont_params:  # Needed for TorchScript typing
+            return []
         return self.cont_param_names
 
     @tr.jit.export
@@ -467,6 +497,8 @@ def get_continuous_param_indices(self) -> List[int]:
         Returns the indices of the position of the continuous parameters in the list
         of knob parameters.
         """
+        if not self.n_cont_params:  # Needed for TorchScript typing
+            return []
         return self.cont_param_indices
 
     @tr.jit.export
@@ -474,6 +506,8 @@ def get_categorical_param_names(self) -> List[str]:
         """
         Returns the names of the categorical parameters.
         """
+        if not self.n_cat_params:  # Needed for TorchScript typing
+            return []
         return self.cat_param_names
 
     @tr.jit.export
@@ -482,6 +516,8 @@ def get_categorical_param_indices(self) -> List[int]:
         Returns the indices of the position of the categorical parameters in the list
         of knob parameters.
         """
+        if not self.n_cat_params:  # Needed for TorchScript typing
+            return []
         return self.cat_param_indices
 
     @tr.jit.export
@@ -490,6 +526,8 @@ def get_categorical_param_n_values(self) -> Dict[str, int]:
         Returns a dictionary of the number of values for each categorical parameter.
         Indexed by the name of the categorical parameter.
         """
+        if not self.n_cat_params:  # Needed for TorchScript typing
+            return {}
         return self.cat_param_n_values
 
     @tr.jit.export
@@ -498,6 +536,8 @@ def get_categorical_param_labels(self) -> Dict[str, List[str]]:
         Returns a dictionary of the labels for each categorical parameter.
         Indexed by the name of the categorical parameter.
""" + if not self.n_cat_params: # Needed for TorchScript typing + return {} return self.cat_param_labels @tr.jit.export @@ -505,6 +545,8 @@ def get_text_param_names(self) -> List[str]: """ Returns the names of the text parameters. """ + if not self.n_text_params: # Needed for TorchScript typing + return [] return self.text_param_names @tr.jit.export @@ -512,6 +554,8 @@ def get_text_param_max_n_chars(self) -> List[int]: """ Returns the maximum number of characters for each text parameter. """ + if not self.n_text_params: # Needed for TorchScript typing + return [] return self.text_param_max_n_chars @tr.jit.export @@ -519,6 +563,8 @@ def get_text_param_default_values(self) -> List[str]: """ Returns the default values for the text parameters. """ + if not self.n_text_params: # Needed for TorchScript typing + return [] return self.text_param_default_values @tr.jit.export