diff --git a/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py b/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py
index 4445ab02244..65b4ff943dc 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py
@@ -420,6 +420,9 @@ def forward(self, data: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
         """ Forward-pass routine for GatherNd op """
+        if self.batch_dims == 0:
+            return self._gather_nd(data, indices)
+
         data_rank = len(data.shape)
         assert indices.shape[-1] <= data_rank
@@ -455,6 +458,29 @@ def forward(self, data: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
             return torch.tensor(output_data_buffer, device=data.device).reshape(output_shape)
         return torch.cat(output_data_buffer).reshape(output_shape)
 
+    @staticmethod
+    def _gather_nd(data: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
+        """
+        GatherNd operation for the batch_dims=0 case
+
+        :param data: Tensor to gather values from
+        :param indices: Index tensor to be used to gather values
+        :return: Tensor after GatherNd operation
+        """
+        data_rank, m = len(data.shape), indices.shape[-1]
+        assert (
+            m <= data_rank
+        ), f"m: {m} should be less than or equal to data_rank: {data_rank}"
+
+        total_samples = indices.shape[:-1].numel()
+        output_shape = indices.shape[:-1] + data.shape[m:]
+        reshaped_indices = torch.split(
+            tensor=indices.reshape(total_samples, m).transpose(0, 1),
+            split_size_or_sections=1,
+        )
+
+        return data[reshaped_indices].reshape(output_shape).contiguous()
+
 
 class ScatterElements(torch.nn.Module):
     """ ScatterElements op implementation """
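
For reference, below is a minimal, self-contained sketch (not part of the patch) of the indexing trick the new `_gather_nd` static method relies on: the trailing axis of `indices` (size `m`) is split into `m` index vectors, and a single advanced-indexing lookup gathers the corresponding slices of `data`. The result is checked against a naive loop that follows the ONNX GatherND definition for `batch_dims=0`. The function names `gather_nd_batch0` and `gather_nd_reference` are illustrative only and do not exist in aimet_torch.

```python
import torch


def gather_nd_batch0(data: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Vectorized GatherND for batch_dims=0, mirroring the logic of the new static method."""
    m = indices.shape[-1]
    assert m <= len(data.shape)
    total_samples = indices.shape[:-1].numel()
    output_shape = indices.shape[:-1] + data.shape[m:]
    # Split the (m, total_samples) index matrix into m row vectors; passing the resulting
    # tuple to data[...] triggers advanced indexing over the first m axes of data.
    idx = torch.split(indices.reshape(total_samples, m).transpose(0, 1), 1)
    return data[idx].reshape(output_shape).contiguous()


def gather_nd_reference(data: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Naive per-index-tuple loop following the ONNX GatherND (batch_dims=0) definition."""
    m = indices.shape[-1]
    gathered = [data[tuple(row.tolist())] for row in indices.reshape(-1, m)]
    return torch.stack(gathered).reshape(indices.shape[:-1] + data.shape[m:])


if __name__ == "__main__":
    data = torch.arange(24.0).reshape(2, 3, 4)            # rank-3 data
    indices = torch.tensor([[[0, 1], [1, 2]],
                            [[1, 0], [0, 2]]])             # shape (2, 2, 2), so m=2
    out = gather_nd_batch0(data, indices)
    assert out.shape == (2, 2, 4)
    assert torch.equal(out, gather_nd_reference(data, indices))
```

Compared with the general path in `forward`, which concatenates an intermediate `output_data_buffer`, the `batch_dims=0` fast path appears to perform the whole gather in one vectorized indexing call.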