Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Docstrings for image classification data (#1093)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Jan 4, 2022
1 parent 1a599ef commit 088bd18
Show file tree
Hide file tree
Showing 9 changed files with 513 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ repos:
rev: v1.11.0
hooks:
- id: blacken-docs
args: [ --line-length=120 ]
args: [ --line-length=120, --skip-errors ]
additional_dependencies: [ black==21.10b0 ]

- repo: https://github.com/PyCQA/flake8
Expand Down
7 changes: 7 additions & 0 deletions docs/source/general/classification_targets.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
.. _formatting_classification_targets:

*********************************
Formatting Classification Targets
*********************************

.. note:: The contents of this page are currently being updated. Stay tuned!
7 changes: 7 additions & 0 deletions docs/source/general/customizing_transforms.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
.. _customizing_transforms:

**********************
Customizing Transforms
**********************

.. note:: The contents of this page are currently being updated. Stay tuned!
2 changes: 2 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ Lightning Flash
general/serve
general/backbones
general/optimization
general/classification_targets
general/customizing_transforms

.. toctree::
:maxdepth: 1
Expand Down
8 changes: 8 additions & 0 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,14 @@ def num_classes(self) -> Optional[int]:
n_cls_test = getattr(self.test_dataset, "num_classes", None)
return n_cls_train or n_cls_val or n_cls_test

@property
def labels(self) -> Optional[int]:
"""Property that returns the labels if this ``DataModule`` contains classification data."""
n_cls_train = getattr(self.train_dataset, "labels", None)
n_cls_val = getattr(self.val_dataset, "labels", None)
n_cls_test = getattr(self.test_dataset, "labels", None)
return n_cls_train or n_cls_val or n_cls_test

@property
def multi_label(self) -> Optional[bool]:
"""Property that returns ``True`` if this ``DataModule`` contains multi-label data."""
Expand Down
465 changes: 462 additions & 3 deletions flash/image/classification/data.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ codecov>=2.1
pytest>=5.0
pytest-flake8
flake8
pytest-doctestplus

# install pkg
check-manifest
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ norecursedirs =
.git
dist
build
doctest_plus = enabled
addopts =
--strict
--doctest-modules
--durations=0
--color=yes

Expand Down
49 changes: 24 additions & 25 deletions tests/core/data/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import ANY, call, MagicMock

import torch

Expand All @@ -29,7 +28,7 @@
def test_flash_callback(_, __, tmpdir):
"""Test the callback hook system for fit."""

callback_mock = MagicMock()
callback_mock = mock.MagicMock()

inputs = [(torch.rand(1), torch.rand(1))]
dm = DataModule(
Expand All @@ -44,10 +43,10 @@ def test_flash_callback(_, __, tmpdir):
_ = next(iter(dm.train_dataloader()))

assert callback_mock.method_calls == [
call.on_load_sample(ANY, RunningStage.TRAINING),
call.on_per_sample_transform(ANY, RunningStage.TRAINING),
call.on_collate(ANY, RunningStage.TRAINING),
call.on_per_batch_transform(ANY, RunningStage.TRAINING),
mock.call.on_load_sample(mock.ANY, RunningStage.TRAINING),
mock.call.on_per_sample_transform(mock.ANY, RunningStage.TRAINING),
mock.call.on_collate(mock.ANY, RunningStage.TRAINING),
mock.call.on_per_batch_transform(mock.ANY, RunningStage.TRAINING),
]

class CustomModel(Task):
Expand Down Expand Up @@ -75,23 +74,23 @@ def step(self, batch, batch_idx, metrics):
trainer.fit(CustomModel(), datamodule=dm)

assert callback_mock.method_calls == [
call.on_load_sample(ANY, RunningStage.TRAINING),
call.on_per_sample_transform(ANY, RunningStage.TRAINING),
call.on_collate(ANY, RunningStage.TRAINING),
call.on_per_batch_transform(ANY, RunningStage.TRAINING),
call.on_load_sample(ANY, RunningStage.VALIDATING),
call.on_per_sample_transform(ANY, RunningStage.VALIDATING),
call.on_collate(ANY, RunningStage.VALIDATING),
call.on_per_batch_transform(ANY, RunningStage.VALIDATING),
call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING),
call.on_load_sample(ANY, RunningStage.TRAINING),
call.on_per_sample_transform(ANY, RunningStage.TRAINING),
call.on_collate(ANY, RunningStage.TRAINING),
call.on_per_batch_transform(ANY, RunningStage.TRAINING),
call.on_per_batch_transform_on_device(ANY, RunningStage.TRAINING),
call.on_load_sample(ANY, RunningStage.VALIDATING),
call.on_per_sample_transform(ANY, RunningStage.VALIDATING),
call.on_collate(ANY, RunningStage.VALIDATING),
call.on_per_batch_transform(ANY, RunningStage.VALIDATING),
call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING),
mock.call.on_load_sample(mock.ANY, RunningStage.TRAINING),
mock.call.on_per_sample_transform(mock.ANY, RunningStage.TRAINING),
mock.call.on_collate(mock.ANY, RunningStage.TRAINING),
mock.call.on_per_batch_transform(mock.ANY, RunningStage.TRAINING),
mock.call.on_load_sample(mock.ANY, RunningStage.VALIDATING),
mock.call.on_per_sample_transform(mock.ANY, RunningStage.VALIDATING),
mock.call.on_collate(mock.ANY, RunningStage.VALIDATING),
mock.call.on_per_batch_transform(mock.ANY, RunningStage.VALIDATING),
mock.call.on_per_batch_transform_on_device(mock.ANY, RunningStage.VALIDATING),
mock.call.on_load_sample(mock.ANY, RunningStage.TRAINING),
mock.call.on_per_sample_transform(mock.ANY, RunningStage.TRAINING),
mock.call.on_collate(mock.ANY, RunningStage.TRAINING),
mock.call.on_per_batch_transform(mock.ANY, RunningStage.TRAINING),
mock.call.on_per_batch_transform_on_device(mock.ANY, RunningStage.TRAINING),
mock.call.on_load_sample(mock.ANY, RunningStage.VALIDATING),
mock.call.on_per_sample_transform(mock.ANY, RunningStage.VALIDATING),
mock.call.on_collate(mock.ANY, RunningStage.VALIDATING),
mock.call.on_per_batch_transform(mock.ANY, RunningStage.VALIDATING),
mock.call.on_per_batch_transform_on_device(mock.ANY, RunningStage.VALIDATING),
]

0 comments on commit 088bd18

Please sign in to comment.