Skip to content

Commit

Permalink
refactor: simplify dataset construction
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Nov 27, 2024
1 parent 656d200 commit ce5ff0c
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 61 deletions.
110 changes: 49 additions & 61 deletions deepmd/pt/utils/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import os
import queue
import time
from multiprocessing.dummy import (
from functools import (
partial,
)
from multiprocessing import (
Pool,
)
from queue import (
Queue,
)
from threading import (
Thread,
)
Expand Down Expand Up @@ -52,6 +57,13 @@ def setup_seed(seed) -> None:
torch.backends.cudnn.deterministic = True


def construct_dataset(system, type_map):
return DeepmdDataSetForLoader(
system=system,
type_map=type_map,
)


class DpLoaderSet(Dataset):
"""A dataset for storing DataLoaders to multiple Systems.
Expand Down Expand Up @@ -87,11 +99,7 @@ def __init__(
if len(systems) >= 100:
log.info(f"Constructing DataLoaders from {len(systems)} systems")

def construct_dataset(system):
return DeepmdDataSetForLoader(
system=system,
type_map=type_map,
)
construct_dataset_systems = partial(construct_dataset, type_map=type_map)

with Pool(
os.cpu_count()
Expand All @@ -101,7 +109,7 @@ def construct_dataset(system):
else 1
)
) as pool:
self.systems = pool.map(construct_dataset, systems)
self.systems = pool.map(construct_dataset_systems, systems)

self.sampler_list: list[DistributedSampler] = []
self.index = []
Expand Down Expand Up @@ -185,85 +193,65 @@ def print_summary(
name: str,
prob: list[float],
) -> None:
print_summary(
name,
len(self.systems),
[ss.system for ss in self.systems],
[ss._natoms for ss in self.systems],
self.batch_sizes,
[
ss._data_system.get_sys_numb_batch(self.batch_sizes[ii])
for ii, ss in enumerate(self.systems)
],
prob,
[ss._data_system.pbc for ss in self.systems],
)


_sentinel = object()
QUEUESIZE = 32
rank = dist.get_rank() if dist.is_initialized() else 0
if rank == 0:
print_summary(
name,
len(self.systems),
[ss.system for ss in self.systems],
[ss._natoms for ss in self.systems],
self.batch_sizes,
[
ss._data_system.get_sys_numb_batch(self.batch_sizes[ii])
for ii, ss in enumerate(self.systems)
],
prob,
[ss._data_system.pbc for ss in self.systems],
)


class BackgroundConsumer(Thread):
def __init__(self, queue, source, max_len) -> None:
Thread.__init__(self)
def __init__(self, queue, source) -> None:
super().__init__()
self.daemon = True
self._queue = queue
self._source = source # Main DL iterator
self._max_len = max_len #

def run(self) -> None:
for item in self._source:
self._queue.put(item) # Blocking if the queue is full

# Signal the consumer we are done.
self._queue.put(_sentinel)
# Signal the consumer we are done; this should not happen for DataLoader
self._queue.put(StopIteration)


QUEUESIZE = 32


class BufferedIterator:
def __init__(self, iterable) -> None:
self._queue = queue.Queue(QUEUESIZE)
self._queue = Queue(QUEUESIZE)
self._iterable = iterable
self._consumer = None

self.start_time = time.time()
self.warning_time = None
self.total = len(iterable)

def _create_consumer(self) -> None:
self._consumer = BackgroundConsumer(self._queue, self._iterable, self.total)
self._consumer.daemon = True
self._consumer = BackgroundConsumer(self._queue, self._iterable)
self._consumer.start()
self.len = len(iterable)

def __iter__(self):
return self

def __len__(self) -> int:
return self.total
return self.len

def __next__(self):
# Create consumer if not created yet
if self._consumer is None:
self._create_consumer()
# Notify the user if there is a data loading bottleneck
if self._queue.qsize() < min(2, max(1, self._queue.maxsize // 2)):
if time.time() - self.start_time > 5 * 60:
if (
self.warning_time is None
or time.time() - self.warning_time > 15 * 60
):
log.warning(
"Data loading buffer is empty or nearly empty. This may "
"indicate a data loading bottleneck, and increasing the "
"number of workers (--num-workers) may help."
)
self.warning_time = time.time()

# Get next example
start_wait = time.time()
item = self._queue.get()
wait_time = time.time() - start_wait
if (
wait_time > 1.0
): # Even for Multi-Task training, each step usually takes < 1s
log.warning(f"Data loading is slow, waited {wait_time:.2f} seconds.")
if isinstance(item, Exception):
raise item
if item is _sentinel:
raise StopIteration
return item


Expand Down
3 changes: 3 additions & 0 deletions deepmd/utils/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def __new__(cls, path: str, mode: str = "r"):
raise FileNotFoundError(f"{path} not found")
return super().__new__(cls)

def __getnewargs__(self):
return (self.path, self.mode)

@abstractmethod
def load_numpy(self) -> np.ndarray:
"""Load NumPy array.
Expand Down

0 comments on commit ce5ff0c

Please sign in to comment.