From a89cf3cbdf64462969b8a86b35fa561462ac7096 Mon Sep 17 00:00:00 2001
From: Xilun Wu <12968408+XilunWu@users.noreply.github.com>
Date: Thu, 16 Jan 2025 23:47:59 -0800
Subject: [PATCH] [cp] Add cudnn attention support to Context Parallel

[ghstack-poisoned]
---
Review notes (this section is ignored by `git am`):
- Dropped the unrelated debug hunk in torchtitan/models/llama/__init__.py
  that set the 8B model to n_layers=1 — leftover local smoke-test state,
  not part of "add cudnn attention support".
- The old backend list is deleted rather than left as commented-out code,
  and the now-implemented "TODO: support cuDNN backend" comment is removed.
- cuDNN is ADDED alongside the existing flash/efficient backends instead of
  replacing them, matching the commit subject; replacing the list would have
  silently disabled flash/efficient attention for all users.

 torchtitan/utils.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/torchtitan/utils.py b/torchtitan/utils.py
index 88663c00..4e925864 100644
--- a/torchtitan/utils.py
+++ b/torchtitan/utils.py
@@ -212,7 +212,10 @@ def context(cp_context: Optional[Generator[None, None, None]] = None):
-            # TODO (xilunwu): support cuDNN backend
             stack.enter_context(
                 sdpa_kernel(
-                    [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
+                    [
+                        SDPBackend.FLASH_ATTENTION,
+                        SDPBackend.EFFICIENT_ATTENTION,
+                        SDPBackend.CUDNN_ATTENTION,
+                    ]
                 )
             )
             stack.enter_context(cp_context)