diff --git a/deepmd/pt/utils/auto_batch_size.py b/deepmd/pt/utils/auto_batch_size.py
index 181d56f2f4..13264a336c 100644
--- a/deepmd/pt/utils/auto_batch_size.py
+++ b/deepmd/pt/utils/auto_batch_size.py
@@ -12,6 +12,28 @@
 
 
 class AutoBatchSize(AutoBatchSizeBase):
+    """Auto batch size.
+
+    Parameters
+    ----------
+    initial_batch_size : int, default: 1024
+        initial batch size (number of total atoms) when DP_INFER_BATCH_SIZE
+        is not set
+    factor : float, default: 2.
+        increased factor
+
+    """
+
+    def __init__(
+        self,
+        initial_batch_size: int = 1024,
+        factor: float = 2.0,
+    ):
+        super().__init__(
+            initial_batch_size=initial_batch_size,
+            factor=factor,
+        )
+
     def is_gpu_available(self) -> bool:
         """Check if GPU is available.
 
@@ -78,26 +100,50 @@ def execute_with_batch_size(
         )
 
         index = 0
-        results = []
+        results = None
+        returned_dict = None
         while index < total_size:
             n_batch, result = self.execute(execute_with_batch_size, index, natoms)
-            if not isinstance(result, tuple):
-                result = (result,)
+            returned_dict = (
+                isinstance(result, dict) if returned_dict is None else returned_dict
+            )
+            if not returned_dict:
+                result = (result,) if not isinstance(result, tuple) else result
             index += n_batch
-            if n_batch:
-                for rr in result:
-                    rr.reshape((n_batch, -1))
-                results.append(result)
-        r_list = []
-        for r in zip(*results):
+
+            def append_to_list(res_list, res):
+                if n_batch:
+                    res_list.append(res)
+                return res_list
+
+            if not returned_dict:
+                results = [] if results is None else results
+                results = append_to_list(results, result)
+            else:
+                results = (
+                    {kk: [] for kk in result.keys()} if results is None else results
+                )
+                results = {
+                    kk: append_to_list(results[kk], result[kk]) for kk in result.keys()
+                }
+        assert results is not None
+        assert returned_dict is not None
+
+        def concate_result(r):
             if isinstance(r[0], np.ndarray):
-                r_list.append(np.concatenate(r, axis=0))
+                ret = np.concatenate(r, axis=0)
             elif isinstance(r[0], torch.Tensor):
-                r_list.append(torch.cat(r, dim=0))
+                ret = torch.cat(r, dim=0)
             else:
                 raise RuntimeError(f"Unexpected result type {type(r[0])}")
-        r = tuple(r_list)
-        if len(r) == 1:
-            # avoid returning tuple if callable doesn't return tuple
-            r = r[0]
+            return ret
+
+        if not returned_dict:
+            r_list = [concate_result(r) for r in zip(*results)]
+            r = tuple(r_list)
+            if len(r) == 1:
+                # avoid returning tuple if callable doesn't return tuple
+                r = r[0]
+        else:
+            r = {kk: concate_result(vv) for kk, vv in results.items()}
         return r
diff --git a/source/tests/pt/test_auto_batch_size.py b/source/tests/pt/test_auto_batch_size.py
new file mode 100644
index 0000000000..71194e001e
--- /dev/null
+++ b/source/tests/pt/test_auto_batch_size.py
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import unittest
+
+import numpy as np
+
+from deepmd.pt.utils.auto_batch_size import (
+    AutoBatchSize,
+)
+
+
+class TestAutoBatchSize(unittest.TestCase):
+    def test_execute_all(self):
+        dd0 = np.zeros((10000, 2, 1, 3, 4))
+        dd1 = np.ones((10000, 2, 1, 3, 4))
+        auto_batch_size = AutoBatchSize(256, 2.0)
+
+        def func(dd1):
+            return np.zeros_like(dd1), np.ones_like(dd1)
+
+        dd2 = auto_batch_size.execute_all(func, 10000, 2, dd1)
+        np.testing.assert_equal(dd0, dd2[0])
+        np.testing.assert_equal(dd1, dd2[1])
+
+    def test_execute_all_dict(self):
+        dd0 = np.zeros((10000, 2, 1, 3, 4))
+        dd1 = np.ones((10000, 2, 1, 3, 4))
+        auto_batch_size = AutoBatchSize(256, 2.0)
+
+        def func(dd1):
+            return {
+                "foo": np.zeros_like(dd1),
+                "bar": np.ones_like(dd1),
+            }
+
+        dd2 = auto_batch_size.execute_all(func, 10000, 2, dd1)
+        np.testing.assert_equal(dd0, dd2["foo"])
+        np.testing.assert_equal(dd1, dd2["bar"])
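
For reference, a minimal usage sketch of the dict-returning path this patch enables. Everything here except `AutoBatchSize` and the `execute_all(callable, total_size, natoms, *inputs)` call pattern is an illustrative assumption: `evaluate_batch`, the output keys, and the shapes are hypothetical, not part of the patch.

    import numpy as np

    from deepmd.pt.utils.auto_batch_size import AutoBatchSize

    nframes, natoms = 100, 64
    coords = np.random.rand(nframes, natoms, 3)

    def evaluate_batch(coords):
        # Hypothetical per-batch evaluation. With the patch applied,
        # execute_all concatenates each dict value along axis 0 across
        # batches, so a dict-returning callable round-trips cleanly.
        return {
            "energy": coords.sum(axis=(1, 2)).reshape(-1, 1),
            "force": -coords,
        }

    abs_calc = AutoBatchSize(initial_batch_size=1024, factor=2.0)
    out = abs_calc.execute_all(evaluate_batch, nframes, natoms, coords)
    assert out["energy"].shape == (nframes, 1)
    assert out["force"].shape == (nframes, natoms, 3)

Before this patch, a dict result fell through to the `(result,)` tuple wrapping and failed at concatenation time with `RuntimeError: Unexpected result type`; the `returned_dict` branch now concatenates per key instead of per tuple position.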