Skip to content

Commit

Permalink
reduce batchsize
Browse files Browse the repository at this point in the history
  • Loading branch information
v-chen_data committed Nov 30, 2024
1 parent 4ca7bbb commit 17cbd20
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
8 changes: 4 additions & 4 deletions tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,17 @@ def get_trainer(self, precision='fp32', **kwargs):

train_dataset = RandomClassificationDataset()
eval_dataset = RandomClassificationDataset()
train_batch_size = 4
train_batch_size = 2

evaluator1 = DataLoader(
dataset=eval_dataset,
batch_size=8,
batch_size=2,
sampler=dist.get_sampler(eval_dataset),
)

evaluator2 = DataLoader(
dataset=eval_dataset,
batch_size=4,
batch_size=2,
sampler=dist.get_sampler(eval_dataset),
)

Expand All @@ -57,7 +57,7 @@ def get_trainer(self, precision='fp32', **kwargs):
precision=precision,
train_subset_num_batches=self.train_subset_num_batches,
eval_subset_num_batches=self.eval_subset_num_batches,
max_duration='2ep',
max_duration='1ep1ba',
optimizers=optimizer,
callbacks=[EventCounterCallback()],
**kwargs,
Expand Down
6 changes: 3 additions & 3 deletions tests/trainer/test_fsdp_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def get_trainer(
save_folder: Optional[str] = None,
save_filename: str = 'ba{batch}-rank{rank}.pt',
save_overwrite: bool = False,
num_features: int = 4, # Reduced from default
num_classes: int = 2, # Reduced from default
num_features: int = 4,
num_classes: int = 2,
load_path: Optional[str] = None,
autoresume: bool = False,
run_name: Optional[str] = None,
Expand All @@ -111,7 +111,7 @@ def get_trainer(
val_metrics=val_metrics,
)
model.module.to(model_init_device)
dataset = RandomClassificationDataset(shape=(num_features,), num_classes=num_classes, size=32)
dataset = RandomClassificationDataset(shape=(num_features,), num_classes=num_classes, size=8)
dataloader = DataLoader(
dataset,
sampler=dist.get_sampler(dataset),
Expand Down

0 comments on commit 17cbd20

Please sign in to comment.