From 02a5274b97382f26a0cee383d43a950d85b09256 Mon Sep 17 00:00:00 2001
From: Chaitanya Sri Krishna Lolla
Date: Tue, 12 May 2020 16:24:40 -0700
Subject: [PATCH] Enable support for sparse tensors for multi_tensor_apply
 (#6)

---
 csrc/multi_tensor_apply.cuh | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/csrc/multi_tensor_apply.cuh b/csrc/multi_tensor_apply.cuh
index 2b790cd2b..e0cbe7d10 100644
--- a/csrc/multi_tensor_apply.cuh
+++ b/csrc/multi_tensor_apply.cuh
@@ -56,7 +56,7 @@ void multi_tensor_apply(
     for(int t = 0; t < tensor_lists[l].size(); t++)
     {
       // TODO: Print which tensor fails.
-      bool contiguous_memory = tensor_lists[l][t].is_contiguous();
+      bool contiguous_memory = (tensor_lists[l][t].is_sparse()) ? tensor_lists[l][t]._values().is_contiguous() : tensor_lists[l][t].is_contiguous();
 #ifdef VERSION_GE_1_5
       contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
 #endif
@@ -78,8 +78,15 @@
   for(int t = 0; t < ntensors; t++)
   {
     tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
-    for(int d = 0; d < depth; d++)
-      tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
+    for(int d = 0; d < depth; d++) {
+      if (tensor_lists[d][t].is_sparse()) {
+        at::Tensor dst = at::zeros(tensor_lists[d][t].sizes(), tensor_lists[d][t].options().layout(at::kStrided));
+        dst.add_(tensor_lists[d][t]);
+        tl.addresses[d][loc_tensor_info] = dst.data_ptr();
+      } else {
+        tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
+      }
+    }
     loc_tensor_info++;
 
     int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
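
The first hunk changes the contiguity check: a sparse COO tensor has no strided layout of its own, so the patch falls back to checking the contiguity of its underlying values buffer via _values(). A minimal sketch of that predicate in isolation (the helper name has_contiguous_memory is hypothetical; the ATen calls are the ones used in the hunk):

    #include <ATen/ATen.h>

    // Returns true when the tensor's backing memory can be chunked by
    // multi_tensor_apply: for a sparse COO tensor, inspect the values
    // buffer; otherwise ask the tensor itself.
    static bool has_contiguous_memory(const at::Tensor& t) {
      return t.is_sparse() ? t._values().is_contiguous()
                           : t.is_contiguous();
    }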
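
The second hunk densifies each sparse tensor before harvesting a raw pointer for the kernel: it allocates a zero-filled strided tensor of the same shape and dtype, then accumulates the sparse tensor into it with add_(). A standalone sketch of the same conversion using the public libtorch API (illustrative only; the tensor contents here are made up):

    #include <torch/torch.h>

    int main() {
      // A 2x3 COO sparse tensor with values 3 and 4 at (0,0) and (1,2).
      at::Tensor indices = torch::tensor({{0, 1}, {0, 2}}, torch::kLong);
      at::Tensor values  = torch::tensor({3.0f, 4.0f});
      at::Tensor sparse  = torch::sparse_coo_tensor(indices, values, {2, 3});

      // Same pattern as the patch: allocate a dense (strided) buffer with
      // the sparse tensor's shape and options, then accumulate into it.
      at::Tensor dst = torch::zeros(sparse.sizes(),
                                    sparse.options().layout(at::kStrided));
      dst.add_(sparse);

      // dst now owns contiguous storage whose raw pointer could be handed
      // to a kernel. In the patch, dst is a loop-local temporary, so that
      // pointer is only valid while dst stays alive.
      void* raw = dst.data_ptr();
      (void)raw;
      return 0;
    }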