adding adversarial weight perturbation protocol #2224
Conversation
Signed-off-by: Muhammad Zaid Hameed <Zaid.Hameed@ibm.com>
Codecov Report

@@             Coverage Diff              @@
##           dev_1.16.0    #2224      +/-   ##
==============================================
+ Coverage   84.76%   84.85%   +0.09%
==============================================
  Files         313      315       +2
  Lines       27810    28054     +244
  Branches     5086     5123      +37
==============================================
+ Hits        23572    23805     +233
+ Misses       2948     2941       -7
- Partials     1290     1308      +18

... and 6 files with indirect coverage changes
def fit_generator(  # pylint: disable=W0221
    self,
    generator: DataGenerator,
    validation_data: Optional[Tuple[np.ndarray, np.ndarray]] = None,
    nb_epochs: int = 20,
    **kwargs
):
Check notice (Code scanning / CodeQL): Mismatch between signature and use of an overridden method
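For context, a generic illustration (invented class names, not ART's actual hierarchy) of the pattern this alert points at: an override whose parameter list differs from the base method.

class BaseTrainer:
    def fit_generator(self, generator, nb_epochs: int = 20):
        ...


class AWPTrainer(BaseTrainer):
    # The override adds validation_data and **kwargs, so its signature no longer
    # matches the base method; this is what the CodeQL alert (and pylint's W0221,
    # disabled above) flag.
    def fit_generator(self, generator, validation_data=None, nb_epochs: int = 20, **kwargs):
        ...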
Hi @Zaid-Hameed Thank you very much for your pull request! I have added a few minor comments on using properties to avoid pylint warnings. What do you think?
Have you tested your code on GPUs?
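For illustration, a minimal sketch (toy wrapper class, not ART's actual implementation) of how a read-only property can expose a private attribute such as _model so that call sites no longer need # pylint: disable=W0212:

import torch


class WrappedClassifier:
    # Toy wrapper for illustration only.
    def __init__(self, model: torch.nn.Module) -> None:
        self._model = model

    @property
    def model(self) -> torch.nn.Module:
        # Read-only access to the underlying model.
        return self._model


wrapper = WrappedClassifier(torch.nn.Linear(4, 2))
# Iterating via wrapper.model instead of wrapper._model avoids pylint's
# protected-access warning (W0212) at the call site.
for name, param in wrapper.model.named_parameters():
    print(name, tuple(param.shape))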
from art.utils import CLASSIFIER_LOSS_GRADIENTS_TYPE


class AdversarialTrainerAWP(Trainer, abc.ABC):
Is inheriting from abc.ABC required here? Trainer is already inheriting from abc.ABC.
Done by removing abc.ABC.
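For context, a simplified sketch (abstract method invented for illustration, not ART's actual base class) of why repeating abc.ABC in the bases is redundant:

import abc


class Trainer(abc.ABC):
    # The base class is already abstract because it inherits from abc.ABC.
    @abc.abstractmethod
    def fit(self, x, y, **kwargs) -> None:
        raise NotImplementedError


class AdversarialTrainerAWP(Trainer):
    # Inheriting from Trainer alone keeps the class abstract until fit is
    # implemented; listing abc.ABC again in the bases adds nothing.
    ...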
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module implements adversarial training with AWP protocol.
Let's introduce the abbreviation AWP somewhere.
Suggested change:
- This module implements adversarial training with AWP protocol.
+ This module implements adversarial training with Adversarial Weight Perturbation (AWP) protocol.
Done.
class AdversarialTrainerAWPPyTorch(AdversarialTrainerAWP):
    """
    Class performing adversarial training following AWP protocol.
Suggested change:
- Class performing adversarial training following AWP protocol.
+ Class performing adversarial training following Adversarial Weight Perturbation (AWP) protocol.
Done.
import torch


logger = logging.getLogger(__name__)
EPS = 1e-8
What is the definition of EPS?
Added definition.
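For context, a sketch of the role such a small constant typically plays in AWP-style weight perturbation (illustrative only, with an invented helper, not the PR's actual code):

import torch

EPS = 1e-8  # small constant that keeps divisions numerically well-defined


def scale_perturbation(perturbation: torch.Tensor, weight: torch.Tensor, gamma: float) -> torch.Tensor:
    # Rescale a raw weight perturbation to a fraction gamma of the weight's norm;
    # adding EPS to the denominator avoids division by zero when the perturbation
    # is (numerically) all zeros.
    return gamma * weight.norm() * perturbation / (perturbation.norm() + EPS)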
self._classifier: PyTorchClassifier
self._proxy_classifier: PyTorchClassifier
self._attack: EvasionAttack
self._mode: str
self.gamma: float
self._beta: float
self._warmup: int
self._apply_wp: bool
Are these type assignments needed?
Yes, for static type checking by mypy.
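For illustration, a minimal sketch (toy class, not the PR's code) showing how annotation-only attribute declarations give mypy the type information even though the values are assigned elsewhere:

class Settings:
    # Annotation-only declarations: no value is assigned here, but mypy now knows
    # the attribute types when they are set later, e.g. in a parent __init__.
    _mode: str
    gamma: float

    def configure(self, mode: str, gamma: float) -> None:
        self._mode = mode
        self.gamma = gamma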
params_dict = OrderedDict()  # type: ignore
list_params = []
for name, param in p_classifier._model.state_dict().items():  # pylint: disable=W0212
Suggested change:
- for name, param in p_classifier._model.state_dict().items():  # pylint: disable=W0212
+ for name, param in p_classifier.model.state_dict().items():
Done.
    self, p_classifier: PyTorchClassifier, list_keys: List[str], w_perturb: Dict[str, "torch.Tensor"], op: str
) -> None:
    """
Please add a description of the method.
Added description.
    :param w_perturb: dictionary containing model parameters' names as keys and model parameters as values
    :param op: controls whether weight perturbation will be added or subtracted from model parameters
    """
Done.
else:
    raise ValueError("Incorrect op provided for weight perturbation. 'op' must be among 'add' and 'subtract'.")
with torch.no_grad():
    for name, param in p_classifier._model.named_parameters():  # pylint: disable=W0212
Suggested change:
- for name, param in p_classifier._model.named_parameters():  # pylint: disable=W0212
+ for name, param in p_classifier.model.named_parameters():
Done.
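Putting the snippets above together, a hypothetical sketch of how such a perturb-and-restore helper can work (the function below follows the names shown in this diff but is not the PR's exact code):

import torch


def modify_classifier(p_classifier, list_keys, w_perturb, op):
    # Apply ("add") or undo ("subtract") a weight perturbation in place on the
    # parameters whose names appear in list_keys.
    if op.lower() == "add":
        sign = 1.0
    elif op.lower() == "subtract":
        sign = -1.0
    else:
        raise ValueError("Incorrect op provided for weight perturbation. 'op' must be among 'add' and 'subtract'.")
    with torch.no_grad():
        for name, param in p_classifier.model.named_parameters():
            if name in list_keys:
                param.add_(sign * w_perturb[name])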
examples/adversarial_training_awp.py (outdated)
)


# Build a Keras image augmentation object and wrap it in ART
Is it Keras?
Changed to correct description.
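Since the corrected description presumably refers to a PyTorch data pipeline rather than Keras, here is a hedged sketch (dataset and transform choices are assumptions, not the example script's actual code) of building augmented data with torchvision and wrapping it in an ART data generator:

import torch
from torchvision import datasets, transforms

from art.data_generators import PyTorchDataGenerator

# Build a PyTorch DataLoader with simple augmentation and wrap it in ART.
transform = transforms.Compose(
    [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor()]
)
train_set = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
art_generator = PyTorchDataGenerator(iterator=loader, size=len(train_set), batch_size=128)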
Description
Adversarial Weight Perturbation (AWP) is an important adversarial training approach because it provides better robustness against adversarial attacks and mitigates robust overfitting. AWP was proposed in the paper "Adversarial Weight Perturbation Helps Robust Generalization".
Paper link: https://proceedings.neurips.cc/paper/2020/file/1ef91c212e30e14bf125e9374262401f-Paper.pdf
It is also a base component of more advanced adversarial training approaches.
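A hypothetical usage sketch of the new trainer (constructor arguments are inferred from the attributes shown in this PR, such as gamma, beta, warmup and mode; the exact import path and signature may differ in the merged release):

import numpy as np
import torch

from art.attacks.evasion import ProjectedGradientDescent
from art.defences.trainer import AdversarialTrainerAWPPyTorch  # assumed import path
from art.estimators.classification import PyTorchClassifier


def make_classifier() -> PyTorchClassifier:
    model = torch.nn.Sequential(
        torch.nn.Flatten(), torch.nn.Linear(28 * 28, 128), torch.nn.ReLU(), torch.nn.Linear(128, 10)
    )
    return PyTorchClassifier(
        model=model,
        loss=torch.nn.CrossEntropyLoss(),
        optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
        input_shape=(1, 28, 28),
        nb_classes=10,
    )


classifier = make_classifier()        # classifier being adversarially trained
proxy_classifier = make_classifier()  # proxy used to search for the weight perturbation

attack = ProjectedGradientDescent(classifier, eps=8 / 255, eps_step=2 / 255, max_iter=10)
trainer = AdversarialTrainerAWPPyTorch(
    classifier, proxy_classifier, attack, mode="PGD", gamma=0.01, beta=6.0, warmup=0
)

x_train = np.random.rand(64, 1, 28, 28).astype(np.float32)
y_train = np.random.randint(0, 10, size=64)
trainer.fit(x_train, y_train, nb_epochs=1, batch_size=32)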
Fixes #2164
Type of change
Please check all relevant options.
Testing
Please describe the tests that you ran to verify your changes. Consider listing any relevant details of your test configuration.
Test Configuration:
Checklist