Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Color Transfer: add Initial Reference Image parameter #2848

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion backend/src/nodes/impl/color_transfer/mean_std.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def scale_array(
def mean_std_transfer(
img: np.ndarray,
ref_img: np.ndarray,
init_img: np.ndarray,
colorspace: TransferColorSpace,
overflow_method: OverflowMethod,
valid_indices: np.ndarray,
Expand Down Expand Up @@ -118,12 +119,14 @@ def mean_std_transfer(
c_clip_min, c_clip_max = (-127, 127)
img = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
ref_img = cv2.cvtColor(ref_img, cv2.COLOR_BGR2LAB)
init_img = cv2.cvtColor(init_img, cv2.COLOR_BGR2LAB)
elif colorspace == TransferColorSpace.RGB:
a_clip_min, a_clip_max = (0, 1)
b_clip_min, b_clip_max = (0, 1)
c_clip_min, c_clip_max = (0, 1)
img = img[:, :, :3]
ref_img = ref_img[:, :, :3]
init_img = init_img[:, :, :3]
else:
raise ValueError(f"Invalid color space {colorspace}")

Expand All @@ -135,7 +138,7 @@ def mean_std_transfer(
b_std_tar,
c_mean_tar,
c_std_tar,
) = image_stats(img[valid_indices])
) = image_stats(init_img[valid_indices])
(
a_mean_src,
a_std_src,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,17 @@ class TransferColorAlgorithm(Enum):
icon="MdInput",
inputs=[
ImageInput("Image", channels=[3, 4]),
ImageInput("Reference Image", channels=[3, 4]),
ImageInput("Goal Reference Image", channels=[3, 4]),
EnumInput(
TransferColorAlgorithm,
label="Algorithm",
option_labels=TRANSFER_COLOR_ALGORITHM_LABELS,
default=TransferColorAlgorithm.MEAN_STD,
).with_id(5),
if_enum_group(5, TransferColorAlgorithm.MEAN_STD)(
ImageInput("Initial Reference Image", channels=[3, 4])
.make_optional()
.with_id(6),
EnumInput(
TransferColorSpace,
label="Colorspace",
Expand All @@ -65,10 +68,14 @@ def color_transfer_node(
img: np.ndarray,
ref_img: np.ndarray,
algorithm: TransferColorAlgorithm,
init_img: np.ndarray | None,
colorspace: TransferColorSpace,
overflow_method: OverflowMethod,
reciprocal_scale: bool,
) -> np.ndarray:
if init_img is None:
init_img = img

_, _, img_c = get_h_w_c(img)

# Preserve alpha
Expand All @@ -77,6 +84,13 @@ def color_transfer_node(
alpha = img[:, :, 3]
bgr_img = img[:, :, :3]

_, _, init_img_c = get_h_w_c(init_img)

init_alpha = None
if init_img_c == 4:
init_alpha = init_img[:, :, 3]
bgr_init_img = init_img[:, :, :3]

_, _, ref_img_c = get_h_w_c(ref_img)

ref_alpha = None
Expand All @@ -86,9 +100,9 @@ def color_transfer_node(

# Don't process RGB data if the pixel is fully transparent, since
# such RGB data is indeterminate.
valid_rgb_indices = np.ones(img.shape[:-1], dtype=bool)
if alpha is not None:
valid_rgb_indices = alpha > 0
init_valid_rgb_indices = np.ones(init_img.shape[:-1], dtype=bool)
if init_alpha is not None:
init_valid_rgb_indices = init_alpha > 0

ref_valid_rgb_indices = np.ones(ref_img.shape[:-1], dtype=bool)
if ref_alpha is not None:
Expand All @@ -99,19 +113,20 @@ def color_transfer_node(
transfer = mean_std_transfer(
bgr_img,
bgr_ref_img,
bgr_init_img,
colorspace,
overflow_method,
reciprocal_scale=reciprocal_scale,
valid_indices=valid_rgb_indices,
valid_indices=init_valid_rgb_indices,
ref_valid_indices=ref_valid_rgb_indices,
)
elif algorithm == TransferColorAlgorithm.LINEAR_HISTOGRAM:
transfer = linear_histogram_transfer(
bgr_img, bgr_ref_img, valid_rgb_indices, ref_valid_rgb_indices
bgr_img, bgr_ref_img, init_valid_rgb_indices, ref_valid_rgb_indices
)
elif algorithm == TransferColorAlgorithm.PRINCIPAL_COLOR:
transfer = principal_color_transfer(
bgr_img, bgr_ref_img, valid_rgb_indices, ref_valid_rgb_indices
bgr_img, bgr_ref_img, init_valid_rgb_indices, ref_valid_rgb_indices
)

if alpha is not None:
Expand Down
Loading