-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloss_func.py
50 lines (43 loc) · 1.8 KB
/
loss_func.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
'''
Description:
Loss functions for model training.
Author:
Jiaqi Zhang <jiaqi_zhang2@brown.edu>
'''
import numpy as np
import torch.nn as nn
import geomloss
# ===========================================
def MSELoss(true_obs, est_obs, reduction="mean"):
'''
Mean squared error (MSE).
:param true_obs (torch.FloatTensor): True expression data.
:param est_obs (torch.FloatTensor): Predicted expression data.
:return: (float) MSE value.
'''
loss_func = nn.MSELoss(reduction=reduction)
return loss_func(est_obs, true_obs)
def SinkhornLoss(true_obs, est_obs, blur=0.05, scaling=0.5, batch_size=None):
'''
Wasserstein distance computed by Sinkhorn algorithm.
:param true_obs (torch.FloatTensor): True expression data.
:param est_obs (torch.FloatTensor): Predicted expression data.
:param blur (float): Sinkhorn algorithm hyperparameter. Default as 0.05.
:param scaling (float): Sinkhorn algorithm hyperparameter. Default as 0.5.
:param batch_size (None or int): Either None indicates using all true cell in computation, or an integer indicates
using only a batch of true cells to save computational costs.
:return: (float) Wasserstein distance average over all timepoints.
'''
n_tps = len(true_obs)
loss = 0.0
ot_solver = geomloss.SamplesLoss("sinkhorn", p=2, blur=blur, scaling=scaling, debias=True, backend="tensorized")
for t in range(n_tps):
t_est = est_obs[:, t, :]
t_true = true_obs[t]
if batch_size is not None:
cell_idx = np.random.choice(np.arange(t_true.shape[0]), size = batch_size, replace = (t_true.shape[0] < batch_size))
t_true = t_true[cell_idx, :]
t_loss = ot_solver(t_true, t_est)
loss += t_loss
loss = loss / n_tps
return loss