feat: auto batch size supports methods that return a dict #3626

Merged · 2 commits · Mar 31, 2024
Changes from 1 commit
74 changes: 59 additions & 15 deletions deepmd/pt/utils/auto_batch_size.py
@@ -12,6 +12,32 @@


 class AutoBatchSize(AutoBatchSizeBase):
+    """Auto batch size.
+
+    Parameters
+    ----------
+    initial_batch_size : int, default: 1024
+        initial batch size (number of total atoms) when DP_INFER_BATCH_SIZE
+        is not set
+    factor : float, default: 2.
+        increased factor
+    returned_dict : bool, default: False
+        whether the batched method returns a dict of arrays
+
+    """

+    def __init__(
+        self,
+        initial_batch_size: int = 1024,
+        factor: float = 2.0,
+        returned_dict: bool = False,
+    ):
+        super().__init__(
+            initial_batch_size=initial_batch_size,
+            factor=factor,
+        )
+        self.returned_dict = returned_dict

     def is_gpu_available(self) -> bool:
         """Check if GPU is available.

@@ -78,26 +104,44 @@
         )

         index = 0
-        results = []
+        results = None
         while index < total_size:
             n_batch, result = self.execute(execute_with_batch_size, index, natoms)
-            if not isinstance(result, tuple):
-                result = (result,)
+            if not self.returned_dict:
+                result = (result,) if not isinstance(result, tuple) else result
             index += n_batch
-            if n_batch:
-                for rr in result:
-                    rr.reshape((n_batch, -1))
-                results.append(result)
-        r_list = []
-        for r in zip(*results):
+
+            def append_to_list(res_list, res):
+                if n_batch:
+                    res_list.append(res)
+                return res_list
+
+            if not self.returned_dict:
+                results = [] if results is None else results
+                results = append_to_list(results, result)
+            else:
+                results = (
+                    {kk: [] for kk in result.keys()} if results is None else results
+                )
+                results = {
+                    kk: append_to_list(results[kk], result[kk]) for kk in result.keys()
+                }
+
+        def concate_result(r):
             if isinstance(r[0], np.ndarray):
-                r_list.append(np.concatenate(r, axis=0))
+                ret = np.concatenate(r, axis=0)
             elif isinstance(r[0], torch.Tensor):
-                r_list.append(torch.cat(r, dim=0))
+                ret = torch.cat(r, dim=0)
             else:
                 raise RuntimeError(f"Unexpected result type {type(r[0])}")
-        r = tuple(r_list)
-        if len(r) == 1:
-            # avoid returning tuple if callable doesn't return tuple
-            r = r[0]
+            return ret
+
+        if not self.returned_dict:
+            r_list = [concate_result(r) for r in zip(*results)]
+            r = tuple(r_list)
+            if len(r) == 1:
+                # avoid returning tuple if callable doesn't return tuple
+                r = r[0]
+        else:
+            r = {kk: concate_result(vv) for kk, vv in results.items()}
         return r
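The new dict branch mirrors the existing tuple branch: one list per key is filled batch by batch, then each key's list is concatenated along the leading axis. A minimal standalone sketch of that accumulate-then-concatenate pattern in plain NumPy (the helper name merge_batches and the keys "energy"/"force" are illustrative only, not from this PR):

import numpy as np

def merge_batches(batches):
    # collect each key's per-batch arrays into a list ...
    merged = {kk: [] for kk in batches[0].keys()}
    for batch in batches:
        for kk, vv in batch.items():
            merged[kk].append(vv)
    # ... then concatenate every list along the leading (frame) axis
    return {kk: np.concatenate(vv, axis=0) for kk, vv in merged.items()}

# three batches of unequal size, as the auto batch size loop may produce
batches = [
    {"energy": np.zeros((4, 1)), "force": np.zeros((4, 3))},
    {"energy": np.zeros((2, 1)), "force": np.zeros((2, 3))},
    {"energy": np.zeros((1, 1)), "force": np.zeros((1, 3))},
]
out = merge_batches(batches)
assert out["energy"].shape == (7, 1)
assert out["force"].shape == (7, 3)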
37 changes: 37 additions & 0 deletions source/tests/pt/test_auto_batch_size.py
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import unittest
+
+import numpy as np
+
+from deepmd.pt.utils.auto_batch_size import (
+    AutoBatchSize,
+)
+
+
+class TestAutoBatchSize(unittest.TestCase):
+    def test_execute_all(self):
+        dd0 = np.zeros((10000, 2, 1, 3, 4))
+        dd1 = np.ones((10000, 2, 1, 3, 4))
+        auto_batch_size = AutoBatchSize(256, 2.0)
+
+        def func(dd1):
+            return np.zeros_like(dd1), np.ones_like(dd1)
+
+        dd2 = auto_batch_size.execute_all(func, 10000, 2, dd1)
+        np.testing.assert_equal(dd0, dd2[0])
+        np.testing.assert_equal(dd1, dd2[1])
+
+    def test_execute_all_dict(self):
+        dd0 = np.zeros((10000, 2, 1, 3, 4))
+        dd1 = np.ones((10000, 2, 1, 3, 4))
+        auto_batch_size = AutoBatchSize(256, 2.0, returned_dict=True)
+
+        def func(dd1):
+            return {
+                "foo": np.zeros_like(dd1),
+                "bar": np.ones_like(dd1),
+            }
+
+        dd2 = auto_batch_size.execute_all(func, 10000, 2, dd1)
+        np.testing.assert_equal(dd0, dd2["foo"])
+        np.testing.assert_equal(dd1, dd2["bar"])
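Since concate_result also handles torch.Tensor via torch.cat, the same dict-returning pattern should carry over to tensor outputs. A hedged sketch mirroring test_execute_all_dict above, assuming execute_all slices torch tensors the same way it slices NumPy arrays (a case not covered by the tests in this PR):

import torch

from deepmd.pt.utils.auto_batch_size import AutoBatchSize

dd1 = torch.ones((10000, 2, 1, 3, 4))
auto_batch_size = AutoBatchSize(256, 2.0, returned_dict=True)

def func(x):
    # per batch, return a dict of tensors whose leading dim is the batch size
    return {"foo": torch.zeros_like(x), "bar": torch.ones_like(x)}

dd2 = auto_batch_size.execute_all(func, 10000, 2, dd1)
torch.testing.assert_close(dd2["foo"], torch.zeros_like(dd1))
torch.testing.assert_close(dd2["bar"], torch.ones_like(dd1))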