diff --git a/Readme.md b/Readme.md index 373ded1..5dd716f 100644 --- a/Readme.md +++ b/Readme.md @@ -21,6 +21,7 @@ Bound to a long term development. Below is the list of present contents: - Pointcloud: - Furthest point sampling + - Ball query - Robotics: - Stable inverse jacobian diff --git a/heimdall/pointcloud/sampling.py b/heimdall/pointcloud/sampling.py index add0577..cb82b7a 100644 --- a/heimdall/pointcloud/sampling.py +++ b/heimdall/pointcloud/sampling.py @@ -1,3 +1,5 @@ +from typing import List + import torch from heimdall.utils import convert_numpy_to_tensor @@ -70,10 +72,10 @@ def furthest_point_sampling( @convert_numpy_to_tensor def ball_query( - vector: torch.Tensor, pointcloud: torch.Tensor, radius: float = 0.01, **kwargs -) -> torch.Tensor: - distances = torch.linalg.norm(pointcloud - vectors.unsqueeze(dim=0), dim=-1, ord=2) + vectors: torch.Tensor, pointcloud: torch.Tensor, radius: float = 0.01, **kwargs +) -> List[torch.Tensor]: + distances = torch.linalg.norm(pointcloud - vectors.unsqueeze(dim=1), dim=-1, ord=2) sampling_mask = (distances <= radius) * (distances > 0.0) - return pointcloud[sampling_mask] + return [pointcloud[mask] for mask in sampling_mask]