Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API: Replace custom Trainer with Pytorch Lightning #22

Merged
merged 12 commits into from
Jan 16, 2024
4 changes: 2 additions & 2 deletions ptychonn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# package is not installed
pass

from ptychonn._infer.__main__ import infer, stitch_from_inference, Tester
from ptychonn._train.__main__ import Trainer
from ptychonn._infer.__main__ import infer, stitch_from_inference
from ptychonn._train.__main__ import train, init_or_load_model, create_training_dataloader, ListLogger, create_model_checkpoint
from ptychonn.model import *
from ptychonn.plot import *
128 changes: 37 additions & 91 deletions ptychonn/_infer/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import typing
import glob

from torch.utils.data import TensorDataset, DataLoader
import click
import lightning
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -128,30 +128,34 @@ def infer_cli(

inferences = infer(
data=data,
model_params_path=params_path,
model=ptychonn.init_or_load_model(
ptychonn.model.LitReconSmallModel,
model_checkpoint_path=params_path,
model_init_params=None,
)
)

# Plotting some summary images

pstitched = stitch_from_inference(
inferences[:, 0],
scan,
stitched_pixel_width=1,
inference_pixel_width=1,
)
astitched = stitch_from_inference(
inferences[:, 1],
scan,
stitched_pixel_width=1,
inference_pixel_width=1,
)

# Plotting some summary images
plt.figure(1, figsize=[8.5, 7])
plt.imshow(pstitched)
plt.colorbar()
plt.tight_layout()
plt.title('stitched_phases')
plt.savefig(out_dir / 'pstitched.png', bbox_inches='tight')

astitched = stitch_from_inference(
inferences[:, 1],
scan,
stitched_pixel_width=1,
inference_pixel_width=1,
)
plt.figure(2, figsize=[8.5, 7])
plt.imshow(astitched)
plt.colorbar()
Expand All @@ -177,8 +181,8 @@ def infer_cli(


def infer(
data: npt.NDArray,
model_params_path: pathlib.Path,
data: npt.NDArray[np.float32],
model: lightning.LightningModule,
*,
inferences_out_file: typing.Optional[pathlib.Path] = None,
) -> npt.NDArray:
Expand All @@ -190,10 +194,15 @@ def infer(
Set the CUDA_VISIBLE_DEVICES environment variable to control which GPUs
will be used.

Initialize a model for the model parameter using the `init_or_load_model()`
function.

Parameters
----------
data : (POSITION, WIDTH, HEIGHT)
Diffraction patterns to be reconstructed.
model
An initialized PtychoNN model.
inferences_out_file : pathlib.Path
Optional file to save reconstructed patches.

Expand All @@ -202,82 +211,19 @@ def infer(
inferences : (POSITION, 2, WIDTH, HEIGHT)
The reconstructed patches inferred by the model.
'''
tester = Tester(
model=ptychonn.model.ReconSmallModel(),
model_params_path=model_params_path,
)
tester.setTestData(
data,
batch_size=max(torch.cuda.device_count(), 1) * 64,
)
return tester.predictTestData(npz_save_path=inferences_out_file)


class Tester():

def __init__(
self,
*,
model: torch.nn.Module,
model_params_path: pathlib.Path,
):
self.device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu")
print(f"Let's use {torch.cuda.device_count()} GPUs!")

self.model = model

params = torch.load(
model_params_path,
map_location=self.device,
)
self.model.load_state_dict(params)

self.model = torch.nn.DataParallel(self.model)

self.model.to(self.device)

self.model.eval()

def setTestData(self, X_test: np.ndarray, batch_size: int):
self.X_test = torch.tensor(X_test[:, None, ...], dtype=torch.float32)
self.test_data = TensorDataset(self.X_test)
self.testloader = DataLoader(
self.test_data,
batch_size=batch_size,
shuffle=False,
)

def predictTestData(self, npz_save_path: str = None):

phs_eval = []
with torch.inference_mode():
for (ft_images, ) in self.testloader:
ph_eval = self.model(ft_images.to(self.device))
phs_eval.append(ph_eval.detach().cpu().numpy())

self.phs_eval = np.concatenate(phs_eval, axis=0)

if npz_save_path is not None:
np.savez_compressed(npz_save_path, ph=self.phs_eval)
print(f'Finished the inference stage and saved at {npz_save_path}')

return self.phs_eval

def calcErrors(self, phs_true: np.ndarray, npz_save_path: str = None):
from skimage.metrics import mean_squared_error as mse

self.phs_true = phs_true
self.errors = []
for i, (p1, p2) in enumerate(zip(self.phs_eval, self.phs_true)):
err2 = mse(p1, p2)
self.errors.append([err2])

self.errors = np.array(self.errors)
print("Mean errors in phase")
print(np.mean(self.errors, axis=0))

if npz_save_path is not None:
np.savez_compressed(npz_save_path, phs_err=self.errors[:, 0])

return self.errors
model.eval()
result = []
with torch.no_grad():
for batch in data:
result.append(
model(torch.from_numpy(batch[None, None, :, :]).to("cuda"))
.cpu()
.numpy()
)

result = np.concatenate(result, axis=0)

if inferences_out_file is not None:
np.save(inferences_out_file, result)

return result
Loading