Update quantization to use tensor subclasses
**Summary:** In torchao, we are migrating our quantization flows from module swap to tensor subclasses. The existing `Int8DynActInt4WeightQuantizer` will be deprecated in the near future in favor of quantizing using the `quantize_` API, so we should do the same in torchtune. This quantizer is currently only used by QAT, which also recently migrated to a tensor subclass implementation.

This also changes the eval script slightly, since models quantized through the torchao tensor subclasses are expected to be loaded with `assign=True`: https://github.com/pytorch/ao/blob/9a56e80cb6070599701b8f5f587bd8187c8dccb4/test/quantization/test_quant_api.py#L610. We should load the model similarly in torchtune.

**Test Plan:**

Quantized and evaluated the base Llama3-8B model on 1 A100 GPU:

```
CUDA_VISIBLE_DEVICES=1 tune run quantize --config recipes/configs/my_quantization.yaml \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer.checkpoint_dir=/tmp/Meta-Llama-3-8B-Instruct/original \
    checkpointer.output_dir=/tmp/Meta-Llama-3-8B-Instruct/original \
    checkpointer.checkpoint_files=[consolidated.00.pth] \
    checkpointer.model_type=LLAMA3

CUDA_VISIBLE_DEVICES=1 tune run eleuther_eval --config eleuther_evaluation \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer._component_=torchtune.utils.FullModelTorchTuneCheckpointer \
    checkpointer.checkpoint_dir=/tmp/Meta-Llama-3-8B-Instruct/original \
    checkpointer.output_dir=/tmp/Meta-Llama-3-8B-Instruct/original \
    checkpointer.checkpoint_files=[consolidated-8da4w.pt] \
    checkpointer.model_type=LLAMA3 \
    tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
    tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
    quantizer._component_=torchtune.utils.quantization.Int8DynActInt4WeightQuantizer
```

Reviewers: ebsmothers, kartikayk, RdoubleA

Subscribers: ebsmothers, kartikayk, RdoubleA