forked from acids-ircam/RAVE
-
Notifications
You must be signed in to change notification settings - Fork 1
/
cli_helper.py
176 lines (133 loc) · 4.39 KB
/
cli_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
from os import path
import shutil
class Print:
    """Callable that echoes text to stdout and keeps a transcript of it.

    Each call prints the message followed by ``end`` and appends exactly the
    same text to ``self.msg``, so everything emitted through an instance can
    be read back later from that attribute.
    """

    def __init__(self) -> None:
        # Running transcript of everything emitted through this instance.
        self.msg: str = ""

    def __call__(self, msg: object, end: str = "\n") -> None:
        """Print *msg* terminated by *end* and record both in ``self.msg``.

        Args:
            msg: Value to emit. It is converted with ``str()`` so that any
                object ``print`` accepts is also recorded, instead of being
                printed and then failing on string concatenation.
            end: Terminator written after the message (newline by default).
        """
        # Convert once so the printed text and the recorded text cannot differ.
        text = str(msg)
        print(text, end=end)
        self.msg += text + end
# Shared recorder: it is the default sink of header(), the sink used by
# subsection(), and its accumulated ``msg`` is what the script writes to the
# instruction file at the end.
p = Print()
def header(title: str, out=p):
    """Emit *title* as an upper-cased banner framed by two ``=`` rules.

    The rule width is the terminal width capped at 120 columns. The title is
    shifted right by half of the remaining width (left padding only), and a
    blank line is emitted after the closing rule.

    Args:
        title: Text of the banner.
        out: Callable invoked once per output line with that line as its
            only argument; defaults to the module-level recorder ``p``.
    """
    width = min(120, shutil.get_terminal_size().columns)
    rule = "=" * width
    # A negative count (title wider than the rule) yields no padding at all.
    left_margin = " " * ((width - len(title)) // 2)
    for line in (rule, left_margin + title.upper(), rule, ""):
        out(line)
def subsection(title: str):
    """Emit *title* as an upper-cased sub-heading through the recorder ``p``.

    Output is: a blank line, the upper-cased title, a row of ``-`` as long as
    the original title, and a closing blank line.

    Args:
        title: Text of the sub-heading.
    """
    underline = "-" * len(title)
    for line in ("", title.upper(), underline, ""):
        p(line)
if __name__ == "__main__":
    # Interactive helper: asks for the training settings, prints the shell
    # commands for each step, and saves the printed instructions to
    # ``instruction_<name>.txt`` in the current working directory.
    #
    # Every answer is stripped of surrounding whitespace so that a
    # blank-looking answer ("   ") counts as "not given": required prompts ask
    # again, and optional prompts omit their flag instead of emitting a flag
    # with no value.

    # The banner goes to plain ``print`` so it is displayed but kept out of
    # the transcript accumulated in ``p.msg``.
    header("rave command line helper", out=print)

    # --- Required answers: keep asking until something non-blank is given ---
    name = ""
    while not name:
        name = input("choose a name for the training: ").strip()
    # Normalised because the name is embedded in paths and file names below.
    name = name.lower().replace(" ", "_")

    data = ""
    while not data:
        data = input("path to the .wav files: ").strip()

    preprocessed = ""
    while not preprocessed:
        preprocessed = input("temporary folder (fast drive): ").strip()

    # --- Optional answers: an empty answer leaves the flag out entirely ---
    # (presumably the target script then applies the default quoted in the
    # prompt; verify against those scripts)
    sampling_rate = input("sampling rate (defaults to 48000): ").strip()
    multiband_number = input("multiband number (defaults to 16): ").strip()
    n_signal = input(
        "training example duration (defaults to 65536 samples): ").strip()
    capacity = input("model capacity (defaults to 6): ").strip()
    warmup = input(
        "number of steps for stage 1 (defaults to 1000000): ").strip()
    prior_resolution = input("prior resolution (defaults to 32): ").strip()
    fidelity = input("reconstruction fidelity (defaults to 0.95): ").strip()
    no_latency = input("latency compensation (defaults to false): ").strip()
    latent_size = input("latent size (learned if left blank): ").strip()
    pressure = input("regularization strength (defaults to 0.1): ").strip()

    # From here on everything is emitted through ``p`` and therefore recorded.
    header(f"{name}: training instructions")

    # --- Step 1: command that trains RAVE ---
    subsection("train rave")
    p("Train rave (both training stages are included)")
    p("")

    # NOTE(review): user-supplied values are interpolated unquoted, so a path
    # containing spaces produces a command that needs manual quoting.
    cmd = "python train_rave.py "
    cmd += f"--name {name} "
    cmd += f"--wav {data} "
    prep_rave = path.join(preprocessed, name, "rave")
    cmd += f"--preprocessed {prep_rave} "
    if sampling_rate:
        cmd += f"--sr {sampling_rate} "
    if multiband_number:
        cmd += f"--data-size {multiband_number} "
    if n_signal:
        cmd += f"--n-signal {n_signal} "
    if capacity:
        # The answer is used as an exponent: the flag receives 2**capacity.
        # int() raises ValueError if the answer is not an integer.
        cmd += f"--capacity {2**int(capacity)} "
    if warmup:
        cmd += f"--warmup {warmup} "
    if no_latency:
        cmd += f"--no-latency {no_latency.lower()} "
    if latent_size:
        # int() raises ValueError if the answer is not an integer.
        cmd += f"--cropped-latent-size {int(latent_size)} "
    if pressure:
        cmd += f"--max-kl {pressure} "
    p(cmd)
    p("")
    p("You can follow the training using tensorboard")
    p("")
    p("tensorboard --logdir . --bind_all")
    p("")
    p("Once the training has reached a satisfactory state, kill it (ctrl + C)")

    # --- Step 2: export the trained model, then train the prior on it ---
    subsection("train prior")
    p(
        f"Export the latent space trained on {name}.",
        end="\n\n",
    )
    cmd = "python export_rave.py "
    run = path.join("runs", name, "rave")
    cmd += f"--run {run} "
    cmd += "--cached false "
    if fidelity:
        cmd += f"--fidelity {fidelity} "
    cmd += f"--name {name}"
    p(cmd)
    p("")
    p(
        "Train the prior model.",
        end="\n\n",
    )
    cmd = "python train_prior.py "
    if prior_resolution:
        cmd += f"--resolution {prior_resolution} "
    cmd += f"--pretrained-vae rave_{name}.ts "
    prep_prior = path.join(preprocessed, name, "prior")
    cmd += f"--preprocessed {prep_prior} "
    cmd += f"--wav {data} "
    if n_signal:
        cmd += f"--n-signal {n_signal} "
    cmd += f"--name {name}"
    p(cmd)
    p("")
    p("Once the training has reached a satisfactory state, kill it (ctrl + C)")
    p("")

    # --- Step 3: cached exports of both models, then their combination ---
    header("export to max msp (coming soon)")
    p(
        "In order to use both **rave** and the **prior** model inside max/msp, "
        "we have to export them using **cached convolutions**."
    )
    p("")
    cmd = "python export_rave.py "
    run = path.join("runs", name, "rave")
    cmd += f"--run {run} "
    cmd += "--cached true "
    if fidelity:
        cmd += f"--fidelity {fidelity} "
    cmd += f"--name {name}_rt"
    p(cmd)
    cmd = "python export_prior.py "
    run = path.join("runs", name, "prior")
    cmd += f"--run {run} "
    cmd += f"--name {name}_rt"
    p(cmd)
    cmd = "python combine_models.py "
    cmd += f"--prior prior_{name}_rt.ts "
    cmd += f"--rave rave_{name}_rt.ts "
    cmd += f"--name {name}"
    p(cmd)

    # Persist the recorded instructions. The encoding is explicit so that
    # non-ASCII names or paths do not depend on the platform's locale encoding.
    with open(f"instruction_{name}.txt", "w", encoding="utf-8") as out:
        out.write(p.msg)