Skip to content

Commit

Permalink
changed behavior of resample to retain all-0 rows
Browse files Browse the repository at this point in the history
default behavior now uses argument n_samples_without_labels=None, which retains all rows from original df with all-0 labels.

You can still pass an integer (e.g. 0 or 20) to drop, downsample, or upsample the all-0 rows to that count.

Updated tests
  • Loading branch information
sammlapp committed Jan 23, 2025
1 parent d6f556d commit 46e16cc
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 16 deletions.
40 changes: 27 additions & 13 deletions opensoundscape/data_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,28 @@
def resample(
df,
n_samples_per_class,
n_samples_without_labels=0,
n_samples_without_labels=None,
upsample=True,
downsample=True,
with_replace=False,
random_state=None,
):
"""resample a one-hot encoded label df for a target n_samples_per_class
Returns a new dataframe with duplicated and/or subset rows. Note that the order of samples changes.
Can enable/disable upsampling (randomly repeating rows) and downsampling (randomly subsetting rows)
args:
df: dataframe with one-hot encoded labels: columns are classes, index is sample name/path
n_samples_per_class: target number of samples per class
n_samples_without_labels: number of samples with all-0 labels to include in the returned df
[default: 0]. `upsample` and `downsample` flags are ignored for generating all-0 label samples.
None or integer.
- [default: None] keeps all of the original df's rows that have all-0 labels.
- if integer > 0: upsample or downsample as needed from original df to achieve this number
of rows with all-0 labels
- if 0: no all-0 labels are included in the returned df
Note: `upsample` and `downsample` arguments are ignored for generating all-0 label samples.
upsample: if True, duplicate samples for classes with <n samples to get to n samples
downsample: if True, randomly sample classes with >n samples to get to n samples
with_replace: flag to enable sampling of the same row more than once, default False
Expand All @@ -45,16 +54,16 @@ def resample(

class_dfs = [None] * len(df.columns)
for idx, unique_label in enumerate(df.columns):
sub_df = df[df[unique_label] == 1]
n_class_samples = sub_df.shape[0]
no_labels_df = df[df[unique_label] == 1]
n_class_samples = no_labels_df.shape[0]

if n_class_samples < n_samples_per_class and (not upsample):
# we don't want to upsample, so just keep these samples
class_dfs[idx] = sub_df
class_dfs[idx] = no_labels_df
continue
if n_class_samples > n_samples_per_class and (not downsample):
# we don't want to downsample, so just keep all of samples
class_dfs[idx] = sub_df
class_dfs[idx] = no_labels_df
continue

# upsample or downsample as needed to get to n samples
Expand All @@ -63,21 +72,25 @@ def resample(
# take a random sample for the "remainder" portion
# this is the entirety of the new set of n samples if downsampling,
# and the samples with an 'extra' representation if upsampling
random_df = sub_df.sample(
random_df = no_labels_df.sample(
n=remainder, replace=with_replace, random_state=random_state
)

# if upsampling, repeat all of the samples as many times as necessary
if num_replicates > 0:
repeat_df = pd.concat(itertools.repeat(sub_df, num_replicates))
repeat_df = pd.concat(itertools.repeat(no_labels_df, num_replicates))
class_dfs[idx] = pd.concat([repeat_df, random_df])
else:
class_dfs[idx] = random_df

# add samples without any labels, if desired (i.e. "negatives")
if n_samples_without_labels > 0:
sub_df = df[df.sum(1) == 0]
n_negatives = sub_df.shape[0]
if n_samples_without_labels is None:
# keep all samples (rows) from original df that did not contain any labels
class_dfs.append(df[df.sum(1) == 0])
elif n_samples_without_labels > 0:
# user specified
no_labels_df = df[df.sum(1) == 0]
n_negatives = no_labels_df.shape[0]
num_replicates, remainder = divmod(n_samples_without_labels, n_negatives)

# should we consider the upsampling and downsampling flags here?
Expand All @@ -89,16 +102,17 @@ def resample(
# # we don't want to downsample, so just keep all of samples
# class_dfs.append(sub_df)

random_df = sub_df.sample(
random_df = no_labels_df.sample(
n=remainder, replace=with_replace, random_state=random_state
)

# if upsampling, repeat all of the samples as many times as necessary
if num_replicates > 0:
repeat_df = pd.concat(itertools.repeat(sub_df, num_replicates))
repeat_df = pd.concat(itertools.repeat(no_labels_df, num_replicates))
class_dfs.append(pd.concat([repeat_df, random_df]))
else:
class_dfs.append(random_df)
# (implicit) else: keep 0 samples without any labels

return pd.concat(class_dfs)

Expand Down
13 changes: 10 additions & 3 deletions tests/test_data_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,22 @@ def input_dataframe():


def test_resample_basic(resample_df):
# changed behavior: now retains all-0 rows by default
for _ in range(5):
df = selection.resample(resample_df, 2)
assert df.shape[0] == 6
assert df.shape[0] == 2 * 3 + 2


def test_resample_no_upsample(resample_df):
for _ in range(5):
df = selection.resample(resample_df, 2, upsample=False)
assert df.shape[0] == 5
assert df.shape[0] == 5 + 2


def test_resample_no_downsample(resample_df):
for _ in range(5):
df = selection.resample(resample_df, 2, downsample=False)
assert df.shape[0] == 8
assert df.shape[0] == 8 + 2


def test_resample_inclue_negatives(resample_df):
Expand All @@ -69,3 +70,9 @@ def test_upsample_basic(upsample_df):
for _ in range(100):
upsampled_df = selection.upsample(upsample_df)
assert upsampled_df.shape[0] == 20


def test_resample_no_negatives(resample_df):
    """Check that n_samples_without_labels=0 yields exactly 2 * 3 rows.

    Repeated a few times because resample() draws rows randomly.
    """
    # presumably 2 samples per class for 3 label columns in the fixture;
    # verify against the resample_df fixture definition
    expected_row_count = 2 * 3
    for _attempt in range(5):
        resampled = selection.resample(
            resample_df, 2, n_samples_without_labels=0
        )
        n_rows = resampled.shape[0]
        assert n_rows == expected_row_count

0 comments on commit 46e16cc

Please sign in to comment.