Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix loss function buffer support (#1203)
Browse files Browse the repository at this point in the history
* Fix loss function buffer support

* Update CHANGELOG.md
  • Loading branch information
ethanwharris committed Mar 1, 2022
1 parent 2713af4 commit c835d8c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed DDP support for `VideoClassifier` ([#1189](https://github.com/PyTorchLightning/lightning-flash/pull/1189))

- Fixed a bug where buffers in loss functions were not correctly registered in the `Task` ([#1203](https://github.com/PyTorchLightning/lightning-flash/pull/1203))

## [0.7.0] - 2022-02-15

### Added
Expand Down
6 changes: 5 additions & 1 deletion flash/core/utilities/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@
# limitations under the License.
from typing import Callable, Dict, Mapping, Sequence, Type, Union

from torch import nn


def get_callable_name(fn_or_class: Union[Callable, object]) -> str:
return getattr(fn_or_class, "__name__", fn_or_class.__class__.__name__).lower()


def get_callable_dict(fn: Union[Callable, Mapping, Sequence]) -> Union[Dict, Mapping]:
def get_callable_dict(fn: Union[nn.Module, Callable, Mapping, Sequence]) -> Union[Dict, Mapping]:
if isinstance(fn, nn.Module):
return nn.ModuleDict({get_callable_name(fn): fn})
if isinstance(fn, Mapping):
return fn
if isinstance(fn, Sequence):
Expand Down
9 changes: 9 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,12 @@ def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") ->
trainer = flash.Trainer(max_epochs=1, callbacks=CheckAccuracy(), gpus=torch.cuda.device_count())
trainer.fit(task, train_dataloader=DataLoader(train_dataset), val_dataloaders=DataLoader(val_dataset))
trainer.test(task, DataLoader(test_dataset))


def test_loss_fn_buffer():
weight = torch.rand(10)
model = Task(loss_fn=nn.CrossEntropyLoss(weight=weight))
state_dict = model.state_dict()

assert len(state_dict) == 1
assert torch.allclose(state_dict["loss_fn.crossentropyloss.weight"], weight)

0 comments on commit c835d8c

Please sign in to comment.