-
Notifications
You must be signed in to change notification settings - Fork 14
/
02-generator.R
50 lines (35 loc) · 1.44 KB
/
02-generator.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Creates a generator from a dataset.
library(tfdatasets)
# Short alias for the audio ops module; data_generator() calls its
# decode_wav() and audio_spectrogram() functions.
# NOTE(review): `tf` must already be in scope when this line runs, and the
# `tf$contrib` path presumably requires TensorFlow 1.x -- confirm the
# installed version before relying on it.
audio_ops <- tf$contrib$framework$python$ops$audio_ops
#' Build a repeating, batched dataset of log-spectrograms and one-hot labels
#'
#' Every slice of `df` is mapped to `list(spectrogram, response)`: the wav
#' file named by `fname` is read and decoded to one channel, turned into a
#' squared-magnitude spectrogram, log-scaled, and paired with `class_id`
#' one-hot encoded to depth 30. The mapped dataset repeats indefinitely and is
#' emitted in padded batches.
#'
#' Millisecond arguments are converted to samples at 16000 samples per second,
#' and the padded time dimension is sized for a 16000-sample clip.
#'
#' @param df Object accepted by `tensor_slices_dataset()` whose slices expose
#'   `fname` (path of a wav file) and `class_id` (presumably a zero-based
#'   integer class index -- confirm against the caller).
#' @param batch_size Number of elements per batch, forwarded to
#'   `dataset_padded_batch()`.
#' @param shuffle If `TRUE`, shuffle with a buffer of 100 elements before
#'   mapping.
#' @param window_size_ms Spectrogram window length in milliseconds.
#' @param window_stride_ms Hop between successive windows in milliseconds.
#' @return The dataset produced by `dataset_padded_batch()`, with padded shapes
#'   `(n_chunks, fft_size, NULL)` for the spectrogram and `(NULL)` for the
#'   response.
data_generator <- function(df, batch_size, shuffle = TRUE,
                           window_size_ms = 30, window_stride_ms = 10) {
  # Samples per second, and also the clip length used to size the time axis.
  sample_rate <- 16000

  window_size <- as.integer(sample_rate * window_size_ms / 1000)
  stride <- as.integer(sample_rate * window_stride_ms / 1000)

  # Fail early with a readable message instead of an obscure seq()/TF error.
  stopifnot(
    window_size >= 1L,
    window_size <= sample_rate,
    stride >= 1L
  )

  # The spectrogram op uses an FFT whose length is the smallest power of two
  # that is >= window_size and emits fft_length / 2 + 1 frequency bins.
  # Doubling an integer keeps this exact; the previous
  # 2^trunc(log(window_size, 2)) + 1 formula returned twice too many bins
  # whenever window_size was itself a power of two (e.g. 32 ms -> 512 samples).
  fft_length <- 1L
  while (fft_length < window_size) {
    fft_length <- fft_length * 2L
  }
  fft_size <- fft_length %/% 2L + 1L

  # Number of window positions that fit in one clip of `sample_rate` samples.
  n_chunks <- length(
    seq(window_size / 2, sample_rate - window_size / 2, stride)
  )

  ds <- tensor_slices_dataset(df)

  if (shuffle) {
    ds <- ds %>% dataset_shuffle(buffer_size = 100)
  }

  ds <- ds %>%
    dataset_map(function(obs) {
      # decoding wav files; fname is reshaped to a scalar before reading
      audio_binary <- tf$read_file(tf$reshape(obs$fname, shape = list()))
      wav <- audio_ops$decode_wav(audio_binary, desired_channels = 1)

      # create the spectrogram
      spectrogram <- audio_ops$audio_spectrogram(
        wav$audio,
        window_size = window_size,
        stride = stride,
        magnitude_squared = TRUE
      )

      # The 0.01 offset keeps the logarithm finite where the magnitude is 0.
      spectrogram <- tf$log(tf$abs(spectrogram) + 0.01)

      # Move the leading axis to the end (presumably channels, giving
      # frames x bins x channels -- confirm against the op's output layout).
      spectrogram <- tf$transpose(spectrogram, perm = c(1L, 2L, 0L))

      # transform the class_id into a one-hot encoded vector
      response <- tf$one_hot(obs$class_id, 30L)

      list(spectrogram, response)
    }) %>%
    dataset_repeat()

  ds %>%
    dataset_padded_batch(
      batch_size,
      list(shape(n_chunks, fft_size, NULL), shape(NULL))
    )
}