diff --git a/paddleseg/core/infer.py b/paddleseg/core/infer.py index 29eef1f917..0cf1297c4c 100644 --- a/paddleseg/core/infer.py +++ b/paddleseg/core/infer.py @@ -99,13 +99,15 @@ def get_reverse_list(ori_shape, transforms): def reverse_transform(pred, ori_shape, transforms, mode='nearest'): """recover pred to origin shape""" reverse_list = get_reverse_list(ori_shape, transforms) + intTypeList = [paddle.int8, paddle.int16, paddle.int32, paddle.int64] + dtype = pred.dtype for item in reverse_list[::-1]: if item[0] == 'resize': h, w = item[1][0], item[1][1] - if paddle.get_device() == 'cpu': - pred = paddle.cast(pred, 'uint8') + if paddle.get_device() == 'cpu' and dtype in intTypeList: + pred = paddle.cast(pred, 'float32') pred = F.interpolate(pred, (h, w), mode=mode) - pred = paddle.cast(pred, 'int32') + pred = paddle.cast(pred, dtype) else: pred = F.interpolate(pred, (h, w), mode=mode) elif item[0] == 'padding':