Skip to content

Commit

Permalink
refactor: simplify dataset construction (#4437)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced a new function for dataset construction, enhancing data
loading processes.
- Added a method to improve pickling and unpickling capabilities for
path handling classes.

- **Bug Fixes**
- Updated summary printing to prevent redundant output during
distributed training.

- **Refactor**
- Simplified initialization of the BackgroundConsumer class.
- Streamlined consumer thread and queue handling in the BufferedIterator
class.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Chun Cai <amoycaic@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 9, 2024
1 parent ce9aeb3 commit b4ade5c
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 41 deletions.
65 changes: 24 additions & 41 deletions deepmd/pt/utils/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import os
import queue
import time
from multiprocessing.dummy import (
Pool,
)
from queue import (
Queue,
)
from threading import (
Thread,
)
Expand Down Expand Up @@ -204,70 +206,51 @@ def print_summary(
)


_sentinel = object()
QUEUESIZE = 32


class BackgroundConsumer(Thread):
    """Daemon thread that drains ``source`` into ``queue`` ahead of the consumer.

    Parameters
    ----------
    queue : Queue
        Bounded queue shared with the consumer; ``put`` blocks when full,
        which throttles prefetching.
    source : iterable
        The main DataLoader iterator to prefetch from.
    """

    def __init__(self, queue, source) -> None:
        super().__init__()
        # Daemon thread: do not block interpreter shutdown if the consumer
        # abandons the iterator early.
        self.daemon = True
        self._queue = queue
        self._source = source  # Main DL iterator

    def run(self) -> None:
        try:
            for item in self._source:
                self._queue.put(item)  # Blocking if the queue is full
        except Exception as e:
            # Forward the failure through the queue; otherwise this thread
            # dies silently and the consumer blocks on queue.get() forever.
            self._queue.put(e)
        else:
            # Signal the consumer we are done; this should not happen for DataLoader
            self._queue.put(StopIteration())


QUEUESIZE = 32


class BufferedIterator:
    """Wrap an iterable with a background prefetch thread.

    Items are produced by a ``BackgroundConsumer`` into a bounded queue and
    handed out one at a time from ``__next__``. Exceptions placed on the
    queue by the producer (including the ``StopIteration`` end-of-stream
    sentinel) are re-raised in the consuming thread.
    """

    def __init__(self, iterable) -> None:
        self._queue = Queue(QUEUESIZE)
        self._iterable = iterable
        # Start prefetching immediately; the daemon thread fills the queue.
        self._consumer = BackgroundConsumer(self._queue, self._iterable)
        self._consumer.start()
        self.last_warning_time = time.time()

    def __iter__(self):
        return self

    def __len__(self) -> int:
        return len(self._iterable)

    def __next__(self):
        start_wait = time.time()
        item = self._queue.get()
        wait_time = time.time() - start_wait
        # Each step usually takes < 1 s even for multi-task training, so a
        # long queue.get() indicates a data-loading bottleneck. Rate-limit
        # the warning to once per 15 minutes.
        stalled = wait_time > 1.0
        cooldown_over = start_wait - self.last_warning_time > 15 * 60
        if stalled and cooldown_over:
            log.warning(
                f"Data loading is slow, waited {wait_time:.2f} seconds. Ignoring this warning for 15 minutes."
            )
            self.last_warning_time = start_wait
        if isinstance(item, Exception):
            # Covers both forwarded errors and the StopIteration sentinel.
            raise item
        return item


Expand Down
10 changes: 10 additions & 0 deletions deepmd/utils/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ def is_file(self) -> bool:
def is_dir(self) -> bool:
"""Check if self is directory."""

@abstractmethod
def __getnewargs__(self):
    """Return the arguments to be passed to __new__ when unpickling an instance.

    Subclasses implement this so DPPath objects survive a pickle round
    trip (e.g. when datasets are handed to worker processes).
    """

@abstractmethod
def __truediv__(self, key: str) -> "DPPath":
"""Used for / operator."""
Expand Down Expand Up @@ -169,6 +173,9 @@ def __init__(self, path: Union[str, Path], mode: str = "r") -> None:
self.mode = mode
self.path = Path(path)

def __getnewargs__(self):
    """Return ``(path, mode)`` so the instance can be rebuilt on unpickle."""
    return self.path, self.mode

def load_numpy(self) -> np.ndarray:
"""Load NumPy array.
Expand Down Expand Up @@ -304,6 +311,9 @@ def __init__(self, path: str, mode: str = "r") -> None:
# h5 path: default is the root path
self._name = s[1] if len(s) > 1 else "/"

def __getnewargs__(self):
    """Return ``(root_path, mode)`` so the instance can be rebuilt on unpickle."""
    return self.root_path, self.mode

@classmethod
@lru_cache(None)
def _load_h5py(cls, path: str, mode: str = "r") -> h5py.File:
Expand Down

0 comments on commit b4ade5c

Please sign in to comment.