Skip to content

Commit

Permalink
update app
Browse files Browse the repository at this point in the history
  • Loading branch information
SkyTNT committed Sep 12, 2023
1 parent 3801d96 commit 745c6f4
Show file tree
Hide file tree
Showing 4 changed files with 518 additions and 99 deletions.
106 changes: 58 additions & 48 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import glob
import json

import PIL
import gradio as gr
import numpy as np
import torch
Expand All @@ -15,6 +15,7 @@
from midi_synthesizer import synthesis
from huggingface_hub import hf_hub_download


@torch.inference_mode()
def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
disable_patch_change=False, disable_control_change=False, disable_channels=None, amp=True):
Expand Down Expand Up @@ -82,43 +83,14 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
break


def create_msg(name, data):
return {"name": name, "data": data}


def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc, amp):
mid_seq = []
max_len = int(gen_events)
img_len = 1024
img = np.full((128 * 2, img_len, 3), 255, dtype=np.uint8)
state = {"t1": 0, "t": 0, "cur_pos": 0}
rand = np.random.RandomState(0)
colors = {(i, j): rand.randint(0, 200, 3) for i in range(128) for j in range(16)}

def draw_event(tokens):
if tokens[0] in tokenizer.id_events:
name = tokenizer.id_events[tokens[0]]
if len(tokens) <= len(tokenizer.events[name]):
return
params = tokens[1:]
params = [params[i] - tokenizer.parameter_ids[p][0] for i, p in enumerate(tokenizer.events[name])]
if not all([0 <= params[i] < tokenizer.event_parameters[p] for i, p in enumerate(tokenizer.events[name])]):
return
event = [name] + params
state["t1"] += event[1]
t = state["t1"] * 16 + event[2]
state["t"] = t
if name == "note":
tr, d, c, p = event[3:7]
shift = t + d - (state["cur_pos"] + img_len)
if shift > 0:
img[:, :-shift] = img[:, shift:]
img[:, -shift:] = 255
state["cur_pos"] += shift
t = t - state["cur_pos"]
img[p * 2:(p + 1) * 2, t: t + d] = colors[(tr, c)]

def get_img():
t = state["t"] - state["cur_pos"]
img_new = img.copy()
img_new[:, t: t + 2] = 0
return PIL.Image.fromarray(np.flip(img_new, 0))
gen_events = int(gen_events)
max_len = gen_events

disable_patch_change = False
disable_channels = None
Expand All @@ -135,7 +107,7 @@ def get_img():
mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i, c, p]))
mid_seq = mid
mid = np.asarray(mid, dtype=np.int64)
if len(instruments) > 0 or drum_kit != "None":
if len(instruments) > 0:
disable_patch_change = True
disable_channels = [i for i in range(16) if i not in patches]
elif mid is not None:
Expand All @@ -144,28 +116,32 @@ def get_img():
mid = mid[:int(midi_events)]
max_len += len(mid)
for token_seq in mid:
mid_seq.append(token_seq)
draw_event(token_seq)
mid_seq.append(token_seq.tolist())

init_msgs = [create_msg("visualizer_clear", None)]
for tokens in mid_seq:
init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
yield mid_seq, None, None, init_msgs
generator = generate(mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
disable_channels=disable_channels, amp=amp)
for token_seq in generator:
for i, token_seq in enumerate(generator):
mid_seq.append(token_seq)
draw_event(token_seq)
yield mid_seq, get_img(), None, None
event = tokenizer.tokens2event(token_seq.tolist())
yield mid_seq, None, None, [create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])]
mid = tokenizer.detokenize(mid_seq)
with open(f"output.mid", 'wb') as f:
f.write(MIDI.score2midi(mid))
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
yield mid_seq, get_img(), "output.mid", (44100, audio)
yield mid_seq, "output.mid", (44100, audio), [create_msg("visualizer_end", None)]


def cancel_run(mid_seq):
mid = tokenizer.detokenize(mid_seq)
with open(f"output.mid", 'wb') as f:
f.write(MIDI.score2midi(mid))
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
return "output.mid", (44100, audio)
return "output.mid", (44100, audio), [create_msg("visualizer_end", None)]


def load_model(path):
Expand All @@ -181,6 +157,38 @@ def get_model_path():
return model_path_input.update(choices=model_paths)


def load_javascript(dir="javascript"):
scripts_list = glob.glob(f"{dir}/*.js")
javascript = ""
for path in scripts_list:
with open(path, "r", encoding="utf8") as jsfile:
javascript += f"\n<!-- {path} --><script>{jsfile.read()}</script>"
template_response_ori = gr.routes.templates.TemplateResponse

