-
Notifications
You must be signed in to change notification settings - Fork 4
/
loss.py
74 lines (62 loc) · 3.09 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from utils import torch_skew_symmetric
def batch_episym(x1, x2, F, eps=1e-15):
    """Symmetric squared epipolar distance for a batch of correspondences.

    Each point pair (p1, p2) is lifted to homogeneous coordinates (a 1 is
    appended), and the returned value is

        (p2^T F p1)^2 * ( 1 / ((F p1)_0^2 + (F p1)_1^2 + eps)
                        + 1 / ((F^T p2)_0^2 + (F^T p2)_1^2 + eps) )

    Note: the parameter ``F`` (the 3x3 matrix) shadows the module-level
    ``torch.nn.functional as F`` inside this function only.

    Args:
        x1: (B, N, 2) tensor of points in the first view.
        x2: (B, N, 2) tensor of points in the second view.
        F: per-batch 3x3 matrix, anything reshapeable to (B, 1, 3, 3),
            e.g. (B, 9) or (B, 3, 3).
        eps: added to each denominator to avoid division by zero.

    Returns:
        (B, N) tensor of symmetric epipolar distances.
    """
    batch_size, num_pts = x1.shape[0], x1.shape[1]
    # Homogeneous column vectors, shape (B, N, 3, 1).
    x1 = torch.cat([x1, x1.new_ones(batch_size, num_pts, 1)], dim=-1).reshape(batch_size, num_pts, 3, 1)
    x2 = torch.cat([x2, x2.new_ones(batch_size, num_pts, 1)], dim=-1).reshape(batch_size, num_pts, 3, 1)
    # (B, 1, 3, 3): matmul broadcasts over the point axis, so no per-point copy is needed.
    F = F.reshape(-1, 1, 3, 3)
    Fx1 = torch.matmul(F, x1)                   # (B, N, 3, 1)
    Ftx2 = torch.matmul(F.transpose(2, 3), x2)  # (B, N, 3, 1)
    # Epipolar residual p2^T F p1, reusing F p1 instead of recomputing it.
    x2Fx1 = torch.matmul(x2.transpose(2, 3), Fx1).reshape(batch_size, num_pts)
    Fx1 = Fx1.reshape(batch_size, num_pts, 3)
    Ftx2 = Ftx2.reshape(batch_size, num_pts, 3)
    ys = x2Fx1 ** 2 * (
        1.0 / (Fx1[:, :, 0] ** 2 + Fx1[:, :, 1] ** 2 + eps)
        + 1.0 / (Ftx2[:, :, 0] ** 2 + Ftx2[:, :, 1] ** 2 + eps)
    )
    return ys
class MatchLoss(object):
    """Training loss: weighted inlier classification plus a clamped epipolar term.

    Reads from ``config``:
        loss_essential: weight of the epipolar term (used only when > 0).
        loss_classif: weight of the classification term (used only when > 0).
        ess_loss_margin: upper clamp on each per-point epipolar distance.
        obj_geod_th: a correspondence whose ``ys`` value is below this
            threshold is labelled positive (inlier).
        loss_essential_init_iter: ``global_step`` from which the epipolar
            term is added to the total loss.
    """

    def __init__(self, config):
        self.loss_essential = config.loss_essential
        self.loss_classif = config.loss_classif
        self.ess_loss_margin = config.ess_loss_margin
        self.obj_geod_th = config.obj_geod_th
        self.loss_essential_init_iter = config.loss_essential_init_iter

    def weight_estimation(self, gt_geod_d, is_pos, ones):
        """Per-correspondence weight used to scale the logits.

        Positives get exp(-|d - th| / th), which equals 1 at the threshold
        ``th = self.obj_geod_th``. Negatives get the value of ``ones``.

        Args:
            gt_geod_d: ground-truth distances, same shape as ``is_pos``.
            is_pos: 0/1 positive mask (as built in ``run``).
            ones: tensor broadcastable against ``is_pos``, e.g. (B, 1).

        Returns:
            Tensor of weights with the broadcast shape of the inputs.
        """
        dis = torch.abs(gt_geod_d - self.obj_geod_th) / self.obj_geod_th
        weight_p = torch.exp(-dis) * is_pos
        weight_n = ones * (1 - is_pos)
        return weight_p + weight_n

    def _classification_loss(self, logits, ys, ones):
        """Sum, over iterations, of class-balanced, weighted BCE-with-logits.

        ``ys[i]`` holds the ground-truth distances matching ``logits[i]``.
        Returns the int 0 when ``logits`` is empty.
        """
        classif_loss = 0
        for i, logit in enumerate(logits):
            gt_geod_d = ys[i]
            is_pos = (gt_geod_d < self.obj_geod_th).to(gt_geod_d.dtype)
            # Explicit comparison (not 1 - is_pos), so NaN distances count as neither class.
            is_neg = (gt_geod_d >= self.obj_geod_th).to(gt_geod_d.dtype)
            with torch.no_grad():
                # Class balance: up-weight positives by neg/pos. Both counts are
                # clamped to >= 1, and so is the resulting ratio.
                pos_num = torch.sum(is_pos, dim=-1, keepdim=True).clamp(min=1)
                neg_num = torch.sum(is_neg, dim=-1, keepdim=True).clamp(min=1)
                pos_w = torch.max(neg_num / pos_num, ones)
            weight = self.weight_estimation(gt_geod_d, is_pos, ones)
            # NOTE(review): the weight scales the logits (a per-point temperature),
            # rather than being passed as BCE's ``weight=`` -- kept as-is; confirm intent.
            classif_loss += F.binary_cross_entropy_with_logits(
                weight * logit, is_pos, pos_weight=pos_w)
        return classif_loss

    def run(self, global_step, data, logits, ys, e_hat, y_hat):
        """Compute the total loss and the two component values for logging.

        Args:
            global_step: current step, compared with ``loss_essential_init_iter``.
            data: dict providing ``'xs'`` (only its batch size and device are
                used) and ``'virtPts'`` (last dim split into two 2-column
                halves, i.e. presumably (B, N, 4)).
            logits: sequence of per-iteration classification logits.
            ys: ground-truth distances, indexed in parallel with ``logits``.
            e_hat: sequence of estimated matrices. Only the last one is used.
            y_hat: unused; kept for interface compatibility.

        Returns:
            (loss, essential_loss, classif_loss). ``loss`` carries gradients
            (or is the int 0 if neither term is enabled); the other two are
            Python floats.
        """
        xs, pts_virt = data['xs'], data['virtPts']
        pts1_virts, pts2_virts = pts_virt[:, :, :2], pts_virt[:, :, 2:]

        # Allocate directly on the target device instead of CPU-then-copy.
        ones = torch.ones((xs.shape[0], 1), device=xs.device)
        classif_loss = self._classification_loss(logits, ys, ones)

        geod = batch_episym(pts1_virts, pts2_virts, e_hat[-1])
        essential_loss = torch.clamp(geod, max=self.ess_loss_margin).mean()

        loss = 0
        # The epipolar term only kicks in after a warm-up number of steps.
        if self.loss_essential > 0 and global_step >= self.loss_essential_init_iter:
            loss += self.loss_essential * essential_loss
        if self.loss_classif > 0:
            loss += self.loss_classif * classif_loss
        # float() instead of .item(): classif_loss is still the int 0 when logits is empty.
        return loss, float(essential_loss), float(classif_loss)