prefetch data during training (#2534)
Fix #2229.

Train models and prefetch data in parallel to decouple the time when
data is produced from the time when data is consumed.
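
The mechanism in one minimal, standalone sketch (not the deepmd-kit code itself; `produce_batch` here stands in for `DeepmdDataSystem.get_batch()`): wrapping the Python-side batch producer in `tf.py_func` turns batch loading into a graph op, so a single `sess.run()` can evaluate the training op and the producer op concurrently.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


def produce_batch():
    # stand-in for DeepmdDataSystem.get_batch(): any slow, Python-side data loading
    return np.random.rand(4, 3)


# wrapping the producer in tf.py_func makes batch loading a graph op
next_batch_op = tf.py_func(produce_batch, [], tf.float64, name="train_data")

x = tf.placeholder(tf.float64, shape=(4, 3))
w = tf.Variable(np.zeros((4, 3)))
loss = tf.reduce_mean(tf.square(x - w))
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = sess.run(next_batch_op)  # only the first batch is fetched synchronously
    for _ in range(10):
        # one run() call trains on the current batch and, concurrently,
        # produces the batch for the next step
        _, batch = sess.run([train_op, next_batch_op], feed_dict={x: batch})
```

Because the producer op has no data dependency on the training op, the TF executor schedules the two in parallel, hiding batch-preparation latency behind the training step.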

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
njzjz authored May 19, 2023
1 parent 0b670f8 commit f6d883f
Showing 1 changed file with 117 additions and 11 deletions.
128 changes: 117 additions & 11 deletions deepmd/train/trainer.py
@@ -5,6 +5,10 @@
 import platform
 import shutil
 import time
+from typing import (
+    Dict,
+    List,
+)
 
 import google.protobuf.message
 import numpy as np
@@ -57,6 +61,9 @@
 from deepmd.utils.argcheck import (
     type_embedding_args,
 )
+from deepmd.utils.data_system import (
+    DeepmdDataSystem,
+)
 from deepmd.utils.errors import (
     GraphTooLargeError,
     GraphWithoutTensorError,
@@ -903,19 +910,55 @@ def train(self, train_data=None, valid_data=None):
 
         train_time = 0
         total_train_time = 0.0
+        wall_time_tic = time.time()
+
+        next_batch_train_op = None
+        next_fitting_key = None
+        next_train_batch_list = None
+        next_datasetloader = None
+
+        # dataset loader op
+        if not self.multi_task_mode:
+            datasetloader = DatasetLoader(train_data)
+            data_op = datasetloader.build()
+        else:
+            datasetloader = {}
+            data_op = {}
+            for fitting_key in self.fitting_type_dict:
+                datasetloader[fitting_key] = DatasetLoader(train_data[fitting_key])
+                data_op[fitting_key] = datasetloader[fitting_key].build()
 
         while cur_batch < stop_batch:
+            # first round validation:
+            if is_first_step:
+                if not self.multi_task_mode:
+                    train_batch = train_data.get_batch()
+                    batch_train_op = self.train_op
+                else:
+                    fitting_idx = dp_random.choice(
+                        np.arange(self.nfitting), p=np.array(self.fitting_prob)
+                    )
+                    fitting_key = self.fitting_key_list[fitting_idx]
+                    train_batch = train_data[fitting_key].get_batch()
+                    batch_train_op = self.train_op[fitting_key]
-            if not self.multi_task_mode:
-                train_batch = train_data.get_batch()
-                batch_train_op = self.train_op
             else:
-                fitting_idx = dp_random.choice(
-                    np.arange(self.nfitting), p=np.array(self.fitting_prob)
-                )
-                fitting_key = self.fitting_key_list[fitting_idx]
-                train_batch = train_data[fitting_key].get_batch()
-                batch_train_op = self.train_op[fitting_key]
+                train_batch = next_datasetloader.get_data_dict(next_train_batch_list)
+                batch_train_op = next_batch_train_op
+                fitting_key = next_fitting_key
+            # for next round
+            if not self.multi_task_mode:
+                next_datasetloader = datasetloader
+                next_batch_train_op = self.train_op
+                next_train_batch_op = data_op
+            else:
+                fitting_idx = dp_random.choice(
+                    np.arange(self.nfitting), p=np.array(self.fitting_prob)
+                )
+                next_fitting_key = self.fitting_key_list[fitting_idx]
+                next_datasetloader = datasetloader[next_fitting_key]
+                next_batch_train_op = self.train_op[next_fitting_key]
+                next_train_batch_op = data_op[next_fitting_key]
 
             if self.display_in_training and is_first_step:
                 if self.run_opt.is_chief:
                     if not self.multi_task_mode:
@@ -964,18 +1007,18 @@ def train(self, train_data=None, valid_data=None):
             # use tensorboard to visualize the training of deepmd-kit
             # it will take some extra execution time to generate the tensorboard data
             if self.tensorboard and (cur_batch % self.tensorboard_freq == 0):
-                summary, _ = run_sess(
+                summary, _, next_train_batch_list = run_sess(
                     self.sess,
-                    [summary_merged_op, batch_train_op],
+                    [summary_merged_op, batch_train_op, next_train_batch_op],
                     feed_dict=train_feed_dict,
                     options=prf_options,
                     run_metadata=prf_run_metadata,
                 )
                 tb_train_writer.add_summary(summary, cur_batch)
             else:
-                run_sess(
+                _, next_train_batch_list = run_sess(
                     self.sess,
-                    [batch_train_op],
+                    [batch_train_op, next_train_batch_op],
                     feed_dict=train_feed_dict,
                     options=prf_options,
                     run_metadata=prf_run_metadata,
@@ -1025,14 +1068,16 @@ def train(self, train_data=None, valid_data=None):
                 if self.timing_in_training:
                     toc = time.time()
                     test_time = toc - tic
+                    wall_time = toc - wall_time_tic
                     log.info(
-                        "batch %7d training time %.2f s, testing time %.2f s"
-                        % (cur_batch, train_time, test_time)
+                        "batch %7d training time %.2f s, testing time %.2f s, total wall time %.2f s"
+                        % (cur_batch, train_time, test_time, wall_time)
                     )
                     # the first training time is not accurate
                     if cur_batch > self.disp_freq or stop_batch < 2 * self.disp_freq:
                         total_train_time += train_time
                     train_time = 0
+                    wall_time_tic = toc
             if (
                 self.save_freq > 0
                 and cur_batch % self.save_freq == 0
@@ -1405,3 +1450,64 @@ def _change_energy_bias(
             bias_shift=bias_shift,
             ntest=self.model_param.get("data_bias_nsample", 10),
         )
+
+
+class DatasetLoader:
+    """Generate an OP that loads the training data from the given DeepmdDataSystem.
+
+    It can be used to load the training data in the training process, so there is
+    no waiting time between training steps.
+
+    Parameters
+    ----------
+    train_data : DeepmdDataSystem
+        The training data.
+
+    Examples
+    --------
+    >>> loader = DatasetLoader(train_data)
+    >>> data_op = loader.build()
+    >>> with tf.Session() as sess:
+    >>>     data_list = sess.run(data_op)
+    >>> data_dict = loader.get_data_dict(data_list)
+    """
+
+    def __init__(self, train_data: DeepmdDataSystem):
+        self.train_data = train_data
+        # get the keys of the data
+        batch_data = self.train_data.get_batch()
+        self.data_keys = batch_data.keys()
+        self.data_types = [tf.as_dtype(x.dtype) for x in batch_data.values()]
+
+    def build(self) -> List[tf.Tensor]:
+        """Build the OP that loads the training data.
+
+        Returns
+        -------
+        List[tf.Tensor]
+            Tensor of the loaded data.
+        """
+        train_data = self.train_data
+
+        def get_train_batch() -> List[np.ndarray]:
+            batch_data = train_data.get_batch()
+            # convert dict to list of arrays
+            batch_data = tuple([batch_data[kk] for kk in self.data_keys])
+            return batch_data
+
+        return tf.py_func(get_train_batch, [], self.data_types, name="train_data")
+
+    def get_data_dict(self, batch_list: List[np.ndarray]) -> Dict[str, np.ndarray]:
+        """Generate a dict of the loaded data.
+
+        Parameters
+        ----------
+        batch_list : List[np.ndarray]
+            The loaded data.
+
+        Returns
+        -------
+        Dict[str, np.ndarray]
+            The dict of the loaded data.
+        """
+        return {kk: vv for kk, vv in zip(self.data_keys, batch_list)}
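
Taken together, `DatasetLoader` and the loop changes above form a one-step pipeline: each session call fetches both the training op for the current batch and the `py_func` op that produces the next batch. A condensed sketch of that handshake (hypothetical `train_op`, `make_feed_dict`, and `num_steps`; the real loop additionally handles multi-task mode and the first-step validation path):

```python
loader = DatasetLoader(train_data)
data_op = loader.build()

batch = train_data.get_batch()  # step 0 is fetched synchronously, as in the first round above
with tf.Session() as sess:
    for _ in range(num_steps):
        feed = make_feed_dict(batch)  # hypothetical helper building the feed_dict
        # training on the current batch overlaps with loading the next one
        _, batch_list = sess.run([train_op, data_op], feed_dict=feed)
        batch = loader.get_data_dict(batch_list)
```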
