Skip to content

Commit

Permalink
Enable support for sparse tensors for multi_tensor_apply (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
lcskrishna authored May 12, 2020
1 parent 2d0f9cf commit 02a5274
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions csrc/multi_tensor_apply.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void multi_tensor_apply(
for(int t = 0; t < tensor_lists[l].size(); t++)
{
// TODO: Print which tensor fails.
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
bool contiguous_memory = (tensor_lists[l][t].is_sparse()) ? tensor_lists[l][t]._values().is_contiguous() : tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
#endif
Expand All @@ -78,8 +78,15 @@ void multi_tensor_apply(
for(int t = 0; t < ntensors; t++)
{
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
for(int d = 0; d < depth; d++)
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
for(int d = 0; d < depth; d++) {
if (tensor_lists[d][t].is_sparse()) {
at::Tensor dst = at::zeros(tensor_lists[d][t].sizes(), tensor_lists[d][t].options().layout(at::kStrided));
dst.add_(tensor_lists[d][t]);
tl.addresses[d][loc_tensor_info] = dst.data_ptr();
} else {
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
}
}
loc_tensor_info++;

int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
Expand Down

0 comments on commit 02a5274

Please sign in to comment.