Skip to content

Commit

Permalink
[cm] Adding categorical params to non realtime wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
christhetree committed Feb 6, 2024
1 parent ca9d071 commit cd602dd
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 84 deletions.
2 changes: 1 addition & 1 deletion examples/example_clipper_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def get_native_buffer_sizes(self) -> List[int]:
def is_one_shot_model(self) -> bool:
return False

def aggregate_cont_params(self, cont_params: Tensor) -> Tensor:
def aggregate_continuous_params(self, cont_params: Tensor) -> Tensor:
return cont_params # We want sample-level control, so no aggregation

def do_forward_pass(self,
Expand Down
9 changes: 5 additions & 4 deletions examples/music_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torch as tr
from torch import Tensor

from neutone_sdk import NeutoneParameter, KnobNeutoneParameter, TextNeutoneParameter
from neutone_sdk import NeutoneParameter, TextNeutoneParameter, \
ContinuousNeutoneParameter
from neutone_sdk.non_realtime_wrapper import NonRealtimeBase

logging.basicConfig()
Expand Down Expand Up @@ -48,7 +49,7 @@ def get_neutone_parameters(self) -> List[NeutoneParameter]:
return [
TextNeutoneParameter("prompt", "testing"),
# TODO(cm): convert to categorical
KnobNeutoneParameter("duration", "how many seconds to generate", default_value=0.0),
ContinuousNeutoneParameter("duration", "how many seconds to generate", default_value=0.0),
]

@tr.jit.export
Expand Down Expand Up @@ -120,15 +121,15 @@ def do_forward_pass(self,
wrapper = MusicGenModelWrapper(model)
audio_out = wrapper.forward(curr_block_idx=0,
audio_in=[],
cont_params=tr.tensor([0.0, 1.0]).unsqueeze(1),
knob_params=tr.tensor([0.0, 1.0]).unsqueeze(1),
text_params=["testing"])
log.info(audio_out[0].shape)

# wrapper.prepare_for_inference()
ts = tr.jit.script(wrapper)
audio_out = ts.forward(curr_block_idx=0,
audio_in=[],
cont_params=tr.tensor([0.0]).unsqueeze(1),
knob_params=tr.tensor([0.0]).unsqueeze(1),
text_params=["testing"])
log.info(audio_out[0].shape)

Expand Down
Loading

0 comments on commit cd602dd

Please sign in to comment.