Fix transpose on neon for real
lilith committed Aug 22, 2024
1 parent 566725d commit b26f5b6
Showing 1 changed file with 37 additions and 40 deletions.
imageflow_core/src/graphics/transpose.rs
@@ -1,5 +1,6 @@
 #![allow(non_snake_case)]
 
 // Consider using mit-licensed https://github.com/ejmahler/transpose/blob/master/src/out_of_place.rs
+// for recursive approach?
 use crate::graphics::prelude::*;
 use multiversion::multiversion;

@@ -113,50 +114,46 @@ unsafe fn transpose_8x8_avx2(src: *const u32, dst: *mut u32, src_stride: usize,
     _mm256_storeu_si256(dst.add(dst_stride * 7) as *mut __m256i, row7);
 }
 
 #[inline]
 #[target_feature(enable = "neon")]
 #[cfg(target_arch = "aarch64")]
-unsafe fn transpose_8x8_neon(src: *const u32, dst: *mut u32, src_stride: usize, dst_stride: usize) {
-    // Load 8 rows of 8 32-bit integers each
-    let row0 = vld1q_u32(src as *const u32);
-    let row1 = vld1q_u32(src.add(src_stride) as *const u32);
-    let row2 = vld1q_u32(src.add(src_stride * 2) as *const u32);
-    let row3 = vld1q_u32(src.add(src_stride * 3) as *const u32);
-    let row4 = vld1q_u32(src.add(src_stride * 4) as *const u32);
-    let row5 = vld1q_u32(src.add(src_stride * 5) as *const u32);
-    let row6 = vld1q_u32(src.add(src_stride * 6) as *const u32);
-    let row7 = vld1q_u32(src.add(src_stride * 7) as *const u32);
-
-    // Transpose 8x8 matrix
-    let tmp01 = vtrnq_u32(row0, row1);
-    let tmp23 = vtrnq_u32(row2, row3);
-    let tmp45 = vtrnq_u32(row4, row5);
-    let tmp67 = vtrnq_u32(row6, row7);
-
-    let tmp89 = vuzpq_u32(tmp01.0, tmp23.0);
-    let tmp1011 = vuzpq_u32(tmp01.1, tmp23.1);
-    let tmp1213 = vuzpq_u32(tmp45.0, tmp67.0);
-    let tmp1415 = vuzpq_u32(tmp45.1, tmp67.1);
-
-    let result0 = vreinterpretq_u32_u64(vtrn1q_u64(vreinterpretq_u64_u32(tmp89.0), vreinterpretq_u64_u32(tmp1213.0)));
-    let result1 = vreinterpretq_u32_u64(vtrn1q_u64(vreinterpretq_u64_u32(tmp89.1), vreinterpretq_u64_u32(tmp1213.1)));
-    let result2 = vreinterpretq_u32_u64(vtrn1q_u64(vreinterpretq_u64_u32(tmp1011.0), vreinterpretq_u64_u32(tmp1415.0)));
-    let result3 = vreinterpretq_u32_u64(vtrn1q_u64(vreinterpretq_u64_u32(tmp1011.1), vreinterpretq_u64_u32(tmp1415.1)));
-    let result4 = vreinterpretq_u32_u64(vtrn2q_u64(vreinterpretq_u64_u32(tmp89.0), vreinterpretq_u64_u32(tmp1213.0)));
-    let result5 = vreinterpretq_u32_u64(vtrn2q_u64(vreinterpretq_u64_u32(tmp89.1), vreinterpretq_u64_u32(tmp1213.1)));
-    let result6 = vreinterpretq_u32_u64(vtrn2q_u64(vreinterpretq_u64_u32(tmp1011.0), vreinterpretq_u64_u32(tmp1415.0)));
-    let result7 = vreinterpretq_u32_u64(vtrn2q_u64(vreinterpretq_u64_u32(tmp1011.1), vreinterpretq_u64_u32(tmp1415.1)));
-    // Store the transposed rows
-    vst1q_u32(dst as *mut u32, result0);
-    vst1q_u32(dst.add(dst_stride) as *mut u32, result1);
-    vst1q_u32(dst.add(dst_stride * 2) as *mut u32, result2);
-    vst1q_u32(dst.add(dst_stride * 3) as *mut u32, result3);
-    vst1q_u32(dst.add(dst_stride * 4) as *mut u32, result4);
-    vst1q_u32(dst.add(dst_stride * 5) as *mut u32, result5);
-    vst1q_u32(dst.add(dst_stride * 6) as *mut u32, result6);
-    vst1q_u32(dst.add(dst_stride * 7) as *mut u32, result7);
-}
+unsafe fn transpose_4x4_neon(src: *const u32, dst: *mut u32, src_stride: usize, dst_stride: usize) {
+    let r0 = vld1q_f32(src as *const f32);
+    let r1 = vld1q_f32(src.add(src_stride) as *const f32);
+    let r2 = vld1q_f32(src.add(src_stride * 2) as *const f32);
+    let r3 = vld1q_f32(src.add(src_stride * 3) as *const f32);
+
+    let c0 = vzip1q_f32(r0, r1);
+    let c1 = vzip2q_f32(r0, r1);
+    let c2 = vzip1q_f32(r2, r3);
+    let c3 = vzip2q_f32(r2, r3);
+
+    let t0 = vcombine_f32(vget_low_f32(c0), vget_low_f32(c2));
+    let t1 = vcombine_f32(vget_high_f32(c0), vget_high_f32(c2));
+    let t2 = vcombine_f32(vget_low_f32(c1), vget_low_f32(c3));
+    let t3 = vcombine_f32(vget_high_f32(c1), vget_high_f32(c3));
+
+    vst1q_f32(dst as *mut f32, t0);
+    vst1q_f32(dst.add(dst_stride) as *mut f32, t1);
+    vst1q_f32(dst.add(dst_stride * 2) as *mut f32, t2);
+    vst1q_f32(dst.add(dst_stride * 3) as *mut f32, t3);
+}
+#[target_feature(enable = "neon")]
+#[cfg(target_arch = "aarch64")]
+pub unsafe fn transpose_8x8_neon(src: *const u32, dst: *mut u32, src_stride: usize, dst_stride: usize) {
+    // Transpose top-left 4x4 quadrant
+    transpose_4x4_neon(src, dst, src_stride, dst_stride);
+
+    // Transpose top-right 4x4 quadrant
+    transpose_4x4_neon(src.add(4), dst.add(dst_stride * 4), src_stride, dst_stride);
+
+    // Transpose bottom-left 4x4 quadrant
+    transpose_4x4_neon(src.add(src_stride * 4), dst.add(4), src_stride, dst_stride);
+
+    // Transpose bottom-right 4x4 quadrant
+    transpose_4x4_neon(src.add(src_stride * 4).add(4), dst.add(dst_stride * 4).add(4), src_stride, dst_stride);
+}
 
 #[inline]
 unsafe fn transpose4x4_generic(A: *mut f32, B: *mut f32, lda: i32, ldb: i32) {
     for i in 0..4 {
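Why this is a fix: vld1q_u32 and vld1q_f32 are 128-bit loads, i.e. four 32-bit lanes, so the removed kernel only ever read the first four columns of each row and could not have transposed a full 8x8 tile. The replacement works at the width the loads actually provide: vzip1q_f32/vzip2q_f32 interleave the low and high halves of two rows, and vcombine_f32 stitches the resulting half-vectors into whole columns, giving an exact 4x4 transpose (reinterpreting u32 data as f32 is harmless here because zips and combines are pure lane permutes). A minimal lane-index check, offered as an illustrative sketch rather than code from this commit, assuming it lives in the same module so the private transpose_4x4_neon is in scope:

#[cfg(all(test, target_arch = "aarch64"))]
mod transpose_4x4_sketch {
    use super::*;

    #[test]
    fn neon_4x4_matches_index_swap() {
        // A 4x4 tile of distinct values 0..16, row stride 4.
        let src: Vec<u32> = (0..16).collect();
        let mut dst = vec![0u32; 16];
        // Safety: NEON is baseline on aarch64; both buffers hold a full 4x4 tile.
        unsafe { transpose_4x4_neon(src.as_ptr(), dst.as_mut_ptr(), 4, 4) };
        for y in 0..4 {
            for x in 0..4 {
                // A transpose swaps row and column indices.
                assert_eq!(dst[x * 4 + y], src[y * 4 + x]);
            }
        }
    }
}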

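The new transpose_8x8_neon then follows from the block identity: writing the tile as M = [[A, B], [C, D]] in 4x4 blocks, the transpose is Mᵀ = [[Aᵀ, Cᵀ], [Bᵀ, Dᵀ]]. The diagonal quadrants transpose in place, while the off-diagonal quadrants trade corners, which is why the top-right source offset src.add(4) pairs with the bottom-left destination offset dst.add(dst_stride * 4) and vice versa. The same illustrative check extended to the full tile (again a sketch under the same assumptions, not code from the commit):

#[cfg(all(test, target_arch = "aarch64"))]
mod transpose_8x8_sketch {
    use super::*;

    #[test]
    fn neon_8x8_matches_index_swap() {
        // An 8x8 tile of distinct values 0..64, row stride 8.
        let src: Vec<u32> = (0..64).collect();
        let mut dst = vec![0u32; 64];
        // Safety: NEON is baseline on aarch64; both buffers hold a full 8x8 tile.
        unsafe { transpose_8x8_neon(src.as_ptr(), dst.as_mut_ptr(), 8, 8) };
        for y in 0..8 {
            for x in 0..8 {
                assert_eq!(dst[x * 8 + y], src[y * 8 + x]);
            }
        }
    }
}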