
Commit

code release of TrajectoryFormer
Cedarch committed Aug 22, 2023
1 parent 9721a70 commit dfc83da
Showing 13 changed files with 143 additions and 100 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -4,6 +4,7 @@
An Efficient, Flexible, and General deep learning framework that remains minimal. Users can use EFG to explore any research topic by following the project templates.

# What's New
* 2023.08.22 Code release of ICCV2023 paper: [TrajectoryFormer: 3D Object Tracking Transformer with Predictive Trajectory Hypotheses](https://github.com/poodarchu/EFG/blob/master/playground/tracking.3d/waymo/trajectoryformer/README.md).
* 2023.04.13 Support COCO Panoptic Segmentation with Mask2Former.
* 2023.03.30 Support PyTorch 2.0.
* 2023.03.21 Code release of CVPR2023 **Highlight** paper: [ConQueR: Query Contrast Voxel-DETR for 3D Object Detection](https://github.com/poodarchu/EFG/blob/master/playground/detection.3d/waymo/conquer/README.md).
@@ -1,6 +1,7 @@
import numpy as np
import copy


def greedy_assignment(dist):
matched_indices = []
if dist.shape[1] == 0:
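The greedy_assignment hunk above is truncated by the diff view. For context, here is a minimal, self-contained sketch of the greedy matching step as it is commonly written in 3D tracking codebases (the threshold values and the remainder of the body are assumptions, not necessarily this file's exact code):

```python
import numpy as np

def greedy_assignment(dist):
    # dist: (num_detections, num_tracks) cost matrix (e.g. center distances).
    # Greedily take the cheapest track for each detection, in row order,
    # marking a used track by setting its column to a huge cost.
    matched_indices = []
    if dist.shape[1] == 0:
        return np.array(matched_indices, np.int32).reshape(-1, 2)
    for i in range(dist.shape[0]):
        j = dist[i].argmin()
        if dist[i][j] < 1e16:
            dist[:, j] = 1e18
            matched_indices.append([i, j])
    return np.array(matched_indices, np.int32).reshape(-1, 2)
```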
@@ -1,9 +1,16 @@
from sample import *
from track_evaluator import *
from env import *
from aug import CusTomFilterByRange, CusTomRandomFlip3D, CusTomGlobalScaling, CusTomGlobalRotation
from sample import SeqInferenceSampler
from track_evaluator import CustomWaymoTrackEvaluator
from env import CustomWDDataset
from trajectoryformer import TrajectoryFormer


__all__ = [
"CusTomFilterByRange", "CusTomRandomFlip3D", "CusTomGlobalScaling", "CusTomGlobalRotation",
"SeqInferenceSampler", "CustomWaymoTrackEvaluator", "CustomWDDataset",
]


def build_model(self, config):
model = TrajectoryFormer(config).cuda()

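The build_model hunk above is also cut off; a plausible completion (an assumption — only the first line of the body is visible in this diff) would simply return the constructed model:

```python
def build_model(self, config):
    # `self` is presumably the trainer object EFG binds this hook to (assumption).
    model = TrajectoryFormer(config).cuda()
    return model
```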
@@ -50,7 +50,7 @@ def evaluate(self):
for label in target["annotations"]["labels"]
]
)

target["annotations"]["gt_boxes"][:, -1] = limit_period(
target["annotations"]["gt_boxes"][:, -1],
offset=0.5,
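limit_period here is the usual heading-wrapping helper found across 3D detection codebases; a minimal sketch of what it typically computes (assumed, since its definition is not part of this diff):

```python
import math
import torch

def limit_period(val, offset=0.5, period=2 * math.pi):
    # Wrap angles into [-offset * period, (1 - offset) * period),
    # e.g. offset=0.5 with period=2*pi keeps box headings in [-pi, pi).
    return val - torch.floor(val / period + offset) * period
```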
@@ -101,8 +101,7 @@ def __init__(self, config):
self.cyc_embed = torch.tensor([0, 0, 1]).cuda().float().reshape(1, 1, 3)
self.train_nms_thresh = self.config.dataset.nms_thresh
self.train_score_thresh = self.config.dataset.score_thresh

### eval ###
# eval
self.max_id = 0
self.WAYMO_TRACKING_NAMES = config.dataset.classes
self.nms_thresh = self.config.model.nms_thresh
@@ -432,8 +431,7 @@ def get_history_traj(self, cur_ids):
.clone()
)
transfered_traj, transfered_vel = transform_global_to_current_torch(
boxes_cat, vels_cat, pose_cur_cuda
)
boxes_cat, vels_cat, pose_cur_cuda)
traj[0, : boxes_cat.shape[0], k] = transfered_traj
traj_vels[0, : vels_cat.shape[0], k] = transfered_vel

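transform_global_to_current_torch is called with (boxes, velocities, current pose); its body is not part of this diff. Below is a hedged sketch of what such a helper typically does — mapping world-frame boxes and planar velocities into the current ego frame — assuming boxes laid out as [x, y, z, dx, dy, dz, heading, ...] and a 4x4 ego-to-world pose:

```python
import torch

def transform_global_to_current_torch(boxes_global, vels_global, pose_cur):
    # boxes_global: (N, 7+) world-frame boxes, heading assumed at index 6.
    # vels_global:  (N, 2) world-frame planar velocities [vx, vy].
    # pose_cur:     (4, 4) ego-to-world pose of the current frame.
    world_to_ego = torch.linalg.inv(pose_cur)
    boxes = boxes_global.clone()

    # Bring box centers into the ego frame with a homogeneous transform.
    centers = torch.cat([boxes[:, :3], torch.ones_like(boxes[:, :1])], dim=1)
    boxes[:, :3] = (centers @ world_to_ego.T)[:, :3]

    # Rotate headings and velocities by the yaw of the world-to-ego rotation.
    yaw = torch.atan2(world_to_ego[1, 0], world_to_ego[0, 0])
    boxes[:, 6] = boxes[:, 6] + yaw
    rot = torch.stack([
        torch.stack([torch.cos(yaw), -torch.sin(yaw)]),
        torch.stack([torch.sin(yaw), torch.cos(yaw)]),
    ])
    vels = vels_global @ rot.T
    return boxes, vels
```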
@@ -1091,11 +1089,10 @@ def get_pred_candi(self, traj, traj_vels):

def get_pred_motion(self, traj, pred_vel=None):
traj_rois = traj.clone().unsqueeze(3)
batch_size, len_traj, num_track, num_hypo = (
batch_size, len_traj, num_track = (
traj_rois.shape[0],
traj_rois.shape[1],
traj_rois.shape[2],
traj_rois.shape[3],
)
self.num_future = 10  # the pretrained motion model predicts 10 future frames
history_traj = traj_rois
@@ -1,36 +1,34 @@
from torch import nn
import torch.nn.functional as F

class TransformerEncoder(nn.Module):

def __init__(self, encoder_layer, num_layers, norm=None,config=None):
class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None, config=None):
super().__init__()
self.layers = nn.ModuleList(encoder_layer)
self.num_layers = num_layers
self.norm = norm

def forward(self, token, src, pos=None):

token_list = []
output = src
for layer in self.layers:
output,token = layer(token, output,pos=pos)
output, token = layer(token, output, pos=pos)
token_list.append(token)
if self.norm is not None:
output = self.norm(output)

return token_list

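For orientation, a hypothetical way to build and call the refactored TransformerEncoder — the layer count, feature sizes, and passing config=None are assumptions for illustration, not values taken from this repository:

```python
import torch

d_model, nhead, num_layers = 256, 4, 3
layers = [TransformerEncoderLayer(config=None, d_model=d_model, nhead=nhead)
          for _ in range(num_layers)]
encoder = TransformerEncoder(layers, num_layers=num_layers)

token = torch.randn(16, 1, d_model)   # one summary token per trajectory hypothesis
src = torch.randn(16, 128, d_model)   # e.g. per-point features for each hypothesis
token_list = encoder(token, src)      # one refined token tensor per encoder layer
```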
class TransformerEncoderGlobalLocal(nn.Module):

def __init__(self, encoder_layer, num_layers, norm=None,config=None):

class TransformerEncoderGlobalLocal(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None, config=None):
super().__init__()
self.layers = nn.ModuleList(encoder_layer)
self.num_layers = num_layers
self.norm = norm

def forward(self, src):

token_list = []
output = src
for layer in self.layers:
@@ -41,6 +39,7 @@ def forward(self, src):

return token_list


class TransformerEncoderLayer(nn.Module):
def __init__(self, config, d_model, nhead, dim_feedforward=2048, dropout=0):
super().__init__()
@@ -60,16 +59,21 @@ def with_pos_embed(self, tensor, pos):
return tensor if pos is None else tensor + pos

def forward_post(self, token, src, pos=None):

src_mix = self.point_attn(query=src.permute(1,0,2), key=src.permute(1,0,2), value=src.permute(1,0,2))[0]
src_mix = src_mix.permute(1,0,2)
src_mix = self.point_attn(
query=src.permute(1, 0, 2),
key=src.permute(1, 0, 2),
value=src.permute(1, 0, 2),
)[0]
src_mix = src_mix.permute(1, 0, 2)
src = src + self.dropout1(src_mix)
src = self.norm1(src)
src_mix = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src_mix)
src = self.norm2(src)
src_summary = self.self_attn(token.permute(1,0,2), key=src.permute(1,0,2), value=src.permute(1,0,2))[0]
src_summary = src_summary.permute(1,0,2)
src_summary = self.self_attn(
token.permute(1, 0, 2), key=src.permute(1, 0, 2), value=src.permute(1, 0, 2)
)[0]
src_summary = src_summary.permute(1, 0, 2)
token = token + self.dropout1(src_summary)
token = self.norm1(token)
src_summary = self.linear2(self.dropout(self.activation(self.linear1(token))))
@@ -79,48 +83,71 @@ def forward_post(self, token, src, pos=None):
return src, token

def forward(self, token, src, pos=None):

return self.forward_post(token, src, pos)

class TransformerEncoderLayerGlobalLocal(nn.Module):

class TransformerEncoderLayerGlobalLocal(nn.Module):
def __init__(self, config, d_model, nhead, dim_feedforward=2048, dropout=0):
super().__init__()

self.global_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.local_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

self.ffn1 = FFN(d_model,dim_feedforward)
self.ffn2 = FFN(d_model,dim_feedforward)
self.ffn1 = FFN(d_model, dim_feedforward)
self.ffn2 = FFN(d_model, dim_feedforward)

self.activation = F.relu


def with_pos_embed(self, tensor, pos):
return tensor if pos is None else tensor + pos

def forward_post(self, src):

bs, num_track, candi, = src.shape[0], src.shape[1], src.shape[2]
src_global = src.reshape(bs,-1,src.shape[-1])
src_mix = self.global_attn(query=src_global.permute(1,0,2), key=src_global.permute(1,0,2), value=src_global.permute(1,0,2))[0]
src_mix = src_mix.permute(1,0,2)
src_global = self.ffn1(src_global,src_mix)
src_local = src_global.reshape(bs,num_track,candi,-1).reshape(bs*num_track,candi,-1)
src_mix = self.local_attn(query=src_local.permute(1,0,2), key=src_local.permute(1,0,2), value=src_local.permute(1,0,2))[0]
src_mix = src_mix.permute(1,0,2)
src_local = self.ffn2(src_local,src_mix)

return src_local.reshape(bs,num_track,candi,-1)
(
bs,
num_track,
candi,
) = (
src.shape[0],
src.shape[1],
src.shape[2],
)
src_global = src.reshape(bs, -1, src.shape[-1])
src_mix = self.global_attn(
query=src_global.permute(1, 0, 2),
key=src_global.permute(1, 0, 2),
value=src_global.permute(1, 0, 2),
)[0]
src_mix = src_mix.permute(1, 0, 2)
src_global = self.ffn1(src_global, src_mix)
src_local = src_global.reshape(bs, num_track, candi, -1).reshape(
bs * num_track, candi, -1
)
src_mix = self.local_attn(
query=src_local.permute(1, 0, 2),
key=src_local.permute(1, 0, 2),
value=src_local.permute(1, 0, 2),
)[0]
src_mix = src_mix.permute(1, 0, 2)
src_local = self.ffn2(src_local, src_mix)

return src_local.reshape(bs, num_track, candi, -1)

def forward(self, src):
return self.forward_post(src)
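Similarly, a hypothetical call into the global-local layer; the shapes are assumptions, and the intent (as the code suggests) is that the global attention mixes all tracks and hypotheses jointly, while the local attention then mixes only the hypotheses belonging to the same track:

```python
import torch

layer = TransformerEncoderLayerGlobalLocal(config=None, d_model=256, nhead=4)
src = torch.randn(2, 10, 5, 256)  # (batch, num_track, num_hypotheses, d_model)
out = layer(src)                  # same shape after global + local attention
```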



class FFN(nn.Module):
def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,dout=None,
activation="relu", normalize_before=False):
def __init__(
self,
d_model,
dim_feedforward=2048,
dropout=0.0,
dout=None,
activation="relu",
normalize_before=False,
):
super().__init__()

self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
@@ -133,16 +160,16 @@ def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,dout=None,

self.activation = _get_activation_fn(activation)

def forward(self, tgt,tgt_input):

def forward(self, tgt, tgt_input):
tgt = tgt + self.dropout2(tgt_input)
tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)

return tgt



def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
@@ -151,4 +178,4 @@ def _get_activation_fn(activation):
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")