Skip to content

Commit

Permalink
refactor(ai2bmd): sanitise logging
Browse files Browse the repository at this point in the history
remove 'DEBUG_RC' environment variable -- level of output is controlled
solely by the '--verbose' argument. reassign logging levels to different
output, based on these (rough) guidelines:

1: verbose output for fragmentation stage
2: DEBUG-level logging for the main process
3: output from ViSNet/tinker servers
4: verbose output for preprocess stage
  • Loading branch information
bi-ran committed Dec 9, 2024
1 parent ec7b172 commit 04c65d3
Show file tree
Hide file tree
Showing 9 changed files with 33 additions and 46 deletions.
3 changes: 0 additions & 3 deletions src/AIMD/envflags.py

This file was deleted.

12 changes: 7 additions & 5 deletions src/AIMD/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
from typing import List, Tuple

from AIMD import arguments, envflags
from AIMD import arguments
from Calculators.device_strategy import DeviceStrategy
from utils.pdb import reorder_atoms, standardise_pdb, translate_coord_pdb, reorder_coord_amber2tinker
from utils.system import get_physical_core_count
Expand All @@ -18,7 +18,7 @@ def run_command(command: str, cwd_path: str) -> None:
It is more safe than os.system.
"""

if envflags.DEBUG_RC:
if arguments.get().verbose >= 4:
print("run_command: ", command)

proc = subprocess.Popen(
Expand All @@ -37,11 +37,13 @@ def run_command(command: str, cwd_path: str) -> None:
'Failed with command "{}" failed in '
""
"{} with error code {}"
"stdout: {}"
"stderr: {}".format(command, path, proc.returncode, out, err)
"stdout:"
"{}"
"stderr:"
"{}".format(command, path, proc.returncode, out, err)
)
raise ValueError(msg)
elif envflags.DEBUG_RC:
elif arguments.get().verbose >= 4:
print('-------------- stdout -----------------')
print(out)
print('-------------- stderr -----------------')
Expand Down
5 changes: 2 additions & 3 deletions src/AIMD/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ase.md.nvtberendsen import NVTBerendsen
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution

from AIMD import arguments, envflags
from AIMD import arguments
from AIMD.protein import Protein
from Calculators.device_strategy import DeviceStrategy
from Calculators.fragment import FragmentCalculator
Expand Down Expand Up @@ -200,8 +200,7 @@ def simulate(
if build_frames and not restart:
self.build_frames_from_traj(prot_name, record_per_steps, MolDyn.nsteps)

if not envflags.DEBUG_RC:
shutil.rmtree(os.path.join(self.log_path, "SimulationResults"))
shutil.rmtree(os.path.join(self.log_path, "SimulationResults"))

def build_frames_from_traj(self, prot_name, record_per_steps, nsteps):
print("Building frames from trajectory...")
Expand Down
2 changes: 1 addition & 1 deletion src/Calculators/nonbonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __call__(self, prot: Protein) -> tuple[np.float32, np.ndarray]:
force. Non-bonded forces are calculated by calculating the gradient of
the non-bonded energy with respect to the atom positions.
"""
pos = torch.FloatTensor(prot.get_positions()).to(self.device)
pos = torch.tensor(prot.get_positions(), dtype=torch.float, device=self.device)

vec = pos[self.dst] - pos[self.src]
d2 = (vec**2).sum(-1)
Expand Down
8 changes: 4 additions & 4 deletions src/Calculators/pme.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ def __init__(self, beta=0.3, cutoff=9.0, grid_spacing=1.0, device="cpu") -> None

