Skip to content

Commit

Permalink
Support limiting num sequences per category.
Browse files Browse the repository at this point in the history
Summary:
Adds stratified sampling of sequences within categories applied after category / sequence filters but before the num sequence limit.
It respects the insertion order into the sequence_annots table, i.e. it takes the top N sequences within each category.

Reviewed By: bottler

Differential Revision: D46724002

fbshipit-source-id: 597cb2a795c3f3bc07f838fc51b4e95a4f981ad3
  • Loading branch information
shapovalov authored and facebook-github-bot committed Jun 14, 2023
1 parent 5ffeb4d commit 09a99f2
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 6 deletions.
39 changes: 33 additions & 6 deletions pytorch3d/implicitron/dataset/sql_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
pick_categories: Restrict the dataset to the given list of categories.
pick_sequences: A Sequence of sequence names to restrict the dataset to.
exclude_sequences: A Sequence of the names of the sequences to exclude.
limit_sequences_per_category_to: Limit the dataset to the first up to N
sequences within each category (applies after all other sequence filters
but before `limit_sequences_to`).
limit_sequences_to: Limit the dataset to the first `limit_sequences_to`
sequences (after other sequence filters have been applied but before
frame-based filters).
Expand All @@ -115,6 +118,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore

pick_sequences: Tuple[str, ...] = ()
exclude_sequences: Tuple[str, ...] = ()
limit_sequences_per_category_to: int = 0
limit_sequences_to: int = 0
limit_to: int = 0
n_frames_per_sequence: int = -1
Expand Down Expand Up @@ -373,27 +377,46 @@ def is_filtered(self) -> bool:
self.remove_empty_masks
or self.limit_to > 0
or self.limit_sequences_to > 0
or self.limit_sequences_per_category_to > 0
or len(self.pick_sequences) > 0
or len(self.exclude_sequences) > 0
or len(self.pick_categories) > 0
or self.n_frames_per_sequence > 0
)

def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
# maximum possible query: WHERE category IN 'self.pick_categories'
# maximum possible filter (if limit_sequences_per_category_to == 0):
# WHERE category IN 'self.pick_categories'
# AND sequence_name IN 'self.pick_sequences'
# AND sequence_name NOT IN 'self.exclude_sequences'
# LIMIT 'self.limit_sequences_to'

stmt = sa.select(SqlSequenceAnnotation.sequence_name)

where_conditions = [
*self._get_category_filters(),
*self._get_pick_filters(),
*self._get_exclude_filters(),
]
if where_conditions:
stmt = stmt.where(*where_conditions)

def add_where(stmt):
return stmt.where(*where_conditions) if where_conditions else stmt

if self.limit_sequences_per_category_to <= 0:
stmt = add_where(sa.select(SqlSequenceAnnotation.sequence_name))
else:
subquery = sa.select(
SqlSequenceAnnotation.sequence_name,
sa.func.row_number()
.over(
order_by=sa.text("ROWID"), # NOTE: ROWID is SQLite-specific
partition_by=SqlSequenceAnnotation.category,
)
.label("row_number"),
)

subquery = add_where(subquery).subquery()
stmt = sa.select(subquery.c.sequence_name).where(
subquery.c.row_number <= self.limit_sequences_per_category_to
)

if self.limit_sequences_to > 0:
logger.info(
Expand All @@ -402,7 +425,11 @@ def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
# NOTE: ROWID is SQLite-specific
stmt = stmt.order_by(sa.text("ROWID")).limit(self.limit_sequences_to)

if not where_conditions and self.limit_sequences_to <= 0:
if (
not where_conditions
and self.limit_sequences_to <= 0
and self.limit_sequences_per_category_to <= 0
):
# we will not need to filter by sequences
return None

Expand Down
24 changes: 24 additions & 0 deletions tests/implicitron/test_sql_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,30 @@ def test_limit_frames_per_sequence(self, num_frames=2):
)
self.assertEqual(len(dataset), 100)

def test_limit_sequence_per_category(self, num_sequences=2):
    """Check stratified limiting: keep only the first `num_sequences`
    sequences of each category, then verify a non-binding limit is a no-op.
    """
    # kwargs shared by both dataset constructions below
    common_kwargs = dict(
        sqlite_metadata_file=METADATA_FILE,
        remove_empty_masks=False,
        frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
    )

    dataset = SqlIndexDataset(
        limit_sequences_per_category_to=num_sequences, **common_kwargs
    )

    # 2 categories x num_sequences sequences x 10 frames each
    self.assertEqual(len(dataset), num_sequences * 10 * 2)
    seq_names = list(dataset.sequence_names())
    self.assertEqual(len(seq_names), num_sequences * 2)
    # insertion order is respected: the retained sequences are the ones
    # whose trailing digit is below the per-category limit
    for name in seq_names:
        self.assertLess(int(name[-1]), num_sequences)

    # a limit exceeding the sequences available per category changes nothing
    dataset = SqlIndexDataset(
        limit_sequences_per_category_to=13, **common_kwargs
    )
    self.assertEqual(len(dataset), 100)

def test_filter_medley(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
Expand Down

0 comments on commit 09a99f2

Please sign in to comment.