-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathtest2.py
41 lines (35 loc) · 1.06 KB
/
test2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
from torch import nn
import torchsnooper
# example 1
# @torchsnooper.snoop()
# def myfunc(mask, x):
# y = torch.zeros(6, device='cuda')
# y.masked_scatter_(mask, x)
# return y
#
# mask = torch.tensor([0, 1, 0, 1, 1, 0], device='cuda')
# source = torch.tensor([1.0, 2.0, 3.0], device='cuda')
# mask = torch.tensor([0, 1, 0, 1, 1, 0], device='cuda', dtype=torch.uint8)
# y = myfunc(mask, source)
# example 2
class Model(torch.nn.Module):
    """Minimal regression model consisting of a single fully connected layer.

    Args:
        in_features: Size of each input sample. Defaults to 2, the number of
            feature columns in the toy data this script trains on.
        out_features: Size of each output sample. Defaults to 1.
    """

    def __init__(self, in_features=2, out_features=1):
        super().__init__()
        self.layer = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply the linear layer to ``x``.

        The result keeps a trailing dimension of size ``out_features``: with
        the defaults, an input of shape (N, 2) yields shape (N, 1), not (N,).
        Callers comparing against a 1-D target must squeeze that dimension
        to avoid unintended broadcasting.
        """
        return self.layer(x)
# Toy regression data: 4 samples x 2 features. The targets satisfy
# y = 3 + x0 + 2 * x1 exactly, so a single linear layer can fit them.
x = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])  # shape (4, 2)
y = torch.tensor([3.0, 5.0, 4.0, 6.0])  # shape (4,) -- 1-D, no trailing dim
model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Trace the training loop line by line with torchsnooper, presumably to
# inspect tensor shapes while debugging.
# NOTE(review): tracing all 100 iterations produces a lot of output.
with torchsnooper.snoop():
    for _ in range(100):
        optimizer.zero_grad()
        # pred = model(x)  # buggy variant: output is (4, 1); (4,) - (4, 1) broadcasts to (4, 4)
        # model(x) has shape (4, 1) because Model uses Linear(2, 1). squeeze()
        # drops the size-1 dim so that y - pred is elementwise over 4 samples.
        # NOTE(review): squeeze() removes *all* size-1 dims (a batch of 1 would
        # become 0-d); squeeze(-1) would be the narrower choice.
        pred = model(x).squeeze()
        # Mean squared error over the 4 samples.
        squared_diff = (y - pred) ** 2
        loss = squared_diff.mean()
        print(loss.item())
        loss.backward()
        optimizer.step()