here is the commit #64

Open. Wants to merge 10 commits into base: main.
3 changes: 0 additions & 3 deletions .env.template

This file was deleted.

6 changes: 3 additions & 3 deletions cdvae/common/data_utils.py
@@ -248,7 +248,7 @@ def frac_to_cart_coords(
angles,
num_atoms,
):
lattice = lattice_params_to_matrix_torch(lengths, angles)
    lattice = lattice_params_to_matrix_torch(lengths, angles)  # (N_cryst, 3, 3) lattice matrix
lattice_nodes = torch.repeat_interleave(lattice, num_atoms, dim=0)
pos = torch.einsum('bi,bij->bj', frac_coords, lattice_nodes) # cart coords
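For context, a minimal self-contained sketch of the fractional-to-Cartesian conversion annotated above: build one (3, 3) lattice matrix per crystal, repeat it once per atom, and contract it with the fractional coordinates. The tensors below are illustrative values, not data from the repo.

import torch

lattice = torch.eye(3).repeat(2, 1, 1) * 5.0  # two 5 Å cubic cells, shape (2, 3, 3)
num_atoms = torch.tensor([2, 1])              # atoms per crystal
frac_coords = torch.tensor([[0.00, 0.00, 0.00],
                            [0.50, 0.50, 0.50],
                            [0.25, 0.25, 0.25]])
lattice_nodes = torch.repeat_interleave(lattice, num_atoms, dim=0)    # one matrix per atom
cart_coords = torch.einsum('bi,bij->bj', frac_coords, lattice_nodes)  # cart coords
print(cart_coords)  # rows: [0, 0, 0], [2.5, 2.5, 2.5], [1.25, 1.25, 1.25]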

@@ -293,7 +293,7 @@ def get_pbc_distances(

distance_vectors = pos[j_index] - pos[i_index]

# correct for pbc
    # correct for pbc: shifting distance vectors across cell boundaries requires adding the lattice image offsets
lattice_edges = torch.repeat_interleave(lattice, num_bonds, dim=0)
offsets = torch.einsum('bi,bij->bj', to_jimages.float(), lattice_edges)
distance_vectors += offsets
@@ -302,7 +302,7 @@ def get_pbc_distances(
distances = distance_vectors.norm(dim=-1)

out = {
"edge_index": edge_index,
"edge_index": edge_index, #edge index stays the same
"distances": distances,
}
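As a toy example of the pbc correction commented above (values made up): a raw difference vector of -3 along x, with the neighbor one image over in a 4 Å cubic cell, becomes the true 1 Å separation once the lattice offset is added.

import torch

lattice = torch.eye(3).unsqueeze(0) * 4.0            # one 4 Å cubic cell, shape (1, 3, 3)
num_bonds = torch.tensor([1])
to_jimages = torch.tensor([[1, 0, 0]])               # neighbor lies one cell over along x
distance_vectors = torch.tensor([[-3.0, 0.0, 0.0]])  # raw, unwrapped difference

lattice_edges = torch.repeat_interleave(lattice, num_bonds, dim=0)
offsets = torch.einsum('bi,bij->bj', to_jimages.float(), lattice_edges)
distance_vectors = distance_vectors + offsets
print(distance_vectors.norm(dim=-1))  # tensor([1.]) -- the corrected pbc distance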

3 changes: 1 addition & 2 deletions cdvae/pl_data/datamodule.py
@@ -49,7 +49,6 @@ def __init__(
self.train_dataset: Optional[Dataset] = None
self.val_datasets: Optional[Sequence[Dataset]] = None
self.test_datasets: Optional[Sequence[Dataset]] = None

self.get_scaler(scaler_path)

def prepare_data(self) -> None:
@@ -66,6 +65,7 @@ def get_scaler(self, scaler_path):
self.scaler = get_scaler_from_data_list(
train_dataset.cached_data,
key=train_dataset.prop)
self.train_dataset = train_dataset
else:
self.lattice_scaler = torch.load(
Path(scaler_path) / 'lattice_scaler.pt')
@@ -76,7 +76,6 @@ def setup(self, stage: Optional[str] = None):
construct datasets and assign data scalers.
"""
if stage is None or stage == "fit":
self.train_dataset = hydra.utils.instantiate(self.datasets.train)
self.val_datasets = [
hydra.utils.instantiate(dataset_cfg)
for dataset_cfg in self.datasets.val
111 changes: 90 additions & 21 deletions cdvae/pl_modules/decoder.py
@@ -13,22 +13,48 @@ def build_mlp(in_dim, hidden_dim, fc_num_layers, out_dim):
mods += [nn.Linear(hidden_dim, out_dim)]
return nn.Sequential(*mods)

def split_atoms(input_tensor, num_atoms, latent_dim, max_num_atoms=20):
    # input_tensor: (num_atoms.sum(), latent_dim) per-atom features
    # Split the input tensor into one tensor per crystal
    molecule_tensors = torch.split(input_tensor, num_atoms.tolist())

    # Pad the tensors to ensure they all have the same shape (max_num_atoms, latent_dim)
padded_tensors = []
masks = []

for mol_tensor in molecule_tensors:
num_atoms_in_molecule = mol_tensor.shape[0]

# Pad the tensor
padding = (0, 0, 0, max_num_atoms - num_atoms_in_molecule) # Pad only the second dimension
padded_tensor = F.pad(mol_tensor, padding, "constant", 0)
padded_tensors.append(padded_tensor)

# Create the mask
mask = torch.zeros(max_num_atoms, dtype=float)
mask[:num_atoms_in_molecule] = 1
masks.append(mask)

# Stack the padded tensors and masks to form the final batch tensors
padded_batch_tensor = torch.stack(padded_tensors) # Shape: (batch_size, max_num_atoms, latent_dim)
mask_tensor = torch.stack(masks) # Shape: (batch_size, max_num_atoms)
return padded_batch_tensor, mask_tensor
class GemNetTDecoder(nn.Module):
"""Decoder with GemNetT."""

def __init__(
self,
hidden_dim=128,
latent_dim=256,
hidden_dim=64,
latent_dim=128,
max_neighbors=20,
radius=6.,
scale_file=None,
):
super(GemNetTDecoder, self).__init__()
self.cutoff = radius
self.max_num_neighbors = max_neighbors

self.latent_dim = latent_dim
self.gemnet = GemNetT(
num_targets=1,
latent_dim=latent_dim,
@@ -40,10 +66,17 @@ def __init__(
otf_graph=True,
scale_file=scale_file,
)
self.fc_atom = nn.Linear(hidden_dim, MAX_ATOMIC_NUM)

def forward(self, z, pred_frac_coords, pred_atom_types, num_atoms,
lengths, angles):
self.fc_atom = build_mlp(hidden_dim, hidden_dim, 2, MAX_ATOMIC_NUM)
# self.fc_lengths = nn.Linear(hidden_dim, 3)
# self.fc_angles = nn.Linear(hidden_dim, 3)
        # alternative heads (kept for reference):
# self.fc_atom = build_mlp(latent_dim, latent_dim*2, 1, MAX_ATOMIC_NUM)
# self.fc_lengths = build_mlp(latent_dim*20, latent_dim, 1, 3)
# self.fc_angles = build_mlp(latent_dim*20, latent_dim, 1, 3)
# self.fc_hidden = build_mlp(hidden_dim, latent_dim*2, 1, latent_dim)
# self.len_attention = nn.MultiheadAttention(hidden_dim, num_heads=4)
# self.angles_attention = nn.MultiheadAttention(hidden_dim, num_heads=4)
def forward(self, z, pred_frac_coords, pred_atom_types, num_atoms, lengths, angles, batch):
"""
args:
z: (N_cryst, num_latent)
@@ -56,17 +89,53 @@ def forward(self, z, pred_frac_coords, pred_atom_types, num_atoms,
atom_frac_coords: (N_atoms, 3)
atom_types: (N_atoms, MAX_ATOMIC_NUM)
"""
# (num_atoms, hidden_dim) (num_crysts, 3)
h, pred_cart_coord_diff = self.gemnet(
z=z,
frac_coords=pred_frac_coords,
atom_types=pred_atom_types,
num_atoms=num_atoms,
lengths=lengths,
angles=angles,
edge_index=None,
to_jimages=None,
num_bonds=None,
)
pred_atom_types = self.fc_atom(h)
return pred_cart_coord_diff, pred_atom_types
try:
# Attempt using predicted lengths and angles
h, pred_cart_coords = self.gemnet(
z=z,
frac_coords=pred_frac_coords,
atom_types=pred_atom_types,
num_atoms=num_atoms,
lengths=lengths,
angles=angles,
edge_index=None,
to_jimages=None,
num_bonds=None,
)
pred_atom_types = self.fc_atom(h)
return pred_cart_coords, pred_atom_types

except Exception as e:
print(f"Prediction error: {str(e)}. Attempting to use ground truth lengths and angles.")

try:
ground_truth_angles = batch.angles
ground_truth_lengths = batch.lengths
                ground_truth_coords = batch.frac_coords  # alternative: 0.9 * batch.frac_coords + 0.1 * pred_frac_coords
# Attempt using ground truth lengths and angles
h, pred_cart_coords = self.gemnet(
z=z,
frac_coords=ground_truth_coords,
atom_types=pred_atom_types,
num_atoms=num_atoms,
lengths=ground_truth_lengths,
angles=ground_truth_angles,
edge_index=None,
to_jimages=None,
num_bonds=None,
)
pred_atom_types = self.fc_atom(h)
return pred_cart_coords, pred_atom_types

except Exception as inner_e:
# If both attempts fail, log the error and skip the batch
print(f"Error in batch after fallback: {str(inner_e)}")
return None, None

# pred_atom_types = self.fc_atom(insider)
# #split:
# insider_split, mask = split_atoms(insider, num_atoms, self.latent_dim)
# insider_view = insider_split.view(-1, self.latent_dim*20) #warning bsz
# pred_lengths, pred_angles = self.fc_lengths(insider_view), self.fc_angles(insider_view)
# if transformer:
# pred_lengths, pred_angles = self.len_attention(insider, insider, insider), self.angles_attention(insider, insider, insider)
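A quick sketch of what the new atom-type head looks like compared with the old single linear layer. This assumes build_mlp stays importable from cdvae/pl_modules/decoder.py and uses the new default hidden_dim=64; it is illustrative, not part of the PR.

import torch
from cdvae.pl_modules.decoder import build_mlp
from cdvae.pl_modules.embeddings import MAX_ATOMIC_NUM

# Old head: nn.Linear(hidden_dim, MAX_ATOMIC_NUM), a single affine map.
# New head: a small MLP -- Linear(64, 64), ReLU, Linear(64, 64), ReLU, Linear(64, 100).
fc_atom = build_mlp(64, 64, 2, MAX_ATOMIC_NUM)

h = torch.randn(5, 64)   # hidden states for 5 atoms
logits = fc_atom(h)
print(logits.shape)      # torch.Size([5, 100]) -- per-atom type logits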
94 changes: 94 additions & 0 deletions cdvae/pl_modules/decoderdenoiser.py
@@ -0,0 +1,94 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from cdvae.pl_modules.embeddings import MAX_ATOMIC_NUM
from cdvae.pl_modules.gemnet.gemnetdenoiser import GemNetTDenoiser

def build_mlp(in_dim, hidden_dim, fc_num_layers, out_dim):
mods = [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
for i in range(fc_num_layers-1):
mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
mods += [nn.Linear(hidden_dim, out_dim)]
return nn.Sequential(*mods)

def split_atoms(input_tensor, num_atoms, latent_dim, max_num_atoms=20):
    # input_tensor: (num_atoms.sum(), latent_dim) per-atom features
    # Split the input tensor into one tensor per crystal
    molecule_tensors = torch.split(input_tensor, num_atoms.tolist())

    # Pad the tensors to ensure they all have the same shape (max_num_atoms, latent_dim)
padded_tensors = []
masks = []

for mol_tensor in molecule_tensors:
num_atoms_in_molecule = mol_tensor.shape[0]

# Pad the tensor
padding = (0, 0, 0, max_num_atoms - num_atoms_in_molecule) # Pad only the second dimension
padded_tensor = F.pad(mol_tensor, padding, "constant", 0)
padded_tensors.append(padded_tensor)

# Create the mask
mask = torch.zeros(max_num_atoms, dtype=float)
mask[:num_atoms_in_molecule] = 1
masks.append(mask)

# Stack the padded tensors and masks to form the final batch tensors
padded_batch_tensor = torch.stack(padded_tensors) # Shape: (batch_size, max_num_atoms, latent_dim)
mask_tensor = torch.stack(masks) # Shape: (batch_size, max_num_atoms)
return padded_batch_tensor, mask_tensor
class GemNetTDenoiserDecoder(nn.Module):
"""Denoiser with GemNetT."""

def __init__(
self,
hidden_dim=64,
latent_dim=128,
max_neighbors=20,
radius=6.,
scale_file=None,
):
super(GemNetTDenoiserDecoder, self).__init__()
self.cutoff = radius
self.max_num_neighbors = max_neighbors
self.latent_dim = latent_dim
self.gemnet = GemNetTDenoiser(
num_targets=1,
latent_dim=latent_dim,
emb_size_atom=hidden_dim,
emb_size_edge=hidden_dim,
regress_forces=True,
cutoff=self.cutoff,
max_neighbors=self.max_num_neighbors,
otf_graph=True,
scale_file=scale_file,
)
def forward(self, z, pred_frac_coords, pred_atom_types, num_atoms, lengths, angles, batch, timesteps):
"""
args:
z: (N_cryst, num_latent)
pred_frac_coords: (N_atoms, 3)
pred_atom_types: (N_atoms, ), need to use atomic number e.g. H = 1
num_atoms: (N_cryst,)
lengths: (N_cryst, 3)
            angles: (N_cryst, 3)
            batch: data batch (not used in this forward pass)
            timesteps: diffusion timestep indices
        returns:
            pred_z_a_noise: predicted noise for the atom-type representation
            pred_z_x_noise: predicted noise for the coordinate representation
        """
        # run the denoiser on the noised structure with the given lengths and angles
pred_z_a_noise, pred_z_x_noise = self.gemnet(
z=z,
frac_coords=pred_frac_coords,
atom_types=pred_atom_types,
num_atoms=num_atoms,
lengths=lengths,
angles=angles,
edge_index=None,
to_jimages=None,
num_bonds=None,
timesteps=timesteps
)
return pred_z_a_noise, pred_z_x_noise
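For reviewers, a small usage sketch of the split_atoms helper defined in this new file: it splits per-atom features into per-crystal tensors, pads them to max_num_atoms, and returns a matching mask. The import path assumes this module lands as cdvae/pl_modules/decoderdenoiser.py; the numbers are illustrative.

import torch
from cdvae.pl_modules.decoderdenoiser import split_atoms

num_atoms = torch.tensor([3, 2])   # two crystals with 3 and 2 atoms
latent_dim = 4
per_atom_h = torch.randn(int(num_atoms.sum()), latent_dim)

padded, mask = split_atoms(per_atom_h, num_atoms, latent_dim)
print(padded.shape)  # torch.Size([2, 20, 4]) -- (batch, max_num_atoms, latent_dim)
print(mask.shape)    # torch.Size([2, 20])
print(mask[0])       # first 3 entries are 1 (real atoms), the rest 0 (padding)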
2 changes: 1 addition & 1 deletion cdvae/pl_modules/embeddings/__init__.py
@@ -9,4 +9,4 @@
from .continuous_embeddings import CONTINUOUS_EMBEDDINGS
from .khot_embeddings import KHOT_EMBEDDINGS

MAX_ATOMIC_NUM = 100
MAX_ATOMIC_NUM = 100  # the actual maximum atomic number is 94, but 100 is used for stability
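As a sketch of why the extra headroom is harmless: atom types are 1-indexed atomic numbers (H = 1), so a MAX_ATOMIC_NUM-way one-hot simply leaves the classes above 94 unused. The -1 offset below is an assumed indexing convention, not taken from this diff.

import torch
import torch.nn.functional as F

MAX_ATOMIC_NUM = 100
atom_types = torch.tensor([1, 8, 26])  # H, O, Fe as atomic numbers
one_hot = F.one_hot(atom_types - 1, num_classes=MAX_ATOMIC_NUM)
print(one_hot.shape)  # torch.Size([3, 100]); columns 94-99 are never set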