Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new: add main to api index, update quaterion docstrings #126 #131

Merged
merged 2 commits into from
Jun 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/source/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,16 @@ Extras

~pytorch_metric_learning_wrapper.PytorchMetricLearningWrapper

MAIN
----

.. py:currentmodule:: quaterion.main

.. autosummary::
:nosignatures:

Quaterion

TRAIN
-----

Expand Down
18 changes: 12 additions & 6 deletions quaterion/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@


class Quaterion:
"""A dwarf on a giant's shoulders sees farther of the two"""
"""Fine-tuning entry point

Contains methods to launch the actual training and evaluation processes.
"""

@classmethod
def fit(
Expand All @@ -44,9 +47,10 @@ def fit(
stage
val_dataloader: Optional DataLoader instance to retrieve samples during
validation stage
ckpt_path: Path/URL of the checkpoint from which training is resumed. If there is
no checkpoint file at the path, an exception is raised. If resuming from mid-epoch checkpoint,
training will start from the beginning of the next epoch.
ckpt_path: Path/URL of the checkpoint from which training is resumed.
If there is no checkpoint file at the path, an exception is raised.
If resuming from mid-epoch checkpoint, training will start from the beginning of
the next epoch.
"""

if isinstance(train_dataloader, PairsSimilarityDataLoader):
Expand Down Expand Up @@ -100,8 +104,10 @@ def evaluate(
Compute metrics on a dataset

Args:
evaluator: Object which holds the configuration of which metrics to use and how to obtain samples for them
dataset: Sized object, like list, tuple, torch.utils.data.Dataset, etc. to compute metrics
evaluator: Object which holds the configuration of which metrics to use and how to
obtain samples for them
dataset: Sized object, like list, tuple, torch.utils.data.Dataset, etc. to compute
metrics
model: SimilarityModel instance to perform objects encoding

Returns:
Expand Down