Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
update task
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton committed Apr 21, 2021
1 parent 795f8d1 commit 644f062
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
8 changes: 4 additions & 4 deletions flash/core/schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
_SCHEDULER_REGISTRY = FlashRegistry("scheduler")

if _TRANSFORMERS_AVAILABLE:
from transformers.optimization import TYPE_TO_SCHEDULER_FUNCTION

for v in TYPE_TO_SCHEDULER_FUNCTION.values():
_SCHEDULER_REGISTRY(v, name=v.__name__[4:])
from transformers import optimization
functions = [getattr(optimization, n) for n in dir(optimization) if ("get_" in n and n != 'get_scheduler')]
for fn in functions:
_SCHEDULER_REGISTRY(fn, name=fn.__name__[4:])
6 changes: 3 additions & 3 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def test_optimization(tmpdir):
assert isinstance(scheduler[0], torch.optim.lr_scheduler.StepLR)

if _TRANSFORMERS_AVAILABLE:
from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from transformers.optimization import get_linear_schedule_with_warmup

assert task.available_schedulers() == [
'constant_schedule', 'constant_schedule_with_warmup', 'cosine_schedule_with_warmup',
Expand All @@ -219,11 +219,11 @@ def test_optimization(tmpdir):
scheduler_kwargs={"num_warmup_steps": 0.1},
loss_fn=F.nll_loss,
)
trainer = flash.Trainer(max_epochs=100)
trainer = flash.Trainer(max_epochs=1, limit_train_batches=2)
ds = DummyDataset()
trainer.fit(task, train_dataloader=DataLoader(ds))
optimizer, scheduler = task.configure_optimizers()
assert isinstance(optimizer[0], torch.optim.Adadelta)
assert isinstance(scheduler[0], torch.optim.lr_scheduler.LambdaLR)
expected = TYPE_TO_SCHEDULER_FUNCTION[SchedulerType.LINEAR].__name__
expected = get_linear_schedule_with_warmup.__name__
assert scheduler[0].lr_lambdas[0].__qualname__.split('.')[0] == expected

0 comments on commit 644f062

Please sign in to comment.