-
Notifications
You must be signed in to change notification settings - Fork 5
/
softargmax.py
35 lines (24 loc) · 904 Bytes
/
softargmax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import numpy as np
import torch
import torch.nn as nn
def softargmax2d(input, beta=100):
    """Differentiable 2-D soft-argmax over the last two dimensions.

    Computes a softmax-weighted average of (row, col) coordinates, which
    approaches the hard argmax location as ``beta`` grows.

    Args:
        input: Tensor of shape ``(*, h, w)``; any leading batch dims.
        beta: Sharpness of the softmax; larger values approximate a hard
            argmax more closely (default 100).

    Returns:
        Tensor of shape ``(*, 2)`` holding the expected ``(row, col)``
        coordinates, in the same dtype and on the same device as ``input``.
    """
    *batch, h, w = input.shape
    # Flatten the spatial plane so one softmax covers all h*w positions.
    probs = nn.functional.softmax(beta * input.reshape(*batch, h * w), dim=-1)

    # Build coordinate vectors directly in torch on input's device/dtype.
    # (The original numpy float64 grids broke CUDA inputs and silently
    # promoted float32 results to float64.)
    rows = torch.linspace(0, h - 1, h, device=input.device, dtype=input.dtype)
    cols = torch.linspace(0, w - 1, w, device=input.device, dtype=input.dtype)
    # Row-major flattening: position i*w + j carries row i and column j.
    indices_r = rows.repeat_interleave(w)
    indices_c = cols.repeat(h)

    result_r = torch.sum(probs * indices_r, dim=-1)
    result_c = torch.sum(probs * indices_c, dim=-1)
    return torch.stack([result_r, result_c], dim=-1)
def softargmax1d(input, beta=100):
    """Differentiable 1-D soft-argmax over the last dimension.

    Computes a softmax-weighted average of positions ``0..n-1``, which
    approaches the hard argmax index as ``beta`` grows.

    Args:
        input: Tensor of shape ``(*, n)``; any leading batch dims.
        beta: Sharpness of the softmax; larger values approximate a hard
            argmax more closely (default 100).

    Returns:
        Tensor of shape ``(*,)`` holding the expected index, in the same
        dtype and on the same device as ``input``.
    """
    *_, n = input.shape
    probs = nn.functional.softmax(beta * input, dim=-1)
    # Match device/dtype of the input: the default torch.linspace lives on
    # CPU in float32, which fails for CUDA inputs and mismatches float64/half.
    indices = torch.linspace(0, n - 1, n, device=input.device, dtype=input.dtype)
    return torch.sum(probs * indices, dim=-1)