def template_response(*args, **kwargs):
res = template_response_ori(*args, **kwargs)
res.body = res.body.replace(
b'</head>', f'{javascript}</head>'.encode("utf8"))
res.init_headers()
return res

gr.routes.templates.TemplateResponse = template_response


class JSMsgReceiver(gr.HTML):

def __init__(self, **kwargs):
super().__init__(elem_id="msg_receiver", visible=False, **kwargs)

def postprocess(self, y):
if y:
y = f"<p>{json.dumps(y)}</p>"
return super().postprocess(y)

def get_block_name(self) -> str:
return "html"


number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
40: "Blush", 48: "Orchestra"}
patch2number = {v: k for k, v in MIDI.Number2patch.items()}
Expand All @@ -195,8 +203,10 @@ def get_model_path():
tokenizer = MIDITokenizer()
model = MIDIModel(tokenizer).to(device=opt.device)

load_javascript()
app = gr.Blocks()
with app:
js_msg = JSMsgReceiver()
with gr.Accordion(label="Model option", open=False):
load_model_path_btn = gr.Button("Get Models")
model_path_input = gr.Dropdown(label="model")
Expand Down Expand Up @@ -243,12 +253,12 @@ def get_model_path():
run_btn = gr.Button("generate", variant="primary")
stop_btn = gr.Button("stop and output")
output_midi_seq = gr.Variable()
output_midi_img = gr.Image(label="output image")
output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
output_audio = gr.Audio(label="output audio", format="wav", elem_id="midi_audio")
output_midi = gr.File(label="output midi", file_types=[".mid"])
output_audio = gr.Audio(label="output audio", format="mp3")
run_event = run_btn.click(run, [tab_select, input_instruments, input_drum_kit, input_midi, input_midi_events,
input_gen_events, input_temp, input_top_p, input_top_k,
input_allow_cc, input_amp],
[output_midi_seq, output_midi_img, output_midi, output_audio])
stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio], cancels=run_event, queue=False)
[output_midi_seq, output_midi, output_audio, js_msg])
stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
app.queue(1).launch(server_port=opt.port)
109 changes: 58 additions & 51 deletions app_onnx.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import argparse
import glob
import json
import os.path
from sys import exit

import gradio as gr
import numpy as np
import onnxruntime as rt
import PIL
import PIL.ImageColor
import requests
import tqdm

Expand Down Expand Up @@ -105,44 +105,14 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
break


def create_msg(name, data):
return {"name": name, "data": data}


def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
mid_seq = []
max_len = int(gen_events)
img_len = 1024
img = np.full((128 * 2, img_len, 3), 255, dtype=np.uint8)
state = {"t1": 0, "t": 0, "cur_pos": 0}
colors = ['navy', 'blue', 'deepskyblue', 'teal', 'green', 'lightgreen', 'lime', 'orange',
'brown', 'grey', 'red', 'pink', 'aqua', 'orchid', 'bisque', 'coral']
colors = [PIL.ImageColor.getrgb(color) for color in colors]

def draw_event(tokens):
if tokens[0] in tokenizer.id_events:
name = tokenizer.id_events[tokens[0]]
if len(tokens) <= len(tokenizer.events[name]):
return
params = tokens[1:]
params = [params[i] - tokenizer.parameter_ids[p][0] for i, p in enumerate(tokenizer.events[name])]
if not all([0 <= params[i] < tokenizer.event_parameters[p] for i, p in enumerate(tokenizer.events[name])]):
return
event = [name] + params
state["t1"] += event[1]
t = state["t1"] * 16 + event[2]
state["t"] = t
if name == "note":
tr, d, c, p = event[3:7]
shift = t + d - (state["cur_pos"] + img_len)
if shift > 0:
img[:, :-shift] = img[:, shift:]
img[:, -shift:] = 255
state["cur_pos"] += shift
t = t - state["cur_pos"]
img[p * 2:(p + 1) * 2, t: t + d] = colors[c]

def get_img():
t = state["t"] - state["cur_pos"]
img_new = img.copy()
img_new[:, t: t + 2] = 0
return PIL.Image.fromarray(np.flip(img_new, 0))
gen_events = int(gen_events)
max_len = gen_events

disable_patch_change = False
disable_channels = None
Expand All @@ -168,20 +138,24 @@ def get_img():
mid = mid[:int(midi_events)]
max_len += len(mid)
for token_seq in mid:
mid_seq.append(token_seq)
draw_event(token_seq)
mid_seq.append(token_seq.tolist())

