diff --git a/trattack/attack_methods.py b/trattack/attack_methods.py index b227b4e..645aadd 100644 --- a/trattack/attack_methods.py +++ b/trattack/attack_methods.py @@ -238,7 +238,7 @@ def tr_attack_adaptive_iter(model, data, target, eps, c = 9, p = 2, iter = 100, update_num += torch.sum(tmp_mask.long()) if torch.sum(tmp_mask.long()) < 1: return X_adv.cpu(), update_num - attack_mask = tmp_mask.nonzero().view(-1) + attack_mask = tmp_mask.nonzero().view(-1).cpu() X_adv[attack_mask,:], eps[attack_mask,:] = tr_attack_adaptive(model, X_adv[attack_mask,:], target[attack_mask], target_ind[attack_mask], eps[attack_mask,:], p = p) return X_adv.cpu(), update_num