Skip to content

Commit

Permalink
Pre commit (#11)
Browse files Browse the repository at this point in the history
* add some base hooks

* dummy
  • Loading branch information
drisspg authored Dec 20, 2023
1 parent 6db36c3 commit 22a8572
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 0 deletions.
20 changes: 20 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files

- repo: https://github.com/omnilib/ufmt
rev: v2.1.0
hooks:
- id: ufmt
additional_dependencies:
- black == 23.3.0
- usort == 1.0.6
- ufmt == 2.1.0
- libcst == 1.0.1
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,28 @@ git clone https://github.com/drisspg/transformer_nuggets.git
pip install -e .
```

#### Dev Tool Chain
``` Shell
pip install -e ".[dev]"
```
pre-commit is used to make sure that I don't forget to format stuff, I am going to see if I like this or not. This
should be installed when installing the dev tools.

## Project Structure

- **benchmarks**: Contains scripts and data related to benchmarking the transformer components.
- **data**: Benchmark data files.
- `flash.py`: Benchmarking script for Flash.
- `llama.py`: Benchmarking script for Llama.
- `qlora.py`: Benchmarking script for Qlora.
- `fp8_sat_cast.py`: Benchmarks for comparing FP8 saturated casting kernel to eager and compile code.

- **transformer_nuggets**: The main directory containing all transformer components/modules.
- **flash**: Components related to the FlashAttention.
- **quant**: Implementation of NF4 Tensor and QLora in pure Pytorch
- **sdpa**: Prototype for updated SDPA interface in Pytorch.
- **fp8**: Components related interacting with PyTorch FP8 tensors.
- **llama**: Contains a model def for llama2 models as well as a pretraining script.
- **utils**: General utility functions and scripts.
- `benchmark.py`: Benchmark-related utility functions.
- `tracing.py`: Tracing utilities for transformers.
Expand All @@ -34,3 +44,4 @@ pip install -e .
- `test_flash.py`: Tests for Flash.
- `test_qlora.py`: Tests for Qlora.
- `test_sdpa.py`: Tests for SDPA.
- `test_fp8.py`: Tests for FP8.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dev = [
"usort==1.0.6",
"ufmt==2.1.0",
"libcst==1.0.1",
"pre-commit-3.6.0",
"bumpver",
"pip-tools",
"pytest"
Expand Down

0 comments on commit 22a8572

Please sign in to comment.