-
Notifications
You must be signed in to change notification settings - Fork 0
/
Evaluator.h
48 lines (40 loc) · 1.32 KB
/
Evaluator.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#pragma once
#include <torch/torch.h>
#include "Attackers/IAttacker.h"
#include "utilities.h"
/// Accumulates clean and adversarial accuracy of a network, batch by batch.
///
/// Each call to evaluate_single_batch() feeds one batch through the network and
/// updates two running meters: one for the unmodified inputs and, unless the
/// attacker is a no-op, one for inputs produced by the configured attacker.
///
/// @tparam NetworkType  Module implementation type wrapped by torch::nn::ModuleHolder.
template <typename NetworkType>
class Evaluator
{
public:
    /// @param attacker  Attack used to craft adversarial inputs. If its getType()
    ///                  returns AttackType::Noop, adversarial evaluation is skipped.
    /// @param device    Stored in `_device`; not currently read by this class
    ///                  (the network is moved to the device of each batch instead).
    Evaluator(std::shared_ptr<IAttacker<NetworkType>> attacker, const c10::Device& device)
        : _device(device), _attacker(std::move(attacker)) {}

    /// Evaluates one batch and updates the running accuracy meters.
    ///
    /// The network is moved to the device that `example.data` lives on. Clean
    /// accuracy is always updated. Adversarial accuracy is updated only when the
    /// attacker is not AttackType::Noop.
    ///
    /// @param network  Network under evaluation (ModuleHolder copies share the module).
    /// @param example  Batch whose `data` is the input and `target` the labels.
    void evaluate_single_batch(torch::nn::ModuleHolder<NetworkType> network, const torch::data::Example<>& example)
    {
        // Non-const tensor handles (cheap refcounted copies) keep the calls below
        // compatible with callees that take non-const references.
        auto data = example.data;
        auto label = example.target;
        network->to(data.device());

        {
            // Clean forward pass: no autograd graph is needed.
            torch::NoGradGuard no_grad;
            auto prediction = network(data);
            _clean_accuracy.update(calculate_torch_accuracy(prediction, label));
        }

        if (_attacker->getType() == AttackType::Noop)
        {
            return;
        }

        // The attack is crafted OUTSIDE the NoGradGuard: attacks that receive the
        // network and labels are typically gradient-based and need autograd.
        // NOTE(review): confirm IAttacker implementations do not rely on being
        // called with grad mode disabled.
        auto adversarial_input = (*_attacker)(network, data, label);

        {
            torch::NoGradGuard no_grad;
            // Evaluate on the crafted input. Previously this re-ran the network on
            // the clean `data`, so "adversarial accuracy" merely duplicated clean accuracy.
            auto adv_prediction = network(adversarial_input);
            _adversarial_accuracy.update(calculate_torch_accuracy(adv_prediction, label));
        }
    }

    /// @return (mean clean accuracy, mean adversarial accuracy) over all batches
    ///         evaluated since construction or the last reset().
    std::pair<double, double> get_accuracies()
    {
        return std::make_pair(_clean_accuracy.getMean(), _adversarial_accuracy.getMean());
    }

    /// Clears both running accuracy meters.
    void reset()
    {
        _clean_accuracy.reset();
        _adversarial_accuracy.reset();
    }

    /// Running accuracy on unmodified inputs.
    average_meter _clean_accuracy = average_meter("clean accuracy");
    /// Running accuracy on attacker-crafted inputs.
    average_meter _adversarial_accuracy = average_meter("adversarial accuracy");
    /// Device passed at construction; currently unused by this class.
    c10::Device _device;
    /// Attack used to generate adversarial inputs.
    std::shared_ptr<IAttacker<NetworkType>> _attacker;
};