Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pt: explicitly set device #3307

Merged
merged 8 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 26 additions & 20 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,13 +343,17 @@
natoms = len(atom_types[0])

coord_input = torch.tensor(
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION
).to(DEVICE)
type_input = torch.tensor(atom_types, dtype=torch.long).to(DEVICE)
coords.reshape([-1, natoms, 3]),
dtype=GLOBAL_PT_FLOAT_PRECISION,
device=DEVICE,
)
type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)

Check warning on line 350 in deepmd/pt/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/infer/deep_eval.py#L350

Added line #L350 was not covered by tests
if cells is not None:
box_input = torch.tensor(
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION
).to(DEVICE)
cells.reshape([-1, 3, 3]),
dtype=GLOBAL_PT_FLOAT_PRECISION,
device=DEVICE,
)
else:
box_input = None

Expand Down Expand Up @@ -420,7 +424,7 @@
if cells is not None:
assert isinstance(cells, torch.Tensor), err_msg
assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list)
atom_types = torch.tensor(atom_types, dtype=torch.long).to(DEVICE)
atom_types = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)

Check warning on line 427 in deepmd/pt/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/infer/deep_eval.py#L427

Added line #L427 was not covered by tests
elif isinstance(coords, np.ndarray):
if cells is not None:
assert isinstance(cells, np.ndarray), err_msg
Expand All @@ -441,17 +445,17 @@
natoms = len(atom_types[0])

coord_input = torch.tensor(
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION
).to(DEVICE)
type_input = torch.tensor(atom_types, dtype=torch.long).to(DEVICE)
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)

Check warning on line 450 in deepmd/pt/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/infer/deep_eval.py#L450

Added line #L450 was not covered by tests
box_input = None
if cells is None:
pbc = False
else:
pbc = True
box_input = torch.tensor(
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION
).to(DEVICE)
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size)

