From 6c0567c9025f40e58edd1075ece9f7115c9a7b3d Mon Sep 17 00:00:00 2001 From: wolearyc Date: Fri, 20 Sep 2024 13:27:43 -0700 Subject: [PATCH] added progress bar for trajectory.get_spectrum --- cspell.json | 1 + ramannoodle/polarizability/torch/gnn.py | 18 ++++++++++++------ test/tests/torch/test_gnn.py | 2 +- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/cspell.json b/cspell.json index 00d5d66..04135c7 100644 --- a/cspell.json +++ b/cspell.json @@ -72,6 +72,7 @@ "Softplus", "spglib", "subbatch", + "subbatchs", "symprec", "tablefmt", "timestep", diff --git a/ramannoodle/polarizability/torch/gnn.py b/ramannoodle/polarizability/torch/gnn.py index 051b844..2ed4956 100644 --- a/ramannoodle/polarizability/torch/gnn.py +++ b/ramannoodle/polarizability/torch/gnn.py @@ -672,9 +672,13 @@ def calc_polarizabilities( self.eval() polarizabilities = torch.zeros((positions_batch.shape[0], 3, 3)) - for subbatch_index, positions_subbatch in tqdm( - enumerate(rn_torch_utils.batch_positions(positions_batch, batch_size=100)) - ): + positions_subbatchs = rn_torch_utils.batch_positions( + positions_batch, batch_size=100 + ) + progress_bar = tqdm(total=positions_batch.shape[0], unit=" configs") + + start_index = 0 + for positions_subbatch in positions_subbatchs: subbatch_size = positions_subbatch.shape[0] lattice = torch.tensor(self._ref_structure.lattice) @@ -693,10 +697,12 @@ def calc_polarizabilities( polarizability = rn_torch_utils.polarizability_vectors_to_tensors( polarizability ) - - start_index = subbatch_index * subbatch_size - end_index = start_index + positions_subbatch.shape[0] + end_index = start_index + subbatch_size polarizabilities[start_index:end_index] = polarizability.detach() + progress_bar.update(end_index - start_index) + start_index = end_index + + progress_bar.close() result = polarizabilities.detach().numpy() * self._stddev_polarizability result += self._mean_polarizability diff --git a/test/tests/torch/test_gnn.py b/test/tests/torch/test_gnn.py index 83a9cdd..10261f5 100644 --- a/test/tests/torch/test_gnn.py +++ b/test/tests/torch/test_gnn.py @@ -161,7 +161,7 @@ def test_calc_polarizabilities( model = PotGNN(ref_structure, 2, 5, 5, 5, 0, 5, np.zeros((3, 3)), np.ones((3, 3))) model.eval() - for batch_size in [50, 100, 200]: + for batch_size in [50, 100, 180]: # Generate random data. num_atoms = len(ref_structure.atomic_numbers)