Merge pull request #22 from carterbox/lightning
API: Replace custom Trainer with Pytorch Lightning
carterbox authored Jan 16, 2024
2 parents 40ba697 + 051f4dc commit 78e4ec9
Showing 7 changed files with 341 additions and 523 deletions.
4 changes: 2 additions & 2 deletions ptychonn/__init__.py
@@ -6,7 +6,7 @@
# package is not installed
pass

from ptychonn._infer.__main__ import infer, stitch_from_inference, Tester
from ptychonn._train.__main__ import Trainer
from ptychonn._infer.__main__ import infer, stitch_from_inference
from ptychonn._train.__main__ import train, init_or_load_model, create_training_dataloader, ListLogger, create_model_checkpoint
from ptychonn.model import *
from ptychonn.plot import *
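
The import changes above summarize the new public API: the custom Trainer and Tester classes are dropped, and the Lightning-based helpers from ptychonn._train are exported instead. A minimal sketch of what downstream code can now import (names taken from the updated __init__.py; their signatures are not shown in this diff):

# Sketch only: these names come from the new ptychonn/__init__.py above.
# The old Trainer and Tester classes are no longer importable after this commit.
from ptychonn import (
    infer,                       # patch-wise inference (see the _infer diff below)
    stitch_from_inference,       # stitch inferred patches onto the scan grid
    train,                       # training entry point exported from _train
    init_or_load_model,          # builds a model, optionally from a checkpoint
    create_training_dataloader,  # dataloader helper exported from _train
    ListLogger,                  # logger exported from _train
    create_model_checkpoint,     # checkpoint helper exported from _train
)
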
128 changes: 37 additions & 91 deletions ptychonn/_infer/__main__.py
Expand Up @@ -2,8 +2,8 @@
import typing
import glob

from torch.utils.data import TensorDataset, DataLoader
import click
import lightning
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
@@ -128,30 +128,34 @@ def infer_cli(

inferences = infer(
data=data,
model_params_path=params_path,
model=ptychonn.init_or_load_model(
ptychonn.model.LitReconSmallModel,
model_checkpoint_path=params_path,
model_init_params=None,
)
)

# Plotting some summary images

pstitched = stitch_from_inference(
inferences[:, 0],
scan,
stitched_pixel_width=1,
inference_pixel_width=1,
)
astitched = stitch_from_inference(
inferences[:, 1],
scan,
stitched_pixel_width=1,
inference_pixel_width=1,
)

# Plotting some summary images
plt.figure(1, figsize=[8.5, 7])
plt.imshow(pstitched)
plt.colorbar()
plt.tight_layout()
plt.title('stitched_phases')
plt.savefig(out_dir / 'pstitched.png', bbox_inches='tight')

astitched = stitch_from_inference(
inferences[:, 1],
scan,
stitched_pixel_width=1,
inference_pixel_width=1,
)
plt.figure(2, figsize=[8.5, 7])
plt.imshow(astitched)
plt.colorbar()
@@ -177,8 +181,8 @@ def infer_cli(


def infer(
data: npt.NDArray,
model_params_path: pathlib.Path,
data: npt.NDArray[np.float32],
model: lightning.LightningModule,
*,
inferences_out_file: typing.Optional[pathlib.Path] = None,
) -> npt.NDArray:
@@ -190,10 +194,15 @@ def infer(
Set the CUDA_VISIBLE_DEVICES environment variable to control which GPUs
will be used.
Initialize a model for the model parameter using the `init_or_load_model()`
function.
Parameters
----------
data : (POSITION, WIDTH, HEIGHT)
Diffraction patterns to be reconstructed.
model
An initialized PtychoNN model.
inferences_out_file : pathlib.Path
Optional file to save reconstructed patches.
@@ -202,82 +211,19 @@ def infer(
inferences : (POSITION, 2, WIDTH, HEIGHT)
The reconstructed patches inferred by the model.
'''
tester = Tester(
model=ptychonn.model.ReconSmallModel(),
model_params_path=model_params_path,
)
tester.setTestData(
data,
batch_size=max(torch.cuda.device_count(), 1) * 64,
)
return tester.predictTestData(npz_save_path=inferences_out_file)


class Tester():

def __init__(
self,
*,
model: torch.nn.Module,
model_params_path: pathlib.Path,
):
self.device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu")
print(f"Let's use {torch.cuda.device_count()} GPUs!")

self.model = model

params = torch.load(
model_params_path,
map_location=self.device,
)
self.model.load_state_dict(params)

self.model = torch.nn.DataParallel(self.model)

self.model.to(self.device)

self.model.eval()

def setTestData(self, X_test: np.ndarray, batch_size: int):
self.X_test = torch.tensor(X_test[:, None, ...], dtype=torch.float32)
self.test_data = TensorDataset(self.X_test)
self.testloader = DataLoader(
self.test_data,
batch_size=batch_size,
shuffle=False,
)

def predictTestData(self, npz_save_path: str = None):

phs_eval = []
with torch.inference_mode():
for (ft_images, ) in self.testloader:
ph_eval = self.model(ft_images.to(self.device))
phs_eval.append(ph_eval.detach().cpu().numpy())

self.phs_eval = np.concatenate(phs_eval, axis=0)

if npz_save_path is not None:
np.savez_compressed(npz_save_path, ph=self.phs_eval)
print(f'Finished the inference stage and saved at {npz_save_path}')

return self.phs_eval

def calcErrors(self, phs_true: np.ndarray, npz_save_path: str = None):
from skimage.metrics import mean_squared_error as mse

self.phs_true = phs_true
self.errors = []
for i, (p1, p2) in enumerate(zip(self.phs_eval, self.phs_true)):
err2 = mse(p1, p2)
self.errors.append([err2])

self.errors = np.array(self.errors)
print("Mean errors in phase")
print(np.mean(self.errors, axis=0))

if npz_save_path is not None:
np.savez_compressed(npz_save_path, phs_err=self.errors[:, 0])

return self.errors
model.eval()
result = []
with torch.no_grad():
for batch in data:
result.append(
model(torch.from_numpy(batch[None, None, :, :]).to("cuda"))
.cpu()
.numpy()
)

result = np.concatenate(result, axis=0)

if inferences_out_file is not None:
np.save(inferences_out_file, result)

return result
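
For reference, a minimal end-to-end sketch of the new inference path, mirroring the infer_cli call earlier in this diff. The data file and checkpoint path are placeholders, and a CUDA device is assumed because the new infer implementation moves each pattern to "cuda".

import os
import pathlib

import numpy as np

import ptychonn

# Limit visible GPUs, as suggested by the infer() docstring.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

# Placeholder input: float32 diffraction patterns with shape (POSITION, WIDTH, HEIGHT).
data = np.load('diffraction.npy').astype(np.float32)

# Load weights into the Lightning module with the new helper, then run inference.
model = ptychonn.init_or_load_model(
    ptychonn.model.LitReconSmallModel,
    model_checkpoint_path=pathlib.Path('best_model.ckpt'),  # placeholder checkpoint
    model_init_params=None,
)
inferences = ptychonn.infer(
    data=data,
    model=model,
    inferences_out_file=pathlib.Path('inferences.npy'),  # optional; written with np.save
)
# inferences has shape (POSITION, 2, WIDTH, HEIGHT); infer_cli stitches channel 0
# into the phase image and channel 1 into the second ('astitched') image.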