API: Replace custom Trainer with PyTorch Lightning #22
Conversation
@stevehenke, please take a look at the new API, and at the following changes specifically:

- The network model is now `LitReconSmallModel` instead of `ReconSmallModel`, because some of the training parameters have been moved into the model's own constructor parameters (see the sketch below).
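A minimal sketch of what that change could imply for callers. The parameter names `nconv` and `max_lr` are assumptions for illustration only; the thread does not show the actual `LitReconSmallModel` signature:

```python
import ptychonn

# Before, training hyperparameters were passed to the custom Trainer.
# After (assumed), they are passed to the Lightning module itself.
# Parameter names below are hypothetical, not the confirmed signature.
model = ptychonn.LitReconSmallModel(
    nconv=32,     # hypothetical: width of the convolutional stacks
    max_lr=1e-3,  # hypothetical: peak learning rate, now owned by the model
)
```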
These changes look like a significant improvement in PtychoNN training capabilities. Nice job!
```python
import numpy as np
import ptychonn
from threading import Event

if __name__ == '__main__':
    # Synthetic training pairs: 1000 diffraction patterns and phase labels.
    x = np.random.rand(1_000, 256, 256).astype(np.float32)
    y = np.random.rand(1_000, 1, 256, 256).astype(np.float32)
    print("Transferring memory to PtychoNN")
    # Old API:
    # t = ptychonn.Trainer(
    #     model=ptychonn.ReconSmallModel,
    #     batch_size=32,
    # )
    # t.setTrainingData(
    #     x,
    #     y,
    # )
    # New API:
    dataloader0, dataloader1 = ptychonn.create_training_dataloader(
        x,
        y,
        batch_size=32,
        training_fraction=0.8,
    )
    print("Waiting forever for you to check memory consumption.")
    # Block indefinitely so resident memory can be inspected externally.
    Event().wait()
```

I have used the above script to test the difference in memory consumption between the old and new API. In my tests, memory usage decreased from 1.2 GiB to 0.76 GiB, which is roughly a 37% reduction.
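For reference, one way to read the script's resident memory while it is blocked on `Event().wait()`. This is a sketch using `psutil`, which is an assumption of mine and not part of the PR:

```python
import sys

import psutil


def rss_gib(pid: int) -> float:
    """Return the resident set size (RSS) of a process, in GiB."""
    return psutil.Process(pid).memory_info().rss / 2**30


if __name__ == '__main__':
    # Pass the PID of the running test script, e.g. found with `pgrep`.
    print(f"{rss_gib(int(sys.argv[1])):.2f} GiB resident")
```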
I have refactored the PtychoNN API to be functional and replaced the custom training management scripts with PyTorch Lightning.
This has the following advantages:
Fixes #21
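For context, a minimal sketch of what a functional, Lightning-based training flow might look like after this refactor. It assumes `LitReconSmallModel` is a standard `LightningModule` that can be driven by a stock `lightning.Trainer`; the `create_training_dataloader` call is taken from the test script above, while the pairing with `lightning.Trainer` is my assumption, not a confirmed usage:

```python
import lightning
import numpy as np
import ptychonn

# Synthetic training pairs, as in the memory test script above.
x = np.random.rand(1_000, 256, 256).astype(np.float32)
y = np.random.rand(1_000, 1, 256, 256).astype(np.float32)

# Split the data into training/validation dataloaders via the new API.
train_dl, val_dl = ptychonn.create_training_dataloader(
    x,
    y,
    batch_size=32,
    training_fraction=0.8,
)

# Assumed: LitReconSmallModel is a LightningModule, so a standard
# Lightning Trainer can run the training loop that the removed custom
# Trainer used to manage.
model = ptychonn.LitReconSmallModel()
trainer = lightning.Trainer(max_epochs=10)
trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
```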