def set_parameters(self, prot: Protein) -> None:
# Initialize parameters
sigmas = torch.FloatTensor(prot.sigmas)
epsilons = torch.FloatTensor(prot.epsilons)
charges = torch.FloatTensor(prot.charges)
sigmas = torch.tensor(prot.sigmas, dtype=torch.float)
epsilons = torch.tensor(prot.epsilons, dtype=torch.float)
charges = torch.tensor(prot.charges, dtype=torch.float)
self.sigmas = sigmas.to(self.device)
self.epsilons = epsilons.to(self.device)
self.charges = charges.to(self.device)
Expand All @@ -159,7 +159,7 @@ def __call__(self, prot: Protein) -> tuple[np.float32, np.ndarray]:
Non-bonded forces are calculated by grading the non-bonded energy
with respect of the atom positions.
"""
pos_cpu = torch.FloatTensor(prot.get_positions())
pos_cpu = torch.tensor(prot.get_positions(), dtype=torch.float)
pos = pos_cpu.to(self.device)
src, dst = radius_graph(
pos, self.cutoff, max_num_neighbors=len(pos), loop=False
Expand Down
16 changes: 7 additions & 9 deletions src/Calculators/tinker_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ase.calculators.calculator import Calculator
from ase.units import kcal, mol

from AIMD import envflags
from AIMD import arguments
from AIMD.preprocess import run_command
from Calculators.async_utils import AsyncServer

Expand All @@ -33,13 +33,14 @@ def __init__(self, pdb_file, utils_dir, devices: list[str], **kwargs):
self.devices = devices
self._tinker_proc = None
self.atoms: Atoms
self.server = AsyncServer("tinker")
self.logger = getLogger("Tinker-Proxy")

global _tinker_instance_id
self.instance_id = _tinker_instance_id
_tinker_instance_id += 1

self.server = AsyncServer("tinker")
self.logger = getLogger(f"Tinker-Proxy-{self.instance_id}")

if any(map(lambda x: x.startswith('cuda'), devices)):
self.command_dir = '/usr/local/gpu-m'
else:
Expand Down Expand Up @@ -67,18 +68,15 @@ def _start_tinker(self):
elif len(gpus) == 1:
envs["CUDA_VISIBLE_DEVICES"] = gpus[0]

outfd = None
stderrfd = open(os.devnull, 'wb')
log_args = "<< _EOF\n" if envflags.DEBUG_RC else f" > dynamic{self.instance_id}.log << _EOF\n"
outfd = None if arguments.get().verbose >= 3 else subprocess.DEVNULL
self._tinker_proc = subprocess.Popen(
f"{self.command_dir}/tinker9 ai2bmd {self.prot_name} -k {self.prot_name} "
f"{log_args}"
f"{self.command_dir}/tinker9 ai2bmd {self.prot_name} -k {self.prot_name} << _EOF\n"
f"{self.server.socket_path}\n" # unix socket for IPC
f"_EOF",
shell=True,
env=envs,
stdout=outfd,
stderr=stderrfd,
stderr=outfd,
)
self.logger.debug('Waiting for Tinker to start...')
self.server.accept()
Expand Down
15 changes: 3 additions & 12 deletions src/Calculators/visnet_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@

import numpy as np
import torch
import torch.multiprocessing as mp
from ase.calculators.calculator import Calculator

from AIMD import envflags
from AIMD import arguments
from AIMD.fragment import FragmentData
from Calculators.device_strategy import DeviceStrategy
from Calculators.async_utils import AsyncServer, AsyncClient
Expand Down Expand Up @@ -69,7 +68,6 @@ def from_file(cls, **kwargs):
raise ValueError("model_path must be provided")

model_path = kwargs["model_path"]

device = kwargs.get("device", "cpu")

model = load_model(model_path)
Expand All @@ -87,7 +85,7 @@ def __init__(self, model_path: str, device: str):
self.logger = getLogger("ViSNet-Proxy")
envs = os.environ.copy()
envs["PYTHONPATH"] = f"{osp.abspath(osp.join(osp.dirname(__file__), '..'))}:{envs['PYTHONPATH']}"
outfd = subprocess.PIPE if not envflags.DEBUG_RC else None
outfd = None if arguments.get().verbose >= 3 else subprocess.DEVNULL
# use __file__ as process so that viztracer-patched subprocess doesn't track us
# this file should have chmod +x
self.proc = subprocess.Popen(
Expand Down Expand Up @@ -158,27 +156,19 @@ def calculate(self, atoms, properties, system_changes):


if __name__ == "__main__":
mp.set_sharing_strategy("file_system")
mp.set_start_method("spawn")

parser = argparse.ArgumentParser("ViSNet proxy")
parser.add_argument("--model-path", type=str, required=True)
parser.add_argument("--device", type=str, required=True)
parser.add_argument("--socket-path", type=str, required=True)
args = parser.parse_args()

logger = getLogger("AI2BMD-ViSNet-Worker")
logger.info(f'model-path: {args.model_path}')
logger.info(f'device: {args.device}')

kwargs = {
'model_path': args.model_path,
'device': args.device,
}
calculator = ViSNetModel.from_file(**kwargs)
client = AsyncClient(args.socket_path)
# start serving
logger.info('Start serving.')
try:
while True:
data: FragmentData = client.recv_object()
Expand All @@ -190,6 +180,7 @@ def calculate(self, atoms, properties, system_changes):
ViSNetModelLike = Union[ViSNetModel, ViSNetAsyncModel]
_local_calc: dict[str, ViSNetModel] = {}


def get_visnet_model(model_path: str, device: str):
# allow up to 1 copy of GPU model to run in the master process
device_sig = device
Expand Down
6 changes: 3 additions & 3 deletions src/ViSNet/model/visnet_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
import torch.nn as nn
from torch import IntTensor, Tensor
from torch import Tensor
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter

Expand Down Expand Up @@ -102,9 +102,9 @@ def reset_parameters(self):

def forward(self, data: dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
# z, pos, batch = data['z'], data['pos'], data['batch']
z: IntTensor = data['z']
z: Tensor = data['z']
pos: Tensor = data['pos']
batch: IntTensor = data['batch']
batch: Tensor = data['batch']

# Embedding Layers
x = self.embedding(z)
Expand Down
12 changes: 6 additions & 6 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import logging
import os
import time
import warnings

import torch.multiprocessing as mp

from AIMD import arguments, envflags
from AIMD import arguments
from AIMD.preprocess import Preprocess
from AIMD.protein import Protein
from AIMD.simulator import SolventSimulator, NoSolventSimulator
Expand All @@ -18,11 +17,12 @@
mp.set_start_method("spawn")

args = arguments.init()
if not envflags.DEBUG_RC:
warnings.filterwarnings("ignore")
logging.disable(logging.WARNING)
else:
if args.verbose >= 2:
logging.basicConfig(level=logging.DEBUG)
elif args.verbose >= 1:
logging.basicConfig(level=logging.INFO)
else:
logging.basicConfig(level=logging.ERROR)

logfile = os.path.join(args.log_dir, f"main-{time.strftime('%Y%m%d-%H%M%S')}.log")
redir_output(logfile)
Expand Down

0 comments on commit 04c65d3

Please sign in to comment.