diff --git a/src/layer/arm/multiheadattention_arm.cpp b/src/layer/arm/multiheadattention_arm.cpp
index 7ea0f62272b..15eca715699 100644
--- a/src/layer/arm/multiheadattention_arm.cpp
+++ b/src/layer/arm/multiheadattention_arm.cpp
@@ -44,7 +44,8 @@ MultiHeadAttention_arm::MultiHeadAttention_arm()
 int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
 {
     Option opt = _opt;
-    opt.use_bf16_storage = false;
+    opt.use_fp16_storage &= support_fp16_storage;
+    opt.use_bf16_storage &= support_bf16_storage;
 
     {
         qk_softmax = ncnn::create_layer(ncnn::LayerType::Softmax);
@@ -224,7 +225,8 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
 int MultiHeadAttention_arm::destroy_pipeline(const Option& _opt)
 {
     Option opt = _opt;
-    opt.use_bf16_storage = false;
+    opt.use_fp16_storage &= support_fp16_storage;
+    opt.use_bf16_storage &= support_bf16_storage;
 
     if (qk_softmax)
     {
@@ -286,7 +288,8 @@ int MultiHeadAttention_arm::forward(const std::vector<Mat>& bottom_blobs, std::v
     const Mat& attn_mask_blob = attn_mask ? bottom_blobs[bottom_blobs.size() - 1] : Mat();
 
     Option opt = _opt;
-    opt.use_bf16_storage = false;
+    opt.use_fp16_storage &= support_fp16_storage;
+    opt.use_bf16_storage &= support_bf16_storage;
 
     Mat attn_mask_blob_unpacked;
     if (attn_mask_blob.elempack != 1)
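
For context, the patch replaces the blanket `opt.use_bf16_storage = false;` override with capability gating: the layer's local `Option` copy keeps fp16/bf16 storage only when the hardware actually supports those formats. Below is a minimal standalone sketch of that `&=` gating pattern. The `Option` struct here is a stripped-down stand-in for ncnn's real one, and the hard-coded support flags stand in for the layer's `support_fp16_storage` / `support_bf16_storage` members, which ncnn derives from runtime CPU feature detection.

// Minimal sketch of the capability-gating pattern used in the patch.
// The Option struct and the support flags are simplified stand-ins;
// in ncnn the support flags come from runtime CPU feature checks.
#include <cstdio>

struct Option
{
    // Storage modes as requested globally by the caller.
    bool use_fp16_storage = true;
    bool use_bf16_storage = true;
};

int main()
{
    // Stand-ins for support_fp16_storage / support_bf16_storage,
    // hard-coded here for illustration only.
    const bool support_fp16_storage = true;
    const bool support_bf16_storage = false;

    Option opt; // local copy, mirroring `Option opt = _opt;` in the patch

    // The gating itself: a requested mode survives only if supported.
    opt.use_fp16_storage &= support_fp16_storage; // stays enabled
    opt.use_bf16_storage &= support_bf16_storage; // forced off on this CPU

    printf("use_fp16_storage=%d use_bf16_storage=%d\n",
           (int)opt.use_fp16_storage, (int)opt.use_bf16_storage);
    return 0;
}

Compared with unconditionally disabling bf16 storage, this lets the fp16/bf16 paths of the sub-layers this layer creates (such as the qk_softmax seen in the diff) engage on hardware that supports them, while degrading to the original behavior elsewhere.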