Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fast vector operation for pillar scatter #1676

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 109 additions & 3 deletions pcdet/models/backbones_2d/map_to_bev/pointpillar_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(self, model_cfg, grid_size, **kwargs):
self.nx, self.ny, self.nz = grid_size
assert self.nz == 1

def forward(self, batch_dict, **kwargs):
def slow_forward(self, batch_dict, **kwargs):
pillar_features, coords = batch_dict['pillar_features'], batch_dict['voxel_coords']
batch_spatial_features = []
batch_size = coords[:, 0].max().int().item() + 1
Expand All @@ -35,6 +35,26 @@ def forward(self, batch_dict, **kwargs):
batch_spatial_features = batch_spatial_features.view(batch_size, self.num_bev_features * self.nz, self.ny, self.nx)
batch_dict['spatial_features'] = batch_spatial_features
return batch_dict

def forward(self, batch_dict, **kwargs):
# coords -> (N, 4) [batch_idx, grid_z_idx, grid_y_idx, grid_x_idx]
# pillar_features -> (N, C): N == num total voxels, C == channel features
pillar_features, coords = batch_dict['pillar_features'], batch_dict['voxel_coords']
batch_size = coords[:, 0].max().int().item() + 1

spatial_features = torch.zeros(self.num_bev_features, batch_size*self.nz*self.ny*self.nx, dtype=pillar_features.dtype, device=pillar_features.device)

coors_unique_idx = coords[:, 0] * self.nx * self.ny + coords[:, 2] * self.nx + coords[:, 3] # (N)
coors_unique_idx_expand = coors_unique_idx.unsqueeze(0).expand(self.num_bev_features, -1) # (C,N)

feature_values = pillar_features.t() # (C,N)

spatial_features.scatter_(1, coors_unique_idx_expand.type(torch.long), feature_values)
spatial_features = spatial_features.view(self.num_bev_features, batch_size, self.ny, self.nx).permute(1, 0, 2, 3)

# Add spatial features back to batch_dict
batch_dict['spatial_features'] = spatial_features
return batch_dict


class PointPillarScatter3d(nn.Module):
Expand All @@ -46,7 +66,7 @@ def __init__(self, model_cfg, grid_size, **kwargs):
self.num_bev_features = self.model_cfg.NUM_BEV_FEATURES
self.num_bev_features_before_compression = self.model_cfg.NUM_BEV_FEATURES // self.nz

def forward(self, batch_dict, **kwargs):
def slow_forward(self, batch_dict, **kwargs):
pillar_features, coords = batch_dict['pillar_features'], batch_dict['voxel_coords']

batch_spatial_features = []
Expand All @@ -70,4 +90,90 @@ def forward(self, batch_dict, **kwargs):
batch_spatial_features = torch.stack(batch_spatial_features, 0)
batch_spatial_features = batch_spatial_features.view(batch_size, self.num_bev_features_before_compression * self.nz, self.ny, self.nx)
batch_dict['spatial_features'] = batch_spatial_features
return batch_dict
return batch_dict

def forward(self, batch_dict, **kwargs):
# coords -> (N, 4) [batch_idx, grid_z_idx, grid_y_idx, grid_x_idx]
# pillar_features -> (N, C): N == num total voxels, C == channel features
pillar_features, coords = batch_dict['pillar_features'], batch_dict['voxel_coords']
batch_size = coords[:, 0].max().int().item() + 1

spatial_features = torch.zeros(self.num_bev_features_before_compression, batch_size*self.nz*self.ny*self.nx, dtype=pillar_features.dtype, device=pillar_features.device)

coors_unique_idx = coords[:, 0] * self.nz * self.ny * self.nx + coords[:, 1] * self.ny * self.nx + coords[:, 2] * self.nx + coords[:, 3] # (N)
coors_unique_idx_expand = coors_unique_idx.unsqueeze(0).expand(self.num_bev_features_before_compression, -1) # (C,N)

feature_values = pillar_features.t() # (C,N)

spatial_features.scatter_(1, coors_unique_idx_expand.type(torch.long), feature_values)
spatial_features = spatial_features.view(self.num_bev_features_before_compression, batch_size, self.nz, self.ny, self.nx).permute(1, 0, 2, 3, 4)
batch_spatial_features = spatial_features.view(batch_size, self.num_bev_features_before_compression * self.nz, self.ny, self.nx)

# Add spatial features back to batch_dict
batch_dict['spatial_features'] = batch_spatial_features
return batch_dict


if __name__ == '__main__':
# To test vector operation forward approach vs slow for loop forward
import numpy as np
from munch import Munch

def generate_sample_data_coors(nx, ny, nz, batch_size, channels, num_voxels_each_sample):
N = sum(num_voxels_each_sample)
coors = [] # (N,4) // b_idx,z,y,x
for i in range(batch_size):
random_coords = set()
while len(random_coords) < num_voxels_each_sample[i]:
x = np.random.randint(0, nx)
y = np.random.randint(0, ny)
z = np.random.randint(0, nz)
random_coords.add((x, y, z))
random_coords = list(random_coords)
for j in range(num_voxels_each_sample[i]):
coors.append([i, random_coords[j][2], random_coords[j][1], random_coords[j][0]])
# Stack coors vertically as torch tensor
coors = torch.tensor(coors, dtype=torch.int64)
return coors

def generate_sample_data_features(nx, ny, nz, batch_size, channels, num_voxels_each_sample):
N = sum(num_voxels_each_sample)
voxel_features = [] # (N,C)
for i in range(N):
voxel_features.append(np.random.normal(size=channels))
voxel_features = torch.tensor(voxel_features)
return voxel_features


nx = 24 * 8
ny = 8 * 8
nz = 1 # For pillar based scatter
bs = 24
channels = 64
num_voxels_each_sample = [np.random.randint(1, nx*ny*nz) for _ in range(bs)]

coors = generate_sample_data_coors(nx, ny, nz, bs, channels, num_voxels_each_sample)
features = generate_sample_data_features(nx, ny, nz, bs, channels, num_voxels_each_sample)

fake_dict = {
'pillar_features': features,
'voxel_coords': coors
}
fake_model_cfg = Munch()
fake_model_cfg.NUM_BEV_FEATURES = channels
fake_model_cfg.INPUT_SHAPE = (nx, ny, nz)

def test_pillar_scatter_vectorized():
pillar_scatter = PointPillarScatter(fake_model_cfg, (nx, ny, nz))
slow_forward = pillar_scatter.slow_forward(fake_dict)
fast_forward = pillar_scatter.forward(fake_dict)
assert torch.all(torch.isclose(slow_forward['spatial_features'], fast_forward['spatial_features']))

def test_pillar_scatter3d_vectorized():
pillar_scatter3d = PointPillarScatter3d(fake_model_cfg, (nx, ny, nz))
slow_forward = pillar_scatter3d.slow_forward(fake_dict)
fast_forward = pillar_scatter3d.forward(fake_dict)
assert torch.all(torch.isclose(slow_forward['spatial_features'], fast_forward['spatial_features']))

test_pillar_scatter_vectorized()
test_pillar_scatter3d_vectorized()