diff --git a/src/beignet/bond_angles.py b/src/beignet/bond_angles.py new file mode 100644 index 0000000000..b154477d28 --- /dev/null +++ b/src/beignet/bond_angles.py @@ -0,0 +1,46 @@ +import torch +from torch import Tensor + + +def bond_angles(input: Tensor, angle_indices: Tensor) -> Tensor: + r""" + Compute the bond angles between the supplied triplets of indices in each frame of a trajectory using PyTorch. + + Parameters + ---------- + input : Tensor + Trajectory tensor with shape=(n_frames, n_atoms, 3). + angle_indices : Tensor + Tensor of shape=(num_angles, 3), each row consists of indices of three atoms. + + Returns + ------- + angles : Tensor + Angles for each specified group of indices, shape=(n_frames, num_angles). Angles are in radians. + """ + # Data verification + num_frames, n_atoms, _ = input.shape + if torch.any(angle_indices >= n_atoms) or torch.any(angle_indices < 0): + raise ValueError("angle_indices must be between 0 and %d" % (n_atoms - 1)) + + if angle_indices.shape[0] == 0: + return torch.zeros((num_frames, 0), dtype=torch.float32) + + # Initializing the output tensor + angles = torch.zeros((num_frames, angle_indices.shape[0]), dtype=torch.float32) + + # Gathering vectors related to the angle calculation + vec1 = input[:, angle_indices[:, 1]] - input[:, angle_indices[:, 0]] + vec2 = input[:, angle_indices[:, 1]] - input[:, angle_indices[:, 2]] + + # Normalize the vectors + vec1_norm = torch.norm(vec1, dim=2, keepdim=True) + vec2_norm = torch.norm(vec2, dim=2, keepdim=True) + vec1_unit = vec1 / vec1_norm + vec2_unit = vec2 / vec2_norm + + # Compute angles using arccos of dot products + dot_products = torch.sum(vec1_unit * vec2_unit, dim=2) + angles = torch.acos(dot_products) + + return angles diff --git a/src/beignet/center_of_mass.py b/src/beignet/center_of_mass.py new file mode 100644 index 0000000000..a77498ba5c --- /dev/null +++ b/src/beignet/center_of_mass.py @@ -0,0 +1,23 @@ +from torch import Tensor + + +def center_of_mass(input: Tensor, masses: Tensor) -> Tensor: + r"""Compute the center of mass for each frame. + + Parameters + ---------- + input : Tensor + A tensor of shape (n_frames, n_atoms, 3) which contains the XYZ coordinates of atoms in each frame. + masses : Tensor + A tensor of shape (n_atoms,) containing the masses of each atom. + + Returns + ------- + output : Tensor, shape=(n_frames, 3) + A tensor of shape (n_frames, 3) with the coordinates of the center of mass for each frame. + """ + total_mass = masses.sum() + + weighted_positions = input * masses[:, None] + + return weighted_positions.sum(dim=1) / total_mass diff --git a/src/beignet/dihedrals.py b/src/beignet/dihedrals.py new file mode 100644 index 0000000000..37406e301b --- /dev/null +++ b/src/beignet/dihedrals.py @@ -0,0 +1,56 @@ +import math + +import torch +from torch import Tensor + + +def dihedrals(input: Tensor, indices: Tensor) -> Tensor: + r""" + Compute the dihedral angles between specified quartets of atoms in each frame of a trajectory using PyTorch. + + Parameters + ---------- + input : Tensor + A tensor representing the trajectory with shape (n_frames, n_atoms, 3). + indices : Tensor + Each row gives the indices of four atoms which, together, define a dihedral angle, + shape (n_dihedrals, 4). + + Returns + ------- + Tensor + A tensor of dihedral angles in radians, shape = (n_frames, n_dihedrals). + """ + n_frames, n_atoms, _ = input.shape + if torch.any(indices >= n_atoms) or torch.any(indices < 0): + raise ValueError("indices must be between 0 and %d" % (n_atoms - 1)) + + if len(indices) == 0: + return torch.zeros((n_frames, 0), dtype=torch.float32) + + # Get vectors between atoms + vec1 = input[:, indices[:, 1]] - input[:, indices[:, 0]] + vec2 = input[:, indices[:, 2]] - input[:, indices[:, 1]] + vec3 = input[:, indices[:, 3]] - input[:, indices[:, 2]] + + # Compute normals to the planes defined by the first three and last three atoms + normal1 = torch.cross(vec1, vec2, dim=2) + normal2 = torch.cross(vec2, vec3, dim=2) + + # Compute norms and check for zero to avoid division by zero + norm1 = torch.norm(normal1, dim=2, keepdim=True) + norm2 = torch.norm(normal2, dim=2, keepdim=True) + normal1 = torch.where(norm1 > 0, normal1 / norm1, normal1) + normal2 = torch.where(norm2 > 0, normal2 / norm2, normal2) + + cosine = torch.sum(normal1 * normal2, dim=2) + cosine = torch.clamp(cosine, -1.0, 1.0) + + cross = torch.cross(normal1, normal2, dim=2) + sine = torch.norm(cross, dim=2) * torch.sign(torch.sum(cross * vec2, dim=2)) + + # Handle case where the cross product is zero - indicating collinear points + angle = torch.atan2(sine, cosine) + angle = torch.where((norm1 > 0) & (norm2 > 0), angle, torch.tensor(math.pi)) + + return angle diff --git a/src/beignet/dipole_moments.py b/src/beignet/dipole_moments.py new file mode 100644 index 0000000000..31c35442d8 --- /dev/null +++ b/src/beignet/dipole_moments.py @@ -0,0 +1,36 @@ +import torch +from torch import Tensor + + +def dipole_moments(input: Tensor, charges: Tensor) -> Tensor: + """ + Calculate the dipole moments of each frame in a tensor-represented trajectory using PyTorch. + + Parameters + ---------- + input : Tensor + A tensor representing the trajectory with shape (n_frames, n_atoms, 3). + charges : Tensor + Charges of each atom in the trajectory. Shape (n_atoms,), units of elementary charges. + + Returns + ------- + moments : Tensor + Dipole moments of trajectory, units of nm * elementary charge, shape (n_frames, 3). + + Notes + ----- + This function performs a straightforward calculation of the dipole moments based on the input atomic positions + and charges. The dipole moment is calculated as the sum of charge-weighted atomic positions for each frame. + """ + # Ensure charges is at least 2D: (n_atoms, 1) + charges = charges.view(-1, 1) + + # Calculate the weighted positions by charges for each frame + weighted_positions = input * charges + + # TODO (isaacoh) this is broken for symmetric atoms w/ mixed charges + # Sum over all atoms to get the dipole moment per frame + moments = torch.sum(weighted_positions, dim=1) + + return moments diff --git a/src/beignet/gyration_tensor.py b/src/beignet/gyration_tensor.py new file mode 100644 index 0000000000..b9e36f7ef7 --- /dev/null +++ b/src/beignet/gyration_tensor.py @@ -0,0 +1,50 @@ +import torch +from torch import Tensor + + +def _compute_center_of_geometry(input: Tensor) -> Tensor: + r"""Compute the center of geometry for each frame. + Parameters + ---------- + input : Tensor + Trajectory to compute center of geometry for, shape=(n_frames, n_atoms, 3) + + Returns + ------- + centers : Tensor, shape=(n_frames, 3) + Coordinates of the center of geometry for each frame. + """ + centers = torch.mean(input, dim=1) + return centers + + +def gyration_tensor(input: Tensor) -> Tensor: + """Compute the gyration tensor of a trajectory. + + Parameters + ---------- + input : Tensor + Trajectory for which to compute gyration tensor, shape=(n_frames, n_atoms, 3) + + Returns + ------- + gyration_tensors: Tensor, shape=(n_frames, 3, 3) + Gyration tensors for each frame. + + References + ---------- + .. [1] https://isg.nist.gov/deepzoomweb/measurement3Ddata_help#shape-metrics-formulas + """ + n_frames, n_atoms, _ = input.shape + center_of_geometry = _compute_center_of_geometry(input).unsqueeze(1) + + # Translate the atoms by subtracting the center of geometry + translated_trajectory = input - center_of_geometry + + # Compute gyration tensor for each frame + gyration_tensors = ( + torch.einsum("nij,nik->njk", translated_trajectory, translated_trajectory) + / n_atoms + ) + + return gyration_tensors diff --git a/src/beignet/rmsd.py b/src/beignet/rmsd.py new file mode 100644 index 0000000000..05c5c2e78b --- /dev/null +++ b/src/beignet/rmsd.py @@ -0,0 +1,44 @@ +from torch import Tensor +import torch + + +# TODO (isaacsoh) parallelize and speed up, eliminate 3-D requirement +def rmsd(traj1: Tensor, traj2: Tensor): + r""" + Compute the Root Mean Square Deviation (RMSD) between two trajectories. + + Parameters + ---------- + traj1 : Tensor + First trajectory tensor, shape (num_frames, num_atoms, dim). + traj2 : Tensor + Second trajectory tensor (reference), same shape as traj1. + + Returns + ------- + rmsd_result : Tensor + The RMSD calculation of two trajectories. + """ + assert traj1.shape == traj2.shape, "Input tensors must have the same shape" + + num_frames = traj1.shape[0] + rmsd_result = torch.zeros(num_frames) + + for i in range(num_frames): + traj1_centered = traj1[i] - traj1[i].mean(dim=0, keepdim=True) + traj2_centered = traj2[i] - traj2[i].mean(dim=0, keepdim=True) + + u, s, vh = torch.linalg.svd(torch.mm(traj1_centered.t(), traj2_centered)) + d = torch.sign(torch.det(torch.mm(vh.t(), u.t()))) + + if d < 0: + vh[:, -1] *= -1 + + rot_matrix = torch.mm(vh.t(), u.t()) + traj2_rotated = torch.mm(traj2_centered, rot_matrix) + + rmsd = torch.sqrt(((traj1_centered - traj2_rotated) ** 2).sum(dim=1).mean()) + + rmsd_result[i] = rmsd + + return rmsd_result diff --git a/tests/beignet/test_bond_angles.py b/tests/beignet/test_bond_angles.py new file mode 100644 index 0000000000..96f0d9cf1a --- /dev/null +++ b/tests/beignet/test_bond_angles.py @@ -0,0 +1,62 @@ +import torch +import pytest + +from beignet.bond_angles import bond_angles + + +def radians(degrees): + """Utility function to convert degrees to radians.""" + return degrees * torch.pi / 180 + + +def test_straight_line_angle(): + # Tests three collinear points which must produce an angle of pi radians (180 degrees) + traj = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]]], dtype=torch.float32) + angle_indices = torch.tensor([[0, 1, 2]]) + + expected_angles = torch.tensor([[radians(180)]]) + computed_angles = bond_angles(traj, angle_indices) + + assert torch.allclose( + computed_angles, expected_angles + ), "Should calculate 180 degrees for collinear points." + + +def test_right_angle(): + # Tests an L shape (right angle, 90 degrees) + traj = torch.tensor([[[0, 0, 0], [1, 0, 0], [1, 1, 0]]], dtype=torch.float32) + angle_indices = torch.tensor([[0, 1, 2]]) + + expected_angles = torch.tensor([[radians(90)]]) + computed_angles = bond_angles(traj, angle_indices) + + assert torch.allclose( + computed_angles, expected_angles, atol=1e-5 + ), "Should calculate 90 degrees for orthogonal vectors." + + +def test_acute_angle(): + # Acute angle test 45 degrees + traj = torch.tensor( + [[[0, 0, 0], [1, 0, 0], [1, torch.sqrt(torch.tensor(2.0)), 0]]], + dtype=torch.float32, + ) + angle_indices = torch.tensor([[0, 1, 2]]) + + expected_angles = torch.tensor([[radians(45)]]) + computed_angles = bond_angles(traj, angle_indices) + + assert torch.allclose( + computed_angles, expected_angles, atol=1e-5 + ), "Should calculate 45 degrees for acute angle." + + +def test_no_indices_provided(): + # Providing no indices should return an empty tensor + traj = torch.randn(1, 10, 3) + angle_indices = torch.empty((0, 3), dtype=torch.int32) + + computed_angles = bond_angles(traj, angle_indices) + assert ( + computed_angles.nelement() == 0 + ), "Providing no indices should result in an empty tensor." diff --git a/tests/beignet/test_center_of_mass.py b/tests/beignet/test_center_of_mass.py new file mode 100644 index 0000000000..0c7d7c0da7 --- /dev/null +++ b/tests/beignet/test_center_of_mass.py @@ -0,0 +1,42 @@ +import torch + +from beignet.center_of_mass import center_of_mass + + +def test_center_of_mass_basic(): + positions = torch.tensor( + [ + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], + [[0.0, 0.0, 0.0], [2.0, 2.0, 0.0]], + ] + ) + masses = torch.tensor([1.0, 1.0]) + + expected_com = torch.tensor( + [ + [0.5, 0.5, 0.0], + [1.0, 1.0, 0.0], + ] + ) + + com = center_of_mass(positions, masses) + + assert torch.allclose(com, expected_com) + + +def test_center_of_mass_shape(): + positions = torch.randn(10, 5, 3) + masses = torch.rand(5) + + com = center_of_mass(positions, masses) + + assert com.shape == (10, 3) + + +def test_center_of_mass_at_origin(): + positions = torch.zeros(3, 4, 3) + masses = torch.rand(4) + + com = center_of_mass(positions, masses) + + assert torch.all(com == 0) diff --git a/tests/beignet/test_dihedrals.py b/tests/beignet/test_dihedrals.py new file mode 100644 index 0000000000..9f5b57a287 --- /dev/null +++ b/tests/beignet/test_dihedrals.py @@ -0,0 +1,63 @@ +import torch +import pytest + +from beignet.dihedrals import dihedrals + + +def radians(degrees): + """Utility function to convert degrees to radians.""" + return degrees * torch.pi / 180 + + +def test_dihedral_180_degrees(): + # Tests four collinear points which should result in a dihedral angle of 180 degrees (pi radians) + traj = torch.tensor( + [[[0, 0, 0], [1, 0, 0], [2, 0, 0], [3, 0, 0]]], dtype=torch.float32 + ) + indices = torch.tensor([[0, 1, 2, 3]]) + + expected_angles = torch.tensor([[radians(180)]]) + computed_angles = dihedrals(traj, indices) + + assert torch.allclose( + computed_angles, expected_angles + ), "Dihedral angle should be 180 degrees for collinear points." + + +def test_dihedral_90_degrees(): + # Configuration that should result in a dihedral angle of 90 degrees + traj = torch.tensor( + [[[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.float32 + ) # A right-angle turn at the second atom + indices = torch.tensor([[0, 1, 2, 3]]) + + expected_angles = torch.tensor([[radians(90)]]) + computed_angles = dihedrals(traj, indices) + + assert torch.allclose( + torch.abs(computed_angles), expected_angles, atol=1e-5 + ), "Dihedral angle should be 90 degrees." + + +def test_no_indices_provided(): + # Providing no indices should return an empty tensor + traj = torch.randn(1, 10, 3) + indices = torch.empty((0, 4), dtype=torch.int32) + + computed_angles = dihedrals(traj, indices) + assert ( + computed_angles.nelement() == 0 + ), "Providing no dihedral indices should result in an empty tensor." + + +def test_index_out_of_bounds(): + # Providing indices out of bounds should raise an error + traj = torch.randn(1, 4, 3) + indices = torch.tensor([[0, 1, 5, 3]]) # Index 5 is out of bounds + + with pytest.raises(ValueError): + dihedrals(traj, indices) + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/beignet/test_dipole_moments.py b/tests/beignet/test_dipole_moments.py new file mode 100644 index 0000000000..6784914683 --- /dev/null +++ b/tests/beignet/test_dipole_moments.py @@ -0,0 +1,50 @@ +import torch +import pytest + +from beignet.dipole_moments import dipole_moments + + +def test_basic_dipole_moments(): + # This test case assumes a simple setup where the math can be easily verified. + positions = torch.tensor( + [[[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]], [[0.0, 1.0, 0.0], [0.0, -1.0, 0.0]]], + dtype=torch.float32, + ) # Two frames, two atoms + charges = torch.tensor( + [1.0, -1.0], dtype=torch.float32 + ) # Positive and negative charges + + expected_dipoles = torch.tensor( + [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]], dtype=torch.float32 + ) + computed_dipoles = dipole_moments(positions, charges) + + assert torch.allclose( + computed_dipoles, expected_dipoles + ), "Basic dipole moments calculation failed." + + +def test_zero_charges(): + # Tests the scenario where all charges are zero - resulting in zero dipole moment. + positions = torch.randn(1, 5, 3) # One frame, five atoms, arbitrary positions + charges = torch.zeros(5, dtype=torch.float32) # Zero charges + + expected_dipoles = torch.zeros(1, 3, dtype=torch.float32) + computed_dipoles = dipole_moments(positions, charges) + + assert torch.allclose( + computed_dipoles, expected_dipoles + ), "Dipole moment should be zero when all charges are zero." + + +def test_negative_and_positive_charges(): + # Mixed charges but symmetrically arranged atoms ensuring zero dipole moment + positions = torch.tensor([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]], dtype=torch.float32) + charges = torch.tensor([1.0, -1.0], dtype=torch.float32) + + expected_dipoles = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32) + computed_dipoles = dipole_moments(positions.unsqueeze(0), charges) + + assert torch.allclose( + computed_dipoles, expected_dipoles + ), "Dipole moment should be zero for symmetric charges and positions." diff --git a/tests/beignet/test_gyration_tensor.py b/tests/beignet/test_gyration_tensor.py new file mode 100644 index 0000000000..1e43733a8b --- /dev/null +++ b/tests/beignet/test_gyration_tensor.py @@ -0,0 +1,58 @@ +import torch + +from beignet.gyration_tensor import gyration_tensor, _compute_center_of_geometry + + +def test_center_of_geometry_origin(): + # Scenario where all atoms are at the origin + traj = torch.zeros(1, 10, 3) # 1 frame, 10 atoms, all at origin + expected_center = torch.zeros(1, 3) + computed_center = _compute_center_of_geometry(traj) + assert torch.allclose( + computed_center, expected_center + ), "Center of geometry should be at origin for zeroed data." + + +def test_gyration_tensor_origin(): + # All atoms at the origin, expecting zero gyration tensor + traj = torch.zeros(1, 10, 3) + expected_gyration = torch.zeros(1, 3, 3) + computed_gyration = gyration_tensor(traj) + assert torch.allclose( + computed_gyration, expected_gyration + ), "Gyration tensor should be zero for zeroed data." + + +def test_center_of_geometry_translation(): + # Translate structure along x-axis + traj = torch.zeros(1, 10, 3) + torch.tensor([[[5.0, 0.0, 0.0]]]) + expected_center = torch.tensor([[5.0, 0.0, 0.0]]) + computed_center = _compute_center_of_geometry(traj) + assert torch.allclose( + computed_center, expected_center + ), "Translation mismatch in calculated center of geometry." + + +def test_gyration_tensor_translation(): + # Gyration tensor for translated atoms that are spaced uniformly along x from 1 to 10 + traj = torch.arange(1, 11).float().view(1, 10, 1).expand(-1, -1, 3) + centers = _compute_center_of_geometry(traj) + + deviations = traj[:, :, 0] - centers[:, 0].unsqueeze(1) + expected_s_xx = torch.sum(deviations**2, dim=1) / traj.shape[1] + + computed_gyration = gyration_tensor(traj) + + assert torch.allclose( + computed_gyration[:, 0, 0], expected_s_xx + ), "Gyration tensor calculation error after translation." + + +def test_random_data(): + # Check with random data to ensure no errors occur in general usage + traj = torch.randn(5, 50, 3) + assert gyration_tensor(traj).shape == ( + 5, + 3, + 3, + ), "Unexpected output shape for gyration tensor with random data." diff --git a/tests/beignet/test_rmsd.py b/tests/beignet/test_rmsd.py new file mode 100644 index 0000000000..bcb76138d8 --- /dev/null +++ b/tests/beignet/test_rmsd.py @@ -0,0 +1,39 @@ +import torch +from beignet.rmsd import rmsd + + +def test_rmsd_2d_case(): + traj1 = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]) + traj2 = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]) + expected_rmsd = torch.zeros(2) + out = rmsd(traj1, traj2) + assert torch.allclose( + out, expected_rmsd, atol=1e-5 + ), "RMSD should be zero for identical trajectories" + + +def test_different_2d_configurations(): + """Test RMSD for genuinely different 2D configurations (misalignment not recoverable by translation or rotation).""" + traj1 = torch.tensor([[[0.0, 0.0], [1.0, 0.0]]]) # In line along the x-axis + traj2 = torch.tensor( + [[[0.0, 1.0], [1.0, 1.0]]] + ) # In line but offset along the y-axis + out = rmsd(traj1, traj2) + print("RMSD:", rmsd) + assert not torch.allclose( + out, torch.zeros(1), atol=1e-5 + ), "RMSD should not be zero for misaligned configurations" + + +def test_rmsd_3d_case(): + traj1 = torch.tensor( + [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]] + ) + traj2 = torch.tensor( + [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]] + ) + expected_rmsd = torch.zeros(2) + out = rmsd(traj1, traj2) + assert torch.allclose( + out, expected_rmsd, atol=1e-5 + ), "RMSD should be zero for identical trajectories"