
API: Replace custom Trainer with Pytorch Lightning #22

Merged
12 commits, merged Jan 16, 2024

Conversation

carterbox
Collaborator

I have refactored the PtychoNN API to be functional and replaced the custom training-management scripts with PyTorch Lightning.

This has the following advantages:

  • reduces the amount of boilerplate in this project
  • automates device management
  • provides a more robust model checkpointing/saving/reloading method

Fixes #21

@carterbox carterbox requested a review from stevehenke January 2, 2024 22:09
@carterbox
Collaborator Author

@stevehenke, please take a look at the new API, in particular the following functions:

  • ptychonn.train
  • ptychonn.init_or_load_model
  • ptychonn.infer

The network model is now LitReconSmallModel instead of ReconSmallModel because some of the training parameters have been moved into the model's constructor parameters.

@stevehenke
Collaborator

These changes look like a significant improvement in PtychoNN training capabilities. Nice job!

@carterbox
Collaborator Author

carterbox commented Jan 3, 2024

  • Create a custom logger that doesn't touch disk
  • Create a checkpointing function so downstream doesn't need to import pytorch lightning
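A helper like the second item might look roughly as follows. This is a hypothetical sketch, not the actual ptychonn implementation: the names `create_model_checkpoint` and `load_model_checkpoint` are assumptions, while `Trainer.save_checkpoint` and `LightningModule.load_from_checkpoint` are real PyTorch Lightning calls.

```python
# Hypothetical sketch: wrap Lightning's checkpoint I/O so downstream code
# only ever handles path strings and never imports pytorch_lightning itself.

def create_model_checkpoint(trainer, path: str) -> None:
    """Persist model weights and training state through the wrapped trainer.

    `trainer` is assumed to be a pytorch_lightning.Trainer held inside
    ptychonn; the caller only supplies the destination path.
    """
    trainer.save_checkpoint(path)


def load_model_checkpoint(model_class, path: str):
    """Restore a LightningModule (e.g. LitReconSmallModel) from disk."""
    return model_class.load_from_checkpoint(path)
```

Keeping the Lightning import behind these two functions means downstream callers depend only on the ptychonn package surface, matching the goal stated in the checklist above.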

@carterbox
Collaborator Author

import numpy as np
import ptychonn
from threading import Event

if __name__ == '__main__':

    x = np.random.rand(1_000, 256, 256).astype(np.float32)
    y = np.random.rand(1_000, 1, 256, 256).astype(np.float32)

    print("Transferring memory to PtychoNN")
    # t = ptychonn.Trainer(
    #     model=ptychonn.ReconSmallModel,
    #     batch_size=32,
    # )
    # t.setTrainingData(
    #     x,
    #     y,
    # )
    dataloader0, dataloader1 = ptychonn.create_training_dataloader(
        x,
        y,
        batch_size=32,
        training_fraction=0.8,
    )

    print("Waiting forever for you to check memory consumption.")
    Event().wait()

I used the above script to test the difference in memory consumption between the old and new API. In my tests, memory usage decreased from 1.2 GiB to 0.76 GiB, roughly a 37% reduction.
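For reference, the reduction implied by those two measurements works out as:

```python
# Resident memory reported by the test above, in GiB.
old_api = 1.2
new_api = 0.76

reduction = (old_api - new_api) / old_api
print(f"{reduction:.0%}")  # prints "37%"
```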

@carterbox carterbox merged commit 78e4ec9 into mcherukara:package Jan 16, 2024