Skip to content

Commit

Permalink
refactor: replace dataclasses to attrs (#27)
Browse files Browse the repository at this point in the history
Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp>
  • Loading branch information
ktro2828 authored Nov 16, 2024
1 parent 7cb0735 commit 293e144
Show file tree
Hide file tree
Showing 28 changed files with 203 additions and 418 deletions.
27 changes: 27 additions & 0 deletions t4_devkit/common/converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
from pyquaternion import Quaternion

if TYPE_CHECKING:
from t4_devkit.typing import ArrayLike, NDArray

__all__ = ["as_quaternion"]


def as_quaternion(value: ArrayLike | NDArray) -> Quaternion:
"""Convert input rotation like array to `Quaternion`.
Args:
value (ArrayLike | NDArray): Rotation matrix or quaternion.
Returns:
Quaternion: Converted instance.
"""
return (
Quaternion(matrix=value)
if isinstance(value, np.ndarray) and value.ndim == 2
else Quaternion(value)
)
36 changes: 12 additions & 24 deletions t4_devkit/dataclass/box.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, TypeVar

import numpy as np
from pyquaternion import Quaternion
from attrs import define, field
from attrs.converters import optional
from shapely.geometry import Polygon
from typing_extensions import Self

from t4_devkit.common.converter import as_quaternion

from .roi import Roi
from .trajectory import to_trajectories

Expand Down Expand Up @@ -57,7 +59,7 @@ def distance_box(box: BoxType, tf_matrix: HomogeneousMatrix) -> float | None:
return np.linalg.norm(position)


@dataclass(eq=False)
@define(eq=False)
class BaseBox:
"""Abstract base class for box objects."""

Expand All @@ -72,7 +74,7 @@ class BaseBox:
# >>> e.g.) box.as_state() -> BoxState


@dataclass(eq=False)
@define(eq=False)
class Box3D(BaseBox):
"""A class to represent 3D box.
Expand Down Expand Up @@ -109,25 +111,15 @@ class Box3D(BaseBox):
... )
"""

position: TranslationType
rotation: RotationType
position: TranslationType = field(converter=np.asarray)
rotation: RotationType = field(converter=as_quaternion)
shape: Shape
velocity: VelocityType | None = field(default=None)
velocity: VelocityType | None = field(default=None, converter=optional(np.asarray))
num_points: int | None = field(default=None)

# additional attributes: set by `with_**`
future: list[Trajectory] | None = field(default=None, init=False)

def __post_init__(self) -> None:
if not isinstance(self.position, np.ndarray):
self.position = np.array(self.position)

if not isinstance(self.rotation, Quaternion):
self.rotation = Quaternion(self.rotation)

if self.velocity is not None and not isinstance(self.velocity, np.ndarray):
self.velocity = np.array(self.velocity)

def with_future(
self,
waypoints: list[TrajectoryType],
Expand Down Expand Up @@ -195,7 +187,7 @@ def corners(self, box_scale: float = 1.0) -> NDArrayF64:
return np.dot(self.rotation.rotation_matrix, corners).T + self.position


@dataclass(eq=False)
@define(eq=False)
class Box2D(BaseBox):
"""A class to represent 2D box.
Expand All @@ -222,15 +214,11 @@ class Box2D(BaseBox):
>>> box2d = box2d.with_position(position=(1.0, 1.0, 1.0))
"""

roi: Roi | None = field(default=None)
roi: Roi | None = field(default=None, converter=lambda x: None if x is None else Roi(x))

# additional attributes: set by `with_**`
position: TranslationType | None = field(default=None, init=False)

def __post_init__(self) -> None:
if self.roi is not None and not isinstance(self.roi, Roi):
self.roi = Roi(self.roi)

def with_position(self, position: TranslationType) -> Self:
"""Return a self instance setting `position` attribute.
Expand All @@ -240,7 +228,7 @@ def with_position(self, position: TranslationType) -> Self:
Returns:
Self instance after setting `position`.
"""
self.position = np.array(position) if not isinstance(position, np.ndarray) else position
self.position = np.asarray(position)
return self

def __eq__(self, other: Box2D | None) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions t4_devkit/dataclass/label.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

from dataclasses import dataclass, field
from attrs import define, field

__all__ = ["SemanticLabel"]


@dataclass(frozen=True, eq=False)
@define(frozen=True, eq=False)
class SemanticLabel:
"""A dataclass to represent semantic labels.
Expand All @@ -15,7 +15,7 @@ class SemanticLabel:
"""

name: str
attributes: list[str] = field(default_factory=list)
attributes: list[str] = field(factory=list)

def __eq__(self, other: str | SemanticLabel) -> bool:
return self.name == other if isinstance(other, str) else self.name == other.name
22 changes: 13 additions & 9 deletions t4_devkit/dataclass/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import struct
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, TypeVar

import numpy as np
from attrs import define, field

if TYPE_CHECKING:
from typing_extensions import Self
Expand All @@ -21,14 +21,18 @@
]


@dataclass
@define
class PointCloud:
"""Abstract base dataclass for pointcloud data."""

points: NDArrayFloat
points: NDArrayFloat = field(converter=np.asarray)

def __post_init__(self) -> None:
assert self.points.shape[0] == self.num_dims()
@points.validator
def check_dims(self, attribute, value) -> None:
if value.shape[0] != self.num_dims():
raise ValueError(
f"Expected point dimension is {self.num_dims()}, but got {value.shape[0]}"
)

@staticmethod
@abstractmethod
Expand Down Expand Up @@ -74,7 +78,7 @@ def transform(self, matrix: NDArrayFloat) -> None:
)[:3, :]


