
Fix FP8 checkpoint resumption with onnx export flag #2907

Merged

merged 3 commits into dev on Jan 26, 2024

Conversation

@j316chuck (Contributor) commented on Jan 26, 2024:

What does this PR do?

Before this PR, checkpoint resumption with FP8 did not work: checkpoints serialized transformer_engine's extra FP8 buffers as raw bytes, which is incompatible with torch's load_sharded_optimizers.

After this PR, checkpoint resumption with FP8 works: the extra FP8 buffers are serialized as tensors by enabling ONNX-export mode.
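For context, a minimal sketch of what the fix does, assuming transformer_engine's public API (`te.onnx_export` and `te.fp8_autocast` are the library's context managers; the recipe choice here is illustrative, not necessarily what Composer uses):

```python
# Hedged sketch: enable transformer_engine's ONNX-export mode around the FP8
# autocast, so that TE modules serialize their extra FP8 state (amax history,
# scaling factors) as plain tensors rather than pickled bytes. Tensor-valued
# extra state is what torch's sharded checkpoint load can handle.
from contextlib import contextmanager

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling


@contextmanager
def fp8_precision_context():
    fp8_recipe = DelayedScaling()  # illustrative default recipe
    with te.onnx_export(True):  # flips TE's global export flag
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            yield
```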

Tests

Torch 2.1

  • Failed run without this PR: mpt-125m-sharded-regression-fp8-ignore-td0frt 🔴
  • Working run with this PR: mpt-125m-sharded-regression-fp8-ignore-6DD62n

Torch 2.3

  • Run without this PR: mpt-125m-sharded-regression-fp8-ignore-tAGCla
  • Working run with this PR: mpt-125m-sharded-regression-fp8-ignore-8bZirC

@j316chuck j316chuck requested review from eracah, dakinggg and a team as code owners January 26, 2024 01:00
@j316chuck j316chuck requested a review from mvpatel2000 January 26, 2024 01:00
@j316chuck j316chuck force-pushed the chuck/add_onnx_eport branch from 901e5cb to 1a0685f on January 26, 2024 01:08
@j316chuck j316chuck changed the title Add fp8 onnx export to fix resumption/checkpoint errors. Fix FP8 resumption/checkpoint errors with onnx export Jan 26, 2024
@j316chuck j316chuck changed the title Fix FP8 resumption/checkpoint errors with onnx export Fix FP8 checkpoint resumption with onnx export flag Jan 26, 2024
@j316chuck j316chuck requested review from dskhudia and removed request for eracah January 26, 2024 01:14
@dskhudia (Contributor) commented:

@j316chuck: Thinking about it a bit more, I think we should apply it when we are creating a checkpoint instead of at the time when we request precision.

@j316chuck (Contributor, Author) replied:

@dskhudia I think it's a bit cleaner to cast when requesting precision, since the checkpoint save/load code is scattered throughout checkpoint.py and I would have to add this context manager in multiple places (10+ lines vs. 1 line). See the sketch below. Wdyt?
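To illustrate the tradeoff being discussed, a hypothetical sketch of the checkpoint-time alternative (`save_sharded_checkpoint` is an invented name for illustration; the real save/load logic lives in several places in checkpoint.py, which is the point):

```python
import transformer_engine.pytorch as te


def save_sharded_checkpoint(model, optimizer):
    # Every site that materializes state dicts would need this wrapper,
    # and the symmetric load paths would need it too...
    with te.onnx_export(True):
        model_state = model.state_dict()
        optim_state = optimizer.state_dict()
    return {'model': model_state, 'optimizer': optim_state}
```

Wrapping once inside the precision context avoids repeating this at each scattered save/load site.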

Review thread on composer/core/precision.py (resolved)
@dskhudia (Contributor) replied:

@j316chuck: Sounds good. I looked at the definition of this context manager, and it doesn't do much besides setting a global variable. I was worried about this slowing down training, but it probably doesn't. Feel free to merge.
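For a sense of why there is no training overhead: a paraphrased sketch of what a flag-only context manager like this looks like (not the library's actual source; the global name is invented):

```python
from contextlib import contextmanager

_FP8_ONNX_EXPORT_MODE = False  # hypothetical module-level flag


@contextmanager
def onnx_export(enabled: bool = False):
    # Toggle the flag on entry and restore it on exit; nothing runs per
    # training step, so enabling it costs effectively nothing.
    global _FP8_ONNX_EXPORT_MODE
    previous = _FP8_ONNX_EXPORT_MODE
    _FP8_ONNX_EXPORT_MODE = enabled
    try:
        yield
    finally:
        _FP8_ONNX_EXPORT_MODE = previous
```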

@j316chuck j316chuck merged commit 2c6a390 into dev Jan 26, 2024
16 checks passed
@j316chuck j316chuck deleted the chuck/add_onnx_eport branch January 26, 2024 21:49
ShashankMosaicML pushed a commit to ShashankMosaicML/composer that referenced this pull request Feb 3, 2024