Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add BaseViz Callback (2 / 2) #201

Merged
merged 37 commits into from
Apr 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
168b231
wip
tchaton Mar 30, 2021
cda64d3
add base_viz + new features for DataPipeline
tchaton Mar 31, 2021
2b2c499
update
tchaton Mar 31, 2021
6db6b1c
resolve flake8
tchaton Mar 31, 2021
f61deea
update
tchaton Mar 31, 2021
cb85981
Merge branch 'master' into base_viz
tchaton Mar 31, 2021
ffaa7c7
resolve tests
tchaton Mar 31, 2021
596a523
update
tchaton Mar 31, 2021
2fdefbe
wip
tchaton Mar 31, 2021
4381441
update
tchaton Mar 31, 2021
d572248
resolve doc
tchaton Mar 31, 2021
b928fc5
resolve doc
tchaton Mar 31, 2021
9381d41
update doc
tchaton Mar 31, 2021
108a7cc
update
tchaton Apr 1, 2021
6da92b3
update
tchaton Apr 1, 2021
d4cf9f5
update
tchaton Apr 1, 2021
16deb7b
convert to staticmethod
tchaton Apr 1, 2021
4025eb0
initial visualisation implementation
edgarriba Apr 1, 2021
37c8084
Merge branch 'base_viz_2' of https://github.com/PyTorchLightning/ligh…
edgarriba Apr 1, 2021
d2076d4
implement test case using Kornia transforms
edgarriba Apr 1, 2021
ff8e1ad
update on comments
tchaton Apr 1, 2021
84eaa68
resolve bug
tchaton Apr 1, 2021
881851a
Merge branch 'data_pipeline_current_fn' into base_viz_2
tchaton Apr 1, 2021
fb25c04
update
tchaton Apr 1, 2021
cc760a5
Merge branch 'master' into base_viz_2
tchaton Apr 1, 2021
d3932c9
update
tchaton Apr 1, 2021
ee9f781
Merge branch 'base_viz_2' of https://github.com/PyTorchLightning/ligh…
tchaton Apr 1, 2021
f6f33b8
add test
tchaton Apr 1, 2021
2de0e15
update
tchaton Apr 1, 2021
631f06f
resolve tests
tchaton Apr 6, 2021
bda5ff2
resolve flake8
tchaton Apr 6, 2021
0e74167
update
tchaton Apr 6, 2021
d0fb78d
update
tchaton Apr 6, 2021
098d7ab
update
tchaton Apr 6, 2021
ba0a992
Merge branch 'master' into base_viz_2
tchaton Apr 6, 2021
9bdd179
resolve test
tchaton Apr 6, 2021
67ba94c
Merge branch 'base_viz_2' of https://github.com/PyTorchLightning/ligh…
tchaton Apr 6, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,22 +245,11 @@ def on_fit_end(self) -> None:
self.data_pipeline._detach_from_model(self)
super().on_fit_end()

@staticmethod
def _sanetize_funcs(obj: Any) -> Any:
if hasattr(obj, "__dict__"):
for k, v in obj.__dict__.items():
if isinstance(v, Callable):
obj.__dict__[k] = inspect.unwrap(v)
return obj

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
# TODO: Is this the best way to do this? or should we also use some kind of hparams here?
# This may be an issue since here we create the same problems with pickle as in
# https://pytorch.org/docs/stable/notes/serialization.html
if self.data_pipeline is not None and 'data_pipeline' not in checkpoint:
self._preprocess = self._sanetize_funcs(self._preprocess)
checkpoint['data_pipeline'] = self.data_pipeline
# todo (tchaton) re-wrap visualization
super().on_save_checkpoint(checkpoint)

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
Expand Down
14 changes: 13 additions & 1 deletion flash/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, Mapping, Sequence, Union
from typing import Any, Callable, Dict, Mapping, Sequence, Type, Union


