forked from dhyuan99/VecKM

Official GitHub repo for VecKM. A very efficient dense local geometry encoder.


deepak-1530/VecKM

 
 


VecKM: A Linear Time and Space Local Point Cloud Geometry Encoder

Dehao Yuan, Cornelia Fermüller, Tahseen Rabbani, Furong Huang, Yiannis Aloimonos

arXiv-Preprint, 2024 [arXiv]

Highlighted Features

Installation

First, install the dependencies:

conda create -n VecKM python=3.11
conda activate VecKM
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
conda install -c conda-forge cudatoolkit-dev
pip install scipy
pip install complexPyTorch

To use the pure PyTorch implementation (slower but more convenient, without the CUDA runtime and memory optimizations), simply install it with:

pip install .

To use the CUDA-optimized implementation, run:

cd src/cuvkm
python setup.py install
cd -
pip install .

Usage

Case 1: For small point clouds (e.g. fewer than 5000 points), the following implementation is recommended:

from VecKM.cuvkm.cuvkm import VecKM
vkm = VecKM(d=128, alpha=30, beta=9, positional_encoding=False).cuda()

Alternatively, to use the slower pure-Python implementation without building the CUDA extension:

from VecKM.pyvkm.vkm_small import VecKM
vkm = VecKM(d=128, alpha=30, beta=9, positional_encoding=False).cuda()
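If the same script has to run on machines with and without the compiled CUDA extension, one option is to try the CUDA module first and fall back to the pure-PyTorch one. This is a minimal sketch: the helper name `load_veckm_class` is hypothetical, and the two module paths are simply the two Case 1 imports above.

```python
import importlib

# Case 1 implementations, in order of preference (paths taken from the imports above).
CANDIDATE_MODULES = ("VecKM.cuvkm.cuvkm", "VecKM.pyvkm.vkm_small")

def load_veckm_class():
    """Hypothetical helper: return (VecKM class, module name), or (None, None)
    if neither implementation can be imported."""
    for module_name in CANDIDATE_MODULES:
        try:
            module = importlib.import_module(module_name)
        except ImportError:  # also covers ModuleNotFoundError
            continue
        return module.VecKM, module_name
    return None, None
```

Usage: `cls, name = load_veckm_class()`, then `vkm = cls(d=128, alpha=30, beta=9, positional_encoding=False).cuda()` when `cls` is not `None`.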

Case 2: For large point clouds (e.g. more than 10000 points), the following implementation is recommended:

from VecKM.pyvkm.vkm_large import VecKM
vkm = VecKM(d=256, alpha=30, beta=9, p=2048).cuda()
# Please refer to the "Implementation by Yourself" section for suggestions on picking d and p.

Then compute the local geometry encoding with:

import torch
n = 1000                         # number of points (example value)
pts = torch.randn(n, 3).cuda()   # your input point cloud, shape (n, 3)
G = vkm(pts)                     # local geometry encoding

⚠️ Caution: VecKM is sensitive to scaling. Please make sure your data is properly scaled before passing it to VecKM.
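The alpha/beta guidance later in this README assumes the input lies within a ball of radius 1. A minimal sketch of that normalization, assuming you center each cloud at its mean and scale it by its own farthest point (the name `normalize_to_unit_ball` is hypothetical; a single global scale over a whole dataset may suit some tasks better):

```python
import torch

def normalize_to_unit_ball(pts):
    """Center a point cloud and scale it to fit inside a ball of radius 1.

    pts: (n, 3) or (b, n, 3) tensor; each cloud is normalized independently.
    """
    centered = pts - pts.mean(dim=-2, keepdim=True)
    # Distance of the farthest point from the centroid, per cloud.
    radius = centered.norm(dim=-1, keepdim=True).amax(dim=-2, keepdim=True)
    return centered / radius.clamp_min(1e-12)  # guard against a degenerate cloud
```

Usage: `G = vkm(normalize_to_unit_ball(pts))`.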

Implementation by Yourself

If you are struggling with installation (e.g. due to environment issues), VecKM is very simple to implement yourself and incorporate into your own code. Suppose your input point cloud pts has shape (n, 3) or (b, n, 3); the following code gives the VecKM local geometry encoding with output shape (n, d) or (b, n, d). PyTorch >= 1.13.0 is recommended since it has better support for complex tensors, but lower versions should also work.
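Written out term by term (our own reading of the code in this section, not a formula quoted from the paper), with $\mathbf{x}_i$ the $i$-th point, $\mathbf{b}_k$ the $k$-th column of B, $A$ the $3 \times d$ matrix, and $\exp$ applied elementwise, the encoding of point $i$ is:

$$
\tilde{\mathbf{G}}_i = \sum_{j=1}^{n} \Big( \sum_{k=1}^{p} \cos\big(\mathbf{b}_k^{\top}(\mathbf{x}_i - \mathbf{x}_j)\big) \Big) \exp\!\big(\mathrm{i}\, A^{\top}(\mathbf{x}_j - \mathbf{x}_i)\big) \in \mathbb{C}^{d},
\qquad
\mathbf{G}_i = \sqrt{d}\, \frac{\tilde{\mathbf{G}}_i}{\lVert \tilde{\mathbf{G}}_i \rVert}
$$

The inner sum is the $(i, j)$ entry of eB eB^T, and every term depends only on offsets $\mathbf{x}_j - \mathbf{x}_i$, so the encoding is translation-invariant but not scale-invariant, which is consistent with the scaling caution above.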

import torch
import torch.nn as nn
import numpy as np
from scipy.stats import norm

def strict_standard_normal(d):
    # This function generates outcomes very similar to torch.randn(d), but the
    # values are the evenly spaced quantiles of N(0, 1); only their order is random.
    y = np.linspace(0, 1, d+2)
    x = norm.ppf(y)[1:-1]           # drop the -inf / +inf endpoints
    np.random.shuffle(x)
    x = torch.tensor(x).float()
    return x

class VecKM(nn.Module):
    def __init__(self, d=256, alpha=30, beta=9, p=4096):
        """ Empirically, here are some general suggestions for selecting parameters d and p:
        d = 256, p = 4096 is for point cloud size ~20k. Runtime is about 28ms.
        d = 128, p = 8192 is for point cloud size ~50k. Runtime is about 76ms.
        For larger point clouds, please enlarge p; if that costs too much, reduce d.
        A general empirical observation is that (d*p) is positively correlated with the encoding quality.

        For the selection of parameters alpha and beta, please see the section
        "Effect of Parameters alpha and beta" below.
        """
        super().__init__()
        self.d = d                                                              # needed to split real/imag parts in forward
        self.sqrt_d = d ** 0.5

        self.A = torch.stack(
            [strict_standard_normal(d) for _ in range(3)],
            dim=0
        ) * alpha
        self.A = nn.Parameter(self.A, False)                                    # (3, d)

        self.B = torch.stack(
            [strict_standard_normal(p) for _ in range(3)],
            dim=0
        ) * beta
        self.B = nn.Parameter(self.B, False)                                    # (3, p)

    def forward(self, pts):
        """ Compute the dense local geometry encodings of the given point cloud.
        Args:
            pts: (bs, n, 3) or (n, 3) tensor, the input point cloud.

        Returns:
            G: (bs, n, d) or (n, d) tensor
               the dense local geometry encodings.
               note: it is complex valued.
        """
        pA = pts @ self.A                                                       # Real(..., n, d)
        pB = pts @ self.B                                                       # Real(..., n, p)
        # Concatenate along the last dim so batched (bs, n, 3) inputs also work.
        eA = torch.cat((torch.cos(pA), torch.sin(pA)), dim=-1)                  # Real(..., n, 2d)
        eB = torch.cat((torch.cos(pB), torch.sin(pB)), dim=-1)                  # Real(..., n, 2p)
        # Contract over points first, so no (n, n) kernel matrix is ever formed.
        G = torch.matmul(
            eB,                                                                 # Real(..., n, 2p)
            eB.transpose(-1,-2) @ eA                                            # Real(..., 2p, 2d)
        )                                                                       # Real(..., n, 2d)
        # Reassemble complex values and divide by exp(i * pA) of each point.
        G = torch.complex(
            G[..., :self.d], G[..., self.d:]
        ) / torch.complex(
            eA[..., :self.d], eA[..., self.d:]
        )                                                                       # Complex(..., n, d)
        # Rescale every encoding to have norm sqrt(d).
        G = G / torch.norm(G, dim=-1, keepdim=True) * self.sqrt_d
        return G

vkm = VecKM()
pts = torch.rand((10,1000,3))
print(vkm(pts).shape) # it will be Complex(10,1000,256)
pts = torch.rand((1000,3))
print(vkm(pts).shape) # it will be Complex(1000, 256)

from complexPyTorch.complexLayers import ComplexLinear, ComplexReLU
# You may want to apply a two-layer feature transform to the encoding.
feat_trans = nn.Sequential(
    ComplexLinear(256, 128),
    ComplexReLU(),
    ComplexLinear(128, 128)
)
G = feat_trans(vkm(pts))
G = G.real**2 + G.imag**2 # squared magnitude: Real(1000, 128) here, or Real(10, 1000, 128) for a batched input.
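The self-implementation computes eB @ (eB^T @ eA) rather than (eB @ eB^T) @ eA. Both give the same result, but the first ordering only builds a (2p, 2d) intermediate, so time and memory stay linear in the number of points n, while the second builds an (n, n) kernel matrix. The sketch below checks the equivalence on random data for unbatched (n, 3) input. The function names are hypothetical, and `torch.randn` stands in for `strict_standard_normal`.

```python
import torch

def features(pts, M):
    # Random Fourier features [cos(x M), sin(x M)] for an (n, 3) cloud -> (n, 2m).
    proj = pts @ M
    return torch.cat((torch.cos(proj), torch.sin(proj)), dim=-1)

def encode_naive(pts, A, B):
    eA, eB = features(pts, A), features(pts, B)
    K = eB @ eB.T            # (n, n): K[i, j] = sum_k cos(b_k . (x_i - x_j))
    return K @ eA            # (n, 2d), quadratic in n

def encode_factorized(pts, A, B):
    eA, eB = features(pts, A), features(pts, B)
    return eB @ (eB.T @ eA)  # (2p, 2d) intermediate, linear in n
```

In float64 the two results agree to within rounding error; only the cost differs.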

Effect of Parameters $\alpha$ and $\beta$

There are two parameters, alpha and beta, in the VecKM encoding. They control the resolution and the receptive field of VecKM, respectively. A higher alpha produces a more detailed encoding of the local geometry, and a smaller alpha produces a more abstract encoding. A higher beta results in a smaller receptive field. See the figure below for a rough understanding.
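One way to see why a larger beta shrinks the receptive field: in the self-implementation, the kernel weight between points i and j is the sum over the columns b_k of B of cos(b_k · (x_i − x_j)). Since the entries of B behave like samples of N(0, beta²), the average of these cosines is approximately exp(−beta² r² / 2), a Gaussian in the distance r whose width scales as 1/beta. This is the standard random-Fourier-feature argument, offered here as an explanation rather than the authors' derivation. The sketch below checks it numerically; all function names are hypothetical.

```python
import numpy as np
from scipy.stats import norm

def quantile_normal(p, rng):
    # Evenly spaced quantiles of N(0, 1) in random order
    # (mirrors strict_standard_normal in the self-implementation).
    x = norm.ppf(np.linspace(0, 1, p + 2))[1:-1]
    rng.shuffle(x)
    return x

def empirical_kernel(delta, beta, p=4096, seed=0):
    """Mean of cos(b_k . delta) over the p columns b_k of a (3, p) matrix B.

    delta is the offset x_i - x_j between two points.
    """
    rng = np.random.default_rng(seed)
    B = np.stack([quantile_normal(p, rng) for _ in range(3)]) * beta  # (3, p)
    return float(np.cos(np.asarray(delta, dtype=float) @ B).mean())

def gaussian_kernel(r, beta):
    # E[cos(b . delta)] for b ~ N(0, beta^2 I) and |delta| = r.
    return float(np.exp(-(beta * r) ** 2 / 2))
```

For example, comparing `empirical_kernel([0.3, 0, 0], beta)` for beta = 3, 6, 9 shows the weight at a fixed distance falling as beta grows, tracking `gaussian_kernel(0.3, beta)`.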

Assume your input is normalized to lie within a ball of radius 1. For low-level tasks, such as feature matching and normal estimation, alpha in the range (60, 120) is suggested. For high-level tasks, such as classification and segmentation, alpha in the range (20, 30) is suggested. Beta is closely related to the neighborhood radius; the table below gives the correspondence. For example, to extract the local geometry encoding with radius 0.3, select beta = 6.

beta 1 2 3 4 5 6 7 8 9 10
radius 1.800 0.900 0.600 0.450 0.360 0.300 0.257 0.225 0.200 0.180
beta 11 12 13 14 15 16 17 18 19 20
radius 0.163 0.150 0.138 0.129 0.120 0.113 0.106 0.100 0.095 0.090
beta 21 22 23 24 25 26 27 28 29 30
radius 0.086 0.082 0.078 0.075 0.072 0.069 0.067 0.065 0.062 0.060
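Every entry in the table above agrees, to within about 0.001, with radius ≈ 1.8 / beta. This relation is read off the table, not taken from the VecKM code. Under a Gaussian-kernel reading exp(−beta² r² / 2) it corresponds to a kernel weight of about 0.2 at the listed radius. A sketch with hypothetical helper names for converting between the two, assuming unit-ball-normalized input:

```python
# Hypothetical helpers; the constant 1.8 is fitted to the table, not an official value.
RADIUS_CONSTANT = 1.8

def radius_for_beta(beta):
    """Approximate neighborhood radius for a given beta."""
    return RADIUS_CONSTANT / beta

def beta_for_radius(radius):
    """Approximate beta for a desired neighborhood radius."""
    return RADIUS_CONSTANT / radius
```

For example, `beta_for_radius(0.3)` gives 6, matching the worked example in the text. Round to a nearby integer if you prefer to stay on the table's grid.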

Examples

Check out the examples for an example analysis of VecKM.

Experiments

Check out the applications of VecKM to normal estimation, classification, and part segmentation. The overall architecture change looks like this:

Citation

If you find it helpful, please consider citing our paper:

@misc{yuan2024linear,
      title={A Linear Time and Space Local Point Cloud Geometry Encoder via Vectorized Kernel Mixture (VecKM)}, 
      author={Dehao Yuan and Cornelia Fermüller and Tahseen Rabbani and Furong Huang and Yiannis Aloimonos},
      year={2024},
      eprint={2404.01568},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
