[prototype] Speed up adjust_hue_image_tensor (#6938)
* Performance optimization on adjust_hue_image_tensor

* handle ints

* Inplace logical ops

* Remove unnecessary casting.

* Fix linter.
datumbox authored Nov 10, 2022
1 parent 70edf96 commit d72e906
Showing 3 changed files with 14 additions and 11 deletions.
21 changes: 11 additions & 10 deletions torchvision/prototype/transforms/functional/_color.py
@@ -208,11 +208,10 @@ def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor:
 
     mask_maxc_neq_r = maxc != r
     mask_maxc_eq_g = maxc == g
-    mask_maxc_neq_g = ~mask_maxc_eq_g
 
-    hr = (bc - gc).mul_(~mask_maxc_neq_r)
-    hg = (2.0 + rc).sub_(bc).mul_(mask_maxc_eq_g & mask_maxc_neq_r)
-    hb = (4.0 + gc).sub_(rc).mul_(mask_maxc_neq_g & mask_maxc_neq_r)
+    hg = rc.add(2.0).sub_(bc).mul_(mask_maxc_eq_g & mask_maxc_neq_r)
+    hr = bc.sub_(gc).mul_(~mask_maxc_neq_r)
+    hb = gc.add_(4.0).sub_(rc).mul_(mask_maxc_neq_r.logical_and_(mask_maxc_eq_g.logical_not_()))
 
     h = hr.add_(hg).add_(hb)
     h = h.mul_(1.0 / 6.0).add_(1.0).fmod_(1.0)
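The reordering of hr, hg and hb above is not cosmetic: once the arithmetic runs in place, each line has to execute before its operands are overwritten. A minimal sketch of that constraint with placeholder scalars (the mask multiplications are left out; rc, gc, bc here are stand-ins for the chroma terms in _rgb_to_hsv):

import torch

rc = torch.tensor([0.75])
gc = torch.tensor([0.50])
bc = torch.tensor([0.25])

hg = rc.add(2.0).sub_(bc)   # must read the original bc; rc.add() is out-of-place so rc survives
hr = bc.sub_(gc)            # in-place: bc now holds bc - gc, so this must come after hg
hb = gc.add_(4.0).sub_(rc)  # in-place on gc, which nothing reads afterwards; still needs the original rc

print(hg.item(), hr.item(), hb.item())  # 2.5 -0.25 3.75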
@@ -221,14 +220,16 @@ def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor:
 
 def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
     h, s, v = img.unbind(dim=-3)
-    h6 = h * 6
+    h6 = h.mul(6)
     i = torch.floor(h6)
-    f = h6 - i
+    f = h6.sub_(i)
     i = i.to(dtype=torch.int32)
 
-    p = (v * (1.0 - s)).clamp_(0.0, 1.0)
-    q = (v * (1.0 - s * f)).clamp_(0.0, 1.0)
-    t = (v * (1.0 - s * (1.0 - f))).clamp_(0.0, 1.0)
+    sxf = s * f
+    one_minus_s = 1.0 - s
+    q = (1.0 - sxf).mul_(v).clamp_(0.0, 1.0)
+    t = sxf.add_(one_minus_s).mul_(v).clamp_(0.0, 1.0)
+    p = one_minus_s.mul_(v).clamp_(0.0, 1.0)
     i.remainder_(6)
 
     mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
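The rewritten q, t and p are algebraically the same as before: p = v * (1 - s), q = v * (1 - s * f), and t = v * (1 - s * (1 - f)) = v * (s * f + (1 - s)), which is why s * f and 1 - s can be computed once and reused, with the later lines mutating those temporaries in place once they are no longer needed. A quick numerical check with random tensors (a standalone sketch, not part of the commit):

import torch

s, f, v = torch.rand(3, 8, 8), torch.rand(3, 8, 8), torch.rand(3, 8, 8)

sxf = s * f
one_minus_s = 1.0 - s
q = (1.0 - sxf).mul_(v).clamp_(0.0, 1.0)
t = sxf.add_(one_minus_s).mul_(v).clamp_(0.0, 1.0)  # consumes sxf, no longer needed afterwards
p = one_minus_s.mul_(v).clamp_(0.0, 1.0)            # consumes one_minus_s last

torch.testing.assert_close(p, (v * (1.0 - s)).clamp(0.0, 1.0))
torch.testing.assert_close(q, (v * (1.0 - s * f)).clamp(0.0, 1.0))
torch.testing.assert_close(t, (v * (1.0 - s * (1.0 - f))).clamp(0.0, 1.0))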
@@ -238,7 +239,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
     a3 = torch.stack((p, p, t, v, v, q), dim=-3)
     a4 = torch.stack((a1, a2, a3), dim=-4)
 
-    return (a4.mul_(mask.to(dtype=img.dtype).unsqueeze(dim=-4))).sum(dim=-3)
+    return (a4.mul_(mask.unsqueeze(dim=-4))).sum(dim=-3)
 
 
 def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
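Dropping mask.to(dtype=img.dtype) is safe because PyTorch type promotion lets a floating-point tensor be multiplied in place by a boolean mask directly. A small check (a sketch, not code from the PR):

import torch

a4 = torch.rand(6, 2, 2)
mask = torch.rand(6, 2, 2) > 0.5
torch.testing.assert_close(a4.clone().mul_(mask), a4 * mask.to(a4.dtype))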
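To see the end-to-end effect on the kernel this commit targets, one can time adjust_hue_image_tensor on a sample tensor before and after checking out d72e906. A rough sketch using torch.utils.benchmark; the import path assumes the prototype functional namespace re-exports the kernel from _color.py:

import torch
from torch.utils import benchmark
from torchvision.prototype.transforms.functional import adjust_hue_image_tensor

image = torch.rand(3, 400, 400)  # float image in [0, 1]
timer = benchmark.Timer(
    stmt="adjust_hue_image_tensor(image, 0.1)",
    globals={"adjust_hue_image_tensor": adjust_hue_image_tensor, "image": image},
)
print(timer.blocked_autorange())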
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/functional/_meta.py
@@ -164,6 +164,7 @@ def convert_format_bounding_box(
     if new_format == old_format:
         return bounding_box
 
+    # TODO: Add _xywh_to_cxcywh and _cxcywh_to_xywh to improve performance
     if old_format == BoundingBoxFormat.XYWH:
         bounding_box = _xywh_to_xyxy(bounding_box, inplace)
     elif old_format == BoundingBoxFormat.CXCYWH:
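The new TODO points at converting XYWH to CXCYWH (and back) directly instead of routing through XYXY. A hypothetical _xywh_to_cxcywh along those lines, sketched for floating-point boxes only; the name and signature mirror the existing private helpers but are assumptions, not code from this commit:

import torch


def _xywh_to_cxcywh(bounding_box: torch.Tensor, inplace: bool) -> torch.Tensor:
    if not inplace:
        bounding_box = bounding_box.clone()
    # (x, y, w, h) -> (cx, cy, w, h): shift the top-left corner to the box center.
    bounding_box[..., :2] += bounding_box[..., 2:] / 2
    return bounding_box


boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])  # x, y, w, h
print(_xywh_to_cxcywh(boxes, inplace=False))      # tensor([[25., 40., 30., 40.]])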
3 changes: 2 additions & 1 deletion
@@ -1,4 +1,3 @@
-import unittest.mock
 from typing import Any, Dict, Tuple, Union
 
 import numpy as np
@@ -20,6 +19,8 @@ def decode_image_with_pil(encoded_image: torch.Tensor) -> features.Image:
 
 @torch.jit.unused
 def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
+    import unittest.mock
+
     with unittest.mock.patch("torchvision.io.video.os.path.exists", return_value=True):
         return read_video(ReadOnlyTensorBuffer(encoded_video))  # type: ignore[arg-type]
 
