fix(pt): make 'find_' to be float in get data (deepmodeling#3992)
Keep `find_*` entries as floats in `get_data`. Fixes deepmodeling#3991.

On my device, the profiler indicates that `cudaStreamSynchronize` takes
negligible time, resulting in minimal speedup.


## Summary by CodeRabbit

- **New Features**
  - Enhanced data loading by adding a `collate_fn` parameter for more flexible data collation.
  - Improved data filtering by excluding keys containing "find_" in addition to existing filters.

iProzd authored and Mathieu Taillefumier committed Sep 18, 2024
1 parent 5f0a45c commit d1882a6
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion deepmd/pt/train/training.py
@@ -184,6 +184,7 @@ def get_dataloader_and_buffer(_data, _params):
                 if dist.is_available()
                 else 0,  # setting to 0 diverges the behavior of its iterator; should be >=1
                 drop_last=False,
+                collate_fn=lambda batch: batch,  # prevent extra conversion
                 pin_memory=True,
             )
             with torch.device("cpu"):
@@ -1093,7 +1094,7 @@ def get_data(self, is_train=True, task_key="Default"):
             batch_data = next(iter(self.validation_data[task_key]))

         for key in batch_data.keys():
-            if key == "sid" or key == "fid" or key == "box":
+            if key == "sid" or key == "fid" or key == "box" or "find_" in key:
                 continue
             elif not isinstance(batch_data[key], list):
                 if batch_data[key] is not None:
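
The effect of the new `"find_" in key` check can be sketched in isolation. This is a simplified illustration and not the code in `training.py`. `to_device` is a hypothetical stand-in for moving a tensor to the training device, `prepare_batch` and the sample keys are invented for the example, and list-valued entries are passed through unchanged for brevity.

```python
# Hypothetical stand-in for moving a value to the training device.
def to_device(value):
    return ("on_device", value)


# Keys that were already skipped before this commit.
SKIPPED_KEYS = {"sid", "fid", "box"}


def prepare_batch(batch_data):
    """Return a copy of batch_data with device-bound entries converted."""
    prepared = {}
    for key, value in batch_data.items():
        if key in SKIPPED_KEYS or "find_" in key:
            # Left untouched, so find_* flags remain plain floats.
            prepared[key] = value
        elif value is not None and not isinstance(value, list):
            prepared[key] = to_device(value)
        else:
            # None and list values are passed through in this sketch.
            prepared[key] = value
    return prepared


batch = {
    "sid": 0,
    "coord": (0.0, 0.0, 0.0),
    "find_energy": 1.0,
    "force": None,
}
result = prepare_batch(batch)
```

Here `result["find_energy"]` is still the Python float `1.0`, while `result["coord"]` has gone through `to_device`.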