Hello,
I tried to fit a curve (a set of discrete points) using Soft-DTW as the loss function, but the loss does not converge and the final prediction does not match the target points exactly. Is there something wrong with the way I am using it?
The code is as follows:
import torch
import torch.nn as nn
from soft_dtw_cuda import SoftDTW  # assumption: the SoftDTW class from soft_dtw_cuda.py

if __name__ == "__main__":
    batch_size = 1
    len_x = 15
    len_predict = 10
    dims = 2  # each point is (y, x)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Target curve y = x^2 sampled at len_x points. The targets are constants, so they
    # must not require grad; this is also why retain_graph=True is no longer needed.
    x = torch.linspace(1, 4, steps=len_x).view(batch_size, len_x, 1)
    y = x ** 2
    # (batch, length, dims) ----> (1, 15, 2)
    truth_points = torch.cat((y, x), dim=2).to(device)
    # Fixed network input: (1, 20)
    input = torch.linspace(1, 4, steps=len_predict * dims).unsqueeze(0).to(device)

    class testNN(torch.nn.Module):
        def __init__(self):
            super(testNN, self).__init__()
            self.layer = nn.Sequential(
                nn.Linear(20, 50),
                nn.ReLU(),
                nn.Linear(50, 200),
                nn.ReLU(),
                nn.Linear(200, 50),
                nn.ReLU(),
                nn.Linear(50, 20),
                nn.ReLU(),  # final ReLU keeps every predicted coordinate >= 0
            )

        def forward(self, x):
            x = self.layer(x)
            return x

    test = testNN().to(device)
    # use_cuda has to match the device the tensors live on
    loss_function = SoftDTW(use_cuda=torch.cuda.is_available(), gamma=0.01, normalize=False)
    optimizer = torch.optim.Adam(test.parameters(), lr=0.01)

    for epoch in range(1000):
        predict = test(input)
        # (1, 20) reshape to (1, 10, 2)
        predict = predict.reshape(batch_size, len_predict, dims)
        loss = loss_function(predict, truth_points)  # one value per batch element
        optimizer.zero_grad()
        loss.mean().backward()
        optimizer.step()
        if epoch % 10 == 0:
            print("epoch : %d | loss : %f" % (epoch, loss.mean().item()))

    plt_predict = predict.cpu().detach().numpy()
    print(plt_predict[0, :, 0])  # predicted y values
    print(plt_predict[0, :, 1])  # predicted x values
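For reference, below is a minimal pure-PyTorch sketch of the soft-DTW recursion (Cuturi & Blondel, 2017) with a squared Euclidean cost. It is not the CUDA implementation used above, and the function name `soft_dtw` is hypothetical. It illustrates one property that may be relevant here: with gamma > 0 the soft-min is always below the hard min, so even a perfect fit (identical sequences) gives a negative loss rather than zero. The loss value alone therefore does not indicate how close the fit is.

```python
import torch


def soft_dtw(x: torch.Tensor, y: torch.Tensor, gamma: float = 0.01) -> torch.Tensor:
    """Soft-DTW between x of shape (n, d) and y of shape (m, d).

    Hypothetical reference sketch: plain O(n*m) Python loop, differentiable via autograd.
    """
    n, m = x.shape[0], y.shape[0]
    # Pairwise squared Euclidean costs, shape (n, m)
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(dim=-1)
    inf = torch.tensor(float("inf"), dtype=x.dtype)
    # R[i][j] = soft-DTW cost of aligning x[:i] with y[:j]; borders are +inf except R[0][0]
    R = [[inf] * (m + 1) for _ in range(n + 1)]
    R[0][0] = torch.tensor(0.0, dtype=x.dtype)
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            prev = torch.stack([R[i - 1][j - 1], R[i - 1][j], R[i][j - 1]])
            # Smoothed minimum: softmin_gamma(a) = -gamma * log(sum(exp(-a / gamma)))
            softmin = -gamma * torch.logsumexp(-prev / gamma, dim=0)
            R[i][j] = cost[i - 1, j - 1] + softmin
    return R[n][m]


if __name__ == "__main__":
    curve = torch.linspace(1, 4, 15).unsqueeze(-1)
    # Identical sequences: negative, and more negative as gamma grows
    print(soft_dtw(curve, curve, gamma=0.01).item())
    print(soft_dtw(curve, curve, gamma=1.0).item())
```

Because of this offset, it may be more informative to compare the predicted points with the targets directly (or track a hard-DTW value) than to expect the Soft-DTW loss to reach zero.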