4-bit is 10x slower compared to fp16 LLaMa #82

Closed
fpgaminer opened this issue Mar 26, 2023 · 27 comments

@fpgaminer

On my setup the stock 16-bit 7B LLaMa model runs at 0.6s per iteration with a 1x2048 input. The 4-bit quantized model runs at 8.3s per iteration. That makes the 4-bit version more than 10x slower than the non-quantized model. Is that normal?

Setup:

- GPTQ-for-LLaMa at 19c0535
- GPU: RTX 3090
- Driver Version: 515.86.01; CUDA Version: 11.7
- 7B model quantized using c4 --wbits 4 --true-sequential --act-order
- Environment and benchmark code listed below

Environment

name: llama
channels:
  - pytorch
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - asttokens=2.2.1=pyhd8ed1ab_0
  - backcall=0.2.0=pyh9f0ad1d_0
  - backports=1.0=pyhd8ed1ab_3
  - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
  - blas=1.0=mkl
  - brotlipy=0.7.0=py310h7f8727e_1002
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2022.12.7=ha878542_0
  - certifi=2022.12.7=pyhd8ed1ab_0
  - cffi=1.15.1=py310h5eee18b_3
  - charset-normalizer=2.0.4=pyhd3eb1b0_0
  - cryptography=39.0.1=py310h9ce1e76_0
  - cuda-cudart=11.7.99=0
  - cuda-cupti=11.7.101=0
  - cuda-libraries=11.7.1=0
  - cuda-nvrtc=11.7.99=0
  - cuda-nvtx=11.7.91=0
  - cuda-runtime=11.7.1=0
  - cudatoolkit-dev=11.7.0=h1de0b5d_6
  - debugpy=1.5.1=py310h295c915_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - entrypoints=0.4=pyhd8ed1ab_0
  - executing=1.2.0=pyhd8ed1ab_0
  - ffmpeg=4.3=hf484d3e_0
  - filelock=3.9.0=py310h06a4308_0
  - flit-core=3.8.0=py310h06a4308_0
  - freetype=2.12.1=h4a9f257_0
  - giflib=5.2.1=h5eee18b_3
  - gmp=6.2.1=h295c915_3
  - gmpy2=2.1.2=py310heeb90bb_0
  - gnutls=3.6.15=he1e5248_0
  - idna=3.4=py310h06a4308_0
  - intel-openmp=2021.4.0=h06a4308_3561
  - ipykernel=6.15.0=pyh210e3f2_0
  - ipython=8.11.0=pyh41d4057_0
  - jedi=0.18.2=pyhd8ed1ab_0
  - jinja2=3.1.2=py310h06a4308_0
  - jpeg=9e=h5eee18b_1
  - jupyter_client=7.3.4=pyhd8ed1ab_0
  - jupyter_core=4.12.0=py310hff52083_0
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - lerc=3.0=h295c915_0
  - libcublas=11.10.3.66=0
  - libcufft=10.7.2.124=h4fbf590_0
  - libcufile=1.6.0.25=0
  - libcurand=10.3.2.56=0
  - libcusolver=11.4.0.1=0
  - libcusparse=11.7.4.91=0
  - libdeflate=1.17=h5eee18b_0
  - libffi=3.4.2=h6a678d5_6
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libiconv=1.16=h7f8727e_2
  - libidn2=2.3.2=h7f8727e_0
  - libnpp=11.7.4.75=0
  - libnvjpeg=11.8.0.2=0
  - libpng=1.6.39=h5eee18b_0
  - libsodium=1.0.18=h36c2ea0_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtasn1=4.16.0=h27cfd23_0
  - libtiff=4.5.0=h6a678d5_2
  - libunistring=0.9.10=h27cfd23_0
  - libuuid=1.41.5=h5eee18b_0
  - libwebp=1.2.4=h11a3e52_1
  - libwebp-base=1.2.4=h5eee18b_1
  - lz4-c=1.9.4=h6a678d5_0
  - markupsafe=2.1.1=py310h7f8727e_0
  - matplotlib-inline=0.1.6=pyhd8ed1ab_0
  - mkl=2021.4.0=h06a4308_640
  - mkl-service=2.4.0=py310h7f8727e_0
  - mkl_fft=1.3.1=py310hd6ae3a3_0
  - mkl_random=1.2.2=py310h00e6091_0
  - mpc=1.1.0=h10f8cd9_1
  - mpfr=4.0.2=hb69a4c5_1
  - ncurses=6.4=h6a678d5_0
  - nest-asyncio=1.5.6=pyhd8ed1ab_0
  - nettle=3.7.3=hbbd107a_1
  - networkx=2.8.4=py310h06a4308_1
  - numpy=1.23.5=py310hd5efca6_0
  - numpy-base=1.23.5=py310h8e6c178_0
  - openh264=2.1.1=h4ff587b_0
  - openssl=1.1.1t=h7f8727e_0
  - packaging=23.0=pyhd8ed1ab_0
  - parso=0.8.3=pyhd8ed1ab_0
  - pexpect=4.8.0=pyh1a96a4e_2
  - pickleshare=0.7.5=py_1003
  - pillow=9.4.0=py310h6a678d5_0
  - pip=23.0.1=py310h06a4308_0
  - prompt-toolkit=3.0.38=pyha770c72_0
  - prompt_toolkit=3.0.38=hd8ed1ab_0
  - ptyprocess=0.7.0=pyhd3deb0d_0
  - pure_eval=0.2.2=pyhd8ed1ab_0
  - pycparser=2.21=pyhd3eb1b0_0
  - pygments=2.14.0=pyhd8ed1ab_0
  - pyopenssl=23.0.0=py310h06a4308_0
  - pysocks=1.7.1=py310h06a4308_0
  - python=3.10.10=h7a1cb2a_2
  - python-dateutil=2.8.2=pyhd8ed1ab_0
  - python_abi=3.10=2_cp310
  - pytorch=2.0.0=py3.10_cuda11.7_cudnn8.5.0_0
  - pytorch-cuda=11.7=h778d358_3
  - pytorch-mutex=1.0=cuda
  - pyzmq=23.2.0=py310h6a678d5_0
  - readline=8.2=h5eee18b_0
  - requests=2.28.1=py310h06a4308_1
  - setuptools=65.6.3=py310h06a4308_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.41.1=h5eee18b_0
  - stack_data=0.6.2=pyhd8ed1ab_0
  - sympy=1.11.1=py310h06a4308_0
  - tk=8.6.12=h1ccaba5_0
  - torchaudio=2.0.0=py310_cu117
  - torchtriton=2.0.0=py310
  - torchvision=0.15.0=py310_cu117
  - tornado=6.1=py310h5764c6d_3
  - traitlets=5.9.0=pyhd8ed1ab_0
  - typing_extensions=4.4.0=py310h06a4308_0
  - tzdata=2022g=h04d1e81_0
  - urllib3=1.26.14=py310h06a4308_0
  - wcwidth=0.2.6=pyhd8ed1ab_0
  - wheel=0.38.4=py310h06a4308_0
  - xz=5.2.10=h5eee18b_1
  - zeromq=4.3.4=h9c3ff4c_1
  - zlib=1.2.13=h5eee18b_0
  - zstd=1.5.2=ha4553b6_0
  - pip:
      - accelerate==0.17.1
      - aiohttp==3.8.4
      - aiosignal==1.3.1
      - async-timeout==4.0.2
      - attrs==22.2.0
      - datasets==2.10.1
      - dill==0.3.6
      - frozenlist==1.3.3
      - fsspec==2023.3.0
      - huggingface-hub==0.13.3
      - mpmath==1.2.1
      - multidict==6.0.4
      - multiprocess==0.70.14
      - pandas==1.5.3
      - psutil==5.9.4
      - pyarrow==11.0.0
      - pytz==2022.7.1
      - pyyaml==6.0
      - quant-cuda==0.0.0
      - regex==2023.3.22
      - responses==0.18.0
      - safetensors==0.3.0
      - sentencepiece==0.1.97
      - tokenizers==0.13.2
      - tqdm==4.65.0
      - transformers==4.28.0.dev0
      - xxhash==3.2.0
      - yarl==1.8.2

Code

#!/usr/bin/env python3
import argparse

import torch
import torch.nn as nn
import transformers
from datasets import load_dataset
from modelutils import find_layers
from quant import make_quant
from tqdm import tqdm
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

from llama import get_llama

parser = argparse.ArgumentParser()

parser.add_argument('--model', type=str, help='Path to HuggingFace model')
parser.add_argument('--quant', type=str, help='Path to quantized model')
parser.add_argument('--stride', type=int, default=512, help='Stride for calculating perplexity')
parser.add_argument('--wbits', type=int, default=4, help='Number of bits for weights')
parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize used during quantization')


def main():
	global args  # Hack for load_quant
	args = parser.parse_args()

	if not args.quant:
		model = get_llama(args.model)
		model.eval()
		model.to('cuda')
	else:
		model = load_quant(args.model, args.quant, args.wbits, args.groupsize)
		model.eval()
		model.to('cuda')
	
	tokenizer = AutoTokenizer.from_pretrained(args.model)

	for dataset in ['wikitext-2', 'ptb', 'c4']:
		ppl = calculate_perplexity(model, tokenizer, dataset, max_length=model.seqlen, stride=args.stride)
		print(f"{dataset} perplexity: {ppl}")


# NOTE: Have to modify this to work around the usage of `args` in the original...
def load_quant(model, checkpoint, wbits, groupsize):
	config = LlamaConfig.from_pretrained(model)
	def noop(*args, **kwargs):
		pass
	torch.nn.init.kaiming_uniform_ = noop 
	torch.nn.init.uniform_ = noop 
	torch.nn.init.normal_ = noop 

	torch.set_default_dtype(torch.half)
	transformers.modeling_utils._init_weights = False
	torch.set_default_dtype(torch.half)
	model = LlamaForCausalLM(config)
	torch.set_default_dtype(torch.float)
	model = model.eval()
	layers = find_layers(model)
	for name in ['lm_head']:
		if name in layers:
			del layers[name]
	make_quant(model, layers, wbits, groupsize, faster=False)

	del layers

	print('Loading model ...')
	if checkpoint.endswith('.safetensors'):
		from safetensors.torch import load_file as safe_load
		model.load_state_dict(safe_load(checkpoint))
	else:
		model.load_state_dict(torch.load(checkpoint))
	model.seqlen = 2048
	print('Done.')

	return model


def get_dataset(dataset_name: str, tokenizer) -> torch.Tensor:
	if dataset_name == "wikitext-2":
		test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
		encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt").input_ids
	elif dataset_name == 'ptb':
		test = load_dataset("ptb_text_only", 'penn_treebank', split="validation")
		encodings = tokenizer("\n\n".join(test["sentence"]), return_tensors="pt").input_ids
	elif dataset_name == 'c4':
		# WARNING: Many of the files in the allenai/c4 repo are marked as "Unsafe" by HuggingFace, possibly containing a virus.  This particular file is not, and I doubt it's an issue, but worth noting.
		test = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')
		encodings = [tokenizer(x, return_tensors="pt").input_ids for x in test['text'][:1000]]
		encodings = torch.cat(encodings, dim=1)
	else:
		raise ValueError(f"Unknown dataset {dataset_name}")

	return encodings


def calculate_perplexity(model, tokenizer, dataset: str, max_length: int, stride: int = 512) -> float:
	encodings = get_dataset(dataset, tokenizer)
	seq_len = encodings.size(1)

	print(f"Sequence length: {seq_len}")
	print(f"Max length: {max_length}")
	print(f"Stride: {stride}")

	nlls = []
	prev_end_loc = 0

	for begin_loc in tqdm(range(0, seq_len - 1, stride)):
		end_loc = min(seq_len - 1, begin_loc + max_length)
		trg_len = end_loc - prev_end_loc  # How many tokens we want to predict
		input_ids = encodings[:, begin_loc:end_loc+1].to('cuda')  # +1 for the labels

		with torch.no_grad():
			# Ask the model for logits
			outputs = model(input_ids[:, :-1])
			# We only want the last trg_len logits
			logits = outputs.logits[..., -trg_len:, :].contiguous()
			# The last trg_len tokens are the labels
			labels = input_ids[:, -trg_len:].contiguous()

			# Compute the NLL for this batch using flattened logits and labels
			loss_fct = nn.CrossEntropyLoss()
			loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
		
		nlls.append(loss)

		prev_end_loc = end_loc
		if end_loc == (seq_len - 1):
			break

	ppl = torch.exp(torch.stack(nlls).mean())

	return ppl


if __name__ == '__main__':
	main()

@fpgaminer
Author

I've done some preliminary debugging and found that test_kernel.py is able to demonstrate the issue. I adjust B in the code for each of the runs below, with the values 16, 32, and 64. It looks like the quantized kernels perform faster than FP16 at B=16 and below, but at >=32 they slow down considerably.

> CUDA_VISIBLE_DEVICES=0 python test_kernel.py
Benchmarking LLaMa-7B FC2 matvec ...
FP16: 0.0007160284519195557
2bit: 0.0007571296691894531
3bit: 0.0005737292766571045
3bit: 0.0004833848476409912 (faster)
4bit: 0.0005299983024597168
8bit: 0.00042980456352233887

> CUDA_VISIBLE_DEVICES=0 python test_kernel.py
Benchmarking LLaMa-7B FC2 matvec ...
FP16: 0.0007020277976989746
2bit: 0.001501572847366333
3bit: 0.0011098787784576416
3bit: 0.0009780402183532716 (faster)
4bit: 0.0010332670211791993
8bit: 0.0008434817790985107

> CUDA_VISIBLE_DEVICES=0 python test_kernel.py
Benchmarking LLaMa-7B FC2 matvec ...
FP16: 0.0006992788314819336
2bit: 0.0030598652362823484
3bit: 0.0022263057231903076
3bit: 0.0019433841705322266 (faster)
4bit: 0.0020571749210357665
8bit: 0.0016708543300628662
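For reference, here is a minimal sketch of the kind of timing loop such a kernel benchmark performs. The shape and the quant_matmul placeholder are assumptions for illustration, not the repo's actual test_kernel.py:

import torch

def bench(fn, iters=100):
    # Time a CUDA op with events; warm up first so compilation/caching isn't measured.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(10):
        fn()
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters / 1000  # seconds per call

K, N = 11008, 4096              # assumed FC2-like shape for LLaMa-7B (11008 -> 4096)
for B in (16, 32, 64):          # the batch/sequence dimension varied in the runs above
    x = torch.randn(B, K, device='cuda', dtype=torch.float16)
    w = torch.randn(K, N, device='cuda', dtype=torch.float16)
    print(f"B={B} FP16: {bench(lambda: x @ w):.6f}s")
    # print(f"B={B} 4bit: {bench(lambda: quant_matmul(x)):.6f}s")  # hypothetical quantized kernel call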

@MasterTaffer
Contributor

MasterTaffer commented Mar 26, 2023

I have the same problem on my RTX 3080, driver 530.41.03, CUDA 11.7. The performance of the 4-bit quantized models is very slow with large contexts. In fact, it seems that it is (much!) faster to unpack the layer weights on the fly and use standard PyTorch matmul when sufficiently large matrices are involved. Here's a hackish implementation of QuantLinear that falls back to PyTorch when the context size grows:

MasterTaffer/GPTQ-for-LLaMa@b46c976

On my test setup (4-bit 13B LLaMa generating 20 tokens with 2,000 context tokens), inference speed improves by roughly 5x.
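To make the idea concrete, here is a rough sketch of that fallback. This is not MasterTaffer's actual code; it assumes the 4-bit, groupsize -1 packing convention used elsewhere in this thread, with scales and already-scaled zeros as 1-D tensors of length N, and fused_matvec standing in for the existing quantized kernel:

import torch

def dequantize_4bit(qweight, scales, zeros):
    # qweight: (K//8, N) int32, eight 4-bit values packed per 32-bit word along K.
    # scales, zeros: 1-D float16 tensors of length N (groupsize -1, zeros already scaled).
    K = qweight.shape[0] * 8
    shifter = (torch.arange(K, device=qweight.device, dtype=torch.int32) % 8) * 4
    w = qweight.repeat_interleave(8, dim=0)          # (K, N): each packed word repeated 8x
    w = (w >> shifter[:, None]) & 0xF                # extract the 4-bit values
    return w.half() * scales[None, :] - zeros[None, :]

def quantlinear_forward(x, qweight, scales, zeros, fused_matvec, threshold=128):
    # Fused kernel for short inputs; dequantize once and let cuBLAS do the big matmul otherwise.
    tokens = x.numel() // x.shape[-1]
    if tokens < threshold:
        return fused_matvec(x)                       # the existing quantized kernel (assumed callable)
    w = dequantize_4bit(qweight, scales, zeros)      # unpacked on the fly, discarded afterwards
    return x @ w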

@Qubitium
Contributor

@MasterTaffer I want to test your changes, but what exactly does this commit fix? I know the commit comment says it's a groupsize fix, but there are no reported groupsize bugs in the current repo. Did you run into something that others haven't encountered?

7fb58e3

@EyeDeck

EyeDeck commented Mar 27, 2023

@MasterTaffer
Any chance of making a pull request? Hackish or not, it's also ridiculously faster for me.
Testing on a 3090 with the various LLaMA 30B pre-quantized models linked here: oobabooga/text-generation-webui#530
and running through text-generation-webui with the --wbits 3 --group-size 128, --wbits 3 --group-size 32, --wbits 4 --act-order, and --wbits 4 --group-size 128 quantizations, I see the same ~5x inference improvement.
I'm also seeing a VRAM usage increase of around 5%*, so I guess it'll probably need a toggle.

@diegomontoya
That's a followup commit for MasterTaffer's own MasterTaffer@b46c976, one of the four models I tested (the --wbits 4 --group-size 128 one) crashed without that change.

*edit: Don't quote me on that though; with some models it looks slightly higher. But for some reason I'm currently able to run a LLaMA 30B --wbits 4 --true-sequential --act-order model at full context on my 3090 without OOMing, while the same setup without these changes would always OOM yesterday. I have no idea why.

@sterlind

Does this only apply when group size is enabled? I compared your change to the baseline (the latest commit as of writing this comment) and the time is pretty much the same. But I quantized with --true-sequential --act-order rather than specifying any group size.

@EyeDeck

EyeDeck commented Mar 27, 2023

What's your context size? The main improvement I'm seeing is the elimination of the delay before output starts being generated. With a lot of context tokens (like 1800) I usually see a 45-50 second delay on LLaMA 30B; with those changes, that drops to low single digits.

@fpgaminer
Author

I'm working on a kernel implementation in Triton. My hope is to lean on Triton's ability to optimize for the hardware on the fly, as well as to implement the matmul kernel in a more cache-optimal way than the current CUDA kernel.

So far I have a working kernel, though I haven't fully verified accuracy. Its performance curve is a lot better: not as good as FP16 PyTorch yet, but at least in the ballpark now, and it scales correctly with context length. I've included the code below. It's still a WIP, and I need to evaluate correctness more thoroughly; as of right now I'm seeing an absolute error of 0.0039 on 256x4096 test vectors relative to an fp16 simulation.

[performance plot attached in the original comment]

The one major snag I've hit is that Triton doesn't seem to have a way of expanding a tensor, i.e. something similar to PyTorch's repeat_interleave: I can't fetch the quantized weights of size [K//8, N] and then unpack them in SRAM into [K, N]. The hack for now is to configure the fetch like a repeat_interleave, but that means we lose a lot of performance, since it's doing 8x the loads compared to optimal. I think the kernel would run faster than PyTorch if I could fix this; bandwidth tends to be the major performance bottleneck.

from typing import Optional

import torch
import triton
import triton.language as tl


@triton.autotune(
	configs=[
		triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
		triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
		triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
		triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
		triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
		triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
		triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
		triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
		triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
		triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
		triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
	],
	key=['M', 'N', 'K'],
)
@triton.jit
def matmul4_kernel(
	a_ptr, b_ptr, c_ptr, #debug_ptr,
	scales_ptr, zeros_ptr,
	M, N, K,
	stride_am, stride_ak,
	stride_bk, stride_bn,
	stride_cm, stride_cn,
	stride_scales, stride_zeros, #stride_dk, stride_dn,
	BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
	GROUP_SIZE_M: tl.constexpr,
):
	"""
	Compute the matrix multiplication C = A x B.
	A is of shape (M, K) float16
	B is of shape (K//8, N) int32
	C is of shape (M, N) float16
	scales is of shape (1, N) float16
	zeros is of shape (1, N) float16

	WARNING: This kernel assumes that K is a multiple of BLOCK_SIZE_K.
	"""
	pid = tl.program_id(axis=0)
	num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
	num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
	num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
	num_pid_in_group = GROUP_SIZE_M * num_pid_n
	group_id = pid // num_pid_in_group
	first_pid_m = group_id * GROUP_SIZE_M
	group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
	pid_m = first_pid_m + (pid % group_size_m)
	pid_n = (pid % num_pid_in_group) // group_size_m

	offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
	offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
	offs_k = tl.arange(0, BLOCK_SIZE_K)
	a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)   # (BLOCK_SIZE_M, BLOCK_SIZE_K)
	# b_ptrs is set up such that it repeats elements along the K axis 8 times
	b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn)   # (BLOCK_SIZE_K, BLOCK_SIZE_N)
	scales_ptrs = scales_ptr + offs_bn * stride_scales
	zeros_ptrs = zeros_ptr + offs_bn * stride_zeros

	# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
	scales = tl.load(scales_ptrs)  # (BLOCK_SIZE_N,)
	zeros = tl.load(zeros_ptrs)  # (BLOCK_SIZE_N,)

	# shifter is used to extract the 4 bits of each element in the 32-bit word from B
	shifter = (offs_k % 8) * 4

	# For debugging
	#offs_dk = 0 + tl.arange(0, BLOCK_SIZE_K)
	#offs_dn = 0 + tl.arange(0, BLOCK_SIZE_N)
	#debug_ptrs = debug_ptr + stride_dk * offs_dk[:, None] + stride_dn * offs_dn[None, :]
	#tl.store(debug_ptrs, b)

	# Now calculate a block of output of shape (BLOCK_SIZE_M, BLOCK_SIZE_N)
	# M is along the batch dimension, N is along the outfeatures dimension, K is along the infeatures dimension
	# So this loop is along the infeatures dimension (K)
	# It's calculating BLOCK_SIZE_M batches in parallel, and for each batch, BLOCK_SIZE_N outfeatures in parallel
	accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
	for k in range(0, num_pid_k):
		a = tl.load(a_ptrs)
		b = tl.load(b_ptrs)   # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

		# Now we need to unpack b (which is 4-bit values) into 32-bit values
		b = (b >> shifter[:, None]) & 0xF  # Extract the 4-bit values
		b = b * scales[None, :] - zeros[None, :]  # Scale and shift
		#tl.store(debug_ptrs, b)

		accumulator += tl.dot(a, b)
		a_ptrs += BLOCK_SIZE_K * stride_ak
		b_ptrs += (BLOCK_SIZE_K // 8) * stride_bk
		#debug_ptrs += BLOCK_SIZE_K * stride_dk
	
	c = accumulator.to(tl.float16)
	
	# Store the result
	offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
	offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
	c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
	c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
	tl.store(c_ptrs, c, mask=c_mask)


def triton_matmul4(a: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, bias: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
	"""
	Compute the matrix multiplication C = A x B + bias.
	Where B is quantized using GPTQ and groupsize = -1 into 4-bit values.

	A is of shape (..., K) float16
	qweight is of shape (K//8, N) int32
	scales is of shape (1, N) float16
	qzeros is of shape (1, N//8) int32
	bias is of shape (1, N) float16

	Returns C of shape (..., N) float16
	"""
	assert a.shape[-1] == qweight.shape[0] * 8  # == infeatures (K)
	assert a.is_contiguous()

	# Flatten a into (-1, K)
	x = a.view(-1, a.shape[-1])

	M, K = x.shape
	N = qweight.shape[1]
	assert K % 32 == 0

	# Unpack zeros into (1, N) float16
	zeros = qzeros.flatten()  # (N//8,) int32
	zeros = torch.repeat_interleave(zeros, 8)  # ((N//8)*8,) int32
	shifter = (torch.arange(0, N, device='cuda', dtype=torch.int32) % 8) * 4
	zeros = (zeros >> shifter) & 0xF  # (N,) int32
	zeros = (zeros + 1) * scales[0]  # (N,) float16

	c = torch.empty((M, N), device='cuda', dtype=torch.float16)
	#debug = torch.empty((32*32, 128), device='cuda', dtype=torch.float32)
	grid = lambda META: (
		triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
	)
	#grid = lambda META: (1,)  # For debugging
	matmul4_kernel[grid](
		x, qweight, c, #debug,
		scales, zeros,
		M, N, K,
		x.stride(0), x.stride(1),
		qweight.stride(0), qweight.stride(1),
		c.stride(0), c.stride(1),
		scales.stride(1), zeros.stride(0),
		#debug.stride(0), debug.stride(1),
	)

	# Reshape c
	c = c.view(a.shape[:-1] + (N,))  # (..., N); N == outfeatures

	# Add bias
	if bias is not None:
		c = c + bias
	
	return c
	#return debug


#zeros = torch.empty(4096, dtype=torch.float32)
#for i in range(0, 4096):
#	zeros[i] = qlayer.qzeros[0, i // 8] >> ((i % 8) * 4) & 0xF
#	zeros[i] = (zeros[i] + 1) * qlayer.scales[0, i]

#zeros = zeros.to('cuda')

out = triton_matmul4(vec, qlayer.qweight, qlayer.scales.to(torch.float16, copy=True), qlayer.qzeros, qlayer.bias.to(torch.float16, copy=True))
#out = out + layer.bias.data
(out - sim).abs().max()
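(Here vec, qlayer and sim are presumably globals from a test harness: a random fp16 input, a quantized QuantLinear layer, and an fp16 reference output.) As a usage sketch only, an illustration that is not part of the original snippet, the function could be wrapped in a small module like this; the buffer shapes follow the docstring above, but treat the exact layout as an assumption:

import torch
import torch.nn as nn

class TritonQuantLinear(nn.Module):
    """Minimal wrapper around the triton_matmul4 sketch above (4-bit, groupsize -1 only)."""
    def __init__(self, qweight, scales, qzeros, bias=None):
        super().__init__()
        self.register_buffer('qweight', qweight)               # (K//8, N) int32
        self.register_buffer('scales', scales.half())          # (1, N) float16
        self.register_buffer('qzeros', qzeros)                  # (1, N//8) int32
        self.register_buffer('bias', None if bias is None else bias.half())

    def forward(self, x):
        return triton_matmul4(x.half(), self.qweight, self.scales, self.qzeros, self.bias)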

@sterlind

sterlind commented Mar 27, 2023

@EyeDeck I'm not sure what you mean by context size. I'm using llama_inference.py to demonstrate inference, not text-generation-webui.

I quantized LLaMA-7B to 4 bits without group size (llama7b-4bit.pt) and with group size 128 (llama7b-4bit-gp128.pt). Here are the precise args:
python llama.py ./llama-hf/llama-7b c4 --wbits 4 --true-sequential --act-order --save llama7b-4bit.pt
python llama.py ./llama-hf/llama-7b c4 --wbits 4 --true-sequential --groupsize 128 --save llama7b-4bit-gp128.pt

For benchmarking, I modify llama_inference.py to load checkpoints with accelerate, for my own sanity as it was so slow before:

    import accelerate
    model = accelerate.load_checkpoint_and_dispatch(
        model = model,
        checkpoint = checkpoint,
        device_map = "auto",
        no_split_module_classes = ["LlamaDecoderLayer"]
    )
    # if checkpoint.endswith('.safetensors'):
    #     from safetensors.torch import load_file as safe_load
    #     model.load_state_dict(safe_load(checkpoint))
    # else:
    #     model.load_state_dict(torch.load(checkpoint))

I then took @MasterTaffer's quant.py and measured generation speed and loading time. For each configuration, I ran a command like the following:

python llama_inference.py ./llama-hf/llama-7b --wbits 4 --load llama7b-4bit-gp128.pt --text "I think the meaning of life is" --max_length 256 --min_length 256 --groupsize 128

I varied the version of quant.py, the group size (and the matching checkpoint), and generated exactly 256 tokens each time.

| --faster-kernel? | New quant.py? | Groupsize | Model Loading Time (s) | Evaluation Time (s) |
| --- | --- | --- | --- | --- |
| Yes | New | 128 | 18.24 | 27.42 |
| Yes | New | -1 | 12.49 | 24.89 |
| Yes | Old | 128 | 17.20 | 29.47 |
| Yes | Old | -1 | 9.26 | 27.17 |
| No | New | 128 | 11.07 | 33.02 |
| No | New | -1 | 40.26 | 26.91 |
| No | Old | 128 | 11.46 | 27.42 |
| No | Old | -1 | 15.92 | 26.97 |

It's very strange, isn't it? The results are all over the place; I can't tell whether anything has a performance benefit.

I am noticing only ~25% CUDA utilization, which is disappointing. Maybe the HuggingFace implementation is inefficient? Fwiw I'm on WSL2, using an RTX 4090 and PyTorch 2.0.0+cu117.

@IdiotSandwichTheThird

@sterlind
What is meant by context size is the length of the input text, which has been a huge bottleneck up until now.
The new optimization, if I understand correctly, only kicks in once there are more than 128 tokens of context.
It should be really obvious once you test with a few paragraphs of input.

@Qubitium
Contributor

@sterlind What is your CPU? A 4090 needs at the very least a 12th-gen Intel or Zen 4 CPU to keep up with feeding the CUDA cores. 25% GPU utilization is way too low; it should be 80% or higher when generating tokens. Also, to make the results deterministic, try setting a fixed seed for all tests.

@EyeDeck

EyeDeck commented Mar 28, 2023

@sterlind
--text "I think the meaning of life is"
The tokenizer averages about two-thirds of a word per token (depending on the words). In MasterTaffer's latest code (at least the PR that was merged here), the new path doesn't kick in until 128 context tokens. Retry with a ~1300-word input and the difference should be immediately obvious even on a smaller 7B model: probably something on the order of 8-10 seconds of "latency" before output starts without the changes, versus nearly instantaneous with them. Go up to a 30B model and the difference is enormous.

@fpgaminer
Author

Okay, I've published a more polished version of my Triton kernel: https://github.com/fpgaminer/GPTQ-triton

The README on that repo has more detailed metrics, but the Triton kernel indeed performs 10x faster than the CUDA kernel with context length 2048. In fact, it's always faster than the CUDA kernel for all context lengths. It's almost on par with FP16. And memory usage is the same as CUDA.

As for accuracy, it's exactly as accurate as the CUDA kernel across wikitext2, PTB, and C4 validation sets.

Currently the kernel only supports 4-bits and groupsize -1, and I've only tested the 7B LLaMa weights.

@alexconstant9108

@fpgaminer Hypothetically, the speedup should be even bigger with the >7B models, right? Have you had a chance to test with the 13B model, for example?

@sterlind

> @sterlind What is your CPU? A 4090 needs at the very least a 12th-gen Intel or Zen 4 CPU to keep up with feeding the CUDA cores. 25% GPU utilization is way too low; it should be 80% or higher when generating tokens. Also, to make the results deterministic, try setting a fixed seed for all tests.

My CPU is an "old" Threadripper 1950X. Maybe I'm confused, but why should an older CPU struggle to feed the CUDA cores? Shouldn't it need to transfer very little to/from GPU memory once the model weights are loaded?

@Qubitium
Contributor

Almost all inference code is single-threaded, so it doesn't matter that you have 16 cores; it will only use one per GPU.

Just monitor your CPU usage versus GPU usage. If your CPU (the core running the Python inference) is at 100% and the GPU is at 25%, the bottleneck is the CPU: the GPU is waiting for more work while the CPU is maxed out. For reference, a 13900K has roughly 2x the single-core performance of a 1950X; after overclocking, likely 2.2x.
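A minimal sketch of the fixed-seed suggestion (plain PyTorch/NumPy seeding, nothing specific to this repo):

import random

import numpy as np
import torch

def set_seed(seed: int = 0) -> None:
    # Fix every RNG that sampling might touch so repeated benchmark runs produce the same tokens.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)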

My CPU is an "old" Threadripper 1950X. Maybe I'm confused, but why should an older CPU struggle to feed the CUDA cores? Shouldn't it need to transfer very little to/from GPU memory once the model weights are loaded?

@qwopqwop200
Owner

> Okay, I've published a more polished version of my Triton kernel: https://github.com/fpgaminer/GPTQ-triton
>
> The README on that repo has more detailed metrics, but the Triton kernel indeed performs 10x faster than the CUDA kernel with context length 2048. In fact, it's always faster than the CUDA kernel for all context lengths. It's almost on par with FP16. And memory usage is the same as CUDA.
>
> As for accuracy, it's exactly as accurate as the CUDA kernel across wikitext2, PTB, and C4 validation sets.
>
> Currently the kernel only supports 4-bits and groupsize -1, and I've only tested the 7B LLaMa weights.

I rewrote the current GPTQ kernel in Triton using your code, and I'm actually seeing a very large speedup.
In addition, it has been changed to support 2-bit and 8-bit quantization as well as groupsize.
A 3-bit kernel seems difficult to implement, as Triton does not support indexing.

@Qubitium
Contributor

Qubitium commented Mar 30, 2023

@qwopqwop200 Does the triton branch require re-quantization? Switching from the cuda branch to the triton branch throws the following for my 30B test model:

RuntimeError: Error(s) in loading state_dict for LlamaForCausalLM:
	Missing key(s) in state_dict: "model.layers.0.self_attn.k_proj.g_idx", "model.layers.0.self_attn.o_proj.g_idx", "model.layers.0.self_attn.q_proj.g_idx", "model.layers.0.self_attn.v_proj.g_idx", "model.layers.0.mlp.down_proj.g_idx", "model.layers.0.mlp.gate_proj.g_idx", "model.layers.0.mlp.up_proj.g_idx", "model.layers.1.self_attn.k_proj.g_idx", "model.layers.1.self_attn.o_proj.g_idx", "model.layers.1.self_attn.q_proj.g_idx", "model.layers.1.self_attn.v_proj.g_idx", "model.layers.1.mlp.down_proj.g_idx", "model.layers.1.mlp.gate_proj.g_idx", "model.layers.1.mlp.up_proj.g_idx", "model.layers.2.self_attn.k_proj.g_idx", "model.layers.2.self_attn.o_proj.g_idx", "model.layers.2.self_attn.q_proj.g_idx", "model.layers.2.self_attn.v_proj.g_idx", "model.layers.2.mlp.down_proj.g_idx", "model.layers.2.mlp.gate_proj.g_idx", "model.layers.2.mlp.up_proj.g_idx", "model.layers.3.self_attn.k_proj.g_idx", "model.layers.3.self_attn.o_proj.g_idx", "model.layers.3.self_attn.q_proj.g_idx", "model.layers.3.self_attn.v_proj.g_idx", "model.layers.3.mlp.down_proj.g_idx", "model.layers.3.mlp.gate_proj.g_idx", "model.layers.3.mlp.up_proj.g_idx", "model.layers.4.self_attn.k_proj.g_idx", "model.layers.4.self_attn.o_proj.g_idx", "model.layers.4.self_attn.q_proj.g_idx", "model.layers.4.self_attn.v_proj.g_idx", "model.layers.4.mlp.down_proj.g_idx", "model.layers.4.mlp.gate_proj.g_idx", "model.layers.4.mlp.up_proj.g_idx", "model.layers.5.self_attn.k_proj.g_idx", "model.layers.5.self_attn.o_proj.g_idx", "model.layers.5.self_attn.q_proj.g_idx", "model.layers.5.self_attn.v_proj.g_idx", "model.layers.5.mlp.down_proj.g_idx", "model.layers.5.mlp.gate_proj.g_idx", "model.layers.5.mlp.up_proj.g_idx", "model.layers.6.self_attn.k_proj.g_idx", "model.layers.6.self_attn.o_proj.g_idx", "model.layers.6.self_attn.q_proj.g_idx", "model.layers.6.self_attn.v_proj.g_idx", "model.layers.6.mlp.down_proj.g_idx", "model.layers.6.mlp.gate_proj.g_idx", "model.layers.6.mlp.up_proj.g_idx", "model.layers.7.self_attn.k_proj.g_idx", "model.layers.7.self_attn.o_proj.g_idx", "model.layers.7.self_attn.q_proj.g_idx", "model.layers.7.self_attn.v_proj.g_idx", "model.layers.7.mlp.d
 Unexpected key(s) in state_dict: "model.layers.0.self_attn.k_proj.bias", "model.layers.0.self_attn.o_proj.bias", "model.layers.0.self_attn.q_proj.bias", "model.layers.0.self_attn.v_proj.bias", "model.layers.0.mlp.down_proj.bias", "model.layers.0.mlp.gate_proj.bias", "model.layers.0.mlp.up_proj.bias", "model.layers.1.self_attn.k_proj.bias", "model.layers.1.self_attn.o_proj.bias", "model.layers.1.self_attn.q_proj.bias", "model.layers.1.self_attn.v_proj.bias", "model.layers.1.mlp.down_proj.bias", "model.layers.1.mlp.gate_proj.bias", "model.layers.1.mlp.up_proj.bias", "model.layers.2.self_attn.k_proj.bias", "model.layers.2.self_attn.o_proj.bias", "model.layers.2.self_attn.q_proj.bias", "model.layers.2.self_attn.v_proj.bias", "model.layers.2.mlp.down_proj.bias", "model.layers.2.mlp.gate_proj.bias", "model.layers.2.mlp.up_proj.bias", "model.layers.3.self_attn.k_proj.bias", "model.layers.3.self_attn.o_proj.bias", "model.layers.3.self_attn.q_proj.bias", "model.layers.3.self_attn.v_proj.bias", "model.layers.3.mlp.down_proj.bias", "model.layers.3.mlp.gate_proj.bias", "model.layers.3.mlp.up_proj.bias", "model.layers.4.self_attn.k_proj.bias", "model.layers.4.self_attn.o_proj.bias", "model.layers.4.self_attn.q_proj.bias", "model.layers.4.self_attn.v_proj.bias", "model.layers.4.mlp.down_proj.bias", "model.layers.4.mlp.gate_proj.bias", "model.layers.4.mlp.up_proj.bias", "model.layers.5.self_attn.k_proj.bias", "model.layers.5.self_attn.o_proj.bias", "model.layers.5.self_attn.q_proj.bias", "model.layers.5.self_attn.v_proj.bias", "model.layers.5.mlp.down_proj.bias", "model.layers.5.mlp.gate_proj.bias", "model.layers.5.mlp.up_proj.bias", "model.layers.6.self_attn.k_proj.bias", "model.layers.6.self_attn.o_proj.bias",

The 30B model was quantized to 4 bits using only act-order and true-sequential.

@qwopqwop200
Owner

qwopqwop200 commented Mar 30, 2023

It probably needs re-quantization.
Alternatively, you may be able to get it working with some code changes:
https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/triton/quant.py#L248

setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, True))

@USBhost
Contributor

USBhost commented Mar 30, 2023

I was able to run my pytorch-branch-converted model on the triton branch under ooba just fine, though I had to remove the options that are no longer used by Triton.

@Qubitium
Contributor

@USBhost Are you getting degraded output quality under the triton branch? I am seeing both a performance regression and a massive quality drop-off under the triton branch using re-quantized 30B models. Eval scores are normal, but the output diverges wildly from the cuda branch with the same temp/top_p/top_k/etc. config. Still trying to isolate the issue.

@USBhost
Contributor

USBhost commented Mar 31, 2023

Can't say I have, but I am using 65B, so I don't know. I'm still trying to figure out why my c4 (etc.) evaluation scores keep changing from day to day, or even within half a day. Either I'm being a dummy and doing something wrong, or something else is happening.

@sterlind

sterlind commented Apr 1, 2023

> Almost all inference code is single-threaded, so it doesn't matter that you have 16 cores; it will only use one per GPU.
>
> Just monitor your CPU usage versus GPU usage. If your CPU (the core running the Python inference) is at 100% and the GPU is at 25%, the bottleneck is the CPU: the GPU is waiting for more work while the CPU is maxed out. For reference, a 13900K has roughly 2x the single-core performance of a 1950X; after overclocking, likely 2.2x.

The bottleneck seems to just be that there are a ho-jillion CUDA kernels being executed one after another:
[Nsight trace screenshot attached in the original comment]

Each one of those tiny lines is a cudaLaunchKernel call. The average duration of each kernel launch is ~20 µs, and there are literally hundreds of thousands of such events in my trace.
Note that I didn't see any appreciable time in cudaMemcpyAsync after the initial model load, as I expected: just very brief (<<1 ms) copies, presumably sending the predicted tokens back to the CPU side. It really seems like the overhead is just the mammothly inefficient way that transformers and PyTorch work by default, requiring thousands and thousands of round trips that are nearly instantaneous to execute on the GPU side.

Tomorrow I will see about using PyTorch 2.0's torch.compile, or exporting a TorchScript or ONNX model, or something. What I've done in the past is capture and replay entire CUDA graphs, but I couldn't really do that here without carving into the transformers library in a pretty invasive way.

Btw I'm using nsight-compute 2023.1 for my data capture and analysis, with CUDA 12.1 and nsight-systems 2023.2.1. I built my own PyTorch from source to support this, since I'm on WSL and profiling support there is bleeding edge.
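For reference, a minimal sketch of what the torch.compile experiment could look like, assuming a model and input_ids already loaded on the GPU (whether the custom quantized kernels survive compilation untouched is an open question, as the next comment notes):

import torch

# Assumes `model` is the already-loaded LlamaForCausalLM (fp16 or quantized) on CUDA,
# and `input_ids` a (1, seq_len) LongTensor. "reduce-overhead" mode captures CUDA graphs,
# which targets exactly this kind of per-kernel launch overhead.
compiled = torch.compile(model, mode="reduce-overhead")

with torch.no_grad():
    out = compiled(input_ids)   # first call triggers compilation; later calls reuse the graphs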

@aljungberg
Contributor

I think what you're seeing may be completely normal for an untraced HuggingFace transformer, not really specific to this case? It's a composition of unfused PyTorch modules after all, each one launching at least one kernel, with multiple modules per layer.

Definitely feels like something PyTorch's new compiler might make a difference on, or the older jit tracer. I wonder how well that interplays with the new Triton matmul kernel though. Can it really fuse custom kernels? I would be surprised. On the other hand if it can fuse the vanilla PyTorch stuff that might by itself make a big difference.

@EyeDeck

EyeDeck commented Apr 1, 2023

I guess nobody else has mentioned it here, but the current CUDA branch generates at a rate of around 9 seconds per token with a freshly requantized 4-bit LLaMA 30B on a 3090 and ~1850 context tokens. This commit 608f3ba (with an older equivalent quantization that works with it), all else equal, runs at around 5 tokens per second.

@qwopqwop200
Owner

> I guess nobody else has mentioned it here, but the current CUDA branch generates at a rate of around 9 seconds per token with a freshly requantized 4-bit LLaMA 30B on a 3090 and ~1850 context tokens. This commit 608f3ba (with an older equivalent quantization that works with it), all else equal, runs at around 5 tokens per second.

This is because the CUDA kernel has been changed to support act-order and groupsize at the same time. Because of this, we recommend Triton for now.

@fxmarty

fxmarty commented Apr 6, 2023

@fpgaminer Do you have an idea why your Triton implementation is better than the CUDA one?

@fpgaminer
Author

> @fpgaminer Do you have an idea why your Triton implementation is better than the CUDA one?

I'm not terribly well versed in PTX so I can't say for certain.

Each instance of the CUDA kernel only calculates a 1 x 1 x BLOCK_SIZE_K result, and launches a grid of (K//BLOCK_SIZE_K, N, M). BLOCK_SIZE_K is of a fixed size, regardless of M, N, or K.

The Triton kernel calculates a full BLOCK_SIZE_M x BLOCK_SIZE_N x K block in each instance.

To the best of my understanding: the CUDA kernel does less work per thread and launches more threads; the Triton kernel does more work per thread and launches fewer threads. The end result is that the CUDA kernel has to re-fetch data more often than the Triton kernel. This is fine when the data involved fits in the L2 cache, but when it doesn't, the Triton kernel dominates. This occurs in all cases where M>1. The Triton kernel also auto-adapts its block sizes based on M, N, and K. In all cases the Triton kernel has competitive performance with PyTorch FP16.

The one downside right now is that Triton doesn't have support for unpacking quantized data, so I have to do some hacks to get it to work. It works fine, but it isn't getting any of the bandwidth benefits it should. In theory a set of re-written CUDA kernels would handily beat it.
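As a rough worked illustration of that difference (the tile sizes below are assumptions picked for the comparison; only the grid shapes follow the description above):

# Illustrative block counts for one full-context matmul (M=2048 tokens, K=11008, N=4096).
# The tile sizes are assumptions, not the kernels' real constants.
M, K, N = 2048, 11008, 4096

BLOCK_SIZE_K = 256                                    # assumed fixed K-tile of the CUDA kernel
cuda_blocks = (K // BLOCK_SIZE_K) * N * M             # grid = (K // BLOCK_SIZE_K, N, M)

BM, BN = 128, 128                                     # one autotune config of the Triton kernel
triton_programs = ((M + BM - 1) // BM) * ((N + BN - 1) // BN)

print(f"CUDA blocks:     {cuda_blocks:,}")            # ~361 million tiny blocks
print(f"Triton programs: {triton_programs:,}")        # 512 programs, each looping over all of K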
