
Commit

update profile function
drisspg committed May 15, 2024
1 parent e85975f commit 694f73d
Showing 21 changed files with 141 additions and 91 deletions.
7 changes: 2 additions & 5 deletions .github/workflows/python-app.yml
@@ -1,6 +1,6 @@
# Basic flak8 + pytest workflow for Python 3.10
# Pytest workflow for Python 3.10

name: Python Lint and Test
name: Pytest all the things

on:
push:
@@ -27,9 +27,6 @@ jobs:
python -m pip install --upgrade pip
pip install -e .
pip install -e .'[dev]'
- name: Lint with ruff
run: |
ruff check .
- name: Test with pytest
run: |
pytest
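
The trimmed workflow now only installs the package and runs the test suite; linting moves to its own workflow (next file). Reproducing it locally should amount to the same steps, a sketch that assumes a Python 3.10 environment and the repository root as the working directory:

    python -m pip install --upgrade pip
    pip install -e .                 # install transformer_nuggets itself
    pip install -e .'[dev]'          # pull in pytest and the rest of the dev extras
    pytest                           # run the full test suite, as the workflow does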
13 changes: 9 additions & 4 deletions .github/workflows/ufmt.yml → .github/workflows/ruff.yml
@@ -1,4 +1,4 @@
name: Ufmt
name: Code Analysis with Ruff

on:
push:
@@ -20,7 +20,12 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install black==23.3.0 usort==1.0.6 ufmt==2.1.0 libcst==1.0.1
- name: Analyzing the code with ufmt
python -m pip install --upgrade pip
pip install -e .
pip install -e .'[dev]'
- name: Analyzing the code with ruff
run: |
ufmt check .
ruff check .
- name: Check well formatted code
run: |
ruff format --check
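
The renamed workflow replaces the ufmt check with two ruff steps, one for linting and one for verifying formatting. Mirroring it locally should look roughly like this (a sketch, assuming ruff is available from the dev extras installed above):

    ruff check .             # lint, same as the "Analyzing the code with ruff" step
    ruff format --check      # report files ruff would reformat, without modifying them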
25 changes: 9 additions & 16 deletions .pre-commit-config.yaml
@@ -9,19 +9,12 @@ repos:
- id: check-yaml
- id: check-added-large-files

- repo: https://github.com/omnilib/ufmt
rev: v2.1.0
hooks:
- id: ufmt
additional_dependencies:
- black == 23.3.0
- usort == 1.0.6
- ufmt == 2.1.0
- libcst == 1.0.1

- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.0
hooks:
# Run the linter.
- id: ruff
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.4.4
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
# Run the formatter.
- id: ruff-format
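
With the ufmt hook removed, ruff now handles both linting (with --fix) and formatting in pre-commit as well. A minimal local setup might be (a sketch; assumes pre-commit itself is installed, as pinned in the dev dependencies):

    pre-commit install              # register the git hook once per clone
    pre-commit run --all-files      # run the ruff and ruff-format hooks across the whole repo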
5 changes: 4 additions & 1 deletion benchmarks/fp8_dynamic_cast.py
@@ -8,7 +8,10 @@
from tabulate import tabulate
from tqdm import tqdm

from transformer_nuggets.fp8.scaled_quant import dynamic_scaled_quant, eager_dynamic_scaled_quant
from transformer_nuggets.fp8.scaled_quant import (
dynamic_scaled_quant,
eager_dynamic_scaled_quant,
)
from transformer_nuggets.utils import benchmark_torch_function_in_microseconds

device = torch.device("cuda")
12 changes: 10 additions & 2 deletions benchmarks/fp8_sat_cast.py
@@ -74,11 +74,19 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:

# Correctness check:
nuggets_out = scaled_quant(
triton_hp_tensor, triton_abs_max, scale, config.low_precision_dtype, config.saturated
triton_hp_tensor,
triton_abs_max,
scale,
config.low_precision_dtype,
config.saturated,
)
nuggets_out_hp = nuggets_out.to(config.high_precision_dtype)
eager_out = eager_scaled_quant(
high_precision_tensor, eager_abs_max, scale, config.low_precision_dtype, config.saturated
high_precision_tensor,
eager_abs_max,
scale,
config.low_precision_dtype,
config.saturated,
).to(config.high_precision_dtype)
eager_out_hp = eager_out.to(config.high_precision_dtype)
with suppress(AssertionError):
11 changes: 9 additions & 2 deletions benchmarks/llama.py
@@ -2,6 +2,7 @@
Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
"""

# mypy: ignore-errors
import math
from dataclasses import dataclass
@@ -163,7 +164,9 @@ def build_rope_cache(self, idx: torch.Tensor) -> RoPECache:

def build_mask_cache(self, idx: torch.Tensor) -> MaskCache:
ones = torch.ones(
(self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool
(self.config.block_size, self.config.block_size),
device=idx.device,
dtype=torch.bool,
)
return torch.tril(ones).unsqueeze(0).unsqueeze(0)

@@ -361,7 +364,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


def build_rope_cache(
seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
seq_len: int,
n_elem: int,
dtype: torch.dtype,
device: torch.device,
base: int = 10000,
) -> RoPECache:
"""Enhanced Transformer with Rotary Position Embedding.
6 changes: 5 additions & 1 deletion benchmarks/qlora.py
@@ -139,7 +139,11 @@ def mlp_experiment(config: ExperimentConfig) -> ExperimentResult:
bnb_mlp_time = -1.0

return ExperimentResult(
mlp_time, qlora_mlp_time, compiled_qlora_mlp_time, bnb_mlp_time, qlora_mlp_triton_time
mlp_time,
qlora_mlp_time,
compiled_qlora_mlp_time,
bnb_mlp_time,
qlora_mlp_triton_time,
)


7 changes: 6 additions & 1 deletion benchmarks/qlora_memory_trace.py
@@ -5,7 +5,12 @@
import torch.nn as nn
from jsonargparse import CLI

from transformer_nuggets.quant.qlora import get_mlp_weights, get_sample_inputs, MLP, QloraMLP
from transformer_nuggets.quant.qlora import (
get_mlp_weights,
get_sample_inputs,
MLP,
QloraMLP,
)
from transformer_nuggets.utils.benchmark import save_memory_snapshot

logging.basicConfig(level=logging.INFO)
38 changes: 7 additions & 31 deletions pyproject.toml
@@ -27,15 +27,11 @@ dependencies = [

[project.optional-dependencies]
dev = [
"black==23.3.0",
"usort==1.0.6",
"ufmt==2.3.0",
"libcst==1.1.0",
"pre-commit==3.6.0",
"bumpver",
"pip-tools",
"pytest",
"ruff==0.3.0",
"ruff",
"jsonargparse",
"docstring-parser"
]
@@ -54,7 +50,8 @@ llama = [

# ---------- RUFF ------------
[tool.ruff]
ignore = ['E231', 'E731']
target-version = "py38"
line-length = 99
# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
@@ -85,29 +82,8 @@ exclude = [
"venv",
]

[tool.ruff.per-file-ignores]
"__init__.py" = ["F401", "F403"]

# ---------- UFMT ------------

[tool.usort]
first_party_detection = false
[tool.ruff.lint]
ignore = ['E231', 'E731']

# ---------- Black ------------
[tool.black]
target-version = ["py38"]
line-length = 99
include = '\.pyi?$'
exclude = '''
/(
\.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
'''
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401", "F403"]
7 changes: 6 additions & 1 deletion test/test_flash.py
@@ -1,7 +1,12 @@
import pytest
import torch

from transformer_nuggets.flash import attention, BiasMode, build_causal_mask, build_rel_mask
from transformer_nuggets.flash import (
attention,
BiasMode,
build_causal_mask,
build_rel_mask,
)


def clone_grad_and_reset(tensor):
13 changes: 10 additions & 3 deletions test/test_qlora.py
@@ -21,7 +21,8 @@


@pytest.mark.parametrize(
"inpt_size, block_size, scaler_block_size", [(16384, 64, 256), (256, 16, 16), (1024, 32, 32)]
"inpt_size, block_size, scaler_block_size",
[(16384, 64, 256), (256, 16, 16), (1024, 32, 32)],
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
@@ -42,7 +43,8 @@ def test_reconstruction(


@pytest.mark.parametrize(
"inpt_size, block_size, scaler_block_size", [(16384, 64, 256), (256, 16, 16), (1024, 32, 32)]
"inpt_size, block_size, scaler_block_size",
[(16384, 64, 256), (256, 16, 16), (1024, 32, 32)],
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
@@ -210,7 +212,12 @@ def test_bitsandbytes_mlp_parity(embed_dim, compile, dtype):
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
def test_qlora_linear(
embed_dim: int, compile: bool, r: int, dropout: float, run_backward: bool, dtype: torch.dtype
embed_dim: int,
compile: bool,
r: int,
dropout: float,
run_backward: bool,
dtype: torch.dtype,
):
torch._dynamo.reset()
torch.manual_seed(0)
8 changes: 5 additions & 3 deletions test/test_utils.py
@@ -35,9 +35,11 @@ def test_breakpoint():
0.0,
]
)
with pytest.raises(RuntimeError, match="returned a NaN"), mock.patch(
"builtins.breakpoint"
) as mock_breakpoint, NanInfDetect(do_breakpoint=True):
with (
pytest.raises(RuntimeError, match="returned a NaN"),
mock.patch("builtins.breakpoint") as mock_breakpoint,
NanInfDetect(do_breakpoint=True),
):
print(torch.div(a, a))
mock_breakpoint.assert_called_once()

10 changes: 9 additions & 1 deletion transformer_nuggets/fp8/scaled_quant.py
@@ -62,7 +62,15 @@ def scaled_quant(
tl_dtype = {torch.float8_e4m3fn: tl.float8e4nv, torch.float8_e5m2: tl.float8e5}[fp8_dtype]
max_val = torch.finfo(fp8_dtype).max if saturated else 0.0
scaled_cast[grid](
inpt_tensor, out_tensor, scale, abs_max, numel, 4096, tl_dtype, max_val, num_warps=8
inpt_tensor,
out_tensor,
scale,
abs_max,
numel,
4096,
tl_dtype,
max_val,
num_warps=8,
)
return out_tensor

5 changes: 4 additions & 1 deletion transformer_nuggets/llama/finetune.py
@@ -1,6 +1,7 @@
"""
Used to train a model from scratch on big dense blocks of text data using causal attention.
"""

import argparse
import functools
import logging
@@ -77,7 +78,9 @@ def main(
)
qlora.swap_for_qlora(model, qlora_config, torch.bfloat16)
model.setup_caches(
hyper_params.micro_batch_size, hyper_params.max_seq_length, training_config.device
hyper_params.micro_batch_size,
hyper_params.max_seq_length,
training_config.device,
)

if rank == 0:
3 changes: 1 addition & 2 deletions transformer_nuggets/llama/prepare_data.py
@@ -1,5 +1,4 @@
""" Use the tokenizer and data to prepare pretrain dataset - Heavily inspired by Nanogpt"""

"""Use the tokenizer and data to prepare pretrain dataset - Heavily inspired by Nanogpt"""

import logging
import os
5 changes: 4 additions & 1 deletion transformer_nuggets/llama/train.py
@@ -1,6 +1,7 @@
"""
Used to train a model from scratch on big dense blocks of text data using causal attention.
"""

import csv
import logging
import math
@@ -164,7 +165,9 @@ def main(
model.init_parameters()

model.setup_caches(
hyper_params.micro_batch_size, hyper_params.max_seq_length, training_config.device
hyper_params.micro_batch_size,
hyper_params.max_seq_length,
training_config.device,
)

logging.info("Setting up the dataloaders")
12 changes: 10 additions & 2 deletions transformer_nuggets/quant/dequant_kernel.py
@@ -13,7 +13,11 @@ def dequantize(inputs, nf4_lut):

@triton.jit
def dequantize_scalers(
quantized_scalers_ptr, quantization_factor_ptr, scaler_mean_ptr, block_size, scaler_block_size
quantized_scalers_ptr,
quantization_factor_ptr,
scaler_mean_ptr,
block_size,
scaler_block_size,
):
"""Dequantizes the quantized scalers to bfloat16
Args:
@@ -68,7 +72,11 @@ def dequant_nf4_tensor_kernel(

# Dequantize the double quantized scalers
block_scaler = dequantize_scalers(
quantized_scalers_ptr, quantization_factor_ptr, scaler_mean_ptr, XBLOCK, scaler_block_size
quantized_scalers_ptr,
quantization_factor_ptr,
scaler_mean_ptr,
XBLOCK,
scaler_block_size,
)

scaled_first = dequantized_first * block_scaler