Revert "fix: bugs in uts for property fit (deepmodeling#4120)"
This reverts commit 96ed5df.
theAfish committed Sep 21, 2024
1 parent 81b9d20 commit d5a7c04
Showing 4 changed files with 107 additions and 175 deletions.
backend/read_env.py: 2 changes (0 additions, 2 deletions)
@@ -60,8 +60,6 @@ def get_argument_from_env() -> Tuple[str, list, list, dict, str, str]:
         cmake_minimum_required_version = "3.21"
         cmake_args.append("-DUSE_ROCM_TOOLKIT:BOOL=TRUE")
         rocm_root = os.environ.get("ROCM_ROOT")
-        if not rocm_root:
-            rocm_root = os.environ.get("ROCM_PATH")
         if rocm_root:
             cmake_args.append(f"-DCMAKE_HIP_COMPILER_ROCM_ROOT:STRING={rocm_root}")
         hipcc_flags = os.environ.get("HIP_HIPCC_FLAGS")
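For reference, the two deleted lines were part of a lookup chain that the documentation change reverted below also describes: prefer `ROCM_ROOT`, fall back to `ROCM_PATH`, and finally ask `hipconfig --rocmpath`. A minimal standalone sketch of that chain, using the hypothetical helper name `detect_rocm_root`, might look like this:

```python
import os
import shutil
import subprocess
from typing import Optional


def detect_rocm_root() -> Optional[str]:
    """Hypothetical sketch of the ROCm-root lookup described in the reverted docs."""
    # Prefer ROCM_ROOT, then fall back to ROCM_PATH ...
    rocm_root = os.environ.get("ROCM_ROOT") or os.environ.get("ROCM_PATH")
    # ... and finally ask hipconfig for the ROCm install path, if it is available.
    if not rocm_root and shutil.which("hipconfig"):
        rocm_root = subprocess.check_output(
            ["hipconfig", "--rocmpath"], text=True
        ).strip()
    return rocm_root or None
```

After this revert, the build script reads only `ROCM_ROOT` in this code path, as shown in the diff above.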
deepmd/pt/train/training.py: 21 changes (21 additions, 0 deletions)
@@ -392,6 +392,21 @@ def get_lr(lr_params):
         # JIT
         if JIT:
             self.model = torch.jit.script(self.model)
+
+        # Initialize the fparam
+        if model_params["fitting_net"]["numb_fparam"] > 0:
+            nbatches = 10
+            datasets = training_data.systems
+            dataloaders = training_data.dataloaders
+            fparams = []
+            for i in range(len(datasets)):
+                iterator = iter(dataloaders[i])
+                numb_batches = min(nbatches, len(dataloaders[i]))
+                for _ in range(numb_batches):
+                    stat_data = next(iterator)
+                    fparams.append(stat_data['fparam'])
+            fparams = torch.tensor(fparams)
+            init_fparam(self.model, fparams)
 
         # Model Wrapper
         self.wrapper = ModelWrapper(self.model, self.loss, model_params=model_params)
@@ -1212,6 +1227,12 @@ def get_additional_data_requirement(_model):
     return additional_data_requirement
 
 
+def init_fparam(_model, fparams):
+    fitting = _model.get_fitting_net()
+    fitting['fparam_avg'] = torch.unsqueeze(torch.mean(fparams, dim=0), dim=-1).to(DEVICE)
+    fitting['fparam_inv_std'] = torch.unsqueeze(1. / torch.std(fparams, dim=0), dim=-1).to(DEVICE)
+
+
 def get_loss(loss_params, start_lr, _ntypes, _model):
     loss_type = loss_params.get("type", "ener")
     if loss_type == "ener":
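The re-added `init_fparam` helper stores the mean and the reciprocal standard deviation of the collected frame parameters (`fparam`) on the fitting net. Below is a toy sketch of that statistic computation together with an assumed whitening step; the data and the `normalized` variable are illustrative only, not taken from the repository:

```python
import torch

# Toy stand-in for frame parameters gathered from a few training batches:
# 100 sampled frames with 3 fparam components each (shapes are made up).
fparams = torch.randn(100, 3)

# The statistics that init_fparam stores on the fitting net.
fparam_avg = torch.mean(fparams, dim=0)
fparam_inv_std = 1.0 / torch.std(fparams, dim=0)

# Assumed usage: whiten an incoming frame parameter with those statistics.
fparam = torch.randn(3)
normalized = (fparam - fparam_avg) * fparam_inv_std
```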
doc/install/install-from-source.md: 3 changes (1 addition, 2 deletions)
@@ -155,8 +155,7 @@ The path to the CUDA toolkit directory. CUDA 9.0 or later is supported. NVCC is
 
 **Type**: Path; **Default**: Detected automatically
 
-The path to the ROCM toolkit directory. If `ROCM_ROOT` is not set, it will look for `ROCM_PATH`; if `ROCM_PATH` is also not set, it will be detected using `hipconfig --rocmpath`.
-
+The path to the ROCM toolkit directory.
 :::
 
 :::{envvar} DP_ENABLE_TENSORFLOW
