[cm] Fixing torchscript typing errors
christhetree committed Feb 6, 2024
1 parent cd602dd commit 1f9b011
Showing 3 changed files with 79 additions and 21 deletions.
21 changes: 15 additions & 6 deletions examples/example_clipper_gen.py
@@ -8,7 +8,8 @@
import torch.nn as nn
from torch import Tensor

from neutone_sdk import NeutoneParameter, KnobNeutoneParameter, TextNeutoneParameter
from neutone_sdk import NeutoneParameter, TextNeutoneParameter, \
    ContinuousNeutoneParameter, CategoricalNeutoneParameter
from neutone_sdk.non_realtime_wrapper import NonRealtimeBase

logging.basicConfig()
@@ -59,10 +60,12 @@ def is_experimental(self) -> bool:

    def get_neutone_parameters(self) -> List[NeutoneParameter]:
        return [
            KnobNeutoneParameter("min", "min clip threshold", default_value=0.15),
            KnobNeutoneParameter("max", "max clip threshold", default_value=0.15),
            TextNeutoneParameter("text_param", "testing"),
            KnobNeutoneParameter("gain", "scale clip threshold", default_value=1.0),
            ContinuousNeutoneParameter("min", "min clip threshold", default_value=0.15),
            CategoricalNeutoneParameter("cat", "catty", n_values=3, default_value=2),
            ContinuousNeutoneParameter("max", "max clip threshold", default_value=0.15),
            ContinuousNeutoneParameter("gain", "scale clip threshold", default_value=1.0),
            # ContinuousNeutoneParameter("gain2", "scale clip threshold", default_value=1.0),
        ]

    @tr.jit.export
@@ -93,6 +96,9 @@ def do_forward_pass(self,
                        audio_in: List[Tensor],
                        cont_params: Dict[str, Tensor],
                        text_params: List[str]) -> List[Tensor]:
        # print(cont_params)
        # print(text_params)
        # exit()
        min_val, max_val, gain = cont_params["min"], cont_params["max"], cont_params["gain"]
        audio_out = []
        for x in audio_in:
@@ -109,7 +115,10 @@ def do_forward_pass(self,

model = ClipperModel()
wrapper = ClipperModelWrapper(model)
wrapper.forward(0, [tr.rand(2, 2048)], text_params=["ayy"])
# wrapper.forward(0, [tr.rand(2, 2048)], knob_params=tr.tensor([[0.5], [0.1], [0.2], [0.3]]))
# wrapper.forward(0, [tr.rand(2, 2048)], text_params=["ayy"])
wrapper.forward(0, [tr.rand(2, 2048)])

ts = tr.jit.script(wrapper)
ts.forward(0, [tr.rand(2, 2048)], text_params=["ayy"])
ts.forward(0, [tr.rand(2, 2048)])
# ts.forward(0, [tr.rand(2, 2048)], text_params=["ayy"])
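
For reference, a minimal usage sketch (not part of the commit) that activates the commented-out knob-param call above; it assumes the rows of knob_params follow the order declared in get_neutone_parameters() (min, cat, max, gain), with one value per knob, and that the scripted wrapper accepts the same keyword:

# Hypothetical usage: pass explicit knob values to the scripted wrapper.
knob_params = tr.tensor([[0.5], [0.1], [0.2], [0.3]])  # shape: (n_knob_params, 1)
ts.forward(0, [tr.rand(2, 2048)], knob_params=knob_params)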
25 changes: 14 additions & 11 deletions examples/music_gen.py
@@ -1,12 +1,12 @@
import logging
import os
from typing import Dict, List, Optional
from typing import Dict, List

import torch as tr
from torch import Tensor

from neutone_sdk import NeutoneParameter, TextNeutoneParameter, \
    ContinuousNeutoneParameter
    CategoricalNeutoneParameter
from neutone_sdk.non_realtime_wrapper import NonRealtimeBase

logging.basicConfig()
@@ -47,9 +47,15 @@ def is_experimental(self) -> bool:

    def get_neutone_parameters(self) -> List[NeutoneParameter]:
        return [
            TextNeutoneParameter("prompt", "testing"),
            # TODO(cm): convert to categorical
            ContinuousNeutoneParameter("duration", "how many seconds to generate", default_value=0.0),
            TextNeutoneParameter(name="prompt",
                                 description="text prompt for generation",
                                 max_n_chars=128,
                                 default_value="techno kick drum"),
            CategoricalNeutoneParameter(name="duration",
                                        description="how many seconds to generate",
                                        n_values=8,
                                        default_value=0,
                                        labels=[str(idx) for idx in range(1, 9)]),
        ]

    @tr.jit.export
@@ -75,12 +81,9 @@ def is_one_shot_model(self) -> bool:
    def do_forward_pass(self,
                        curr_block_idx: int,
                        audio_in: List[Tensor],
                        cont_params: Dict[str, Tensor],
                        knob_params: Dict[str, Tensor],
                        text_params: List[str]) -> List[Tensor]:
        raw_time = cont_params["duration"]
        # Limit duration between 1 and 8 seconds
        n_seconds = tr.clamp(raw_time * 7.0 + 1.0, 1.0, 8.0).item()
        n_seconds = int(n_seconds)
        n_seconds = knob_params["duration"].item() + 1
        # Convert duration to number of tokens
        n_tokens = (n_seconds * 50) + 4
        if self.use_debug_mode:
@@ -121,7 +124,7 @@ def do_forward_pass(self,
wrapper = MusicGenModelWrapper(model)
audio_out = wrapper.forward(curr_block_idx=0,
                            audio_in=[],
                            knob_params=tr.tensor([0.0, 1.0]).unsqueeze(1),
                            knob_params=tr.tensor([0.0]).unsqueeze(1),
                            text_params=["testing"])
log.info(audio_out[0].shape)
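
A quick worked check of the new duration mapping (a sketch, not part of the commit, assuming 50 tokens per second as used in do_forward_pass above): categorical index 0..7 maps to 1..8 seconds, and n_tokens = (n_seconds * 50) + 4, e.g. index 0 -> 1 s -> 54 tokens.

# Sketch: verify the categorical duration -> token count mapping.
for idx in range(8):                 # categorical indices 0..7 (labels "1".."8")
    n_seconds = idx + 1              # mirrors knob_params["duration"].item() + 1
    n_tokens = (n_seconds * 50) + 4
    print(idx, n_seconds, n_tokens)  # 0 -> 1 s -> 54 tokens, ..., 7 -> 8 s -> 404 tokens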

54 changes: 50 additions & 4 deletions neutone_sdk/non_realtime_wrapper.py
@@ -102,6 +102,20 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
                self.text_param_max_n_chars.append(p.max_n_chars)
                self.text_param_default_values.append(p.default_value)

        # This is needed for TorchScript typing since it cannot infer element
        # types for empty lists and dicts.
        if not self.n_cont_params:
            self.cont_param_names.append("__torchscript_typing")
            self.cont_param_indices.append(-1)
        if not self.n_cat_params:
            self.cat_param_names.append("__torchscript_typing")
            self.cat_param_indices.append(-1)
            self.cat_param_n_values["__torchscript_typing"] = -1
            self.cat_param_labels["__torchscript_typing"] = ["__torchscript_typing"]
        if not self.n_text_params:
            self.text_param_names.append("__torchscript_typing")
            self.text_param_max_n_chars.append(-1)
            self.text_param_default_values.append("__torchscript_typing")

        self.n_knob_params = self.n_cont_params + self.n_cat_params

        assert self.n_knob_params <= constants.NEUTONE_GEN_N_KNOB_PARAMS, (
@@ -118,8 +132,8 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
        all_neutone_parameters = self._get_all_neutone_parameters()
        assert len(all_neutone_parameters) == self._get_max_n_params()

        # This overrides the base class definitions to remove the text param since
        # it is handled separately in the UI.
        # This overrides the base class definitions to remove the text param or extra
        # base param since it is handled separately in the UI.
        if self.has_text_param:
            self.neutone_parameter_names = [p.name for p in all_neutone_parameters
                                            if p.type != NeutoneParameterType.TEXT]
@@ -131,6 +145,15 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
                                            if p.type != NeutoneParameterType.TEXT]
            self.neutone_parameter_used = [p.used for p in all_neutone_parameters
                                           if p.type != NeutoneParameterType.TEXT]
        else:
            self.neutone_parameter_names = self.neutone_parameter_names[
                :constants.NEUTONE_GEN_N_KNOB_PARAMS]
            self.neutone_parameter_descriptions = self.neutone_parameter_descriptions[
                :constants.NEUTONE_GEN_N_KNOB_PARAMS]
            self.neutone_parameter_types = self.neutone_parameter_types[
                :constants.NEUTONE_GEN_N_KNOB_PARAMS]
            self.neutone_parameter_used = self.neutone_parameter_used[
                :constants.NEUTONE_GEN_N_KNOB_PARAMS]

        assert (len(self.get_default_param_names()) ==
                constants.NEUTONE_GEN_N_KNOB_PARAMS)
@@ -311,7 +334,11 @@ def forward(self,
        If params is None, we fill in the default values.
        """
        if text_params is None:
            text_params = self.text_param_default_values
            # Needed for TorchScript typing
            if self.n_text_params:
                text_params = self.text_param_default_values
            else:
                text_params = []

        if self.use_debug_mode:
            assert len(audio_in) == len(self.get_audio_in_channels())
@@ -355,14 +382,15 @@ def forward(self,
                remapped_knob_params[self.cont_param_names[idx]] = cont_params[idx]
        # Aggregate and remap the categorical parameters
        if self.n_cat_params > 0:
            # TODO(cm): param validation
            cat_params = knob_params[self.cat_param_indices, :]
            cat_params = self.aggregate_categorical_params(cat_params)
            if self.use_debug_mode:
                assert cat_params.ndim == 2
                assert cat_params.size(0) == self.n_cat_params
            for idx in range(self.n_cat_params):
                remapped_knob_params[self.cat_param_names[idx]] = (
                    int(cat_params[idx]))
                    cat_params[idx].int())

        if self.should_cancel_forward_pass():
            return []
@@ -459,6 +487,8 @@ def get_continuous_param_names(self) -> List[str]:
"""
Returns the names of the continuous parameters.
"""
if not self.n_cont_params: # Needed for TorchScript typing
return []
return self.cont_param_names

@tr.jit.export
@@ -467,13 +497,17 @@ def get_continuous_param_indices(self) -> List[int]:
        Returns the indices of the position of the continuous parameters in the list of
        knob parameters.
        """
        if not self.n_cont_params:  # Needed for TorchScript typing
            return []
        return self.cont_param_indices

    @tr.jit.export
    def get_categorical_param_names(self) -> List[str]:
        """
        Returns the names of the categorical parameters.
        """
        if not self.n_cat_params:  # Needed for TorchScript typing
            return []
        return self.cat_param_names

    @tr.jit.export
@@ -482,6 +516,8 @@ def get_categorical_param_indices(self) -> List[int]:
        Returns the indices of the position of the categorical parameters in the list of
        knob parameters.
        """
        if not self.n_cat_params:  # Needed for TorchScript typing
            return []
        return self.cat_param_indices

    @tr.jit.export
@@ -490,6 +526,8 @@ def get_categorical_param_n_values(self) -> Dict[str, int]:
        Returns a dictionary of the number of values for each categorical parameter.
        Indexed by the name of the categorical parameter.
        """
        if not self.n_cat_params:  # Needed for TorchScript typing
            return {}
        return self.cat_param_n_values

    @tr.jit.export
@@ -498,27 +536,35 @@ def get_categorical_param_labels(self) -> Dict[str, List[str]]:
        Returns a dictionary of the labels for each categorical parameter.
        Indexed by the name of the categorical parameter.
        """
        if not self.n_cat_params:  # Needed for TorchScript typing
            return {}
        return self.cat_param_labels

    @tr.jit.export
    def get_text_param_names(self) -> List[str]:
        """
        Returns the names of the text parameters.
        """
        if not self.n_text_params:  # Needed for TorchScript typing
            return []
        return self.text_param_names

    @tr.jit.export
    def get_text_param_max_n_chars(self) -> List[int]:
        """
        Returns the maximum number of characters for each text parameter.
        """
        if not self.n_text_params:  # Needed for TorchScript typing
            return []
        return self.text_param_max_n_chars

    @tr.jit.export
    def get_text_param_default_values(self) -> List[str]:
        """
        Returns the default values for the text parameters.
        """
        if not self.n_text_params:  # Needed for TorchScript typing
            return []
        return self.text_param_default_values

    @tr.jit.export
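Why the "__torchscript_typing" placeholders and the guarded getters: TorchScript infers a module attribute's element type from its initial value, and an empty Python list is assumed to be List[Tensor], so an attribute that starts as an empty list of names fails to script against a List[str] return annotation. A minimal, self-contained sketch of the failure and the workaround (the class and attribute names here are illustrative, not part of the SDK):

import torch as tr
import torch.nn as nn
from typing import List

class Fails(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.names = []  # inferred as List[Tensor], not List[str]

    @tr.jit.export
    def get_names(self) -> List[str]:
        # tr.jit.script(Fails()) would raise a type error here:
        # the attribute is List[Tensor] but the annotation says List[str].
        return self.names

class Works(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.n_names = 0
        # Placeholder entry pins the inferred element type to str,
        # mirroring the "__torchscript_typing" sentinels in this commit.
        self.names = ["__torchscript_typing"]

    @tr.jit.export
    def get_names(self) -> List[str]:
        if self.n_names == 0:  # hide the placeholder from callers
            return []
        return self.names

ts = tr.jit.script(Works())
assert ts.get_names() == []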
