diff --git a/include/flashinfer/frag_layout_swizzle.cuh b/include/flashinfer/frag_layout_swizzle.cuh index 3dbfdb9d..39cf92bc 100644 --- a/include/flashinfer/frag_layout_swizzle.cuh +++ b/include/flashinfer/frag_layout_swizzle.cuh @@ -29,9 +29,8 @@ __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b(uint32_t x) { } __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b_trans(uint32_t x) { - x = __byte_perm(x, x, 0x3120); uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x4); - x = __byte_perm(x, tmp, ((threadIdx.x & 0x4) == 0) ? 0x5410 : 0x3276); + x = __byte_perm(x, tmp, ((threadIdx.x & 0x4) == 0) ? 0x6420 : 0x3175); tmp = __shfl_xor_sync(0xffffffff, x, 0x8); x = __byte_perm(x, tmp, ((threadIdx.x & 0x8) == 0) ? 0x5410 : 0x3276); tmp = __shfl_xor_sync(0xffffffff, x, 0x10); diff --git a/include/flashinfer/vec_dtypes.cuh b/include/flashinfer/vec_dtypes.cuh index 596763c1..a40b4575 100644 --- a/include/flashinfer/vec_dtypes.cuh +++ b/include/flashinfer/vec_dtypes.cuh @@ -126,7 +126,6 @@ __device__ void fast_dequant_f8f16x4(uint32_t* input, uint2* output) { constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); constexpr int MASK3 = MASK2 & 0x7fffffff; constexpr int MASK = MASK3 | (MASK3 >> 16); - // Final MASK value: 0x7F007F00 q = __byte_perm(q, q, 0x1302); // Extract and shift FP8 values to FP16 format