initial step of adding multitarget support - adjusted BaseDataset
danielle-hausler committed Dec 4, 2024
1 parent e187a72 commit f9d0cec
Showing 2 changed files with 61 additions and 14 deletions.
73 changes: 60 additions & 13 deletions soundbay/data.py
@@ -2,7 +2,6 @@
from itertools import starmap, repeat
from pathlib import Path
from typing import Union
from decimal import Decimal

import numpy as np
import pandas as pd
@@ -71,9 +70,8 @@ def __init__(self, data_path, metadata_path, augmentations, augmentations_p, pre
self.preprocessor = self.set_preprocessor(preprocessors)
assert (0 <= margin_ratio) and (1 >= margin_ratio)
self.margin_ratio = margin_ratio
self.items_per_classes = np.unique(self.metadata['label'], return_counts=True)[1]
weights = 1 / self.items_per_classes
self.samples_weight = np.array([weights[t] for t in self.metadata['label'] ])
self.num_classes = self._get_num_classes()
self.samples_weight = self._get_samples_weight()

@staticmethod
def _update_metadata_by_mode(metadata, mode, split_metadata_by_label):
@@ -108,13 +106,12 @@ def _preprocess_metadata(self, slice_flag=False):
Output:
ClassifierDataset object with self.metadata dataframe after applying the condition
"""
self._preprocess_target()
is_noise = self.metadata['label'].apply(self._is_noise)

# All calls are worthy (because we can later create a bigger slice containing them that is still a call in
# _get_audio) but only long enough background sections will do.
self.metadata = self.metadata[
((self.metadata['call_length'] >= self.seq_length) & (self.metadata['label'] == 0)) |
(self.metadata['label'] > 0)
]
self.metadata = self.metadata[((self.metadata['call_length'] >= self.seq_length) & is_noise) | (~is_noise)]

# sometimes the bbox's end time exceeds the file's length
for name, sub_df in self.metadata.groupby('filename'):
@@ -128,6 +125,26 @@

self.metadata.reset_index(drop=True, inplace=True)

def _preprocess_target(self):
"""
Preprocesses the label column in the metadata. If the labels are strings, they are evaluated and converted to
integers (single target) or integer arrays (multiple targets).
"""
if pd.api.types.is_string_dtype(self.metadata['label']):
assert self.metadata['label'].str.match(r'^(\[|\()?(\d+)(\s*,\s*\d+)*(\]|\))?$').all(), \
"label should be a string that can be evaluated as an integer or a list of integers."
self.metadata['label'] = self.metadata['label'].apply(eval)
if self.metadata['label'].apply(lambda x: isinstance(x, (list, tuple))).all():
self.metadata['label'] = self.metadata['label'].apply(np.array, dtype=int)

@staticmethod
def _is_noise(value: Union[int, np.ndarray]) -> bool:
"""
Checks whether the value is noise, i.e., equal to 0 (for arrays, all entries are 0).
"""
assert isinstance(value, (int, np.integer, np.ndarray)), "value should be either an int or an np.ndarray"
return np.sum(value) == 0

def _grab_fields(self, idx):
"""
grabs fields from metadata according to idx
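
As an illustration (not part of the commit), here is a minimal sketch of how the new label parsing behaves. The metadata values are hypothetical, and reading the arrays as multi-hot vectors is inferred from how _get_samples_weight sums them later in this diff:

    import numpy as np
    import pandas as pd

    # Hypothetical multi-hot label strings, one slot per class; all zeros means background/noise.
    labels = pd.Series(["[0, 0, 0]", "[1, 0, 1]", "[0, 1, 0]"])

    # Same pattern and conversion steps as _preprocess_target above.
    assert labels.str.match(r'^(\[|\()?(\d+)(\s*,\s*\d+)*(\]|\))?$').all()
    parsed = labels.apply(eval)
    if parsed.apply(lambda x: isinstance(x, (list, tuple))).all():
        parsed = parsed.apply(np.array, dtype=int)

    print([arr.tolist() for arr in parsed])          # [[0, 0, 0], [1, 0, 1], [0, 1, 0]]
    print([bool(arr.sum() == 0) for arr in parsed])  # _is_noise per row: [True, False, False]
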
@@ -159,10 +176,10 @@ def _slice_sequence(self):
self.metadata sliced according to buffers
"""
self.metadata = self.metadata.reset_index(drop=True)
count_values_before = self.metadata.value_counts('label', sort=False) # for validating that the following code doesn't lose samples
count_values_before = self.metadata.astype({'label': str}).value_counts('label', sort=False) # for validating that the following code doesn't lose samples
sliced_times = list(starmap(np.arange, zip(self.metadata['begin_time'], self.metadata['end_time'], repeat(self.seq_length))))
# add the last sequence at the end of this list for calls only (only if it does not exceed the file)
sliced_times = list([np.append(s, self.metadata.loc[i, 'end_time']) if self.metadata.loc[i, 'label'] != 0
sliced_times = list([np.append(s, self.metadata.loc[i, 'end_time']) if (not self._is_noise(self.metadata.loc[i, 'label']))
else s for i, s in enumerate(sliced_times)])
new_begin_time = list(x[:-1] for x in sliced_times)
duplicate_size_vector = [len(list_elem) for list_elem in new_begin_time] # vector to duplicate original dataframe
@@ -172,8 +189,10 @@
self.metadata['begin_time'] = new_begin_time
self.metadata['end_time'] = new_end_time
self.metadata['call_length'] = np.shape(self.metadata)[0] * [self.seq_length]
if not all(self.metadata.value_counts('label', sort=False) >= count_values_before):
print(f'Note: seems like _slice_sequence erases data.\nbefore:{count_values_before}\nafter:{self.metadata.value_counts("label", sort=False)}')
count_values_after = self.metadata.astype({'label': str}).value_counts('label', sort=False)
if not all(count_values_after >= count_values_before):
print(f'Note: seems like _slice_sequence erases data.\nbefore:{count_values_before}\n'
f'after:{count_values_after}')
return

def _get_audio(self, path_to_file, begin_time, end_time, label, channel=None):
@@ -215,6 +234,34 @@ def set_preprocessor(preprocessors_args):
preprocessor = torch.nn.Identity()
return preprocessor

def _get_num_classes(self) -> int:
"""
Returns the number of classes in the metadata.
"""
if self.metadata['label'].apply(lambda x: isinstance(x, np.ndarray)).all():
label_lengths = self.metadata['label'].apply(len)
assert label_lengths.nunique() == 1, "All labels should have the same length"
return label_lengths.iloc[0]
else:
return self.metadata['label'].nunique()

def _get_samples_weight(self) -> np.ndarray:
"""
Returns the weight of each sample in the dataset:
- if the label is an integer, the weight is the inverse of its class count.
- if the label is an array, the weight is the inverse of the smallest count among its active classes
(noise samples use the number of noise samples).
"""
if self.metadata['label'].apply(lambda x: isinstance(x, np.ndarray)).all():
noise_counts = self.metadata['label'].apply(self._is_noise).sum()
class_counts = np.sum(self.metadata['label'])
per_sample_min_class_count = (self.metadata['label'].apply(
lambda x: class_counts[x.astype(bool)].min() if not self._is_noise(x) else noise_counts))
return (1 / per_sample_min_class_count).values
else:
weights = 1 / np.unique(self.metadata['label'], return_counts=True)[1]
return np.array([weights[t] for t in self.metadata['label']])


def __getitem__(self, idx):
'''
__getitem__ method loads item according to idx from the metadata
Expand Down Expand Up @@ -276,7 +323,7 @@ def _get_audio(self, path_to_file, begin_time, end_time, label, channel=None):
if self.mode == "train":
if seg_length >= requested_seq_length:
# Only for calls we can safely add sections out of the call and label it as call
if (self.margin_ratio != 0) and (label > 0):
if (self.margin_ratio != 0) and (not self._is_noise(label)):
# self.margin_ratio ranges from 0 to 1 - indicates the relative part from seq_len to exceed call_length
margin_len_begin = int(requested_seq_length * self.margin_ratio)
margin_len_end = int(requested_seq_length * (1 - self.margin_ratio))
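For readers following the weighting change, a small worked example of the logic in _get_num_classes and _get_samples_weight; the label values below are made up, only the computation mirrors the methods above:

    import numpy as np
    import pandas as pd

    # Hypothetical multi-hot labels: 3 classes, 4 samples, one background row.
    label = pd.Series([np.array([0, 0, 0]), np.array([1, 0, 1]),
                       np.array([1, 0, 0]), np.array([0, 1, 0])])

    print(label.apply(len).iloc[0])                    # _get_num_classes: 3

    is_noise = label.apply(lambda x: np.sum(x) == 0)   # [True, False, False, False]
    noise_counts = is_noise.sum()                      # 1 background sample
    class_counts = np.sum(label)                       # per-class positive counts: [2, 1, 1]

    # Each sample is weighted by its rarest active class; background rows use the noise count.
    per_sample_min = label.apply(
        lambda x: class_counts[x.astype(bool)].min() if x.sum() else noise_counts)
    print((1 / per_sample_min).values)                 # sample weights: [1.0, 1.0, 0.5, 1.0]
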
2 changes: 1 addition & 1 deletion soundbay/train.py
@@ -110,7 +110,7 @@ def modeling(
model.to(device)

# Assert number of labels in the dataset and the number of labels in the model
assert model_args['num_classes'] == len(train_dataset.items_per_classes) == len(val_dataset.items_per_classes), \
assert model_args['num_classes'] == train_dataset.num_classes == val_dataset.num_classes, \
"Num of classes in model and the datasets must be equal, check your configs and your dataset labels!!"

# Add model watch to WANDB
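A per-sample weight vector like samples_weight is typically fed to a weighted sampler. Whether soundbay's training loop wires it up this way is not shown in this diff, so the snippet below is only a usage sketch with a stand-in dataset:

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

    # Stand-in for a ClassifierDataset and its samples_weight attribute.
    dataset = TensorDataset(torch.randn(4, 8), torch.tensor([0, 1, 1, 2]))
    samples_weight = np.array([1.0, 0.5, 0.5, 1.0])

    sampler = WeightedRandomSampler(torch.as_tensor(samples_weight, dtype=torch.double),
                                    num_samples=len(samples_weight), replacement=True)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)
    for _, batch_labels in loader:
        print(batch_labels)  # higher-weight (rarer) samples are drawn proportionally more often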
