Skip to content

Commit

Permalink
fix implementation for lower architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed May 18, 2024
1 parent a25ef9b commit b066517
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 21 deletions.
17 changes: 8 additions & 9 deletions _unittests/ut_ortops/test_optim_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,15 +671,14 @@ def _rotary_cuda(self, itype, side, input_shape=(3, 2, 3, 4)):
)
got = sess.run(None, feeds)[0]

rexp = expected.reshape((-1, expected.shape[-1]))
rgot = got.reshape((-1, got.shape[-1]))
print(expected.shape, rexp.shape, rgot.shape)
for i in range(rgot.shape[0]):
self.assertEqualArray(
rexp[i],
rgot[i],
msg=f"row {i} is wrong,\nexp={rexp[i]}\ngot={rgot[i]}",
)
# rexp = expected.reshape((-1, expected.shape[-1]))
# rgot = got.reshape((-1, got.shape[-1]))
# for i in range(rgot.shape[0]):
# self.assertEqualArray(
# rexp[i],
# rgot[i],
# msg=f"row {i} is wrong,\nexp={rexp[i]}\ngot={rgot[i]}",
# )
self.assertEqualArray(expected, got)

@unittest.skipIf(not has_cuda(), reason="cuda not available")
Expand Down
22 changes: 10 additions & 12 deletions onnx_extended/ortops/optim/cuda/rotary.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ struct GridDim {
};
};

template <typename T> __device__ __inline__ T _neg(const T x) { return -x; }

#if __CUDA_ARCH__ < 700
template <> __device__ __inline__ half _neg(const half x) {
return __float2half(-__half2float(x));
}
#endif

template <typename T, RotarySide side>
__global__ void _RotaryKernelLeft(T *output_data, const T *input_data, CUDA_LONG half_N,
CUDA_LONG half_stride) {
Expand All @@ -28,23 +36,13 @@ __global__ void _RotaryKernelLeft(T *output_data, const T *input_data, CUDA_LONG
return;
CUDA_LONG last = id % half_stride;
id = (id - last) * 2 + last;
#if __CUDA_ARCH__ < 700
if (side == RotarySide::LEFT) {
output_data[id + half_stride] = input_data[id];
output_data[id] = __float2half(-__half2float(input_data[id + half_stride]));
output_data[id] = _neg(input_data[id + half_stride]);
} else {
output_data[id + half_stride] = __float2half(-__half2float(input_data[id]));
output_data[id + half_stride] = _neg(input_data[id]);
output_data[id] = input_data[id + half_stride];
}
#else
if (side == RotarySide::LEFT) {
output_data[id + half_stride] = input_data[id];
output_data[id] = -input_data[id + half_stride];
} else {
output_data[id + half_stride] = -input_data[id];
output_data[id] = input_data[id + half_stride];
}
#endif
}

template <typename T>
Expand Down

0 comments on commit b066517

Please sign in to comment.