diff --git a/paddleseg/models/losses/bootstrapped_cross_entropy.py b/paddleseg/models/losses/bootstrapped_cross_entropy.py index 5ca95feb69..6443ccffec 100644 --- a/paddleseg/models/losses/bootstrapped_cross_entropy.py +++ b/paddleseg/models/losses/bootstrapped_cross_entropy.py @@ -38,9 +38,10 @@ def __init__(self, min_K, loss_th, weight=None, ignore_index=255): self.ignore_index = ignore_index self.K = min_K self.threshold = loss_th + if weight is not None: + weight = paddle.to_tensor(weight, dtype='float32') self.weight = weight - self.ignore_index = ignore_index - + def forward(self, logit, label): n, c, h, w = logit.shape @@ -55,7 +56,6 @@ def forward(self, logit, label): y = paddle.transpose(y, (0, 2, 3, 1)) x = paddle.reshape(x, shape=(-1, c)) y = paddle.reshape(y, shape=(-1, )) - loss = F.cross_entropy( x, y,