@dataclass
@define
class LidarPointCloud(PointCloud):
"""A dataclass to represent lidar pointcloud."""

Expand All @@ -91,7 +95,7 @@ def from_file(cls, filepath: str) -> Self:
return cls(points.T)


@dataclass
@define
class RadarPointCloud(PointCloud):
# class variables
invalid_states: ClassVar[list[int]] = [0]
Expand Down Expand Up @@ -188,9 +192,9 @@ def from_file(
return cls(points)


@dataclass
@define
class SegmentationPointCloud(PointCloud):
labels: NDArrayU8
labels: NDArrayU8 = field(converter=lambda x: np.asarray(x, dtype=np.uint8))

@staticmethod
def num_dims() -> int:
Expand Down
10 changes: 4 additions & 6 deletions t4_devkit/dataclass/roi.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,24 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from attrs import define, field

if TYPE_CHECKING:
from t4_devkit.typing import RoiType

__all__ = ["Roi"]


@dataclass
@define
class Roi:
roi: RoiType
roi: RoiType = field(converter=tuple)

def __post_init__(self) -> None:
assert len(self.roi) == 4, (
"Expected roi is (x, y, width, height), " f"but got length with {len(self.roi)}."
)

if not isinstance(self.roi, tuple):
self.roi = tuple(self.roi)

@property
def offset(self) -> tuple[int, int]:
return self.roi[:2]
Expand Down
11 changes: 4 additions & 7 deletions t4_devkit/dataclass/shape.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum, auto, unique
from typing import TYPE_CHECKING

import numpy as np
from attrs import define, field
from shapely.geometry import Polygon
from typing_extensions import Self

Expand Down Expand Up @@ -35,7 +35,7 @@ def from_name(cls, name: str) -> Self:
return cls.__members__[name]


@dataclass
@define
class Shape:
"""A dataclass to represent the 3D box shape.
Expand All @@ -47,13 +47,10 @@ class Shape:
"""

shape_type: ShapeType
size: SizeType
size: SizeType = field(converter=np.asarray)
footprint: Polygon = field(default=None)

def __post_init__(self) -> None:
if not isinstance(self.size, np.ndarray):
self.size = np.array(self.size)

def __attrs_post_init__(self) -> None:
if self.shape_type == ShapeType.POLYGON and self.footprint is None:
raise ValueError("`footprint` must be specified for `POLYGON`.")

Expand Down
14 changes: 6 additions & 8 deletions t4_devkit/dataclass/trajectory.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Generator

import numpy as np
from attrs import define, field

if TYPE_CHECKING:
from t4_devkit.typing import TrajectoryType, TranslationType

__all__ = ["Trajectory", "to_trajectories"]


@dataclass
@define
class Trajectory:
"""A dataclass to represent trajectory.
Expand Down Expand Up @@ -41,14 +41,12 @@ class Trajectory:
[2. 2. 2.]
"""

waypoints: TrajectoryType
waypoints: TrajectoryType = field(converter=np.asarray)
confidence: float = field(default=1.0)

def __post_init__(self) -> None:
if not isinstance(self.waypoints, np.ndarray):
self.waypoints = np.array(self.waypoints)

assert self.waypoints.shape[1] == 3
def __attrs_post_init__(self) -> None:
if self.waypoints.shape[1] != 3:
raise ValueError("Trajectory dimension must be 3.")

def __len__(self) -> int:
return len(self.waypoints)
Expand Down
Loading

0 comments on commit 293e144

Please sign in to comment.