diff --git a/xformers/triton/layer_norm.py b/xformers/triton/layer_norm.py
index 4956f9f3c..cec4cc9b7 100644
--- a/xformers/triton/layer_norm.py
+++ b/xformers/triton/layer_norm.py
@@ -221,7 +221,7 @@ def layer_norm(
             and bias is not None
         ):
             return _LayerNorm.apply(x, weight, bias, eps)
-    except (triton.code_gen.OutOfResources, RuntimeError) as e:
+    except RuntimeError as e:
         # Catch cases where the current GPU does not have enough registers to hold a full tensor line
         # fallback to PyTorch's implementation, which streams the tensor in and out
         _triton_registered_warnings = True
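
For context, a minimal sketch of the fallback pattern this hunk adjusts (not the actual xformers implementation): attempt the fused Triton kernel first, and on a RuntimeError (e.g. the GPU cannot hold a full tensor row in registers) fall back to PyTorch's eager `layer_norm`. The name `fused_fn` is a hypothetical stand-in for the Triton entry point (`_LayerNorm.apply` in xformers).

```python
import torch
import torch.nn.functional as F


def layer_norm_with_fallback(x, weight, bias, eps=1e-5, fused_fn=None):
    # `fused_fn` is a hypothetical stand-in for the fused Triton kernel;
    # it is not part of any real xformers or Triton API.
    if fused_fn is not None and x.is_cuda:
        try:
            return fused_fn(x, weight, bias, eps)
        except RuntimeError:
            # The GPU does not have enough resources for the fused kernel:
            # fall through to PyTorch's implementation, which streams the
            # tensor in and out instead of holding a full row at once.
            pass
    return F.layer_norm(x, [x.shape[-1]], weight, bias, eps)
```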