Skip to content

Commit

Permalink
[cm] Adding knob param info
Browse files Browse the repository at this point in the history
  • Loading branch information
christhetree committed Feb 5, 2024
1 parent 023734a commit f46b266
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 8 deletions.
2 changes: 1 addition & 1 deletion examples/music_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def do_forward_pass(self,
wrapper = MusicGenModelWrapper(model)
audio_out = wrapper.forward(curr_block_idx=0,
audio_in=[],
cont_params=tr.tensor([0.0]).unsqueeze(1),
cont_params=tr.tensor([0.0, 1.0]).unsqueeze(1),
text_params=["testing"])
log.info(audio_out[0].shape)

Expand Down
51 changes: 44 additions & 7 deletions neutone_sdk/non_realtime_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,26 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
f"Max allowed is {constants.NEUTONE_GEN_N_TEXT_PARAMS}"
)
self.n_cont_params = n_cont_params
# TorchScript does not know how to type empty lists
self._cont_param_names = cont_param_names + ["__torchscript_typing"]
self.n_text_params = n_text_params
# TorchScript does not know how to type empty lists
self._text_param_names = text_param_names + ["__torchscript_typing"]
if n_text_params:
self.has_text_param = True

all_neutone_parameters = self._get_all_neutone_parameters()
assert len(all_neutone_parameters) == self._get_max_n_params()

# Continuous and base (unused) parameters are considered knobs
knob_param_types = {NeutoneParameterType.BASE, NeutoneParameterType.KNOB}
self.neutone_knob_param_names = [p.name for p in all_neutone_parameters
if p.type in knob_param_types]
self.neutone_knob_param_descriptions = [
p.description for p in all_neutone_parameters if p.type in knob_param_types
]
self.neutone_knob_param_types = [p.type.value for p in all_neutone_parameters
if p.type in knob_param_types]
self.neutone_knob_param_used = [p.used for p in all_neutone_parameters
if p.type in knob_param_types]
assert len(self.neutone_knob_param_names) == constants.NEUTONE_GEN_N_KNOB_PARAMS

default_cont_param_values = Tensor() # Empty tensor for typing
if self.n_cont_params > 0:
default_cont_param_values = tr.tensor(
Expand All @@ -95,6 +104,7 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
if p.type == NeutoneParameterType.KNOB
]
).unsqueeze(-1)
assert default_cont_param_values.size(0) == self.n_cont_params
# TODO(cm): rename once moved from core
self.register_buffer("default_param_values", default_cont_param_values)

Expand Down Expand Up @@ -280,7 +290,10 @@ def forward(self,

if self.use_debug_mode:
if cont_params is not None and self.n_cont_params > 0:
assert cont_params.shape == (self.n_cont_params, in_n)
assert cont_params.ndim == 2
assert (self.n_cont_params <= cont_params.size(0) <=
constants.NEUTONE_GEN_N_KNOB_PARAMS)
assert cont_params.size(1) == in_n
if self.get_native_buffer_sizes():
assert (
in_n in self.get_native_buffer_sizes()
Expand All @@ -292,9 +305,11 @@ def forward(self,
cont_params = self.aggregate_cont_params(cont_params)
if self.use_debug_mode:
assert cont_params.ndim == 2
assert cont_params.size(0) == self.n_cont_params
assert (self.n_cont_params <= cont_params.size(0) <=
constants.NEUTONE_GEN_N_KNOB_PARAMS)
for idx in range(self.n_cont_params):
remapped_cont_params[self._cont_param_names[idx]] = cont_params[idx]
remapped_cont_params[
self.neutone_knob_param_names[idx]] = cont_params[idx]

if self.should_cancel_forward_pass():
return []
Expand Down Expand Up @@ -386,6 +401,22 @@ def is_text_model(self) -> bool:
"""
return self.has_text_param

@tr.jit.export
def get_default_knob_param_names(self) -> List[str]:
return self.neutone_knob_param_names

@tr.jit.export
def get_default_knob_param_descriptions(self) -> List[str]:
return self.neutone_knob_param_descriptions

@tr.jit.export
def get_default_knob_param_types(self) -> List[str]:
return self.neutone_knob_param_types

@tr.jit.export
def get_default_knob_param_used(self) -> List[bool]:
return self.neutone_knob_param_used

@tr.jit.export
def get_preserved_attributes(self) -> List[str]:
# This avoids using inheritance which torchscript does not support
Expand All @@ -401,7 +432,13 @@ def get_preserved_attributes(self) -> List[str]:
"set_sample_rate_and_buffer_size",
"reset",
"get_progress_percentage",
"should_cancel_forward_pass",
"request_cancel_forward_pass",
"is_text_model",
"get_default_knob_param_names",
"get_default_knob_param_descriptions",
"get_default_knob_param_types",
"get_default_knob_param_used",
"get_preserved_attributes",
"to_metadata",
]
Expand Down

0 comments on commit f46b266

Please sign in to comment.