Skip to content

Commit

Permalink
Merge pull request #26 from carterbox/train-parameter
Browse files Browse the repository at this point in the history
NEW: Train parameter
  • Loading branch information
carterbox authored Jun 6, 2024
2 parents 170ffb7 + b960d96 commit 06c6f27
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions ptychonn/_train/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def train(
out_dir: pathlib.Path | None,
epochs: int = 1,
batch_size: int = 32,
training_fraction: float = 0.8,
) -> typing.Tuple[lightning.Trainer, lightning.pytorch.loggers.CSVLogger | ListLogger]:
"""Train a PtychoNN model.
Expand All @@ -178,6 +179,8 @@ def train(
The maximum number of training epochs
batch_size
The size of one training batch.
training_fraction
The proprotion of X_train and Y_train that is used for training.
"""
if out_dir is not None:
checkpoint_callback = lightning.pytorch.callbacks.ModelCheckpoint(
Expand Down Expand Up @@ -210,6 +213,7 @@ def train(
X_train,
Y_train,
batch_size,
training_fraction=training_fraction,
)

trainer.fit(
Expand Down

0 comments on commit 06c6f27

Please sign in to comment.