for ii in range(num_iter):
Expand Down Expand Up @@ -527,35 +531,37 @@
energy_out = (
torch.cat(energy_out)
if energy_out
else torch.zeros([nframes, 1], dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE)
else torch.zeros(
[nframes, 1], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
)
atomic_energy_out = (
torch.cat(atomic_energy_out)
if atomic_energy_out
else torch.zeros([nframes, natoms, 1], dtype=GLOBAL_PT_FLOAT_PRECISION).to(
DEVICE
else torch.zeros(
[nframes, natoms, 1], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
)
force_out = (
torch.cat(force_out)
if force_out
else torch.zeros([nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION).to(
DEVICE
else torch.zeros(
[nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
)
virial_out = (
torch.cat(virial_out)
if virial_out
else torch.zeros([nframes, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION).to(
DEVICE
else torch.zeros(
[nframes, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
)
atomic_virial_out = (
torch.cat(atomic_virial_out)
if atomic_virial_out
else torch.zeros(
[nframes, natoms, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION
).to(DEVICE)
[nframes, natoms, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
)
updated_coord_out = torch.cat(updated_coord_out) if updated_coord_out else None
logits_out = torch.cat(logits_out) if logits_out else None
Expand Down
6 changes: 4 additions & 2 deletions deepmd/pt/infer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@

@staticmethod
def get_data(data):
batch_data = next(iter(data))
with torch.device("cpu"):
batch_data = next(iter(data))

Check warning on line 175 in deepmd/pt/infer/inference.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/infer/inference.py#L174-L175

Added lines #L174 - L175 were not covered by tests
for key in batch_data.keys():
if key == "sid" or key == "fid":
continue
Expand Down Expand Up @@ -235,7 +236,8 @@
), # setting to 0 diverges the behavior of its iterator; should be >=1
drop_last=False,
)
data = iter(dataloader)
with torch.device("cpu"):
data = iter(dataloader)

Check warning on line 240 in deepmd/pt/infer/inference.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/infer/inference.py#L239-L240

Added lines #L239 - L240 were not covered by tests

single_results = {}
sum_natoms = 0
Expand Down
19 changes: 15 additions & 4 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
FittingOutputDef,
OutputVariableDef,
)
from deepmd.pt.utils import (

Check warning on line 19 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L19

Added line #L19 was not covered by tests
env,
)
from deepmd.pt.utils.nlist import (
build_multiple_neighbor_list,
get_multiple_nlist_key,
Expand Down Expand Up @@ -91,9 +94,17 @@

def _sort_rcuts_sels(self) -> Tuple[List[float], List[int]]:
# sort the pair of rcut and sels in ascending order, first based on sel, then on rcut.
rcuts = torch.tensor(self.get_model_rcuts(), dtype=torch.float64)
nsels = torch.tensor(self.get_model_nsels())
zipped = torch.stack([torch.tensor(rcuts), torch.tensor(nsels)], dim=0).T
rcuts = torch.tensor(

Check warning on line 97 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L97

Added line #L97 was not covered by tests
self.get_model_rcuts(), dtype=torch.float64, device=env.DEVICE
)
nsels = torch.tensor(self.get_model_nsels(), device=env.DEVICE)
zipped = torch.stack(

Check warning on line 101 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L100-L101

Added lines #L100 - L101 were not covered by tests
[
torch.tensor(rcuts, device=env.DEVICE),
torch.tensor(nsels, device=env.DEVICE),
],
dim=0,
).T
inner_sorting = torch.argsort(zipped[:, 1], dim=0)
inner_sorted = zipped[inner_sorting]
outer_sorting = torch.argsort(inner_sorted[:, 0], stable=True)
Expand Down Expand Up @@ -285,7 +296,7 @@
self.smin_alpha = smin_alpha

# this is a placeholder being updated in _compute_weight, to handle Jit attribute init error.
self.zbl_weight = torch.empty(0, dtype=torch.float64)
self.zbl_weight = torch.empty(0, dtype=torch.float64, device=env.DEVICE)

Check warning on line 299 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L299

Added line #L299 was not covered by tests

def serialize(self) -> dict:
return {
Expand Down
8 changes: 6 additions & 2 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
FittingOutputDef,
OutputVariableDef,
)
from deepmd.pt.utils import (

Check warning on line 15 in deepmd/pt/model/atomic_model/pairtab_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/pairtab_atomic_model.py#L15

Added line #L15 was not covered by tests
env,
)
from deepmd.utils.pair_tab import (
PairTab,
)
Expand Down Expand Up @@ -156,15 +159,16 @@
pairwise_rr = self._get_pairwise_dist(
extended_coord, masked_nlist
) # (nframes, nloc, nnei)
self.tab_data = self.tab_data.view(
self.tab_data = self.tab_data.to(device=env.DEVICE).view(

Check warning on line 162 in deepmd/pt/model/atomic_model/pairtab_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/pairtab_atomic_model.py#L162

Added line #L162 was not covered by tests
int(self.tab_info[-1]), int(self.tab_info[-1]), int(self.tab_info[2]), 4
)

# to calculate the atomic_energy, we need 3 tensors, i_type, j_type, pairwise_rr
# i_type : (nframes, nloc), this is atype.
# j_type : (nframes, nloc, nnei)
j_type = extended_atype[
torch.arange(extended_atype.size(0))[:, None, None], masked_nlist
torch.arange(extended_atype.size(0), device=env.DEVICE)[:, None, None],
masked_nlist,
]

raw_atomic_energy = self._pair_tabulated_inter(
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/repformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@
sel = [sel] if isinstance(sel, int) else sel
self.nnei = sum(sel)
assert len(sel) == 1
self.sel = torch.tensor(sel)
self.sel = torch.tensor(sel, device=env.DEVICE)

Check warning on line 329 in deepmd/pt/model/descriptor/repformer_layer.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/repformer_layer.py#L329

Added line #L329 was not covered by tests
self.sec = self.sel
self.axis_dim = axis_dim
self.set_davg_zero = set_davg_zero
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ def forward(
) # shape is [nframes*nall, self.ndescrpt]
xyz_scatter = torch.empty(
1,
device=env.DEVICE,
)
ret = self.filter_layers_old[0](dmatrix)
xyz_scatter = ret
Expand Down
6 changes: 4 additions & 2 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,10 @@ def _format_nlist(
nlist,
-1
* torch.ones(
[n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype
).to(nlist.device),
[n_nf, n_nloc, nnei - n_nnei],
dtype=nlist.dtype,
device=nlist.device,
),
],
dim=-1,
)
Expand Down
13 changes: 10 additions & 3 deletions deepmd/pt/model/network/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@


def Tensor(*shape):
return torch.empty(shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION)
return torch.empty(shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)

Check warning on line 35 in deepmd/pt/model/network/network.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/network/network.py#L35

Added line #L35 was not covered by tests


class Dropout(nn.Module):
Expand Down Expand Up @@ -332,7 +332,13 @@
bias: bool = True,
init: str = "default",
):
super().__init__(d_in, d_out, bias=bias, dtype=env.GLOBAL_PT_FLOAT_PRECISION)
super().__init__(

Check warning on line 335 in deepmd/pt/model/network/network.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/network/network.py#L335

Added line #L335 was not covered by tests
d_in,
d_out,
bias=bias,
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.DEVICE,
)

self.use_bias = bias

Expand Down Expand Up @@ -552,6 +558,7 @@
embed_dim,
padding_idx=type_nums,
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.DEVICE,
)
# nn.init.normal_(self.embedding.weight[:-1], mean=bavg, std=stddev)

Expand Down Expand Up @@ -799,7 +806,7 @@
temperature=temperature,
)
self.attn_layer_norm = nn.LayerNorm(
self.embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION
self.embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
if self.ffn:
self.ffn_embed_dim = ffn_embed_dim
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@
bias_atom_e = np.zeros([self.ntypes])
if not use_tebd:
assert self.ntypes == len(bias_atom_e), "Element count mismatches!"
bias_atom_e = torch.tensor(bias_atom_e)
bias_atom_e = torch.tensor(bias_atom_e, device=env.DEVICE)

Check warning on line 271 in deepmd/pt/model/task/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/ener.py#L271

Added line #L271 was not covered by tests
self.register_buffer("bias_atom_e", bias_atom_e)

filter_layers_dipole = []
Expand Down
9 changes: 6 additions & 3 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@
drop_last=False,
pin_memory=True,
)
training_data_buffered = BufferedIterator(iter(training_dataloader))
with torch.device("cpu"):
training_data_buffered = BufferedIterator(iter(training_dataloader))

Check warning on line 160 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L159-L160

Added lines #L159 - L160 were not covered by tests
validation_dataloader = DataLoader(
_validation_data,
sampler=valid_sampler,
Expand All @@ -166,7 +167,8 @@
pin_memory=True,
)

validation_data_buffered = BufferedIterator(iter(validation_dataloader))
with torch.device("cpu"):
validation_data_buffered = BufferedIterator(iter(validation_dataloader))

Check warning on line 171 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L170-L171

Added lines #L170 - L171 were not covered by tests
if _training_params.get("validation_data", None) is not None:
valid_numb_batch = _training_params["validation_data"].get(
"numb_btch", 1
Expand Down Expand Up @@ -519,7 +521,8 @@
if not torch.isfinite(grad_norm).all():
# check local gradnorm single GPU case, trigger NanDetector
raise FloatingPointError("gradients are Nan/Inf")
self.optimizer.step()
with torch.device("cpu"):
self.optimizer.step()

Check warning on line 525 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L524-L525

Added lines #L524 - L525 were not covered by tests
iProzd marked this conversation as resolved.
Show resolved Hide resolved
self.scheduler.step()
elif self.opt_type == "LKF":
if isinstance(self.loss, EnergyStdLoss):
Expand Down
8 changes: 5 additions & 3 deletions deepmd/pt/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,9 @@
self.total_batch += len(system_dataloader)
# Initialize iterator instances for DataLoader
self.iters = []
for item in self.dataloaders:
self.iters.append(iter(item))
with torch.device("cpu"):
for item in self.dataloaders:
self.iters.append(iter(item))

Check warning on line 125 in deepmd/pt/utils/dataloader.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/dataloader.py#L123-L125

Added lines #L123 - L125 were not covered by tests
iProzd marked this conversation as resolved.
Show resolved Hide resolved

def set_noise(self, noise_settings):
# noise_settings['noise_type'] # "trunc_normal", "normal", "uniform"
Expand Down Expand Up @@ -250,5 +251,6 @@
log.info("Generated weighted sampler with prob array: " + str(probs))
# training_data.total_batch is the size of one epoch, you can increase it to avoid too many rebuilding of iteraters
len_sampler = training_data.total_batch * max(env.NUM_WORKERS, 1)
sampler = WeightedRandomSampler(probs, len_sampler, replacement=True)
with torch.device("cpu"):
sampler = WeightedRandomSampler(probs, len_sampler, replacement=True)

Check warning on line 255 in deepmd/pt/utils/dataloader.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/dataloader.py#L254-L255

Added lines #L254 - L255 were not covered by tests
return sampler
8 changes: 4 additions & 4 deletions deepmd/pt/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,14 @@
nlist = nlist[:, :, :nsel]
else:
rr = torch.cat(
[rr, torch.ones([batch_size, nloc, nsel - nnei]).to(rr.device) + rcut],
[rr, torch.ones([batch_size, nloc, nsel - nnei], device=rr.device) + rcut],
dim=-1,
)
nlist = torch.cat(
[
nlist,
torch.ones([batch_size, nloc, nsel - nnei], dtype=nlist.dtype).to(
rr.device
torch.ones(
[batch_size, nloc, nsel - nnei], dtype=nlist.dtype, device=rr.device
),
],
dim=-1,
Expand Down Expand Up @@ -289,7 +289,7 @@

"""
nf, nloc = atype.shape
aidx = torch.tile(torch.arange(nloc).unsqueeze(0), [nf, 1])
aidx = torch.tile(torch.arange(nloc, device=env.DEVICE).unsqueeze(0), [nf, 1])

Check warning on line 292 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L292

Added line #L292 was not covered by tests
if cell is None:
nall = nloc
extend_coord = coord.clone()
Expand Down
21 changes: 11 additions & 10 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,17 @@
log.info(f"Packing data for statistics from {len(datasets)} systems")
for i in range(len(datasets)):
sys_stat = {key: [] for key in keys}
iterator = iter(dataloaders[i])
for _ in range(nbatches):
try:
stat_data = next(iterator)
except StopIteration:
iterator = iter(dataloaders[i])
stat_data = next(iterator)
for dd in stat_data:
if dd in keys:
sys_stat[dd].append(stat_data[dd])
with torch.device("cpu"):
iterator = iter(dataloaders[i])
for _ in range(nbatches):
try:
stat_data = next(iterator)
except StopIteration:
iterator = iter(dataloaders[i])
stat_data = next(iterator)
for dd in stat_data:
if dd in keys:
sys_stat[dd].append(stat_data[dd])

Check warning on line 45 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L35-L45

Added lines #L35 - L45 were not covered by tests
for key in keys:
if not isinstance(sys_stat[key][0], list):
if sys_stat[key][0] is None:
Expand Down
2 changes: 2 additions & 0 deletions source/tests/pt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@

torch.set_num_threads(1)
torch.set_num_interop_threads(1)
# testing purposes; device should always be set explicitly
torch.set_default_device("cuda:9999999")
Loading