Skip to content

Commit

Permalink
Fix softmax and argmax precision loss
Browse files Browse the repository at this point in the history
  • Loading branch information
nhatdongdang committed Jul 6, 2024
1 parent dc885f0 commit 24f02ec
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/matrix.cu
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,17 @@ static __device__ inline float fastexp(float x) {
}

__device__ void softmax(float* a, int rows) {
float res = (float)0;
for (int i = 0; i < rows; i++) {
res += exp(a[i]);
float sum = 0.0;
for (size_t i = 0; i < rows; i++) {
sum += __expf(a[i]);
}
for (int i = 0; i < rows; i++) {
a[i] /= res;
for (size_t i = 0; i < rows; i++) {
a[i] = __expf(a[i] - __logf(sum));
}
}

__device__ int argmax(float* a, int rows) {
int res = a[0];
float res = a[0];
int idx = 0;
for (int i = 0; i < rows; i++) {
if (res < a[i]) {
Expand Down

0 comments on commit 24f02ec

Please sign in to comment.