diff --git a/soundbay/data.py b/soundbay/data.py index 342369d..6d47ba6 100644 --- a/soundbay/data.py +++ b/soundbay/data.py @@ -2,7 +2,6 @@ from itertools import starmap, repeat from pathlib import Path from typing import Union -from decimal import Decimal import numpy as np import pandas as pd @@ -71,9 +70,8 @@ def __init__(self, data_path, metadata_path, augmentations, augmentations_p, pre self.preprocessor = self.set_preprocessor(preprocessors) assert (0 <= margin_ratio) and (1 >= margin_ratio) self.margin_ratio = margin_ratio - self.items_per_classes = np.unique(self.metadata['label'], return_counts=True)[1] - weights = 1 / self.items_per_classes - self.samples_weight = np.array([weights[t] for t in self.metadata['label'] ]) + self.num_classes = self._get_num_classes() + self.samples_weight = self._get_samples_weight() @staticmethod def _update_metadata_by_mode(metadata, mode, split_metadata_by_label): @@ -108,13 +106,12 @@ def _preprocess_metadata(self, slice_flag=False): Output: ClassifierDataset object with self.metadata dataframe after applying the condition """ + self._preprocess_target() + is_noise = self.metadata['label'].apply(self._is_noise) # All calls are worthy (because we can later create a bigger slice contain them that is still a call in # _get_audio) but only long enough background sections will do. - self.metadata = self.metadata[ - ((self.metadata['call_length'] >= self.seq_length) & (self.metadata['label'] == 0)) | - (self.metadata['label'] > 0) - ] + self.metadata = self.metadata[((self.metadata['call_length'] >= self.seq_length) & is_noise) | (~is_noise)] # sometimes the bbox's end time exceeds the file's length for name, sub_df in self.metadata.groupby('filename'): @@ -128,6 +125,26 @@ def _preprocess_metadata(self, slice_flag=False): self.metadata.reset_index(drop=True, inplace=True) + def _preprocess_target(self): + """ + Preprocesses the label column in the metadata. If the label is a string, it is evaluated and converted to an + integer or a list of integers. + """ + if pd.api.types.is_string_dtype(self.metadata['label']): + assert self.metadata['label'].str.match(r'^(\[|\()?(\d+)(\s*,\s*\d+)*(\]|\))?$').all(), \ + "label should be a string that could be evaluated as a list of integers or integers." + self.metadata['label'] = self.metadata['label'].apply(eval) + if self.metadata['label'].apply(lambda x: isinstance(x, (list, tuple))).all(): + self.metadata['label'] = self.metadata['label'].apply(np.array, dtype=int) + + @staticmethod + def _is_noise(value: [int, np.ndarray]) -> bool: + """ + Checks if the value is a noise, i.e., if it is equal to 0. + """ + assert (isinstance(value, (int, np.integer)) | isinstance(value, np.ndarray)), "value should be either int or np.ndarray" + return np.sum(value) == 0 + def _grab_fields(self, idx): """ grabs fields from metadata according to idx @@ -159,10 +176,10 @@ def _slice_sequence(self): self.metadata sliced according to buffers """ self.metadata = self.metadata.reset_index(drop=True) - count_values_before = self.metadata.value_counts('label', sort=False) # for validating that the following code doesn't lose samples + count_values_before = self.metadata.astype({'label': str}).value_counts('label', sort=False) # for validating that the following code doesn't lose samples sliced_times = list(starmap(np.arange, zip(self.metadata['begin_time'], self.metadata['end_time'], repeat(self.seq_length)))) # add the last sequence at the end of this list for calls only (only if it does not exceed the file) - sliced_times = list([np.append(s, self.metadata.loc[i, 'end_time']) if self.metadata.loc[i, 'label'] != 0 + sliced_times = list([np.append(s, self.metadata.loc[i, 'end_time']) if (not self._is_noise(self.metadata.loc[i, 'label'])) else s for i, s in enumerate(sliced_times)]) new_begin_time = list(x[:-1] for x in sliced_times) duplicate_size_vector = [len(list_elem) for list_elem in new_begin_time] # vector to duplicate original dataframe @@ -172,8 +189,10 @@ def _slice_sequence(self): self.metadata['begin_time'] = new_begin_time self.metadata['end_time'] = new_end_time self.metadata['call_length'] = np.shape(self.metadata)[0] * [self.seq_length] - if not all(self.metadata.value_counts('label', sort=False) >= count_values_before): - print(f'Note: seems like _slice_sequence erases data.\nbefore:{count_values_before}\nafter:{self.metadata.value_counts("label", sort=False)}') + count_values_after = self.metadata.astype({'label': str}).value_counts('label', sort=False) + if not all(count_values_after >= count_values_before): + print(f'Note: seems like _slice_sequence erases data.\nbefore:{count_values_before}\n' + f'after:{count_values_after}') return def _get_audio(self, path_to_file, begin_time, end_time, label, channel=None): @@ -215,6 +234,34 @@ def set_preprocessor(preprocessors_args): preprocessor = torch.nn.Identity() return preprocessor + def _get_num_classes(self) -> int: + """ + Returns the number of classes in the metadata. + """ + if self.metadata['label'].apply(lambda x: isinstance(x, np.ndarray)).all(): + label_lengths = self.metadata['label'].apply(len) + assert label_lengths.nunique() == 1, "All labels should have the same length" + return label_lengths.iloc[0] + else: + return self.metadata['label'].nunique() + + def _get_samples_weight(self) -> np.ndarray: + """ + Returns the weight of each sample in the dataset: + - if the label is integer, the weight is the inverse of the class count. + - if the label is a list, the weight is the inverse of the minimum class count. + """ + if self.metadata['label'].apply(lambda x: isinstance(x, np.ndarray)).all(): + noise_counts = self.metadata['label'].apply(self._is_noise).sum() + class_counts = np.sum(self.metadata['label']) + per_sample_min_class_count = (self.metadata['label'].apply( + lambda x: class_counts[x.astype(bool)].min() if not self._is_noise(x) else noise_counts)) + return (1 / per_sample_min_class_count).values + else: + weights = 1 / np.unique(self.metadata['label'], return_counts=True)[1] + return np.array([weights[t] for t in self.metadata['label']]) + + def __getitem__(self, idx): ''' __getitem__ method loads item according to idx from the metadata @@ -276,7 +323,7 @@ def _get_audio(self, path_to_file, begin_time, end_time, label, channel=None): if self.mode == "train": if seg_length >= requested_seq_length: # Only for calls we can safely add sections out of the call and label it as call - if (self.margin_ratio != 0) and (label > 0): + if (self.margin_ratio != 0) and (not self._is_noise(label)): # self.margin_ratio ranges from 0 to 1 - indicates the relative part from seq_len to exceed call_length margin_len_begin = int(requested_seq_length * self.margin_ratio) margin_len_end = int(requested_seq_length * (1 - self.margin_ratio)) diff --git a/soundbay/train.py b/soundbay/train.py index de09508..9b0914b 100644 --- a/soundbay/train.py +++ b/soundbay/train.py @@ -110,7 +110,7 @@ def modeling( model.to(device) # Assert number of labels in the dataset and the number of labels in the model - assert model_args['num_classes'] == len(train_dataset.items_per_classes) == len(val_dataset.items_per_classes), \ + assert model_args['num_classes'] == train_dataset.num_classes == val_dataset.num_classes, \ "Num of classes in model and the datasets must be equal, check your configs and your dataset labels!!" # Add model watch to WANDB