Skip to content

Commit

Permalink
fix pt
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed Mar 26, 2024
1 parent f6e0bfb commit f33503b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 4 additions & 1 deletion deepmd/dpmodel/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
"int32": np.int32,
"int64": np.int64,
"default": GLOBAL_NP_FLOAT_PRECISION,
# NumPy doesn't have bfloat16 (and does't plan to add). Use float32 as a substitute.
# NumPy doesn't have bfloat16 (and does't plan to add)
# ml_dtypes is a solution, but it seems not supporting np.save/np.load
# hdf5 hasn't supported bfloat16 as well (see https://forum.hdfgroup.org/t/11975)
# Use float32 as a substitute.
"bfloat16": np.float32,
}
assert VALID_PRECISION.issubset(PRECISION_DICT.keys())
Expand Down
5 changes: 4 additions & 1 deletion deepmd/pt/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch.nn.functional as F

from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT
from deepmd.dpmodel.common import RESERVED_PRECISON_DICT as NP_RESERVED_PRECISON_DICT

from .env import (
DEVICE,
Expand Down Expand Up @@ -103,7 +104,9 @@ def to_torch_tensor(
return None
assert xx is not None
# Create a reverse mapping of NP_PRECISION_DICT
reverse_precision_dict = {v: k for k, v in NP_PRECISION_DICT.items()}
# unsafe considering bfloat16:
# reverse_precision_dict = {v: k for k, v in NP_PRECISION_DICT.items()}
reverse_precision_dict = NP_RESERVED_PRECISON_DICT
# Use the reverse mapping to find keys with the desired value
prec = reverse_precision_dict.get(xx.dtype.type, None)
prec = PT_PRECISION_DICT.get(prec, None)
Expand Down

0 comments on commit f33503b

Please sign in to comment.