Revert center_crop
shashaka committed Dec 22, 2024
1 parent 8242f13 commit 0f6e7e0
Showing 1 changed file with 104 additions and 107 deletions.
211 changes: 104 additions & 107 deletions keras/src/layers/preprocessing/image_preprocessing/center_crop.py
@@ -94,85 +94,84 @@ def _get_clipped_bbox(bounding_boxes, h_end, h_start, w_end, w_start):
            )
            return bounding_boxes

-        if training:
-            input_shape = transformation["input_shape"]
-
-            init_height, init_width = _get_height_width(input_shape)
-
-            bounding_boxes = convert_format(
-                bounding_boxes,
-                source=self.bounding_box_format,
-                target="xyxy",
-                height=init_height,
-                width=init_width,
-            )
-
-            h_diff = init_height - self.height
-            w_diff = init_width - self.width
-
-            if h_diff >= 0 and w_diff >= 0:
-                h_start = int(h_diff / 2)
-                w_start = int(w_diff / 2)
-
-                h_end = h_start + self.height
-                w_end = w_start + self.width
-
-                bounding_boxes = _get_clipped_bbox(
-                    bounding_boxes, h_end, h_start, w_end, w_start
-                )
-            else:
-                width = init_width
-                height = init_height
-                target_height = self.height
-                target_width = self.width
-
-                crop_height = int(float(width * target_height) / target_width)
-                crop_height = max(min(height, crop_height), 1)
-                crop_width = int(float(height * target_width) / target_height)
-                crop_width = max(min(width, crop_width), 1)
-                crop_box_hstart = int(float(height - crop_height) / 2)
-                crop_box_wstart = int(float(width - crop_width) / 2)
-
-                h_start = crop_box_hstart
-                w_start = crop_box_wstart
-
-                h_end = crop_box_hstart + crop_height
-                w_end = crop_box_wstart + crop_width
-                bounding_boxes = _get_clipped_bbox(
-                    bounding_boxes, h_end, h_start, w_end, w_start
-                )
-
-                bounding_boxes = convert_format(
-                    bounding_boxes,
-                    source="xyxy",
-                    target="rel_xyxy",
-                    height=crop_height,
-                    width=crop_width,
-                )
-
-                bounding_boxes = convert_format(
-                    bounding_boxes,
-                    source="rel_xyxy",
-                    target="xyxy",
-                    height=self.height,
-                    width=self.width,
-                )
-
-            bounding_boxes = clip_to_image_size(
-                bounding_boxes=bounding_boxes,
-                height=self.height,
-                width=self.width,
-                bounding_box_format="xyxy",
-            )
-
-            bounding_boxes = convert_format(
-                bounding_boxes,
-                source="xyxy",
-                target=self.bounding_box_format,
-                height=self.height,
-                width=self.width,
-            )
+        input_shape = transformation["input_shape"]
+
+        init_height, init_width = _get_height_width(input_shape)
+
+        bounding_boxes = convert_format(
+            bounding_boxes,
+            source=self.bounding_box_format,
+            target="xyxy",
+            height=init_height,
+            width=init_width,
+        )
+
+        h_diff = init_height - self.height
+        w_diff = init_width - self.width
+
+        if h_diff >= 0 and w_diff >= 0:
+            h_start = int(h_diff / 2)
+            w_start = int(w_diff / 2)
+
+            h_end = h_start + self.height
+            w_end = w_start + self.width
+
+            bounding_boxes = _get_clipped_bbox(
+                bounding_boxes, h_end, h_start, w_end, w_start
+            )
+        else:
+            width = init_width
+            height = init_height
+            target_height = self.height
+            target_width = self.width
+
+            crop_height = int(float(width * target_height) / target_width)
+            crop_height = max(min(height, crop_height), 1)
+            crop_width = int(float(height * target_width) / target_height)
+            crop_width = max(min(width, crop_width), 1)
+            crop_box_hstart = int(float(height - crop_height) / 2)
+            crop_box_wstart = int(float(width - crop_width) / 2)
+
+            h_start = crop_box_hstart
+            w_start = crop_box_wstart
+
+            h_end = crop_box_hstart + crop_height
+            w_end = crop_box_wstart + crop_width
+            bounding_boxes = _get_clipped_bbox(
+                bounding_boxes, h_end, h_start, w_end, w_start
+            )
+
+            bounding_boxes = convert_format(
+                bounding_boxes,
+                source="xyxy",
+                target="rel_xyxy",
+                height=crop_height,
+                width=crop_width,
+            )
+
+            bounding_boxes = convert_format(
+                bounding_boxes,
+                source="rel_xyxy",
+                target="xyxy",
+                height=self.height,
+                width=self.width,
+            )
+
+        bounding_boxes = clip_to_image_size(
+            bounding_boxes=bounding_boxes,
+            height=self.height,
+            width=self.width,
+            bounding_box_format="xyxy",
+        )
+
+        bounding_boxes = convert_format(
+            bounding_boxes,
+            source="xyxy",
+            target=self.bounding_box_format,
+            height=self.height,
+            width=self.width,
+        )

        return bounding_boxes
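For review convenience, a standalone sketch of the box arithmetic in the large-enough branch above: compute the centered crop offsets, clip each corner to the crop window, then shift into crop-local coordinates. This is a NumPy illustration with made-up sizes and a hypothetical box, not the layer's backend code.

import numpy as np

# Hypothetical sizes: a 100x120 (HxW) image center-cropped to 60x80.
init_height, init_width = 100, 120
target_height, target_width = 60, 80

h_start = (init_height - target_height) // 2  # 20
w_start = (init_width - target_width) // 2    # 20
h_end, w_end = h_start + target_height, w_start + target_width  # 80, 100

# One box in absolute "xyxy" pixel coordinates of the original image.
boxes = np.array([[10.0, 30.0, 90.0, 70.0]])  # x1, y1, x2, y2

x1, y1, x2, y2 = np.split(boxes, 4, axis=-1)
# Clip each coordinate to the crop window, then shift to crop-local coords.
x1 = np.clip(x1, w_start, w_end) - w_start
y1 = np.clip(y1, h_start, h_end) - h_start
x2 = np.clip(x2, w_start, w_end) - w_start
y2 = np.clip(y2, h_start, h_end) - h_start

cropped_boxes = np.concatenate([x1, y1, x2, y2], axis=-1)
print(cropped_boxes)  # [[ 0. 10. 70. 50.]]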

def transform_segmentation_masks(
@@ -184,62 +183,60 @@ def transform_segmentation_masks(

    def transform_images(self, images, transformation=None, training=True):
        inputs = self.backend.cast(images, self.compute_dtype)
-        if training:
-            if self.data_format == "channels_first":
-                init_height = inputs.shape[-2]
-                init_width = inputs.shape[-1]
-            else:
-                init_height = inputs.shape[-3]
-                init_width = inputs.shape[-2]
-
-            if init_height is None or init_width is None:
-                # Dynamic size case. TODO.
-                raise ValueError(
-                    "At this time, CenterCrop can only "
-                    "process images with a static spatial "
-                    f"shape. Received: inputs.shape={inputs.shape}"
-                )
-
-            h_diff = init_height - self.height
-            w_diff = init_width - self.width
-
-            h_start = int(h_diff / 2)
-            w_start = int(w_diff / 2)
-
-            if h_diff >= 0 and w_diff >= 0:
-                if len(inputs.shape) == 4:
-                    if self.data_format == "channels_first":
-                        return inputs[
-                            :,
-                            :,
-                            h_start : h_start + self.height,
-                            w_start : w_start + self.width,
-                        ]
-                    return inputs[
-                        :,
-                        h_start : h_start + self.height,
-                        w_start : w_start + self.width,
-                        :,
-                    ]
-                elif len(inputs.shape) == 3:
-                    if self.data_format == "channels_first":
-                        return inputs[
-                            :,
-                            h_start : h_start + self.height,
-                            w_start : w_start + self.width,
-                        ]
-                    return inputs[
-                        h_start : h_start + self.height,
-                        w_start : w_start + self.width,
-                        :,
-                    ]
-
-            return image_utils.smart_resize(
-                inputs,
-                [self.height, self.width],
-                data_format=self.data_format,
-                backend_module=self.backend,
-            )
-        return images
+        if self.data_format == "channels_first":
+            init_height = inputs.shape[-2]
+            init_width = inputs.shape[-1]
+        else:
+            init_height = inputs.shape[-3]
+            init_width = inputs.shape[-2]
+
+        if init_height is None or init_width is None:
+            # Dynamic size case. TODO.
+            raise ValueError(
+                "At this time, CenterCrop can only "
+                "process images with a static spatial "
+                f"shape. Received: inputs.shape={inputs.shape}"
+            )
+
+        h_diff = init_height - self.height
+        w_diff = init_width - self.width
+
+        h_start = int(h_diff / 2)
+        w_start = int(w_diff / 2)
+
+        if h_diff >= 0 and w_diff >= 0:
+            if len(inputs.shape) == 4:
+                if self.data_format == "channels_first":
+                    return inputs[
+                        :,
+                        :,
+                        h_start : h_start + self.height,
+                        w_start : w_start + self.width,
+                    ]
+                return inputs[
+                    :,
+                    h_start : h_start + self.height,
+                    w_start : w_start + self.width,
+                    :,
+                ]
+            elif len(inputs.shape) == 3:
+                if self.data_format == "channels_first":
+                    return inputs[
+                        :,
+                        h_start : h_start + self.height,
+                        w_start : w_start + self.width,
+                    ]
+                return inputs[
+                    h_start : h_start + self.height,
+                    w_start : w_start + self.width,
+                    :,
+                ]
+
+        return image_utils.smart_resize(
+            inputs,
+            [self.height, self.width],
+            data_format=self.data_format,
+            backend_module=self.backend,
+        )

    def compute_output_shape(self, input_shape):
        input_shape = list(input_shape)
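As a usage-side check of the file touched by this commit, a minimal sketch with the public CenterCrop layer follows. It assumes the public keras.layers.CenterCrop symbol maps to this implementation; the array shapes are arbitrary, and the smaller-than-target case is expected to go through the aspect-preserving smart_resize fallback shown in the diff above.

import numpy as np
import keras

layer = keras.layers.CenterCrop(height=224, width=224)

# Input larger than the target: a centered 224x224 window is sliced out.
big = np.random.uniform(size=(2, 300, 400, 3)).astype("float32")
print(layer(big).shape)  # (2, 224, 224, 3)

# Input smaller than the target: the layer falls back to an
# aspect-preserving crop-and-resize instead of padding.
small = np.random.uniform(size=(2, 100, 150, 3)).astype("float32")
print(layer(small).shape)  # (2, 224, 224, 3)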
