add wandb and file
iProzd committed Oct 8, 2024
1 parent c99caae commit e18b5b0
Showing 3 changed files with 76 additions and 23 deletions.
3 changes: 3 additions & 0 deletions deepmd/pt/entrypoints/main.py
@@ -286,8 +286,11 @@ def train(FLAGS):
)

# argcheck
wandb_config = config["training"].pop("wandb_config", None)
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config, multi_task=multi_task)
if wandb_config is not None:
config["training"]["wandb_config"] = wandb_config

# do neighbor stat
min_nbor_dist = None
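For context, a minimal sketch of the `wandb_config` block that this hunk pops out of `config["training"]` before `update_deepmd_input`/`normalize` and re-attaches afterwards (presumably so the argcheck schema does not reject it). The keys mirror the lookups in `Trainer.__init__` in `training.py` below; the entity, project, and job name values are placeholders, not part of the commit.

# Hypothetical "wandb_config" entry inside the "training" section of the
# input script; keys follow the lookups in deepmd/pt/train/training.py.
wandb_config = {
    "wandb_enabled": True,        # switch W&B logging on (rank 0 only)
    "entity": "my-team",          # required when wandb_enabled is True
    "project": "deepmd-water",    # required when wandb_enabled is True
    "job_name": "water/dpa2",     # optional; defaults to the last two path components of the CWD
    "wandb_log_model": False,     # if True, wandb.watch() tracks weights/gradients
    "wandb_log_model_freq": 100,  # log_freq passed to wandb.watch()
}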
94 changes: 72 additions & 22 deletions deepmd/pt/train/training.py
@@ -71,7 +71,10 @@
if torch.__version__.startswith("2"):
import torch._dynamo

import os

import torch.distributed as dist
import wandb as wb
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import (
DataLoader,
@@ -146,6 +149,33 @@ def __init__(
)
self.lcurve_should_print_header = True

# Init wandb
self.wandb_config = training_params.get("wandb_config", {})
self.wandb_enabled = self.wandb_config.get("wandb_enabled", False)
self.wandb_log_model = self.wandb_config.get("wandb_log_model", False)
self.wandb_log_model_freq = self.wandb_config.get("wandb_log_model_freq", 100)
if self.wandb_enabled:
entity = self.wandb_config.get("entity", None)
assert (
entity is not None
), "The parameter 'entity' of wandb must be specified."
project = self.wandb_config.get("project", None)
assert (
project is not None
), "The parameter 'project' of wandb must be specified."
job_name = self.wandb_config.get("job_name", None)
if job_name is None:
name_path = os.path.abspath(".").split("/")
job_name = name_path[-2] + "/" + name_path[-1]
if self.rank == 0:
wb.init(
project=project,
entity=entity,
config=model_params,
name=job_name,
settings=wb.Settings(start_method="fork"),
)

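A small illustration of the `job_name` fallback above, under the assumption of a hypothetical working directory; it only reproduces the two lines of path handling from the commit (the split on "/" assumes a POSIX-style path).

import os

# Sketch of the default job_name: the last two components of the current
# working directory, joined with "/". Running from, say,
# /home/alice/runs/water/dpa2 (hypothetical) yields "water/dpa2".
name_path = os.path.abspath(".").split("/")
job_name = name_path[-2] + "/" + name_path[-1]
print(job_name)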
def get_opt_param(params):
opt_type = params.get("opt_type", "Adam")
opt_param = {
@@ -317,14 +347,14 @@ def get_lr(lr_params):
self.validation_data,
self.valid_numb_batch,
) = get_data_loader(training_data, validation_data, training_params)
training_data.print_summary(
"training", to_numpy_array(self.training_dataloader.sampler.weights)
)
if validation_data is not None:
validation_data.print_summary(
"validation",
to_numpy_array(self.validation_dataloader.sampler.weights),
)
# training_data.print_summary(
# "training", to_numpy_array(self.training_dataloader.sampler.weights)
# )
# if validation_data is not None:
# validation_data.print_summary(
# "validation",
# to_numpy_array(self.validation_dataloader.sampler.weights),
# )
else:
(
self.training_dataloader,
@@ -360,20 +390,20 @@ def get_lr(lr_params):
training_params["data_dict"][model_key],
)

training_data[model_key].print_summary(
f"training in {model_key}",
to_numpy_array(self.training_dataloader[model_key].sampler.weights),
)
if (
validation_data is not None
and validation_data[model_key] is not None
):
validation_data[model_key].print_summary(
f"validation in {model_key}",
to_numpy_array(
self.validation_dataloader[model_key].sampler.weights
),
)
# training_data[model_key].print_summary(
# f"training in {model_key}",
# to_numpy_array(self.training_dataloader[model_key].sampler.weights),
# )
# if (
# validation_data is not None
# and validation_data[model_key] is not None
# ):
# validation_data[model_key].print_summary(
# f"validation in {model_key}",
# to_numpy_array(
# self.validation_dataloader[model_key].sampler.weights
# ),
# )

# Learning rate
self.warmup_steps = training_params.get("warmup_steps", 0)
@@ -658,6 +688,8 @@ def run(self):
with_stack=True,
)
prof.start()
if self.wandb_enabled and self.wandb_log_model and self.rank == 0:
wb.watch(self.wrapper, log="all", log_freq=self.wandb_log_model_freq)

def step(_step_id, task_key="Default"):
# PyTorch Profiler
@@ -827,6 +859,7 @@ def log_loss_valid(_task_key="Default"):
learning_rate=cur_lr,
)
)
self.wandb_log(train_results, _step_id, "_train")
if valid_results:
log.info(
format_training_message_per_task(
@@ -836,12 +869,16 @@ def log_loss_valid(_task_key="Default"):
learning_rate=None,
)
)
self.wandb_log(valid_results, _step_id, "_valid")
else:
train_results = {_key: {} for _key in self.model_keys}
valid_results = {_key: {} for _key in self.model_keys}
train_results[task_key] = log_loss_train(
loss, more_loss, _task_key=task_key
)
self.wandb_log(
train_results[task_key], _step_id, f"_train_{task_key}"
)
for _key in self.model_keys:
if _key != task_key:
self.optimizer.zero_grad()
@@ -867,6 +904,9 @@ def log_loss_valid(_task_key="Default"):
learning_rate=cur_lr,
)
)
self.wandb_log(
train_results[_key], _step_id, f"_train_{_key}"
)
if valid_results[_key]:
log.info(
format_training_message_per_task(
@@ -876,6 +916,9 @@ def log_loss_valid(_task_key="Default"):
learning_rate=None,
)
)
self.wandb_log(
valid_results[_key], _step_id, f"_valid_{_key}"
)

current_time = time.time()
train_time = current_time - self.t0
@@ -887,6 +930,7 @@ def log_loss_valid(_task_key="Default"):
wall_time=train_time,
)
)
self.wandb_log({"lr": cur_lr}, step_id)
# the first training time is not accurate
if (
_step_id + 1
@@ -1125,6 +1169,12 @@ def get_data(self, is_train=True, task_key="Default"):
log_dict["sid"] = batch_data["sid"]
return input_dict, label_dict, log_dict

def wandb_log(self, data: dict, step, type_suffix=""):
if not self.wandb_enabled or self.rank != 0:
return
for k, v in data.items():
wb.log({k + type_suffix: v}, step=step)

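To make the suffix convention concrete, a hedged sketch of what one `wandb_log` call translates to; the metric names here are hypothetical stand-ins for the loss keys returned by `log_loss_train`/`log_loss_valid`, not values from the commit.

# Hypothetical per-step metrics, in the shape wandb_log expects.
metrics = {"rmse_e": 1.2e-3, "rmse_f": 4.5e-2}

# trainer.wandb_log(metrics, step=1000, type_suffix="_train") would issue,
# on rank 0 with wandb enabled:
#   wandb.log({"rmse_e_train": 1.2e-3}, step=1000)
#   wandb.log({"rmse_f_train": 4.5e-2}, step=1000)
# i.e. one W&B series per loss key, keyed by the global training step.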
def print_header(self, fout, train_results, valid_results):
train_keys = sorted(train_results.keys())
print_str = ""
2 changes: 1 addition & 1 deletion deepmd/utils/path.py
@@ -319,7 +319,7 @@ def _load_h5py(cls, path: str, mode: str = "r") -> h5py.File:
# this method has cache to avoid duplicated
# loading from different DPH5Path
# However the file will be never closed?
return h5py.File(path, mode)
return h5py.File(path, mode, locking=False)

def load_numpy(self) -> np.ndarray:
"""Load NumPy array.
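For reference, a minimal sketch of what disabling HDF5 file locking means for a reader, assuming a placeholder file name and an h5py version recent enough to accept the `locking` keyword; this is an illustration, not code from the commit (which keeps the cached file handle open rather than using a context manager).

import h5py

# Open a (placeholder) HDF5 data file read-only with file locking disabled,
# so concurrent readers -- e.g. several data-loading processes -- or
# filesystems without lock support do not fail on the HDF5 lock.
with h5py.File("data.h5", "r", locking=False) as f:
    print(list(f.keys()))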
