Feature/log time sample (#97)
* log information per sample in artifacts (audio and spec)
* bugfix in default train run
shaing10 authored Jul 10, 2024
1 parent 08c102e commit d78b6c9
Showing 4 changed files with 13 additions and 10 deletions.
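The diffs below thread a small per-sample metadata dict from the dataset, through the trainer, into the W&B logger, so each uploaded audio clip and spectrogram is captioned with its source. A minimal sketch of that dict for one sample; only the keys come from the commit, the values are invented for illustration:

    # Illustrative only: the keys match the new __getitem__ return value, the values are made up.
    sample_meta = {
        "idx": 42,                  # dataset index of the clip
        "begin_time": 50000,        # clip start offset (the logger divides this by sample_rate)
        "org_file": "recording01",  # stem of the originating audio file
    }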
2 changes: 2 additions & 0 deletions soundbay/conf/data/defaults.yaml
@@ -12,6 +12,7 @@ data:
  train_dataset:
    _target_: soundbay.data.ClassifierDataset
    data_path: './tests/assets/data/'
+    path_hierarchy: 0
    mode: train
    metadata_path: './tests/assets/annotations/sample_annotations.csv'
    augmentations_p: 0.8
@@ -25,6 +26,7 @@ data:
  val_dataset:
    _target_: soundbay.data.ClassifierDataset
    data_path: './tests/assets/data'
+    path_hierarchy: 0
    mode: val
    metadata_path: './tests/assets/annotations/sample_annotations.csv'
    augmentations_p: 0
2 changes: 1 addition & 1 deletion soundbay/data.py
@@ -239,7 +239,7 @@ def __getitem__(self, idx):

        if self.mode == "train" or self.mode == "val":
            label = self.metadata["label"][idx]
-            return audio_processed, label, audio_raw, idx
+            return audio_processed, label, audio_raw, {"idx": idx, "begin_time": begin_time, "org_file": Path(path_to_file).stem}

        elif self.mode == "test":
            return audio_processed
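Since the last element of the returned tuple is now a dict, the DataLoader's default collate_fn is what produces the batched meta the trainer unpacks below: numeric fields come back as tensors and the file stems as a list of strings. A standalone sketch of that behavior (assumed, not code from the repo; default_collate is exposed under torch.utils.data in recent PyTorch releases):

    from torch.utils.data import default_collate

    # Two fake samples shaped like the dict ClassifierDataset.__getitem__ now returns.
    samples = [
        {"idx": 0, "begin_time": 0.0, "org_file": "file_a"},
        {"idx": 1, "begin_time": 2.5, "org_file": "file_b"},
    ]

    batched = default_collate(samples)
    print(batched["idx"])         # tensor([0, 1])
    print(batched["begin_time"])  # tensor of the two start offsets
    print(batched["org_file"])    # ['file_a', 'file_b'] -- strings stay a plain list

That is why the logger can call .detach().cpu().numpy() on meta['idx'] and meta['begin_time'] while iterating meta['org_file'] directly.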
10 changes: 5 additions & 5 deletions soundbay/trainers.py
@@ -98,11 +98,11 @@ def train_epoch(self, epoch):
                break

            self.model.zero_grad()
-            audio, label, raw_wav, idx = batch
-            audio, label = audio.to(self.device), label.to(self.device)
+            audio, label, raw_wav, meta = batch
+            audio, label = audio.to(self.device), label.to(self.device)

            if (it == 0) and (not self.debug) and ((epoch % 5) == 0):
-                self.logger.upload_artifacts(audio, label, raw_wav, idx, sample_rate=self.train_dataloader.dataset.sample_rate, flag='train')
+                self.logger.upload_artifacts(audio, label, raw_wav, meta, sample_rate=self.train_dataloader.dataset.sample_rate, flag='train')

            # estimate and calc losses
            estimated_label = self.model(audio)
@@ -137,10 +137,10 @@ def eval_epoch(self, epoch: int, datatset_name: str = None):
        for it, batch in tqdm(enumerate(dataloader), desc=datatset_name):
            if it == 3 and self.debug:
                break
-            audio, label, raw_wav, idx = batch
+            audio, label, raw_wav, meta = batch
            audio, label = audio.to(self.device), label.to(self.device)
            if (it == 0) and (not self.debug) and ((epoch % 5) == 0):
-                self.logger.upload_artifacts(audio, label, raw_wav, idx, sample_rate=self.train_dataloader.dataset.sample_rate, flag=datatset_name)
+                self.logger.upload_artifacts(audio, label, raw_wav, meta, sample_rate=self.train_dataloader.dataset.sample_rate, flag=datatset_name)

            # estimate and calc losses
            estimated_label = self.model(audio)
9 changes: 5 additions & 4 deletions soundbay/utils/logging.py
@@ -153,11 +153,12 @@ def calc_metrics(self, epoch: int, mode: str = 'train', label_names: Optional[Li
        self.label_list = []
        self.pred_proba_list = []

-    def upload_artifacts(self, audio: torch.Tensor, label: torch.Tensor, raw_wav: torch.Tensor, idx: torch.Tensor, sample_rate: int=16000, flag: str='train'):
+    def upload_artifacts(self, audio: torch.Tensor, label: torch.Tensor, raw_wav: torch.Tensor, meta: dict, sample_rate: int=16000, flag: str='train'):
        """upload algorithm artifacts to W&B during training session"""
        volume = 50
        matplotlib.use('Agg')
-        idx = idx.detach().cpu().numpy()
+        idx = meta['idx'].detach().cpu().numpy()
+        meta['begin_time'] = meta['begin_time'].detach().cpu().numpy()
        label = label.detach().cpu().numpy()

        if audio.shape[0] > self.upload_artifacts_limit:
@@ -170,7 +171,7 @@

        artifact_wav = torch.squeeze(raw_wav).detach().cpu().numpy()
        artifact_wav = artifact_wav / np.expand_dims(np.abs(artifact_wav).max(axis=1) + 1e-8, 1) * 0.5 # gain -6dB
-        list_of_wavs_objects = [wandb.Audio(data_or_path=wav, caption=f'label_{lab}_{ind}_train', sample_rate=sample_rate) for wav, ind, lab in zip(artifact_wav,idx, label)]
+        list_of_wavs_objects = [wandb.Audio(data_or_path=wav, caption=f'{flag}_label{lab}_i{ind}_{round(b_t/sample_rate,3)}sec_{f_n}', sample_rate=sample_rate) for wav, ind, lab, b_t, f_n in zip(artifact_wav,idx, label, meta['begin_time'], meta['org_file'])]

        # Spectrograms batch
        artifact_spec = torch.squeeze(audio).detach().cpu().numpy()
@@ -180,7 +181,7 @@ def upload_artifacts(self, audio: torch.Tensor, label: torch.Tensor, raw_wav: to
            specs.append(librosa.display.specshow(artifact_spec[artifact_id,...], ax=ax[1]))
        plt.close('all')
        del ax
-        list_of_specs_objects = [wandb.Image(data_or_path=spec, caption=f'label_{lab}_{ind}_train') for spec, ind, lab in zip(specs,idx, label)]
+        list_of_specs_objects = [wandb.Image(data_or_path=spec, caption=f'{flag}_label{lab}_i{ind}_{round(b_t/sample_rate,2)}sec_{f_n}') for spec, ind, lab, b_t, f_n in zip(specs,idx, label, meta['begin_time'], meta['org_file'])]
        log_wavs = {f'First batch {flag} original wavs': list_of_wavs_objects}
        log_specs = {f'First batch {flag} augmented spectrogram\'s': list_of_specs_objects}
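For reference, the new caption format encodes the split, label, dataset index, start time, and source-file stem. A quick illustration with invented values, assuming begin_time is stored in samples (the division by sample_rate in the f-string suggests this):

    flag, lab, ind, b_t, f_n = 'train', 1, 42, 50000, 'recording01'
    sample_rate = 16000
    caption = f'{flag}_label{lab}_i{ind}_{round(b_t / sample_rate, 3)}sec_{f_n}'
    print(caption)  # train_label1_i42_3.125sec_recording01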
