Skip to content

Commit

Permalink
added progress bar for trajectory.get_spectrum
Browse files Browse the repository at this point in the history
  • Loading branch information
wolearyc committed Sep 26, 2024
1 parent f29b962 commit 6c0567c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 7 deletions.
1 change: 1 addition & 0 deletions cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
"Softplus",
"spglib",
"subbatch",
"subbatchs",
"symprec",
"tablefmt",
"timestep",
Expand Down
18 changes: 12 additions & 6 deletions ramannoodle/polarizability/torch/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,9 +672,13 @@ def calc_polarizabilities(
self.eval()

polarizabilities = torch.zeros((positions_batch.shape[0], 3, 3))
for subbatch_index, positions_subbatch in tqdm(
enumerate(rn_torch_utils.batch_positions(positions_batch, batch_size=100))
):
positions_subbatchs = rn_torch_utils.batch_positions(
positions_batch, batch_size=100
)
progress_bar = tqdm(total=positions_batch.shape[0], unit=" configs")

start_index = 0
for positions_subbatch in positions_subbatchs:
subbatch_size = positions_subbatch.shape[0]

lattice = torch.tensor(self._ref_structure.lattice)
Expand All @@ -693,10 +697,12 @@ def calc_polarizabilities(
polarizability = rn_torch_utils.polarizability_vectors_to_tensors(
polarizability
)

start_index = subbatch_index * subbatch_size
end_index = start_index + positions_subbatch.shape[0]
end_index = start_index + subbatch_size
polarizabilities[start_index:end_index] = polarizability.detach()
progress_bar.update(end_index - start_index)
start_index = end_index

progress_bar.close()

result = polarizabilities.detach().numpy() * self._stddev_polarizability
result += self._mean_polarizability
Expand Down
2 changes: 1 addition & 1 deletion test/tests/torch/test_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def test_calc_polarizabilities(
model = PotGNN(ref_structure, 2, 5, 5, 5, 0, 5, np.zeros((3, 3)), np.ones((3, 3)))
model.eval()

for batch_size in [50, 100, 200]:
for batch_size in [50, 100, 180]:

# Generate random data.
num_atoms = len(ref_structure.atomic_numbers)
Expand Down

0 comments on commit 6c0567c

Please sign in to comment.