diff --git a/setup.py b/setup.py
index 309d7d3372..eb22d8f6a3 100644
--- a/setup.py
+++ b/setup.py
@@ -104,7 +104,7 @@
 
 # Flash 2 group kept for backwards compatibility
 extra_deps['gpu-flash2'] = [
-    'flash-attn==2.5.8',
+    'flash-attn>=2.5.8,<3',
 ]
 
 extra_deps['gpu'] = copy.deepcopy(extra_deps['gpu-flash2'])