Skip to content

Commit

Permalink
pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
amontanez24 committed Jun 28, 2024
1 parent 9dcf7e6 commit c810b15
Show file tree
Hide file tree
Showing 7 changed files with 279 additions and 261 deletions.
4 changes: 2 additions & 2 deletions sdv/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Utils module."""

from sdv.utils.utils import drop_unknown_references
from sdv.utils.utils import drop_unknown_references, get_random_sequence_subset

__all__ = ('drop_unknown_references',)
__all__ = ('drop_unknown_references', 'get_random_sequence_subset')
66 changes: 0 additions & 66 deletions sdv/utils/poc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@

import warnings

import numpy as np
import pandas as pd

from sdv.errors import InvalidDataError
from sdv.metadata.errors import InvalidMetadataError
from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS
Expand Down Expand Up @@ -142,66 +139,3 @@ def get_random_subset(data, metadata, main_table_name, num_rows, verbose=True):

metadata.validate_data(result)
return result


def get_random_sequence_subset(
data,
metadata,
num_sequences,
max_sequence_length=None,
long_sequence_subsampling_method='first_rows',
):
"""Subsample sequential data based on a number of sequences.
Args:
data (pandas.DataFrame):
The sequential data.
metadata (SingleTableMetadata):
A SingleTableMetadata object describing the data.
num_sequences (int):
The number of sequences to subsample.
max_sequence_length (int):
The maximum length each subsampled sequence is allowed to be. Defaults to None. If
None, do not enforce any max length, meaning that entire sequences will be sampled.
If provided all subsampled sequences must be <= the provided length.
long_sequence_subsampling_method (str):
The method to use when a selected sequence is too long. Options are:
- (default) first_rows: Keep the first n rows of the sequence, where n is the max
sequence length.
- last_rows: Keep the last n rows of the sequence, where n is the max sequence length.
- random: Randomly choose n rows to keep within the sequence. It is important to keep
the randomly chosen rows in the same order as they appear in the original data.
"""
if long_sequence_subsampling_method not in ['first_rows', 'last_rows', 'random']:
raise ValueError(
'long_sequence_subsampling_method must be one of "first_rows", "last_rows" or "random"'
)

sequence_key = metadata.sequence_key
if not sequence_key:
raise ValueError(
'Your metadata does not include a sequence key. A sequence key must be provided to '
'subset the sequential data.'
)

selected_sequences = np.random.permutation(data[sequence_key])[:num_sequences]
subset = data[data[sequence_key].isin(selected_sequences)].reset_index(drop=True)
if max_sequence_length:
grouped_sequences = subset.groupby(sequence_key)
if long_sequence_subsampling_method == 'first_rows':
return grouped_sequences.head(max_sequence_length).reset_index(drop=True)
elif long_sequence_subsampling_method == 'last_rows':
return grouped_sequences.tail(max_sequence_length).reset_index(drop=True)
else:
subsetted_sequences = []
for _, group in grouped_sequences:
if len(group) > max_sequence_length:
idx = np.random.permutation(len(group))[:max_sequence_length]
idx.sort()
subsetted_sequences.append(group.iloc[idx])
else:
subsetted_sequences.append(group)

return pd.concat(subsetted_sequences, ignore_index=True)

return subset
69 changes: 69 additions & 0 deletions sdv/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
from copy import deepcopy

import numpy as np
import pandas as pd

from sdv._utils import _validate_foreign_keys_not_null
Expand Down Expand Up @@ -60,3 +61,71 @@ def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=Tr
sys.stdout.write('\n'.join([success_message, '', summary_table.to_string(index=False)]))

return result


def get_random_sequence_subset(
data,
metadata,
num_sequences,
max_sequence_length=None,
long_sequence_subsampling_method='first_rows',
):
"""Subsample sequential data based on a number of sequences.
Args:
data (pandas.DataFrame):
The sequential data.
metadata (SingleTableMetadata):
A SingleTableMetadata object describing the data.
num_sequences (int):
The number of sequences to subsample.
max_sequence_length (int):
The maximum length each subsampled sequence is allowed to be. Defaults to None. If
None, do not enforce any max length, meaning that entire sequences will be sampled.
If provided all subsampled sequences must be <= the provided length.
long_sequence_subsampling_method (str):
The method to use when a selected sequence is too long. Options are:
- (default) first_rows: Keep the first n rows of the sequence, where n is the max
sequence length.
- last_rows: Keep the last n rows of the sequence, where n is the max sequence length.
- random: Randomly choose n rows to keep within the sequence. It is important to keep
the randomly chosen rows in the same order as they appear in the original data.
"""
if long_sequence_subsampling_method not in ['first_rows', 'last_rows', 'random']:
raise ValueError(
'long_sequence_subsampling_method must be one of "first_rows", "last_rows" or "random"'
)

sequence_key = metadata.sequence_key
if not sequence_key:
raise ValueError(
'Your metadata does not include a sequence key. A sequence key must be provided to '
'subset the sequential data.'
)

if sequence_key not in data:
raise ValueError(
'Your provided sequence key is not in the data. This is required to get a subset.'
)

selected_sequences = np.random.permutation(data[sequence_key])[:num_sequences]
subset = data[data[sequence_key].isin(selected_sequences)].reset_index(drop=True)
if max_sequence_length:
grouped_sequences = subset.groupby(sequence_key)
if long_sequence_subsampling_method == 'first_rows':
return grouped_sequences.head(max_sequence_length).reset_index(drop=True)
elif long_sequence_subsampling_method == 'last_rows':
return grouped_sequences.tail(max_sequence_length).reset_index(drop=True)
else:
subsetted_sequences = []
for _, group in grouped_sequences:
if len(group) > max_sequence_length:
idx = np.random.permutation(len(group))[:max_sequence_length]
idx.sort()
subsetted_sequences.append(group.iloc[idx])
else:
subsetted_sequences.append(group)

return pd.concat(subsetted_sequences, ignore_index=True)

return subset
53 changes: 1 addition & 52 deletions tests/integration/utils/test_poc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sdv.metadata import MultiTableMetadata
from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS, HMASynthesizer
from sdv.multi_table.utils import _get_total_estimated_columns
from sdv.utils.poc import get_random_sequence_subset, get_random_subset, simplify_schema
from sdv.utils.poc import get_random_subset, simplify_schema


@pytest.fixture
Expand Down Expand Up @@ -246,54 +246,3 @@ def test_get_random_subset_with_missing_values(metadata, data):
# Assert
assert len(cleaned_data['child']) == 3
assert not pd.isna(cleaned_data['child']['parent_id']).any()


def test_get_random_sequence_subset():
"""Test that the sequences are subsetted and properly clipped."""
# Setup
data, metadata = download_demo(modality='sequential', dataset_name='nasdaq100_2019')

# Run
subset = get_random_sequence_subset(data, metadata, num_sequences=3, max_sequence_length=5)

# Assert
selected_sequences = subset[metadata.sequence_key].unique()
assert len(selected_sequences) == 3
for sequence_key in selected_sequences:
pd.testing.assert_frame_equal(
subset[subset[metadata.sequence_key] == sequence_key].reset_index(drop=True),
data[data[metadata.sequence_key] == sequence_key].head(5).reset_index(drop=True),
)


def test_get_random_sequence_subset_random_clipping():
"""Test that the sequences are subsetted and properly clipped.
If the long_sequence_sampling_method is set to 'random', the selected sequences should be
subsampled randomly, but maintain the same order.
"""
# Setup
data, metadata = download_demo(modality='sequential', dataset_name='nasdaq100_2019')

# Run
subset = get_random_sequence_subset(
data,
metadata,
num_sequences=3,
max_sequence_length=5,
long_sequence_subsampling_method='random',
)

# Assert
selected_sequences = subset[metadata.sequence_key].unique()
assert len(selected_sequences) == 3
for sequence_key in selected_sequences:
selected_sequence = subset[subset[metadata.sequence_key] == sequence_key]
assert len(selected_sequence) <= 5
subset_data = data[
data['Date'].isin(selected_sequence['Date'])
& data['Symbol'].isin(selected_sequence['Symbol'])
]
pd.testing.assert_frame_equal(
subset_data.reset_index(drop=True), selected_sequence.reset_index(drop=True)
)
54 changes: 53 additions & 1 deletion tests/integration/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import pandas as pd
import pytest

from sdv.datasets.demo import download_demo
from sdv.errors import InvalidDataError
from sdv.metadata import MultiTableMetadata
from sdv.utils import drop_unknown_references
from sdv.utils import drop_unknown_references, get_random_sequence_subset


@pytest.fixture
Expand Down Expand Up @@ -140,3 +141,54 @@ def test_drop_unknown_references_not_drop_missing_values(metadata, data):
pd.testing.assert_frame_equal(cleaned_data['child'], data['child'].iloc[:4])
assert pd.isna(cleaned_data['child']['parent_id']).any()
assert len(cleaned_data['child']) == 4


def test_get_random_sequence_subset():
"""Test that the sequences are subsetted and properly clipped."""
# Setup
data, metadata = download_demo(modality='sequential', dataset_name='nasdaq100_2019')

# Run
subset = get_random_sequence_subset(data, metadata, num_sequences=3, max_sequence_length=5)

# Assert
selected_sequences = subset[metadata.sequence_key].unique()
assert len(selected_sequences) == 3
for sequence_key in selected_sequences:
pd.testing.assert_frame_equal(
subset[subset[metadata.sequence_key] == sequence_key].reset_index(drop=True),
data[data[metadata.sequence_key] == sequence_key].head(5).reset_index(drop=True),
)


def test_get_random_sequence_subset_random_clipping():
"""Test that the sequences are subsetted and properly clipped.
If the long_sequence_sampling_method is set to 'random', the selected sequences should be
subsampled randomly, but maintain the same order.
"""
# Setup
data, metadata = download_demo(modality='sequential', dataset_name='nasdaq100_2019')

# Run
subset = get_random_sequence_subset(
data,
metadata,
num_sequences=3,
max_sequence_length=5,
long_sequence_subsampling_method='random',
)

# Assert
selected_sequences = subset[metadata.sequence_key].unique()
assert len(selected_sequences) == 3
for sequence_key in selected_sequences:
selected_sequence = subset[subset[metadata.sequence_key] == sequence_key]
assert len(selected_sequence) <= 5
subset_data = data[
data['Date'].isin(selected_sequence['Date'])
& data['Symbol'].isin(selected_sequence['Symbol'])
]
pd.testing.assert_frame_equal(
subset_data.reset_index(drop=True), selected_sequence.reset_index(drop=True)
)
Loading

0 comments on commit c810b15

Please sign in to comment.