From e21399e486a82ea9658d6077ec60710ecffc0ddc Mon Sep 17 00:00:00 2001 From: KanishkNavale Date: Tue, 6 Feb 2024 16:03:17 +0100 Subject: [PATCH] Added: Ball query for batch vectors --- Readme.md | 1 + heimdall/pointcloud/sampling.py | 10 ++++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/Readme.md b/Readme.md index 373ded1..5dd716f 100644 --- a/Readme.md +++ b/Readme.md @@ -21,6 +21,7 @@ Bound to a long term development. Below is the list of present contents: - Pointcloud: - Furthest point sampling + - Ball query - Robotics: - Stable inverse jacobian diff --git a/heimdall/pointcloud/sampling.py b/heimdall/pointcloud/sampling.py index add0577..cb82b7a 100644 --- a/heimdall/pointcloud/sampling.py +++ b/heimdall/pointcloud/sampling.py @@ -1,3 +1,5 @@ +from typing import List + import torch from heimdall.utils import convert_numpy_to_tensor @@ -70,10 +72,10 @@ def furthest_point_sampling( @convert_numpy_to_tensor def ball_query( - vector: torch.Tensor, pointcloud: torch.Tensor, radius: float = 0.01, **kwargs -) -> torch.Tensor: - distances = torch.linalg.norm(pointcloud - vectors.unsqueeze(dim=0), dim=-1, ord=2) + vectors: torch.Tensor, pointcloud: torch.Tensor, radius: float = 0.01, **kwargs +) -> List[torch.Tensor]: + distances = torch.linalg.norm(pointcloud - vectors.unsqueeze(dim=1), dim=-1, ord=2) sampling_mask = (distances <= radius) * (distances > 0.0) - return pointcloud[sampling_mask] + return [pointcloud[mask] for mask in sampling_mask]