python setup.py install
pip install -r requirements.txt
Here, we demonstrate how to apply S2FT in training with the LLaMA architecture. See advanced usage in experiments/train/finetune.py (Lines 266-330) and experiments/utils/s2_utils.py.
from s2ft import S2ColumnLinear, S2RowLinear
# Randomly select a fraction (args.d_ratio) of the FFN intermediate channels
# across all layers; the chosen channels become the trainable part of each
# layer's down_proj under S2FT.
intermediate_dim = model.config.intermediate_size
num_layers = model.config.num_hidden_layers

# Global channel ids over all layers: layer = id // intermediate_dim,
# in-layer channel = id % intermediate_dim.
ffn_indices = list(range(intermediate_dim * num_layers))
num_d = int(intermediate_dim * num_layers * args.d_ratio)

# Map layer index -> sorted list of selected in-layer channel indices.
parameters_d = {i: [] for i in range(num_layers)}
for d in sorted(random.sample(ffn_indices, num_d)):
    parameters_d[d // intermediate_dim].append(d % intermediate_dim)

selected_parameters = {"down_proj": parameters_d}
def convert_ffn_layer_to_s2(model, selected_parameters):
    """Swap each layer's ``mlp.down_proj`` for an ``S2RowLinear`` whose
    trainable input slice covers the selected FFN channels.

    For every decoder layer we build a permutation of the intermediate
    dimension that places the selected channels first, replace ``down_proj``
    with an ``S2RowLinear`` trained only on that leading slice, and permute
    ``up_proj``/``gate_proj`` rows and ``down_proj`` columns consistently so
    the MLP computes the same function as before.

    Args:
        model: a LLaMA-style HF model (``model.model.layers[i].mlp.*``).
        selected_parameters: ``{"down_proj": {layer_idx: [channel, ...]}}``
            as produced by the selection snippet above. Not mutated.
    """
    for i in range(model.config.num_hidden_layers):
        layer = model.model.layers[i]

        # Copy so we do not mutate the caller's selection dict (the original
        # appended the remaining indices into the shared list in place).
        selected = list(selected_parameters["down_proj"][i])
        num_selected = len(selected)

        # Full permutation of the intermediate dim: selected channels first,
        # then the rest. Set membership avoids the O(n^2) list scan.
        selected_set = set(selected)
        order = selected + [
            j
            for j in range(model.config.intermediate_size)
            if j not in selected_set
        ]

        module = layer.mlp.down_proj
        checkpoint = copy.deepcopy(module.state_dict())
        layer.mlp.down_proj = S2RowLinear(
            in_features=module.in_features,
            out_features=module.out_features,
            # NOTE(review): passes the bias tensor (or None) through;
            # confirm S2RowLinear expects this rather than a bool flag.
            bias=module.bias,
            # The trainable input slice is the selected channels, which the
            # permutation places at the front. (The original referenced
            # undefined names only_u/ud/only_d here, copied from the
            # attention-conversion code — a NameError as written.)
            start=0,
            end=num_selected,
            device=next(module.parameters()).device,
            dtype=next(module.parameters()).dtype,
        )
        # strict=False: the S2 module may register extra buffers/params not
        # present in the original Linear's state_dict.
        layer.mlp.down_proj.load_state_dict(checkpoint, strict=False)
        del module, checkpoint

        # Apply the permutation consistently across the MLP: rows (outputs)
        # of up_proj/gate_proj and columns (inputs) of down_proj, so the
        # overall MLP output is unchanged.
        u_weight = layer.mlp.up_proj.weight.data
        layer.mlp.up_proj.weight.data = u_weight[order, :]
        g_weight = layer.mlp.gate_proj.weight.data
        layer.mlp.gate_proj.weight.data = g_weight[order, :]
        d_weight = layer.mlp.down_proj.weight.data
        layer.mlp.down_proj.weight.data = d_weight[:, order]
To reproduce the S2FT results on the commonsense and arithmetic reasoning datasets, refer to the instructions and code available in the experiments directory.
For the efficient implementation of S2FT and the related efficiency-analysis code, see the efficiency directory.
@inproceedings{yang2024s2ft,
title={S2FT: Efficient, Scalable and Generalizable LLM Fine-tuning by Structured Sparsity},
author={Yang, Xinyu and Leng, Jixuan and Guo, Geyang and Zhao, Jiawei and Nakada, Ryumei and Zhang, Linjun and Yao, Huaxiu and Chen, Beidi},
booktitle={The 38th Conference on Neural Information Processing Systems (NeurIPS)},
year={2024}
}