-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
74 lines (52 loc) · 1.42 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import random
import kraft
import kraft.device
import kraft.optim
from kraft import nn
from kraft.nn import functional as fun
class MLP(nn.Module):
    """Small multilayer perceptron for two-feature inputs.

    Architecture: Linear(2 -> 32), ReLU, Linear(32 -> 1), Sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Layers are listed in forward order; they are constructed in this same
        # order, so parameter initialisation order is unchanged.
        layers = [
            nn.Linear(2, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, xs):
        """Apply the layer stack to ``xs`` and return the sigmoid-activated output."""
        result = self.network(xs)
        return result
# XOR truth table: each entry is ([x1, x2], x1 XOR x2), all values as floats.
# Rows are ordered (0, 0), (0, 1), (1, 0), (1, 1).
DATA = [
    ([float(a), float(b)], float(a ^ b))
    for a in (0, 1)
    for b in (0, 1)
]
def main():
    """Train ``MLP`` on the XOR table in ``DATA`` and print its predictions.

    Training is per-sample stochastic gradient descent: every epoch shuffles
    the examples and takes one optimizer step per example. Afterwards, each
    input from ``DATA`` is printed with the network output and the target.
    """
    device = kraft.device.get_gpu_device()
    net = MLP()
    net.to_(device)
    sgd = kraft.optim.SGD(net.parameters(), lr=1e-1)
    n_epochs = 300

    # Move the dataset to the device once. Inputs and targets are constants,
    # so they do not require gradients.
    train = [
        (
            kraft.Variable(features, device=device, requires_grad=False),
            kraft.Variable(label, device=device, requires_grad=False),
        )
        for features, label in DATA
    ]

    for _ in range(n_epochs):
        # Shuffle so the per-sample update order differs between epochs.
        random.shuffle(train)
        for inputs, targets in train:
            # Reset gradients before every backward pass. The optimizer steps
            # once per sample, so gradients from earlier samples in this epoch
            # must not be folded into this step. (Assumes backward() accumulates
            # into parameter grads, which is why zero_grad() exists.)
            sgd.zero_grad()
            outputs = net(inputs)
            loss = fun.mse_loss(outputs, targets)
            loss.backward()
            sgd.step()

    # Report predictions on the original, unshuffled truth table. Distinct
    # loop names avoid shadowing any list being iterated.
    for features, label in DATA:
        inputs = kraft.Variable(features, device=device, requires_grad=False)
        print(features, net(inputs).data, label)
# Run the training demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()