added unit tests, fix linter for dataset
kprokofi committed Oct 10, 2024
1 parent 878a844 commit 2ee2ac6
Showing 24 changed files with 845 additions and 387 deletions.
87 changes: 4 additions & 83 deletions src/otx/algo/object_detection_3d/backbones/monodetr_resnet.py
@@ -67,85 +67,6 @@ def forward(self, tensor_list: NestedTensor) -> torch.Tensor:
return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)


class PositionEmbeddingLearned(nn.Module):
"""Absolute pos embedding, learned."""

def __init__(self, num_pos_feats: int = 256):
"""Positional embedding."""
super().__init__()
self.row_embed = nn.Embedding(50, num_pos_feats)
self.col_embed = nn.Embedding(50, num_pos_feats)

def forward(self, tensor_list: NestedTensor) -> torch.Tensor:
"""Forward pass of the PositionEmbeddingLearned module.
Args:
tensor_list (NestedTensor): Input tensor.
Returns:
torch.Tensor: Position embeddings.
"""
x = tensor_list.tensors
h, w = x.shape[-2:]
i = torch.arange(w, device=x.device) / w * 49
j = torch.arange(h, device=x.device) / h * 49
x_emb = self.get_embed(i, self.col_embed)
y_emb = self.get_embed(j, self.row_embed)
return (
torch.cat(
[
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
],
dim=-1,
)
.permute(2, 0, 1)
.unsqueeze(0)
.repeat(x.shape[0], 1, 1, 1)
)

def get_embed(self, coord: torch.Tensor, embed: nn.Embedding) -> torch.Tensor:
"""Get the embedding for the given coordinates.
Args:
coord (torch.Tensor): The coordinates.
embed (nn.Embedding): The embedding layer.
Returns:
torch.Tensor: The embedding for the coordinates.
"""
floor_coord = coord.floor()
delta = (coord - floor_coord).unsqueeze(-1)
floor_coord = floor_coord.long()
ceil_coord = (floor_coord + 1).clamp(max=49)
return embed(floor_coord) * (1 - delta) + embed(ceil_coord) * delta
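
For context, get_embed above linearly interpolates between adjacent rows of the 50-entry embedding table, weighting by the fractional part of the coordinate. A standalone sketch of the same computation with a toy embedding (the table size and feature width here are arbitrary):

import torch
from torch import nn

embed = nn.Embedding(50, 4)                  # toy table: 50 positions, 4 features
coord = torch.tensor([2.25])                 # fractional grid coordinate in [0, 49]
floor_coord = coord.floor()
delta = (coord - floor_coord).unsqueeze(-1)  # 0.25 -> interpolation weight
floor_coord = floor_coord.long()
ceil_coord = (floor_coord + 1).clamp(max=49)
out = embed(floor_coord) * (1 - delta) + embed(ceil_coord) * delta
# out equals 0.75 * row 2 + 0.25 * row 3 of the embedding table.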


def build_position_encoding(
hidden_dim: int,
position_embedding: str | PositionEmbeddingSine | PositionEmbeddingLearned,
) -> PositionEmbeddingSine | PositionEmbeddingLearned:
"""Build the position encoding module.
Args:
hidden_dim (int): The hidden dimension.
position_embedding (Union[str, PositionEmbeddingSine, PositionEmbeddingLearned]): The position embedding type.
Returns:
Union[PositionEmbeddingSine, PositionEmbeddingLearned]: The position encoding module.
"""
n_steps = hidden_dim // 2
if position_embedding in ("v2", "sine"):
position_embedding = PositionEmbeddingSine(n_steps, normalize=True)
elif position_embedding in ("v3", "learned"):
position_embedding = PositionEmbeddingLearned(n_steps)
else:
msg = f"not supported {position_embedding}"
raise ValueError(msg)

return position_embedding


class BackboneBase(nn.Module):
"""BackboneBase module."""

@@ -204,13 +125,13 @@ class Joiner(nn.Sequential):
def __init__(
self,
backbone: nn.Module,
position_embedding: PositionEmbeddingSine | PositionEmbeddingLearned,
position_embedding: PositionEmbeddingSine,
) -> None:
"""Initialize the Joiner module.
Args:
backbone (nn.Module): The backbone module.
position_embedding (Union[PositionEmbeddingSine, PositionEmbeddingLearned]): The position embedding module.
position_embedding (Union[PositionEmbeddingSine]): The position embedding module.
"""
super().__init__(backbone, position_embedding)
self.strides = backbone.strides
@@ -240,7 +161,6 @@ class BackboneBuilder:
"return_interm_layers": True,
"positional_encoding": {
"hidden_dim": 256,
"position_embedding": "sine",
},
},
}
@@ -249,5 +169,6 @@ def __new__(cls, model_name: str) -> Joiner:
"""Constructor for Backbone MonoDetr."""
# TODO (Kirill): change backbone to already implemented in OTX
backbone = Backbone(**cls.CFG[model_name])
position_embedding = build_position_encoding(**cls.CFG[model_name]["positional_encoding"])
n_steps = cls.CFG[model_name]["positional_encoding"]["hidden_dim"] // 2
position_embedding = PositionEmbeddingSine(n_steps, normalize=True)
return Joiner(backbone, position_embedding)
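
Since the sine embedding concatenates pos_y and pos_x (num_pos_feats channels each, as in the forward shown at the top of this file), passing hidden_dim // 2 per axis yields hidden_dim output channels. A minimal sketch of the construction path the builder now takes, assuming PositionEmbeddingSine is importable from this module and using the CFG layout shown above:

from otx.algo.object_detection_3d.backbones.monodetr_resnet import PositionEmbeddingSine

positional_encoding = {"hidden_dim": 256}         # as in CFG[model_name]["positional_encoding"]
n_steps = positional_encoding["hidden_dim"] // 2  # 128 sine features per axis
position_embedding = PositionEmbeddingSine(n_steps, normalize=True)
# position_embedding(tensor_list) returns a (B, hidden_dim, H, W) map: pos_y and pos_x
# (n_steps channels each) are concatenated along the channel dimension.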
7 changes: 5 additions & 2 deletions src/otx/algo/object_detection_3d/detectors/monodetr.py
@@ -25,10 +25,10 @@ def __init__(
backbone: nn.Module,
depthaware_transformer: nn.Module,
depth_predictor: nn.Module,
criterion: nn.Module,
num_classes: int,
num_queries: int,
num_feature_levels: int,
criterion: nn.Module | None = None,
aux_loss: bool = True,
with_box_refine: bool = False,
init_box: bool = False,
@@ -41,7 +41,7 @@ def __init__(
backbone (nn.Module): torch module of the backbone to be used. See backbone.py
depthaware_transformer (nn.Module): depth-aware transformer architecture. See depth_aware_transformer.py
depth_predictor (nn.Module): depth predictor module
criterion (nn.Module): loss criterion module
criterion (nn.Module | None): loss criterion module
num_classes (int): number of object classes
num_queries (int): number of object queries, ie detection slot. This is the maximal number of objects
DETR can detect in a single image. For KITTI, we recommend 50 queries.
@@ -285,6 +285,9 @@ def forward(
)

if mode == "loss":
if self.criterion is None:
msg = "Criterion is not set for the model"
raise ValueError(msg)
return self.criterion(outputs=out, targets=targets)

return out
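
Making criterion an optional keyword lets the detector be instantiated for inference without a loss module, while the guard above fails loudly only when a loss is actually requested. A minimal, self-contained sketch of that pattern (a toy module, not MonoDETR's full signature):

from __future__ import annotations

from torch import nn

class CriterionGuardSketch(nn.Module):
    """Toy module mirroring the optional-criterion guard added above."""

    def __init__(self, criterion: nn.Module | None = None) -> None:
        super().__init__()
        self.criterion = criterion

    def forward(self, out: dict, targets: dict, mode: str = "predict") -> dict:
        if mode == "loss":
            if self.criterion is None:
                msg = "Criterion is not set for the model"
                raise ValueError(msg)
            return self.criterion(outputs=out, targets=targets)
        return out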
54 changes: 1 addition & 53 deletions src/otx/algo/object_detection_3d/heads/depthaware_transformer.py
@@ -167,58 +167,6 @@ def get_proposal_pos_embed(self, proposals: Tensor) -> Tensor:
# N, L, 6, 64, 2
return torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)

def gen_encoder_output_proposals(
self,
memory: Tensor,
memory_padding_mask: Tensor,
spatial_shapes: list[tuple[int, int]],
) -> tuple[Tensor, Tensor]:
"""Generate encoder output and proposals.
Args:
memory (Tensor): Memory tensor of shape (N, S, C).
memory_padding_mask (Tensor): Memory padding mask tensor of shape (N, S).
spatial_shapes (List[Tuple[int, int]]): List of spatial shapes.
Returns:
Tuple[Tensor, Tensor]: Encoder output tensor of shape (N, S, C) and proposals tensor of shape (N, L, 6).
"""
n_, _, _ = memory.shape
proposals = []
_cur = 0
for lvl, (h_, w_) in enumerate(spatial_shapes):
mask_flatten_ = memory_padding_mask[:, _cur : (_cur + h_ * w_)].view(n_, h_, w_, 1)
valid_h = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_w = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

grid_y, grid_x = torch.meshgrid(
torch.linspace(0, h_ - 1, h_, dtype=torch.float32, device=memory.device),
torch.linspace(0, w_ - 1, w_, dtype=torch.float32, device=memory.device),
)
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

scale = torch.cat([valid_w.unsqueeze(-1), valid_h.unsqueeze(-1)], 1).view(n_, 1, 1, 2)
grid = (grid.unsqueeze(0).expand(n_, -1, -1, -1) + 0.5) / scale

lr = torch.ones_like(grid) * 0.05 * (2.0**lvl)
tb = torch.ones_like(grid) * 0.05 * (2.0**lvl)
wh = torch.cat((lr, tb), -1)

proposal = torch.cat((grid, wh), -1).view(n_, -1, 6)
proposals.append(proposal)
_cur += h_ * w_
output_proposals = torch.cat(proposals, 1)
output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
output_proposals = torch.log(output_proposals / (1 - output_proposals))
output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf"))
output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))

output_memory = memory
output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
output_memory = self.enc_output_norm(self.enc_output(output_memory))
return output_memory, output_proposals
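
For context, the removed helper stored proposals in logit space: torch.log(p / (1 - p)) is the inverse of the sigmoid, so a later sigmoid recovers normalized box coordinates, and positions filled with +inf saturate to 1 and are effectively ignored. A small standalone check of that identity:

import torch

p = torch.tensor([0.25, 0.50, 0.99])
logits = torch.log(p / (1 - p))                  # inverse sigmoid, as applied to output_proposals
assert torch.allclose(torch.sigmoid(logits), p)  # sigmoid maps the logits back to the proposals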

def get_valid_ratio(self, mask: Tensor) -> Tensor:
"""Calculate the valid ratio of the mask.
@@ -830,7 +778,7 @@ def forward(
intermediate_reference_dims,
)

return output, reference_points
return output, reference_points, None


class DepthAwareTransformerBuilder:
21 changes: 14 additions & 7 deletions src/otx/core/data/dataset/object_detection_3d.py
@@ -17,7 +17,14 @@
from PIL import Image as PILImage
from torchvision import tv_tensors

from otx.core.data.dataset.utils.kitti_utils import Calibration, affine_transform, angle2class, get_affine_transform
from otx.core.data.dataset.utils.kitti_utils import (
affine_transform,
angle2class,
get_affine_transform,
get_calib_from_file,
rect_to_img,
ry2alpha,
)
from otx.core.data.entity.base import ImageInfo
from otx.core.data.entity.object_detection_3d import Det3DBatchDataEntity, Det3DDataEntity
from otx.core.data.mem_cache import NULL_MEM_CACHE_HANDLER, MemCacheHandlerBase
@@ -45,7 +52,7 @@ def __init__(
max_refetch: int = 1000,
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
to_tv_image: bool = True,
to_tv_image: bool = False,
max_objects: int = 50,
depth_threshold: int = 65,
resolution: tuple[int, int] = (1280, 384), # (W, H)
@@ -69,7 +76,7 @@ def _get_item_impl(self, index: int) -> Det3DDataEntity | None:
entity = self.dm_subset[index]
image = entity.media_as(Image)
image = self._get_img_data_and_shape(image)[0]
calib = Calibration(entity.attributes["calib_path"])
calib = get_calib_from_file(entity.attributes["calib_path"])
original_kitti_format = None # don't use for training
if self.subset_type != "train":
# TODO (Kirill): remove this or duplication of the inputs
@@ -106,7 +113,7 @@ def _get_item_impl(self, index: int) -> Det3DDataEntity | None:
dtype=torch.float32,
),
labels=torch.as_tensor(targets["labels"], dtype=torch.long),
calib_matrix=torch.as_tensor(calib.P2, dtype=torch.float32),
calib_matrix=torch.as_tensor(calib, dtype=torch.float32),
boxes_3d=torch.as_tensor(targets["boxes_3d"], dtype=torch.float32),
size_2d=torch.as_tensor(targets["size_2d"], dtype=torch.float32),
size_3d=torch.as_tensor(targets["size_3d"], dtype=torch.float32),
@@ -123,7 +130,7 @@ def collate_fn(self) -> Callable:
"""Collection function to collect DetDataEntity into DetBatchDataEntity in data loader."""
return partial(Det3DBatchDataEntity.collate_fn, stack_images=self.stack_images)

def _decode_item(self, img: PILImage, annotations: list[Bbox], calib: Calibration) -> tuple: # noqa: C901
def _decode_item(self, img: PILImage, annotations: list[Bbox], calib: np.ndarray) -> tuple: # noqa: C901
"""Decode item for training."""
# data augmentation for image
img_size = np.array(img.size)
@@ -219,7 +226,7 @@ def _decode_item(self, img: PILImage, annotations: list[Bbox], calib: Calibration) -> tuple:
],
) # real 3D center in 3D space
center_3d = center_3d.reshape(-1, 3) # shape adjustment (N, 3)
center_3d, _ = calib.rect_to_img(center_3d) # project 3D center to image plane
center_3d, _ = rect_to_img(calib, center_3d) # project 3D center to image plane
center_3d = center_3d[0] # shape adjustment
if random_flip_flag: # random flip for center3d
center_3d[0] = img_size[0] - center_3d[0]
@@ -264,7 +271,7 @@ def _decode_item(self, img: PILImage, annotations: list[Bbox], calib: Calibration) -> tuple:
depth[i] = cur_obj["location"][-1] * crop_scale

# encoding heading angle
heading_angle = calib.ry2alpha(cur_obj["rotation_y"], (bbox2d[i][0] + bbox2d[i][2]) / 2)
heading_angle = ry2alpha(calib, cur_obj["rotation_y"], (bbox2d[i][0] + bbox2d[i][2]) / 2)
if heading_angle > np.pi:
heading_angle -= 2 * np.pi # check range
if heading_angle < -np.pi:
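The dataset now calls the functional kitti_utils API instead of the removed Calibration wrapper: get_calib_from_file appears to return the P2 projection matrix as a plain array (it is passed straight to torch.as_tensor above), and rect_to_img / ry2alpha take that matrix as their first argument. A hedged usage sketch; the file path is hypothetical and the shapes are inferred from how the dataset consumes the values, not from a documented contract:

import numpy as np
from otx.core.data.dataset.utils.kitti_utils import get_calib_from_file, rect_to_img, ry2alpha

calib = get_calib_from_file("datasets/kitti/calib/000001.txt")  # hypothetical path; assumed (3, 4) P2 matrix
center_3d = np.array([[1.0, 1.5, 10.0]], dtype=np.float32)      # one point in rectified camera coordinates
center_2d, depth = rect_to_img(calib, center_3d)                # project to the image plane, as in _decode_item
alpha = ry2alpha(calib, 1.57, center_2d[0, 0])                  # rotation_y + image column -> observation angle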
(Remaining 20 changed files not shown.)