init_msgs = [create_msg("visualizer_clear", None)]
for tokens in mid_seq:
init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
yield mid_seq, None, None, init_msgs
generator = generate(mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
disable_channels=disable_channels)
for token_seq in generator:
for i, token_seq in enumerate(generator):
mid_seq.append(token_seq)
draw_event(token_seq)
yield mid_seq, get_img(), None, None
event = tokenizer.tokens2event(token_seq.tolist())
yield mid_seq, None, None, [create_msg("visualizer_append", event), create_msg("progress", [i+1, gen_events])]
mid = tokenizer.detokenize(mid_seq)
with open(f"output.mid", 'wb') as f:
f.write(MIDI.score2midi(mid))
audio = synthesis(MIDI.score2opus(mid), opt.soundfont_path)
yield mid_seq, get_img(), "output.mid", (44100, audio)
yield mid_seq, "output.mid", (44100, audio), [create_msg("visualizer_end", None)]


def cancel_run(mid_seq):
Expand All @@ -191,7 +165,7 @@ def cancel_run(mid_seq):
with open(f"output.mid", 'wb') as f:
f.write(MIDI.score2midi(mid))
audio = synthesis(MIDI.score2opus(mid), opt.soundfont_path)
return "output.mid", (44100, audio)
return "output.mid", (44100, audio), [create_msg("visualizer_end", None)]


def download(url, output_file):
Expand All @@ -216,6 +190,38 @@ def download_if_not_exit(url, output_file):
raise e


def load_javascript(dir="javascript"):
scripts_list = glob.glob(f"{dir}/*.js")
javascript = ""
for path in scripts_list:
with open(path, "r", encoding="utf8") as jsfile:
javascript += f"\n<!-- {path} --><script>{jsfile.read()}</script>"
template_response_ori = gr.routes.templates.TemplateResponse

def template_response(*args, **kwargs):
res = template_response_ori(*args, **kwargs)
res.body = res.body.replace(
b'</head>', f'{javascript}</head>'.encode("utf8"))
res.init_headers()
return res

gr.routes.templates.TemplateResponse = template_response


class JSMsgReceiver(gr.HTML):

def __init__(self, **kwargs):
super().__init__(elem_id="msg_receiver", visible=False, **kwargs)

def postprocess(self, y):
if y:
y = f"<p>{json.dumps(y)}</p>"
return super().postprocess(y)

def get_block_name(self) -> str:
return "html"


number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
40: "Blush", 48: "Orchestra"}
patch2number = {v: k for k, v in MIDI.Number2patch.items()}
Expand Down Expand Up @@ -259,6 +265,7 @@ def download_if_not_exit(url, output_file):
input("Failed to load models, maybe you need to delete them and re-download it.\nPress any key to continue...")
exit(-1)

load_javascript()
app = gr.Blocks()
with app:
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Midi Composer</h1>")
Expand All @@ -269,7 +276,7 @@ def download_if_not_exit(url, output_file):
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
" for faster running"
)

js_msg = JSMsgReceiver()
tab_select = gr.Variable(value=0)
with gr.Tabs():
with gr.TabItem("instrument prompt") as tab1:
Expand All @@ -290,7 +297,7 @@ def download_if_not_exit(url, output_file):
], [input_instruments, input_drum_kit])
with gr.TabItem("midi prompt") as tab2:
input_midi = gr.File(label="input midi", file_types=[".midi", ".mid"], type="binary")
input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=1024,
step=1,
value=128)

Expand All @@ -307,14 +314,14 @@ def download_if_not_exit(url, output_file):
run_btn = gr.Button("generate", variant="primary")
stop_btn = gr.Button("stop and output")
output_midi_seq = gr.Variable()
output_midi_img = gr.Image(label="output image")
output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
output_audio = gr.Audio(label="output audio", format="wav", elem_id="midi_audio")
output_midi = gr.File(label="output midi", file_types=[".mid"])
output_audio = gr.Audio(label="output audio", format="wav")
run_event = run_btn.click(run, [tab_select, input_instruments, input_drum_kit, input_midi, input_midi_events,
input_gen_events, input_temp, input_top_p, input_top_k,
input_allow_cc],
[output_midi_seq, output_midi_img, output_midi, output_audio])
stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio], cancels=run_event, queue=False)
[output_midi_seq, output_midi, output_audio, js_msg])
stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
try:
port = opt.port
if port == -1:
Expand Down
Loading

0 comments on commit 745c6f4

Please sign in to comment.