Skip to content

Commit

Permalink
Fix bug in gather cuda kernel (#588)
Browse files Browse the repository at this point in the history
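* fix bug in gather CUDA kernel: decompose the output index using the output (index-tensor) row length and recombine using the input row length, instead of using the input row length for both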

* add test
  • Loading branch information
nkoppel authored Mar 21, 2023
1 parent cbe38a5 commit 8b9fc37
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/tensor_ops/select_and_gather/gather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ __device__ unsigned int get_gathered_index(
}

unsigned int elem_size = 1; // the size of each indexed element
unsigned int row_len = inp_dims[ax]; // the size of the indexed dimension
unsigned int inp_row_len = inp_dims[ax]; // the size of the indexed dimension in the input
unsigned int out_row_len = idx_dims[ax]; // the size of the indexed dimension in the output

for (unsigned int d = 0; d < inp_num_dims - ax - 1; d++) {
unsigned int dim_idx = inp_num_dims - 1 - d;
Expand All @@ -31,13 +32,13 @@ __device__ unsigned int get_gathered_index(
unsigned int idx_idx = get_strided_index(index / elem_size, idx_num_dims, idx_dims, idx_strides);

// indices for dimensions before, at, and after the indexed dimension
unsigned int idx_before = index / (elem_size * row_len);
unsigned int idx_before = index / (elem_size * out_row_len);
unsigned int idx_mid = idx[idx_idx];
assert(idx_mid < inp_dims[ax]);
unsigned int idx_after = index % elem_size;

// recombine
unsigned int new_idx = (idx_before * row_len + idx_mid) * elem_size + idx_after;
unsigned int new_idx = (idx_before * inp_row_len + idx_mid) * elem_size + idx_after;
return get_strided_index(new_idx, inp_num_dims, inp_dims, inp_strides);
}

Expand Down Expand Up @@ -65,6 +66,7 @@ __device__ void gather_fwd(
get_gathered_index(i, inp_num_dims, inp_dims, inp_strides, idx, idx_num_dims, idx_dims, idx_strides, out_num_dims);

out[out_i] = inp[inp_i];
// out[out_i] = inp_i;
}

template<typename T>
Expand Down
11 changes: 11 additions & 0 deletions src/tensor_ops/select_and_gather/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,4 +435,15 @@ mod tests {
let g = r.sum().backward();
assert_eq!(g.get(&t).array(), [[3.; 5], [0.; 5], [1.; 5], [2.; 5]]);
}

/// Gathers a `2x2` result out of a `2x3` tensor using the index tensor
/// `[[0, 1], [0, 1]]`, i.e. the gathered rows are shorter than the input rows.
///
/// Every output row must equal the first two elements of the matching input
/// row. The input values come from `sample_normal`, so the expectation is
/// read from the input itself rather than hard-coded.
#[test]
fn test_gather_smaller_output_row() {
    let dev = TestDevice::default();
    let input: Tensor<Rank2<2, 3>, TestDtype, _> = dev.sample_normal();
    // Snapshot the input before `gather` is called on it.
    let expected = input.array();
    let gathered: Tensor<Rank2<2, 2>, _, _> = input.gather(dev.tensor([[0, 1], [0, 1]]));
    let actual = gathered.array();
    // Check the rows in order: row 0 first, then row 1.
    for row in 0..2 {
        assert_eq!(actual[row][..], expected[row][..2]);
    }
}
}

0 comments on commit 8b9fc37

Please sign in to comment.