diff --git a/paddleseg/core/infer.py b/paddleseg/core/infer.py index bfb3888909..9d6df78b8a 100644 --- a/paddleseg/core/infer.py +++ b/paddleseg/core/infer.py @@ -42,6 +42,23 @@ def get_reverse_list(ori_shape, transforms): if op.__class__.__name__ in ['Padding']: reverse_list.append(('padding', (h, w))) w, h = op.target_size[0], op.target_size[1] + if op.__class__.__name__ in ['LimitLong']: + long_edge = max(h, w) + short_edge = min(h, w) + if ((op.max_long is not None) and (long_edge > op.max_long)): + reverse_list.append(('resize', (h, w))) + long_edge = op.max_long + short_edge = int(round(short_edge * op.max_long / long_edge)) + elif ((op.min_long is not None) and (long_edge < op.min_long)): + reverse_list.append(('resize', (h, w))) + long_edge = op.min_long + short_edge = int(round(short_edge * op.min_long / long_edge)) + if h > w: + h = long_edge + w = short_edge + else: + w = long_edge + h = short_edge return reverse_list diff --git a/paddleseg/core/val.py b/paddleseg/core/val.py index f6371a1c17..003273e01f 100644 --- a/paddleseg/core/val.py +++ b/paddleseg/core/val.py @@ -123,7 +123,8 @@ def evaluate(model, intersect_area_list = [] pred_area_list = [] label_area_list = [] - paddle.distributed.all_gather(intersect_area_list, intersect_area) + paddle.distributed.all_gather(intersect_area_list, + intersect_area) paddle.distributed.all_gather(pred_area_list, pred_area) paddle.distributed.all_gather(label_area_list, label_area) @@ -135,7 +136,8 @@ def evaluate(model, label_area_list = label_area_list[:valid] for i in range(len(intersect_area_list)): - intersect_area_all = intersect_area_all + intersect_area_list[i] + intersect_area_all = intersect_area_all + intersect_area_list[ + i] pred_area_all = pred_area_all + pred_area_list[i] label_area_all = label_area_all + label_area_list[i] else: diff --git a/paddleseg/transforms/transforms.py b/paddleseg/transforms/transforms.py index 7f285ed340..52ba7a29f7 100644 --- a/paddleseg/transforms/transforms.py +++ b/paddleseg/transforms/transforms.py @@ -228,6 +228,71 @@ def __call__(self, im, label=None): return (im, label) +@manager.TRANSFORMS.add_component +class LimitLong: + """ + Limit the long edge of image. + + If the long edge is larger than max_long, resize the long edge + to max_long, while scale the short edge proportionally. + + If the long edge is smaller than min_long, resize the long edge + to min_long, while scale the short edge proportionally. + + Args: + max_long (int, optional): If the long edge of image is larger than max_long, + it will be resize to max_long. Default: None. + min_long (int, optional): If the long edge of image is smaller than min_long, + it will be resize to min_long. Default: None. + """ + + def __init__(self, max_long=None, min_long=None): + if max_long is not None: + if not isinstance(max_long, int): + raise TypeError( + "Type of `max_long` is invalid. It should be int, but it is {}" + .format(type(max_long))) + if min_long is not None: + if not isinstance(min_long, int): + raise TypeError( + "Type of `min_long` is invalid. It should be int, but it is {}" + .format(type(min_long))) + if (max_long is not None) and (min_long is not None): + if min_long > max_long: + raise ValueError( + '`max_long should not smaller than min_long, but they are {} and {}' + .format(max_long, min_long)) + self.max_long = max_long + self.min_long = min_long + + def __call__(self, im, label=None): + """ + Args: + im (np.ndarray): The Image data. + label (np.ndarray, optional): The label data. Default: None. + + Returns: + (tuple). When label is None, it returns (im, ), otherwise it returns (im, label). + """ + h, w = im.shape[0], im.shape[1] + long_edge = max(h, w) + target = long_edge + if (self.max_long is not None) and (long_edge > self.max_long): + target = self.max_long + elif (self.min_long is not None) and (long_edge < self.min_long): + target = self.min_long + + if target != long_edge: + im = functional.resize_long(im, target) + if label is not None: + label = functional.resize_long(label, target, cv2.INTER_NEAREST) + + if label is None: + return (im, ) + else: + return (im, label) + + @manager.TRANSFORMS.add_component class ResizeRangeScaling: """