def get_callable_name(fn_or_class: Union[Callable, object]) -> str:
Expand All @@ -25,3 +25,15 @@ def get_callable_dict(fn: Union[Callable, Mapping, Sequence]) -> Union[Dict, Map
return {get_callable_name(f): f for f in fn}
elif callable(fn):
return {get_callable_name(fn): fn}


def _is_overriden(method_name: str, instance: object, parent: Type[object]) -> bool:
"""
Cropped Version of
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/model_helpers.py
"""

if not hasattr(instance, method_name):
return False

return getattr(instance, method_name).__code__ != getattr(parent, method_name).__code__
12 changes: 11 additions & 1 deletion flash/data/auto_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pytorch_lightning.utilities.warning_utils import rank_zero_warn
from torch.utils.data import Dataset

from flash.data.callback import ControlFlow
from flash.data.process import Preprocess
from flash.data.utils import _STAGES_PREFIX, _STAGES_PREFIX_VALUES, CurrentRunningStageFuncContext

Expand Down Expand Up @@ -82,6 +83,12 @@ def preprocess(self) -> Optional[Preprocess]:
if self.data_pipeline is not None:
return self.data_pipeline._preprocess_pipeline

@property
def control_flow_callback(self) -> Optional[ControlFlow]:
preprocess = self.preprocess
tchaton marked this conversation as resolved.
Show resolved Hide resolved
if preprocess is not None:
return ControlFlow(preprocess.callbacks)
tchaton marked this conversation as resolved.
Show resolved Hide resolved

def _call_load_data(self, data: Any) -> Iterable:
parameters = signature(self.load_data).parameters
if len(parameters) > 1 and self.DATASET_KEY in parameters:
Expand Down Expand Up @@ -124,7 +131,10 @@ def __getitem__(self, index: int) -> Any:
raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.")
if self.load_sample:
with self._load_sample_context:
return self._call_load_sample(self.preprocessed_data[index])
data: Any = self._call_load_sample(self.preprocessed_data[index])
if self.control_flow_callback:
self.control_flow_callback.on_load_sample(data, self.running_stage)
return data
return self.preprocessed_data[index]

def __len__(self) -> int:
Expand Down
111 changes: 111 additions & 0 deletions flash/data/base_viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from contextlib import contextmanager
from typing import Any, Dict, List, Sequence

from pytorch_lightning.trainer.states import RunningStage
from torch import Tensor

from flash.core.utils import _is_overriden
from flash.data.callback import FlashCallback
from flash.data.process import Preprocess
from flash.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX


class BaseViz(FlashCallback):
"""
This class is used to profile ``Preprocess`` hook outputs and visualize the data transformations.
It is disabled by default.
tchaton marked this conversation as resolved.
Show resolved Hide resolved

batches: Dict = {"train": {"to_tensor_transform": [], ...}, ...}

"""

def __init__(self, enabled: bool = False):
self.batches = {k: {} for k in _STAGES_PREFIX.values()}
self.enabled = enabled
self._preprocess = None

def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("load_sample", [])
store["load_sample"].append(sample)

def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("pre_tensor_transform", [])
store["pre_tensor_transform"].append(sample)

def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("to_tensor_transform", [])
store["to_tensor_transform"].append(sample)

def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("post_tensor_transform", [])
store["post_tensor_transform"].append(sample)

def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("per_batch_transform", [])
store["per_batch_transform"].append(batch)

def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("collate", [])
store["collate"].append(batch)

def on_per_sample_transform_on_device(self, samples: Sequence, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("per_sample_transform_on_device", [])
store["per_sample_transform_on_device"].append(samples)

def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("per_batch_transform_on_device", [])
store["per_batch_transform_on_device"].append(batch)

@contextmanager
def enable(self):
self.enabled = True
yield
self.enabled = False

def attach_to_datamodule(self, datamodule) -> None:
datamodule.viz = self

def attach_to_preprocess(self, preprocess: Preprocess) -> None:
preprocess.callbacks = [self]
self._preprocess = preprocess

def show(self, batch: Dict[str, Any], running_stage: RunningStage) -> None:
"""
This function is a hook for users to override with their visualization on a batch.
"""
for func_name in _PREPROCESS_FUNCS:
hook_name = f"show_{func_name}"
if _is_overriden(hook_name, self, BaseViz):
getattr(self, hook_name)(batch[func_name], running_stage)

def show_load_sample(self, samples: List[Any], running_stage: RunningStage):
pass

def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
pass

def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
pass

def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
pass

def show_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
pass

def show_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
pass

def show_per_sample_transform_on_device(self, samples: Sequence, running_stage: RunningStage) -> None:
pass

def show_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None:
pass
29 changes: 23 additions & 6 deletions flash/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor

from flash.data.callback import ControlFlow
from flash.data.utils import _contains_any_tensor, convert_to_modules, CurrentFuncContext, CurrentRunningStageContext

if TYPE_CHECKING:
Expand All @@ -43,6 +44,7 @@ def __init__(
):
super().__init__()
self.preprocess = preprocess
self.callback = ControlFlow(self.preprocess.callbacks)
self.pre_tensor_transform = convert_to_modules(pre_tensor_transform)
self.to_tensor_transform = convert_to_modules(to_tensor_transform)
self.post_tensor_transform = convert_to_modules(post_tensor_transform)
Expand All @@ -58,9 +60,11 @@ def forward(self, sample: Any) -> Any:
with self._current_stage_context:
with self._pre_tensor_transform_context:
sample = self.pre_tensor_transform(sample)
self.callback.on_pre_tensor_transform(sample, self.stage)

with self._to_tensor_transform_context:
sample = self.to_tensor_transform(sample)
self.callback.on_to_tensor_transform(sample, self.stage)

if self.assert_contains_tensor:
if not _contains_any_tensor(sample):
Expand All @@ -71,6 +75,7 @@ def forward(self, sample: Any) -> Any:

with self._post_tensor_transform_context:
sample = self.post_tensor_transform(sample)
self.callback.on_post_tensor_transform(sample, self.stage)

return sample

Expand Down Expand Up @@ -112,36 +117,48 @@ def __init__(
per_batch_transform: Callable,
stage: RunningStage,
apply_per_sample_transform: bool = True,
on_device: bool = False
on_device: bool = False,
):
super().__init__()
self.preprocess = preprocess
self.callback = ControlFlow(self.preprocess.callbacks)
self.collate_fn = convert_to_modules(collate_fn)
self.per_sample_transform = convert_to_modules(per_sample_transform)
self.per_batch_transform = convert_to_modules(per_batch_transform)
self.apply_per_sample_transform = apply_per_sample_transform
self.stage = stage
self.on_device = on_device

extension = f"{'on_device' if self.on_device else ''}"
extension = f"{'_on_device' if self.on_device else ''}"
self._current_stage_context = CurrentRunningStageContext(stage, preprocess)
self._per_sample_transform_context = CurrentFuncContext(f"per_sample_transform_{extension}", preprocess)
self._per_sample_transform_context = CurrentFuncContext(f"per_sample_transform{extension}", preprocess)
self._collate_context = CurrentFuncContext("collate", preprocess)
self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform_{extension}", preprocess)
self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{extension}", preprocess)

def forward(self, samples: Sequence[Any]) -> Any:
with self._current_stage_context:

if self.apply_per_sample_transform:
with self._per_sample_transform_context:
samples = [self.per_sample_transform(sample) for sample in samples]
samples = type(samples)(samples)
_samples = []
for sample in samples:
sample = self.per_sample_transform(sample)
if self.on_device:
self.callback.on_per_sample_transform_on_device(sample, self.stage)
_samples.append(sample)

samples = type(_samples)(_samples)

with self._collate_context:
samples = self.collate_fn(samples)
self.callback.on_collate(samples, self.stage)

with self._per_batch_transform_context:
samples = self.per_batch_transform(samples)
if self.on_device:
self.callback.on_per_batch_transform_on_device(samples, self.stage)
else:
self.callback.on_per_batch_transform(samples, self.stage)
return samples

def __str__(self) -> str:
Expand Down
67 changes: 67 additions & 0 deletions flash/data/callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from typing import Any, List, Sequence

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.states import RunningStage
from torch import Tensor


class FlashCallback(Callback):

def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None:
"""Called once a sample has been loaded using ``load_sample``."""

def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
"""Called once ``pre_tensor_transform`` have been applied to a sample."""

def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
"""Called once ``to_tensor_transform`` have been applied to a sample."""

def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
"""Called once ``post_tensor_transform`` have been applied to a sample."""

def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
"""Called once ``per_batch_transform`` have been applied to a batch."""

def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
"""Called once ``collate`` have been applied to a sequence of samples."""

def on_per_sample_transform_on_device(self, sample: Any, running_stage: RunningStage) -> None:
"""Called once ``per_sample_transform_on_device`` have been applied to a sample."""

def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None:
"""Called once ``per_batch_transform_on_device`` have been applied to a sample."""


class ControlFlow(FlashCallback):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit weird to me as we don't have such thing in Lightning. We just loop over all callbacks for each hook around the code. Do you prefer this object? Why?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I prefer this object and I think we should add a similar one within Lightning. The callbacks logic should be fully managed by callbacks.


def __init__(self, callbacks: List[FlashCallback]):
self._callbacks = callbacks

def run_for_all_callbacks(self, *args, method_name: str, **kwargs):
if self._callbacks:
for cb in self._callbacks:
getattr(cb, method_name)(*args, **kwargs)

def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None:
self.run_for_all_callbacks(sample, running_stage, method_name="on_load_sample")

def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
self.run_for_all_callbacks(sample, running_stage, method_name="on_pre_tensor_transform")

def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
self.run_for_all_callbacks(sample, running_stage, method_name="on_to_tensor_transform")

def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
self.run_for_all_callbacks(sample, running_stage, method_name="on_post_tensor_transform")

def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
self.run_for_all_callbacks(batch, running_stage, method_name="on_per_batch_transform")

def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
self.run_for_all_callbacks(batch, running_stage, method_name="on_collate")

def on_per_sample_transform_on_device(self, sample: Any, running_stage: RunningStage) -> None:
self.run_for_all_callbacks(sample, running_stage, method_name="on_per_sample_transform_on_device")

def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None:
self.run_for_all_callbacks(batch, running_stage, method_name="on_per_batch_transform_on_device")
Loading