This repository provides the official implementation of an efficient 2:4 sparse pre-training toolkit accompanying the following papers:
Accelerating Transformer Pre-training with 2:4 Sparsity [arXiv] [OpenReview] [PDF]
Yuezhou Hu, Kang Zhao, Weiyu Huang, Jianfei Chen, Jun Zhu
International Conference on Machine Learning (ICML), 2024
S-STE: Continuous Pruning Function for Efficient 2:4 Sparse Pre-training [arXiv] [OpenReview]
Yuezhou Hu, Jun Zhu, Jianfei Chen
Neural Information Processing Systems (NeurIPS), 2024
We also provide implementations of some related papers/algorithms: STEP.
From source:
git clone --recursive
cd 2by4-pretrain
pip install -e .
For basic 2:4 spMM, the official torch.sparse works well enough. On top of it, this toolkit adds the following features.
Constructing a 2:4 tensor
To construct a sparse semi-structured tensor, simply call sparse.to_sparse_semi_structured:
import torch
import sparse
A = torch.randn(128, 128, device='cuda:0', dtype=torch.half)
A_sparse = sparse.to_sparse_semi_structured(A)
Unlike PyTorch, this automatically prunes the two smallest elements in every group of four.
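For reference, the default pruning rule fits in a few lines of plain Python. This is a sketch and not the toolkit's kernel. It assumes "smallest" means smallest absolute value and works on a flat list whose length is a multiple of 4. `prune_2_4` is a hypothetical helper.

```python
def prune_2_4(values):
    """Keep the two largest-magnitude entries in every group of four; zero the rest.

    Returns (pruned_values, mask) as flat lists. Ties are broken by position
    (earlier index wins), an arbitrary choice made for this sketch.
    """
    if len(values) % 4 != 0:
        raise ValueError("length must be a multiple of 4")
    pruned, mask = [], []
    for start in range(0, len(values), 4):
        group = values[start:start + 4]
        # positions of the two largest |x| within this group of four
        keep = sorted(range(4), key=lambda i: (-abs(group[i]), i))[:2]
        for i, v in enumerate(group):
            m = 1 if i in keep else 0
            mask.append(m)
            pruned.append(v if m else 0.0)
    return pruned, mask


pruned, mask = prune_2_4([0.1, -3.0, 2.0, 0.5, 4.0, 0.0, -1.0, 0.2])
print(pruned)  # [0.0, -3.0, 2.0, 0.0, 4.0, 0.0, -1.0, 0.0]
print(mask)    # [0, 1, 1, 0, 1, 0, 1, 0]
```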
You can also specify a particular 2:4 mask for this step. The mask is a 0/1 tensor (of any dtype) indicating which elements are kept (1) and which are pruned (0):
A_sparse = sparse.to_sparse_semi_structured(A, mask=your_mask)
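The effect of a user-supplied mask can be sketched on flat Python lists. Here we assume a valid 2:4 mask has exactly two ones in every group of four. `apply_2_4_mask` is hypothetical and is not part of sparse.

```python
def apply_2_4_mask(values, mask):
    """Zero the entries where mask is 0, after checking the mask is a valid 2:4 mask."""
    if len(values) != len(mask) or len(mask) % 4 != 0:
        raise ValueError("values and mask must have equal length, a multiple of 4")
    for start in range(0, len(mask), 4):
        group = mask[start:start + 4]
        # any dtype works: bools, ints and floats all compare equal to 0 or 1
        if any(m not in (0, 1) for m in group) or sum(group) != 2:
            raise ValueError(f"group starting at {start} is not 2:4: {group}")
    return [v if m else 0.0 for v, m in zip(values, mask)]


print(apply_2_4_mask([1.0, 2.0, 3.0, 4.0], [1, 0, 0, 1]))  # [1.0, 0.0, 0.0, 4.0]
```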
Additionally, the toolkit supports the minimum-variance unbiased estimator (MVUE) as a pruning strategy:
A_sparse = sparse.to_sparse_semi_structured(A, MVUE24=True)
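The toolkit's MVUE24 kernel is not reproduced here. As a hedged illustration of the idea, the sketch below builds one unbiased 2:4 estimator in plain Python. It works in two steps:

- Keep-probabilities p_i are proportional to |x_i|, capped at 1 and summing to 2. This choice minimizes the summed variance among estimators that keep at most two entries and rescale them by 1/p_i.
- Systematic sampling picks at most two entries whose marginal keep-probabilities are exactly p_i.

Both function names are hypothetical.

```python
import random


def mvue_keep_probs(group):
    """Keep-probabilities proportional to |x_i|, summing to 2, each capped at 1."""
    mags = [abs(x) for x in group]
    total = sum(mags)
    if total == 0:
        return [0.0] * len(group)
    probs = [2 * m / total for m in mags]
    top = max(range(len(group)), key=lambda i: probs[i])
    if probs[top] > 1:
        # at most one entry can exceed 1: pin it to 1, spread the remaining budget of 1
        rest = total - mags[top]
        probs = [0.0 if rest == 0 else m / rest for m in mags]
        probs[top] = 1.0
    return probs


def mvue_2_4(group, rng):
    """Unbiased 2:4 estimate of a group of four values (at most two nonzeros)."""
    probs = mvue_keep_probs(group)
    u = rng.random()  # a single uniform draw per group (systematic sampling)
    out = [0.0] * len(group)
    lo = 0.0
    for i, (x, p) in enumerate(zip(group, probs)):
        hi = lo + p
        # entry i is kept if its interval [lo, hi) contains u or u + 1
        if p > 0 and (lo <= u < hi or lo <= u + 1 < hi):
            out[i] = x / p  # rescale so that E[out_i] = x_i
        lo = hi
    return out


print(mvue_2_4([1.0, -2.0, 3.0, 0.5], random.Random(0)))
```

Because each interval has length at most 1, it can contain at most one of u and u + 1. This is why each entry is kept with probability exactly p_i and at most two entries survive.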
Support for different dtypes
We now support float16, bfloat16, int8, float8_e5m2 and float8_e4m3fn in dense-sparse conversion. Let's try this out:
A = torch.randn(128, 128, device='cuda:0')
A_sparse = sparse.to_sparse_semi_structured(A, dtype=torch.float16)
A_sparse = sparse.to_sparse_semi_structured(A, dtype=torch.int8)
A_sparse = sparse.to_sparse_semi_structured(A, dtype=torch.float8_e5m2)
This returns A_sparse in the requested dtype, regardless of the original dtype of A.
2:4 operations
As in PyTorch, the following operations are supported:
- torch.addmm(bias, dense, sparse.t())
- torch.mm(dense, sparse)
- torch.mm(sparse, dense)
- aten.linear.default(dense, sparse, bias)
- aten.t.default(sparse)
- aten.t.detach(sparse)
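Conceptually, a 2:4 operand in these operations is stored as half of its values plus a small position index for each kept value. The plain-Python sketch below compresses rows that way. It then evaluates what torch.addmm(bias, dense, sparse.t()) computes, using only the compressed form. The helper names are hypothetical. The real CUTLASS and cuSPARSELt kernels pack positions into compact, hardware-specific metadata that is not modelled here.

```python
def compress_2_4(row):
    """Store a 2:4-sparse row as two (value, position) pairs per group of four."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    values, positions = [], []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        kept = [i for i, v in enumerate(group) if v != 0]
        if len(kept) > 2:
            raise ValueError("row is not 2:4 sparse")
        # pad with explicit zeros so every group stores exactly two entries
        kept += [i for i in range(4) if i not in kept][:2 - len(kept)]
        kept.sort()
        values += [group[i] for i in kept]
        positions += kept
    return values, positions


def addmm_dense_sparse_t(bias, dense, sparse_rows):
    """bias + dense @ sparse.T, reading `sparse` only through its compressed rows."""
    compressed = [compress_2_4(r) for r in sparse_rows]
    out = []
    for x in dense:
        out_row = []
        for b, (values, positions) in zip(bias, compressed):
            acc = b
            for j, (v, pos) in enumerate(zip(values, positions)):
                acc += x[(j // 2) * 4 + pos] * v  # stored entry j belongs to group j // 2
            out_row.append(acc)
        out.append(out_row)
    return out


sparse_w = [[1.0, 0.0, 2.0, 0.0], [0.0, 3.0, 0.0, -1.0]]
print(compress_2_4(sparse_w[0]))  # ([1.0, 2.0], [0, 2])
print(addmm_dense_sparse_t([0.5, -0.5], [[1.0, 2.0, 3.0, 4.0]], sparse_w))  # [[7.5, 1.5]]
```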
There are two 2:4 spMM backends in total: CUTLASS and cuSPARSELt. CUTLASS is used by default; the cuSPARSELt backend is used only when you set
sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
Unlike PyTorch, we support float8 2:4 spMM via CUTLASS (on RTX 4090 and more recent GPUs).
Transposable mask select
An efficient, convolution-based kernel for selecting transposable 2:4 masks (masks that are 2:4 in both the matrix and its transpose):
mask_select = sparse.TransposableSparse()
A_sparse, A_mask = mask_select(A)
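The selection problem can be illustrated on a single 4×4 block by brute force. Among the 90 masks in which every row and every column keeps exactly two entries, the sketch picks the one that keeps the largest total magnitude. This is not the convolution-based kernel behind TransposableSparse. Maximizing kept magnitude is our assumption about its objective. A transposable mask lets one mask serve both W and W^T, as SparseLinearTranspose does below with mask and mask.T.

```python
from itertools import combinations, product

# every length-4 row with exactly two ones (6 of them)
ROWS = [tuple(1 if j in pair else 0 for j in range(4)) for pair in combinations(range(4), 2)]

# 4x4 masks whose rows AND columns each keep exactly two entries
TRANSPOSABLE_MASKS = [
    mask for mask in product(ROWS, repeat=4)
    if all(sum(row[j] for row in mask) == 2 for j in range(4))
]


def best_transposable_mask(block):
    """Return the transposable 2:4 mask keeping the largest total |w| of a 4x4 block."""
    def kept_magnitude(mask):
        return sum(abs(block[i][j]) for i in range(4) for j in range(4) if mask[i][j])
    return max(TRANSPOSABLE_MASKS, key=kept_magnitude)


print(len(TRANSPOSABLE_MASKS))  # 90
```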
Masked decay
Fused kernel for masked decay:
sparse.masked_add_(grad, p, mask, alpha=2e-4)
This is equivalent to grad.add_(p.data * (1 - mask), alpha=2e-4), i.e., weight decay applied only to the pruned (masked-out) weights.
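Elementwise, the fused kernel computes the expression above. Kept weights (mask = 1) get no extra term. Pruned weights (mask = 0) get alpha * p added to their gradient, so the optimizer update shrinks them, like weight decay. Below is a plain-Python reference; `masked_add_reference` is a hypothetical name.

```python
def masked_add_reference(grad, p, mask, alpha):
    """grad += alpha * p * (1 - mask), elementwise and in place, on flat lists."""
    for i in range(len(grad)):
        grad[i] += alpha * p[i] * (1 - mask[i])
    return grad


grad = [0.0, 0.0, 1.0, 1.0]
p = [1.0, 2.0, 3.0, 4.0]
mask = [1, 0, 1, 0]  # 1 = kept, 0 = pruned
print(masked_add_reference(grad, p, mask, alpha=0.5))  # [0.0, 1.0, 1.0, 3.0]
```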
Soft-thresholding (pseudo)
A_sparse, A_mask = sparse.soft_threshold24_triton(A)
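As we understand the S-STE paper, its 2:4 soft-thresholding works per group of four. It subtracts the group's second-smallest magnitude from every magnitude and clamps at zero. At most two entries survive, and unlike hard top-2 pruning, the output varies continuously with the input. The plain-Python sketch below encodes that reading. `soft_threshold_2_4` is hypothetical and is not soft_threshold24_triton.

```python
def soft_threshold_2_4(values):
    """Group-wise 2:4 soft-thresholding of a flat list (length a multiple of 4)."""
    out = []
    for start in range(0, len(values), 4):
        group = values[start:start + 4]
        # threshold = second-smallest magnitude in this group of four
        t = sorted(abs(v) for v in group)[1]
        for v in group:
            shrunk = max(abs(v) - t, 0.0)
            out.append(shrunk if v >= 0 else -shrunk)
    return out


print(soft_threshold_2_4([1.0, -4.0, 2.0, 0.5]))  # [0.0, -3.0, 1.0, 0.0]
```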
Take nanoGPT as an example.
Step 1
Replace nn.Linear with the self-defined SparseLinearTranspose:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
from sparse import SparseSemiStructuredTensor, TransposableSparse, to_sparse_semi_structured

class SparseLinearTranspose(nn.Linear):
    def __init__(self, in_features: int, out_features: int, bias: bool = True, func=lambda step: 'dense',
                 **kwargs):
        super(SparseLinearTranspose, self).__init__(in_features, out_features, bias=bias, **kwargs)
        self.weight.freq = 40  # update freq
        self.weight.cnt = 0  # how many steps after an optim step
        self.weight.counter = 0  # how many optim steps
        self.weight.step = 0  # total training step
        self.weight.mask = torch.ones_like(self.weight, dtype=torch.bool)
        self.weight.weight_sparse = None
        self.weight.weight_sparse_T = None
        self.weight.mode = 'sparse'
        self.func = func
        self.transposable_sparse = TransposableSparse(abs=True)
        SparseSemiStructuredTensor._FORCE_CUTLASS = True  # we won't need this later

    def forward(self, x):
        if self.weight.mode == 'dense':
            x = F.linear(x, self.weight, self.bias)
        else:
            # assumed: keep the mask on the same device as the weight
            self.weight.mask = self.weight.mask.to(self.weight.device)
            # recompute the transposable 2:4 mask every `freq` optimizer steps
            if self.weight.counter % self.weight.freq == 0 and self.weight.cnt == 0:
                _, self.weight.mask = self.transposable_sparse(self.weight)
            # re-compress W and W^T on the first forward pass after each optimizer step
            if self.weight.cnt == 0:
                # dtype=torch.float16 is assumed, matching the fp16 autocast region below
                self.weight.weight_sparse = to_sparse_semi_structured(self.weight, mask=self.weight.mask,
                                                                      dtype=torch.float16)
                self.weight.weight_sparse_T = to_sparse_semi_structured(self.weight.T, mask=self.weight.mask.T,
                                                                        dtype=torch.float16)
            with autocast(device_type='cuda', dtype=torch.float16):
                x = sparse_linear_transpose.apply(x, self.weight, self.weight.weight_sparse,
                                                  self.weight.weight_sparse_T, self.bias)
            if self.weight.cnt == 0:
                self.weight.counter += 1
            self.weight.step += 1
            self.weight.cnt += 1
        return x
import torch
from torch import autograd
from sparse import to_sparse_semi_structured

class sparse_linear_transpose(autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, weight_sparse, weight_sparse_T, bias):
        ctx.save_for_backward(input, weight_sparse_T, bias)
        ctx.shape = input.shape
        input = input.view(-1, input.shape[-1])
        # Y = X W^T using the 2:4-compressed weight
        output = torch.mm(input, weight_sparse.t())
        if bias is None:
            return output.view(*ctx.shape[:-1], -1)
        return output.view(*ctx.shape[:-1], -1) + bias

    @staticmethod
    def backward(ctx, grad_output):
        input, weight_T, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        # materialize an expanded gradient (all strides 0, e.g. from loss.mean()) before the spMM
        if grad_output.stride() == (0, 0, 0):
            grad_output = grad_output.contiguous()
        grad_output = grad_output.view(-1, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            # dX = dY W, computed with the compressed W^T (transposable mask)
            grad_input = torch.mm(grad_output, weight_T.t()).view(ctx.shape)
        if ctx.needs_input_grad[1]:
            input = input.view(-1, input.shape[-1])
            # dW = dY^T X, with dY^T pruned to 2:4 by the unbiased MVUE
            grad_weight = torch.mm(to_sparse_semi_structured(grad_output.t(), MVUE24=True), input)
        if ctx.needs_input_grad[4]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, None, None, grad_bias
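The counters in SparseLinearTranspose.forward are easy to misread. The plain-Python replay below traces them for a single layer; `simulate_schedule` is hypothetical. It assumes each optimizer step spans `grad_accum` forward passes. It also assumes `cnt` is reset to 0 once per optimizer step, as the masked-decay loop in Step 2 does. Under these assumptions, a new mask is selected every `freq` optimizer steps. The compressed weights are refreshed once per optimizer step and reused across gradient-accumulation micro-steps.

```python
def simulate_schedule(num_optim_steps, grad_accum, freq):
    """Replay the mask/compression counters of one SparseLinearTranspose layer."""
    cnt = counter = 0
    events = []
    for step in range(num_optim_steps):
        for _ in range(grad_accum):  # forward passes within one optimizer step
            if counter % freq == 0 and cnt == 0:
                events.append(("new_mask", step))
            if cnt == 0:
                events.append(("recompress", step))
                counter += 1  # counts optimizer steps seen by this layer
            cnt += 1
        cnt = 0  # reset once per optimizer step (done by the masked-decay loop)
    return events


events = simulate_schedule(num_optim_steps=5, grad_accum=2, freq=2)
print([s for e, s in events if e == "new_mask"])    # [0, 2, 4]
print([s for e, s in events if e == "recompress"])  # [0, 1, 2, 3, 4]
```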
Step 2
Apply masked decay:
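from sparse import masked_add_

# run after loss.backward() and before optimizer.step(); alpha is the masked-decay factor (e.g. 2e-4)
with torch.no_grad():
    for p in model.parameters():
        if hasattr(p, 'mask') and p.mode == 'sparse':
            p.grad = p.grad.float()
            masked_add_(p.grad, p.data, p.mask, alpha=alpha)
            p.cnt = 0  # makes SparseLinearTranspose re-compress the updated weight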
Step 3
Dense fine-tuning:
# Step 3: manually switch to the dense fine-tuning stage
if iter_num == 250000:
    for p in model.parameters():
        if hasattr(p, 'mask') and p.mode == 'sparse':
            p.mode = 'dense'
For S-STE, replace nn.Linear with the self-defined SparseLinear:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd
from sparse import soft_threshold24_triton

class SoftThreshold(autograd.Function):
    @staticmethod
    def forward(ctx, weight, scale):
        weight_temp = weight.detach()
        weight_sparse, _ = soft_threshold24_triton(weight_temp)
        return weight_sparse * scale

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through: the gradient goes to the dense weight; scale gets none
        return grad_output, None

class SparseLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
        super(SparseLinear, self).__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        self.register_buffer('scale', torch.tensor(0.))

    def get_sparse_weights(self):
        return SoftThreshold.apply(self.weight, self.scale)

    @torch.no_grad()
    def init_scale(self):
        weight = self.weight.cuda()
        weight_temp = weight.detach()
        weight_sparse, _ = soft_threshold24_triton(weight_temp)
        # least-squares scale: <w, S(w)> / <S(w), S(w)>, stored in the `scale` buffer
        self.scale.copy_(torch.dot(torch.flatten(weight_temp), torch.flatten(weight_sparse)) /
                         torch.dot(torch.flatten(weight_sparse), torch.flatten(weight_sparse)))
        self.weight.scale = self.scale

    def forward(self, x):
        w = self.get_sparse_weights()
        x = F.linear(x, w, self.bias)
        return x
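init_scale fits a single scalar so that scale * S(w) matches the dense weight w as closely as possible in the least-squares sense. Setting the derivative of ||w - beta * S(w)||^2 to zero gives beta = <w, S(w)> / <S(w), S(w)>. This is the ratio of dot products computed in init_scale. Below is a plain-Python check of that formula; `least_squares_scale` and `sq_error` are hypothetical helpers.

```python
def least_squares_scale(w, s):
    """beta minimizing sum((w_i - beta * s_i)^2), i.e. <w, s> / <s, s>."""
    num = sum(a * b for a, b in zip(w, s))
    den = sum(b * b for b in s)
    return num / den if den else 0.0


def sq_error(w, s, beta):
    """Squared reconstruction error ||w - beta * s||^2."""
    return sum((a - beta * b) ** 2 for a, b in zip(w, s))


w = [1.0, -4.0, 2.0, 0.5]
s = [0.0, -3.0, 1.0, 0.0]  # a 2:4 soft-thresholded version of w
beta = least_squares_scale(w, s)
print(beta)  # 1.4
```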
The relevant code of this can be found at
STEP: Learning N:M Structured Sparsity Masks from Scratch with Precondition [PDF]
To replicate STEP in two steps:
Step 1
Replace nn.Linear with an STE-based 2:4 linear module:
import torch
from torch import autograd, nn
import torch.nn.functional as F
import sparse

class Sparse(autograd.Function):
    @staticmethod
    def forward(ctx, weight):
        weight_sparse, _ = sparse.sparse24_triton(weight)
        return weight_sparse

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through estimator: the gradient reaches the dense weight unchanged
        return grad_output

class STEPLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        super(STEPLinear, self).__init__(in_features, out_features, bias, device, dtype)
        setattr(self.weight, 'mask', 'dense')

    def forward(self, x):
        # assumed condition: dense forward only while training in the dense phase
        if self.training and self.weight.mask == 'dense':
            x = F.linear(x, self.weight, self.bias)
        else:
            w = Sparse.apply(self.weight)
            x = F.linear(x, w, self.bias)
        return x
Step 2
Replace the AdamW optimizer with the STEP-specific sparse.AdamW_STEP:
import sparse
- adam = torch.optim.Adam(...)
+ adam = sparse.AdamW_STEP(...)
Some notes on the extra arguments of AdamW_STEP:
- : a tuple of $T_{min}$ and $T_{max}$ in Algorithm 2, recommended to be 10% and 50% of the total optimization steps.
- : different options to compute $Z_t$ in Algorithm 2, need to be 1
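The recommended $T_{min}$ and $T_{max}$ are simple fractions of the training length. The tiny helper below turns a total step count into the suggested tuple, using integer arithmetic. It is hypothetical and not part of sparse.

```python
def recommended_t_bounds(total_steps, min_pct=10, max_pct=50):
    """(T_min, T_max) as 10% and 50% of the total optimization steps."""
    return total_steps * min_pct // 100, total_steps * max_pct // 100


print(recommended_t_bounds(100_000))  # (10000, 50000)
```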
If you find our work useful, please cite:
title={Accelerating Transformer Pre-training with 2:4 Sparsity},
author={Yuezhou Hu and Kang Zhao and Weiyu Huang and Jianfei Chen and Jun Zhu},
booktitle={Forty-first International Conference on Machine Learning},
title={S-{STE}: Continuous Pruning Function for Efficient 2:4 Sparse Pre-training},
author={Yuezhou Hu and Jun Zhu and Jianfei Chen},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},