diff --git a/paddleseg/utils/metrics.py b/paddleseg/utils/metrics.py index ec2ef26c57..cbc4aec592 100644 --- a/paddleseg/utils/metrics.py +++ b/paddleseg/utils/metrics.py @@ -74,29 +74,33 @@ def auc_roc(logits, label, num_classes, ignore_index=None): Returns: auc_roc(float): The area under roc curve """ - if ignore_index or len(np.unique(label))>num_classes: + if ignore_index or len(np.unique(label)) > num_classes: raise RuntimeError('labels with ignore_index is not supported yet.') - + if len(label.shape) != 4: - raise ValueError('The shape of label is not 4 dimension as (N, C, H, W), it is {}'.format(label.shape)) + raise ValueError( + 'The shape of label is not 4 dimension as (N, C, H, W), it is {}'. + format(label.shape)) if len(logits.shape) != 4: - raise ValueError('The shape of logits is not 4 dimension as (N, C, H, W), it is {}'.format(logits.shape)) - - N, C, H, W = logits.shape + raise ValueError( + 'The shape of logits is not 4 dimension as (N, C, H, W), it is {}'. + format(logits.shape)) + + N, C, H, W = logits.shape logits = np.transpose(logits, (1, 0, 2, 3)) - logits = logits.reshape([C, N*H*W]).transpose([1,0]) + logits = logits.reshape([C, N * H * W]).transpose([1, 0]) label = np.transpose(label, (1, 0, 2, 3)) - label = label.reshape([1, N*H*W]).squeeze() + label = label.reshape([1, N * H * W]).squeeze() if not logits.shape[0] == label.shape[0]: raise ValueError('length of `logit` and `label` should be equal, ' 'but they are {} and {}.'.format( - pred.shape[0], label.shape[0])) - + logits.shape[0], label.shape[0])) + if num_classes == 2: - auc = skmetrics.roc_auc_score(label, logits[:,1]) + auc = skmetrics.roc_auc_score(label, logits[:, 1]) else: auc = skmetrics.roc_auc_score(label, logits, multi_class='ovr') @@ -156,7 +160,7 @@ def dice(intersect_area, pred_area, label_area): dice = (2 * intersect_area[i]) / union[i] class_dice.append(dice) mdice = np.mean(class_dice) - return np.array(class_dice), mdice + return np.array(class_dice), mdice def accuracy(intersect_area, pred_area):