Skip to content

Commit

Permalink
bugfix - t[sec]=samples]/original_sample_rate, instead of the resampl…
Browse files Browse the repository at this point in the history
…ed sample rate.
  • Loading branch information
shaing10 committed Jul 14, 2024
1 parent d78b6c9 commit 8824bf8
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
6 changes: 4 additions & 2 deletions soundbay/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def train_epoch(self, epoch):
audio, label = audio.to(self.device), label.to(self.device)

if (it == 0) and (not self.debug) and ((epoch % 5) == 0):
self.logger.upload_artifacts(audio, label, raw_wav, meta, sample_rate=self.train_dataloader.dataset.sample_rate, flag='train')
self.logger.upload_artifacts(audio, label, raw_wav, meta, sample_rate=self.train_dataloader.dataset.sample_rate,
flag='train', data_sample_rate=self.train_dataloader.dataset.data_sample_rate)

# estimate and calc losses
estimated_label = self.model(audio)
Expand Down Expand Up @@ -140,7 +141,8 @@ def eval_epoch(self, epoch: int, datatset_name: str = None):
audio, label, raw_wav, meta = batch
audio, label = audio.to(self.device), label.to(self.device)
if (it == 0) and (not self.debug) and ((epoch % 5) == 0):
self.logger.upload_artifacts(audio, label, raw_wav, meta, sample_rate=self.train_dataloader.dataset.sample_rate, flag=datatset_name)
self.logger.upload_artifacts(audio, label, raw_wav, meta, sample_rate=self.train_dataloader.dataset.sample_rate,
flag=datatset_name, data_sample_rate=self.train_dataloader.dataset.data_sample_rate)

# estimate and calc losses
estimated_label = self.model(audio)
Expand Down
6 changes: 3 additions & 3 deletions soundbay/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def calc_metrics(self, epoch: int, mode: str = 'train', label_names: Optional[Li
self.label_list = []
self.pred_proba_list = []

def upload_artifacts(self, audio: torch.Tensor, label: torch.Tensor, raw_wav: torch.Tensor, meta: dict, sample_rate: int=16000, flag: str='train'):
def upload_artifacts(self, audio: torch.Tensor, label: torch.Tensor, raw_wav: torch.Tensor, meta: dict, sample_rate: int=16000, flag: str='train', data_sample_rate: int = 16000):
"""upload algorithm artifacts to W&B during training session"""
volume = 50
matplotlib.use('Agg')
Expand All @@ -171,7 +171,7 @@ def upload_artifacts(self, audio: torch.Tensor, label: torch.Tensor, raw_wav: to

artifact_wav = torch.squeeze(raw_wav).detach().cpu().numpy()
artifact_wav = artifact_wav / np.expand_dims(np.abs(artifact_wav).max(axis=1) + 1e-8, 1) * 0.5 # gain -6dB
list_of_wavs_objects = [wandb.Audio(data_or_path=wav, caption=f'{flag}_label{lab}_i{ind}_{round(b_t/sample_rate,3)}sec_{f_n}', sample_rate=sample_rate) for wav, ind, lab, b_t, f_n in zip(artifact_wav,idx, label, meta['begin_time'], meta['org_file'])]
list_of_wavs_objects = [wandb.Audio(data_or_path=wav, caption=f'{flag}_label{lab}_i{ind}_{round(b_t/data_sample_rate,2)}sec_{f_n}', sample_rate=sample_rate) for wav, ind, lab, b_t, f_n in zip(artifact_wav,idx, label, meta['begin_time'], meta['org_file'])]

# Spectrograms batch
artifact_spec = torch.squeeze(audio).detach().cpu().numpy()
Expand All @@ -181,7 +181,7 @@ def upload_artifacts(self, audio: torch.Tensor, label: torch.Tensor, raw_wav: to
specs.append(librosa.display.specshow(artifact_spec[artifact_id,...], ax=ax[1]))
plt.close('all')
del ax
list_of_specs_objects = [wandb.Image(data_or_path=spec, caption=f'{flag}_label{lab}_i{ind}_{round(b_t/sample_rate,2)}sec_{f_n}') for spec, ind, lab, b_t, f_n in zip(specs,idx, label, meta['begin_time'], meta['org_file'])]
list_of_specs_objects = [wandb.Image(data_or_path=spec, caption=f'{flag}_label{lab}_i{ind}_{round(b_t/data_sample_rate,2)}sec_{f_n}') for spec, ind, lab, b_t, f_n in zip(specs,idx, label, meta['begin_time'], meta['org_file'])]
log_wavs = {f'First batch {flag} original wavs': list_of_wavs_objects}
log_specs = {f'First batch {flag} augmented spectrogram\'s': list_of_specs_objects}

Expand Down

0 comments on commit 8824bf8

Please sign in to comment.