Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[feat] Add support for schedulers #232

Merged
merged 6 commits into from
Apr 21, 2021
Merged

[feat] Add support for schedulers #232

merged 6 commits into from
Apr 21, 2021

Conversation

tchaton
Copy link
Contributor

@tchaton tchaton commented Apr 21, 2021

What does this PR do?

This PR add a better support for optimizer and schedulers.

Fixes # (issue)

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests? [not needed for typos/docs]
  • Did you verify new and existing tests pass locally with your changes?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

PR review

  • Is this pull request ready for review? (if not, please submit in draft mode)

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@tchaton tchaton added this to the 0.3 milestone Apr 21, 2021
@tchaton tchaton self-assigned this Apr 21, 2021
@codecov
Copy link

codecov bot commented Apr 21, 2021

Codecov Report

Merging #232 (d5d24e5) into master (42f8db4) will increase coverage by 0.00%.
The diff coverage is 87.14%.

Impacted file tree graph

@@           Coverage Diff           @@
##           master     #232   +/-   ##
=======================================
  Coverage   86.56%   86.56%           
=======================================
  Files          57       58    +1     
  Lines        2843     2910   +67     
=======================================
+ Hits         2461     2519   +58     
- Misses        382      391    +9     
Flag Coverage Δ
unittests 86.56% <87.14%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
flash/core/model.py 91.47% <84.48%> (-2.57%) ⬇️
flash/core/schedulers.py 100.00% <100.00%> (ø)
flash/data/data_pipeline.py 87.27% <100.00%> (+0.09%) ⬆️
flash/utils/imports.py 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 42f8db4...d5d24e5. Read the comment docs.

@tchaton tchaton mentioned this pull request Apr 21, 2021
8 tasks
else:
dataset_size = len(self.train_dataloader())

num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe someone else can confirm, but I think num_gpus is automatically set to num_processes now, so this isn't needed, we should be able to do just num_devices = num_processes

Copy link
Contributor

@SeanNaren SeanNaren left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, might want to get someone else who hasn't seen this code to also review :)

flash/core/model.py Outdated Show resolved Hide resolved
flash/core/model.py Outdated Show resolved Hide resolved
flash/core/model.py Outdated Show resolved Hide resolved
flash/core/model.py Outdated Show resolved Hide resolved
flash/core/model.py Outdated Show resolved Hide resolved

if _TRANSFORMERS_AVAILABLE:
from transformers import optimization
functions = [getattr(optimization, n) for n in dir(optimization) if ("get_" in n and n != 'get_scheduler')]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
functions = [getattr(optimization, n) for n in dir(optimization) if ("get_" in n and n != 'get_scheduler')]
functions: List[Callable] = [getattr(optimization, n) for n in dir(optimization) if ("get_" in n and n != 'get_scheduler')]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done !

Copy link
Collaborator

@ethanwharris ethanwharris left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, looks good 😃 Can you maybe add if _TRANSFORMERS_AVAILABLE to the other files that use transformers (e.g. /text/seq2seq/summarization/data.py)?

flash/core/schedulers.py Show resolved Hide resolved
flash/utils/imports.py Outdated Show resolved Hide resolved
@tchaton tchaton merged commit 79f271e into master Apr 21, 2021
@tchaton tchaton deleted the scheduler branch April 21, 2021 19:40
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants