diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..b3757c2 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,20 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + +- repo: https://github.com/omnilib/ufmt + rev: v2.1.0 + hooks: + - id: ufmt + additional_dependencies: + - black == 23.3.0 + - usort == 1.0.6 + - ufmt == 2.1.0 + - libcst == 1.0.1 diff --git a/README.md b/README.md index 3b8f066..24e61fc 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,13 @@ git clone https://github.com/drisspg/transformer_nuggets.git pip install -e . ``` +#### Dev Tool Chain +``` Shell +pip install -e ".[dev]" +``` +pre-commit is used to make sure that I don't forget to format stuff, I am going to see if I like this or not. This +should be installed when installing the dev tools. + ## Project Structure - **benchmarks**: Contains scripts and data related to benchmarking the transformer components. @@ -21,11 +28,14 @@ pip install -e . - `flash.py`: Benchmarking script for Flash. - `llama.py`: Benchmarking script for Llama. - `qlora.py`: Benchmarking script for Qlora. + - `fp8_sat_cast.py`: Benchmarks for comparing FP8 saturated casting kernel to eager and compile code. - **transformer_nuggets**: The main directory containing all transformer components/modules. - **flash**: Components related to the FlashAttention. - **quant**: Implementation of NF4 Tensor and QLora in pure Pytorch - **sdpa**: Prototype for updated SDPA interface in Pytorch. + - **fp8**: Components related interacting with PyTorch FP8 tensors. + - **llama**: Contains a model def for llama2 models as well as a pretraining script. - **utils**: General utility functions and scripts. - `benchmark.py`: Benchmark-related utility functions. - `tracing.py`: Tracing utilities for transformers. @@ -34,3 +44,4 @@ pip install -e . - `test_flash.py`: Tests for Flash. - `test_qlora.py`: Tests for Qlora. - `test_sdpa.py`: Tests for SDPA. + - `test_fp8.py`: Tests for FP8. diff --git a/pyproject.toml b/pyproject.toml index ad28a12..3885ef9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dev = [ "usort==1.0.6", "ufmt==2.1.0", "libcst==1.0.1", + "pre-commit-3.6.0", "bumpver", "pip-tools", "pytest"