-
Notifications
You must be signed in to change notification settings - Fork 0
/
utilities.h
47 lines (38 loc) · 1.1 KB
/
utilities.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#pragma once
#include <string>
#include <utility>
#include <torch/torch.h>
/// <summary>
/// Accumulates a weighted running sum, an observation count and the resulting
/// mean of scalar values.
/// </summary>
struct average_meter
{
    /// <param name="name">Label stored with the meter; readable via getName().</param>
    average_meter(std::string name) : _name(std::move(name)) {}
    average_meter() : average_meter("") {}

    /// Clears the accumulated sum, mean and count. The name is kept.
    void reset()
    {
        _sum = 0;
        _mean = 0;
        _count = 0;
    }

    /// <summary>
    /// Records <c>count</c> observations, each with value <c>value</c>, and
    /// recomputes the mean.
    /// </summary>
    /// <param name="value">Value of each observation; adds count * value to the sum.</param>
    /// <param name="incrementalUpdate">When false, the meter is reset first so that
    /// only this update is reflected.</param>
    /// <param name="count">Number of observations (weight) represented by value.</param>
    void update(long double value, bool incrementalUpdate = true, long long count = 1)
    {
        if (!incrementalUpdate) reset();
        _count += count;
        _sum += count * value;
        // Avoid 0/0 (e.g. count == 0 on an empty meter), which would turn the mean into NaN.
        _mean = (_count != 0) ? _sum / _count : 0;
    }

    long double getMean() const { return _mean; }
    long double getSum() const { return _sum; }
    /// The count is stored as an integer; it is returned as long double for
    /// compatibility with existing callers.
    long double getCount() const { return _count; }
    const std::string& getName() const { return _name; }

private:
    std::string _name;
    long double _sum = 0;
    long double _mean = 0;
    long long _count = 0;
};
/// Returns the number of elements in the tensor <c>a</c>.
/// NOTE(review): the return type is <c>int</c>; a tensor with more than INT_MAX
/// elements cannot be represented. Confirm how the definition (not visible
/// here) handles that case.
int get_element_count(torch::Tensor a);
/// <summary>
/// Checks that tensors <c>a</c> and <c>b</c> hold the same number of elements.
/// It differs from is_same_size (presumably Tensor::is_same_size, which compares
/// shapes) in that only the element count, the product of the sizes along all
/// dimensions, is compared. Tensors with different shapes but equal element
/// counts therefore pass.
/// </summary>
/// <param name="a">First tensor to compare.</param>
/// <param name="b">Second tensor to compare.</param>
/// <remarks>
/// NOTE(review): what happens on a mismatch (assert, exception, ...) is decided
/// in the implementation file, which is not visible here. Confirm before relying on it.
/// </remarks>
void assert_equal_content_count(torch::Tensor a, torch::Tensor b);
/// Presumably computes an accuracy score of <c>output</c> against <c>target</c>.
/// NOTE(review): the expected shapes/dtypes of both tensors (e.g. logits vs.
/// class indices) and the scale of the result (fraction vs. percentage) are
/// defined in the implementation file, not visible here. Confirm there.
double calculate_torch_accuracy(torch::Tensor output, torch::Tensor target);