Skip to content

Commit

Permalink
Clear optical distortion (#2176)
Browse files Browse the repository at this point in the history
* Empty-Commit

* Cleanup in Optical Distortion

* Fixed fisheye
  • Loading branch information
ternaus authored Dec 5, 2024
1 parent a9000e2 commit 9995c60
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 55 deletions.
36 changes: 19 additions & 17 deletions albumentations/augmentations/geometric/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3356,26 +3356,23 @@ def tps_transform(

def get_camera_matrix_distortion_maps(
image_shape: tuple[int, int],
cx: float,
cy: float,
k: float,
center_xy: tuple[float, float],
) -> tuple[np.ndarray, np.ndarray]:
"""Generate distortion maps using camera matrix model.
Args:
image_shape: Image shape
cx: x-coordinate of distortion center
cy: y-coordinate of distortion center
k: Distortion coefficient
center_xy: Center of distortion
Returns:
tuple of:
- map_x: Horizontal displacement map
- map_y: Vertical displacement map
"""
height, width = image_shape[:2]
camera_matrix = np.array(
[[width, 0, cx], [0, height, cy], [0, 0, 1]],
[[width, 0, center_xy[0]], [0, height, center_xy[1]], [0, 0, 1]],
dtype=np.float32,
)
distortion = np.array([k, k, 0, 0, 0], dtype=np.float32)
Expand All @@ -3391,38 +3388,43 @@ def get_camera_matrix_distortion_maps(

def get_fisheye_distortion_maps(
image_shape: tuple[int, int],
cx: float,
cy: float,
k: float,
center_xy: tuple[float, float],
) -> tuple[np.ndarray, np.ndarray]:
"""Generate distortion maps using fisheye model.
Args:
image_shape: Image shape
cx: x-coordinate of distortion center
cy: y-coordinate of distortion center
k: Distortion coefficient
center_xy: Center of distortion
Returns:
tuple of:
- map_x: Horizontal displacement map
- map_y: Vertical displacement map
"""
height, width = image_shape[:2]

center_x, center_y = center_xy

# Create coordinate grid
y, x = np.mgrid[:height, :width].astype(np.float32)
x = x - cx
y = y - cy

x = x - center_x
y = y - center_y

# Calculate polar coordinates
r = np.sqrt(x * x + y * y)
theta = np.arctan2(y, x)

# Apply fisheye distortion
r_dist = r * (1 + k * r * r)
# Normalize radius by the maximum possible radius to keep distortion in check
max_radius = math.sqrt(max(center_x, width - center_x) ** 2 + max(center_y, height - center_y) ** 2)
r_norm = r / max_radius

# Apply fisheye distortion to normalized radius
r_dist = r * (1 + k * r_norm * r_norm)

# Convert back to cartesian coordinates
map_x = cx + r_dist * np.cos(theta)
map_y = cy + r_dist * np.sin(theta)
map_x = r_dist * np.cos(theta) + center_x
map_y = r_dist * np.sin(theta) + center_y

return map_x, map_y
28 changes: 8 additions & 20 deletions albumentations/augmentations/geometric/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,13 +1532,6 @@ class OpticalDistortion(BaseDistortion):
For fisheye model: recommended range (-0.3, 0.3)
Default: (-0.05, 0.05)
shift_limit (float | tuple[float, float]): Range of relative shifts for the image center.
Values are multiplied by image dimensions to get absolute shift in pixels:
- dx = shift_x * image_width
- dy = shift_y * image_height
If shift_limit is a single float value, the range will be (-shift_limit, shift_limit).
Default: (-0.05, 0.05)
mode (Literal['camera', 'fisheye']): Distortion model to use:
- 'camera': Original camera matrix model
- 'fisheye': Fisheye lens model
Expand Down Expand Up @@ -1571,7 +1564,7 @@ class OpticalDistortion(BaseDistortion):
Example:
>>> import albumentations as A
>>> transform = A.Compose([
... A.OpticalDistortion(distort_limit=0.1, shift_limit=0.1, p=1.0),
... A.OpticalDistortion(distort_limit=0.1, p=1.0),
... ])
>>> transformed = transform(image=image, mask=mask, bboxes=bboxes, keypoints=keypoints)
>>> transformed_image = transformed['image']
Expand All @@ -1582,8 +1575,10 @@ class OpticalDistortion(BaseDistortion):

class InitSchema(BaseDistortion.InitSchema):
distort_limit: SymmetricRangeType
shift_limit: SymmetricRangeType
mode: Literal["camera", "fisheye"]
shift_limit: SymmetricRangeType | None = Field(
deprecated="Deprecated. Does not have any effect.",
)
value: ColorType | None = Field(
deprecated="Deprecated. Does not have any effect.",
)
Expand All @@ -1597,7 +1592,7 @@ class InitSchema(BaseDistortion.InitSchema):
def __init__(
self,
distort_limit: ScaleFloatType = (-0.05, 0.05),
shift_limit: ScaleFloatType = (-0.05, 0.05),
shift_limit: ScaleFloatType | None = None,
interpolation: int = cv2.INTER_LINEAR,
border_mode: int | None = None,
value: ColorType | None = None,
Expand All @@ -1612,7 +1607,6 @@ def __init__(
mask_interpolation=mask_interpolation,
p=p,
)
self.shift_limit = cast(tuple[float, float], shift_limit)
self.distort_limit = cast(tuple[float, float], distort_limit)
self.mode = mode

Expand All @@ -1628,33 +1622,27 @@ def get_params_dependent_on_data(
k = self.py_random.uniform(*self.distort_limit)

# Calculate center shift
dx = round(self.py_random.uniform(*self.shift_limit) * width)
dy = round(self.py_random.uniform(*self.shift_limit) * height)
cx = width * 0.5 + dx
cy = height * 0.5 + dy
center_xy = fgeometric.center(image_shape)

# Get distortion maps based on mode
if self.mode == "camera":
map_x, map_y = fgeometric.get_camera_matrix_distortion_maps(
image_shape,
cx,
cy,
k,
center_xy,
)
else: # fisheye
map_x, map_y = fgeometric.get_fisheye_distortion_maps(
image_shape,
cx,
cy,
k,
center_xy,
)

return {"map_x": map_x, "map_y": map_y}

def get_transform_init_args_names(self) -> tuple[str, ...]:
return (
"distort_limit",
"shift_limit",
"mode",
*super().get_transform_init_args_names(),
)
Expand Down
3 changes: 1 addition & 2 deletions tests/aug_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@
A.OpticalDistortion,
{
"distort_limit": 0.2,
"shift_limit": 0.2,
"interpolation": cv2.INTER_CUBIC,
"interpolation": cv2.INTER_AREA,
},
],
[
Expand Down
4 changes: 0 additions & 4 deletions tests/files/transform_serialization_v2_with_totensor.json
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,6 @@
-0.05,
0.05
],
"shift_limit": [
-0.05,
0.05
],
"interpolation": 1,
"mode": "camera"
},
Expand Down
4 changes: 0 additions & 4 deletions tests/files/transform_serialization_v2_without_totensor.json
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,6 @@
-0.05,
0.05
],
"shift_limit": [
-0.05,
0.05
],
"interpolation": 1,
"mode": "camera"
},
Expand Down
4 changes: 0 additions & 4 deletions tests/files/transform_v1.1.0_with_totensor.json
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,6 @@
-0.05,
0.05
],
"shift_limit": [
-0.05,
0.05
],
"interpolation": 1,
"mode": "camera"
},
Expand Down
4 changes: 0 additions & 4 deletions tests/files/transform_v1.1.0_without_totensor.json
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,6 @@
-0.05,
0.05
],
"shift_limit": [
-0.05,
0.05
],
"interpolation": 1,
"mode": "camera"
},
Expand Down

0 comments on commit 9995c60

Please sign in to comment.