-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
55 lines (38 loc) · 1.43 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from audio_diffusion_pytorch import AudioDiffusionModel
import torch
import torchaudio
# https://github.com/archinetai/audio-diffusion-pytorch/discussions/14
# Diffusion model over mono audio: 1 input channel, conditioned on a single
# 1-channel context signal (the "source" recording loaded below).
model = AudioDiffusionModel(
    in_channels=1,
    context_channels=[1]
)
def reshape(tensor):
    """Extract the first channel of a stereo tensor as a [1, 1, N] batch.

    Args:
        tensor: 2-channel audio tensor of shape [2, N].

    Returns:
        Tensor of shape [1, 1, N] holding only the first channel.

    Raises:
        ValueError: if the input does not have exactly 2 channels.
    """
    # Explicit check instead of `assert`, which is silently stripped
    # when Python runs with -O.
    if tensor.shape[0] != 2:
        raise ValueError(
            "expected a stereo tensor of shape [2, N], got "
            + str(tuple(tensor.shape)))
    return tensor[0, :].unsqueeze(0).unsqueeze(0)  # tensor.shape = [1, 1, N]
def readAudioFile(filename):
    """Load an audio file and return its first channel shaped [1, 1, N]."""
    waveform, _sample_rate = torchaudio.load(filename)  # waveform: [2, N]
    return reshape(waveform)
# Load the conditioning ("source") and training-target recordings,
# each shaped [1, 1, N] by readAudioFile.
source = readAudioFile("bucky.wav")
target = readAudioFile("makeba.wav")
N = source.shape[2]
# The model pairs source and target sample-by-sample, so lengths must match.
# Explicit check instead of `assert` (stripped under `python -O`).
if N != target.shape[2]:
    raise ValueError("source and target must have the same number of samples")
# Fixed typo in the original message ("soure ... of of").
print("Processing source and target of number of samples: " + str(N))
print("input shape: " + str(source.shape))
# Train model with pairs of audio sources, i.e. predict target given source
# [batch, in_channels, samples], 2**18 ≈ 12s of audio at a frequency of 22050
# loss = model(target, channels_list=[source])
#
# BUG FIX: the original computed the loss once and called
# loss.backward(retain_graph=True) 20 times. That only re-accumulates the
# same gradients into .grad and never updates any weights, so no training
# happened. A correct loop recomputes the loss and steps an optimizer on
# every iteration.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for i in range(20):
    print("Train:" + str(i))
    optimizer.zero_grad()
    loss = model(target, context=[source])
    loss.backward()
    optimizer.step()
# Sample a target audio given start noise and source audio
noise = torch.randn(1, 1, N)  # [batch, channels, samples], same length as inputs
print("Starting sampling...")
sampled = model.sample(
    context=[source],  # condition generation on the source recording
    noise=noise,
    num_steps=25  # Suggested range: 2-50
)  # It's not [2, 1, 2 ** 18] but rather [1,1,N]
print("Saving output.wav")
# NOTE(review): 48000 is hard-coded, but the sample rate returned by
# torchaudio.load inside readAudioFile is discarded — confirm the inputs
# really are 48 kHz, otherwise output.wav plays at the wrong speed/pitch.
torchaudio.save("output.wav", sampled[0], 48000)