From bd56d274c4748c846b07616d0595a38a76ab3c50 Mon Sep 17 00:00:00 2001 From: alitinet Date: Thu, 19 Dec 2024 17:09:15 +0100 Subject: [PATCH] cleaned up --- src/multigrate/model/_multivae.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/multigrate/model/_multivae.py b/src/multigrate/model/_multivae.py index ca1d5b7..58a136c 100644 --- a/src/multigrate/model/_multivae.py +++ b/src/multigrate/model/_multivae.py @@ -261,31 +261,20 @@ def get_model_output(self, adata=None, batch_size=256, save_unimodal_params=Fals inference_inputs = self.module._get_inference_input(tensors) outputs = self.module.inference(**inference_inputs) z = outputs["z_joint"] - # print('z_joint') - # print(outputs["z_joint"].shape) if save_unimodal_latent is True: - # print('z_marginal') - # print(len(outputs["z_marginal"])) - # print(outputs["z_marginal"].shape) z_marginal += [outputs["z_marginal"].cpu()] if save_unimodal_params is True: - # print('params marginal') - # print(outputs["mu_marginal"].shape) - # print(outputs["logvar_marginal"].shape) mu_marginal += [outputs["mu_marginal"].cpu()] logvar_marginal += [outputs["logvar_marginal"].cpu()] latent += [z.cpu()] if save_unimodal_latent is True: z_marginal = torch.cat(z_marginal) - print(z_marginal.shape) for i in range(z_marginal.shape[1]): adata.obsm[f"X_unimodal_{i}"] = z_marginal[:, i, :].squeeze(1).numpy() if save_unimodal_params is True: mu_marginal = torch.cat(mu_marginal) logvar_marginal = torch.cat(logvar_marginal) - print(mu_marginal.shape) - print(logvar_marginal.shape) for i in range(mu_marginal.shape[1]): adata.obsm[f"mu_unimodal_{i}"] = mu_marginal[:, i, :].squeeze(1).numpy() adata.obsm[f"logvar_unimodal_{i}"] = logvar_marginal[:, i, :].squeeze(1).numpy()