Skip to content

Commit

Permalink
defenses: add unit test for hardness manipulation defense
Browse files Browse the repository at this point in the history
  • Loading branch information
cdluminate committed May 12, 2022
1 parent 25501d3 commit d229f74
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 12 deletions.
11 changes: 1 addition & 10 deletions robrank/defenses/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ def __init__(self):
self.dataset = 'mnist'
self.loss = 'ptripletN'
self.lossfunc = ptripletN()
self.is_advtrain_pnp = True
self.config = rc2f2(self.dataset, self.loss)
self.metric = self.lossfunc.determine_metric()
self.datasetspec = self.lossfunc.datasetspec()
#self.is_advtrain_pnp = True
def forward(self, x):
x = th.nn.Flatten()(x)
return self.fc(x)
Expand All @@ -42,12 +42,3 @@ def test_testnet():
output = model.forward(images)
loss = output.mean()
loss.backward()


@pytest.mark.skip(reason='this is test helper')
def test_xxx_training_step(training_step: callable):
model = TestNet()
images = th.rand(10, 1, 28, 28)
labels = th.stack([th.arange(5), th.arange(5)]).T.flatten()
loss = training_step(model, (images, labels), 0)
loss.backward()
40 changes: 38 additions & 2 deletions robrank/defenses/test_defenses.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,33 @@
See the License for the specific language governing permissions and
limitations under the License.
'''
import itertools as it
import torch as th
from ..losses import ptripletN
from . import test_common
from .test_common import TestNet
from .pnp import *
from .amd import *
from .est import *
from .ses import *
import pytest

@pytest.mark.skip(reason='this is test helper')
def __test_xxx_training_step(training_step: callable):
model = TestNet()
images = th.rand(10, 1, 28, 28)
labels = th.stack([th.arange(5), th.arange(5)]).T.flatten()
loss = training_step(model, (images, labels), 0)
loss.backward()

@pytest.mark.skip(reason='this is test helper')
def __test_hm_training_step(g, r, hm, srch, dsth, ics):
model = TestNet()
images = th.rand(10, 1, 28, 28)
labels = th.stack([th.arange(5), th.arange(5)]).T.flatten()
loss = hm_training_step(model, (images, labels), 0,
gradual=g, fix_anchor=r, hm=hm, srch=srch, desth=dsth,
ics=ics)
loss.backward()

@pytest.mark.parametrize('ts', [
pnp_training_step,
Expand All @@ -30,4 +48,22 @@
ses_training_step,
])
def test_xxx_training_step(ts: callable):
return test_common.test_xxx_training_step(ts)
__test_xxx_training_step(ts)

H_MAP = {'r': 'spc2-random',
'm': 'spc2-semihard',
's': 'spc2-softhard',
'd': 'spc2-distance',
'h': 'spc2-hard'}

@pytest.mark.parametrize('g, r, hm, srch, dsth, ics',
it.product(
(False, True),
(False,),
('KL', 'L2', 'ET'),
'rmsdh',
'rmsdh',
(False, True),
))
def test_hm_training_step(g, r, hm, srch, dsth, ics):
__test_hm_training_step(g, r, hm, H_MAP[srch], H_MAP[dsth], ics)

0 comments on commit d229f74

Please sign in to comment.