Skip to content

Commit

Permalink
Implement optimized GatherNd for batch_dim=0 case
Browse files Browse the repository at this point in the history
Signed-off-by: Geunho Lee <quic_geunlee@quicinc.com>
  • Loading branch information
quic-geunlee authored Mar 5, 2024
1 parent 8378866 commit 93ec6bc
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,9 @@ def forward(self, data: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
"""
Forward-pass routine for GatherNd op
"""
if self.batch_dims == 0:
return self._gather_nd(data, indices)

data_rank = len(data.shape)

assert indices.shape[-1] <= data_rank
Expand Down Expand Up @@ -455,6 +458,29 @@ def forward(self, data: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
return torch.tensor(output_data_buffer, device=data.device).reshape(output_shape)
return torch.cat(output_data_buffer).reshape(output_shape)

@staticmethod
def _gather_nd(data: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
"""
GatherNd operation for batch_dim=0 case
:param data: Tensor to gather values
:param indices: Index tensor to be used to gather values
:return: Tensor after GatherNd operation
"""
data_rank, m = len(data.shape), indices.shape[-1]
assert (
m <= data_rank
), f"m: {m} should be less than or equal to data_rank: {data_rank}"

total_samples = indices.shape[:-1].numel()
output_shape = indices.shape[:-1] + data.shape[m:]
reshaped_indices = torch.split(
tensor=indices.reshape(total_samples, m).transpose(0, 1),
split_size_or_sections=1,
)

return data[reshaped_indices].reshape(output_shape).contiguous()


class ScatterElements(torch.nn.Module):
""" ScatterElements op implementation """
Expand Down

0 comments on commit 93ec6bc

Please sign in to comment.