// test_lite_trainer.cpp (forked from pytorch/pytorch)
#include <gtest/gtest.h>
#include <c10/core/TensorOptions.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/mobile/export_data.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/import_data.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/optim/sgd.h>
#include <torch/csrc/jit/mobile/sequential.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/data/dataloader.h>
#include <torch/torch.h>
// Tests go in torch::jit
namespace torch {
namespace jit {
// Trains the same one-parameter model (y = foo * x + b) twice with
// torch::optim::SGD -- once through the full-JIT module obtained via
// save()/load(), once through the lite-interpreter module obtained via
// _save_for_mobile()/_load_for_mobile() -- and checks that both runs end with
// the same parameter value.
TEST(LiteTrainerTest, Params) {
  Module m("m");
  m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
  // TorchScript follows Python's layout rules: the statements of `forward`
  // must be indented relative to the `def` line.
  m.define(R"(
    def forward(self, x):
      b = 1.0
      return self.foo * x + b
  )");

  const double learning_rate = 0.1;
  const double momentum = 0.1;
  const int n_epoc = 10;
  // init: y = x + 1;
  // target: y = 2 x + 1
  const std::vector<std::pair<Tensor, Tensor>> trainData{
      {1 * torch::ones({1}), 3 * torch::ones({1})},
  };

  // Reference: Full jit
  std::stringstream ms;
  m.save(ms);
  auto mm = load(ms);
  std::vector<::at::Tensor> parameters;
  for (auto parameter : mm.parameters()) {
    parameters.emplace_back(parameter);
  }
  ::torch::optim::SGD optimizer(
      parameters, ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
  for (int epoc = 0; epoc < n_epoc; ++epoc) {
    for (const auto& data : trainData) {
      const auto& source = data.first;
      const auto& targets = data.second;
      optimizer.zero_grad();
      std::vector<IValue> train_inputs{source};
      auto output = mm.forward(train_inputs).toTensor();
      auto loss = ::torch::l1_loss(output, targets);
      loss.backward();
      optimizer.step();
    }
  }

  // Test: lite interpreter, same optimizer type and hyper-parameters.
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<::at::Tensor> bc_parameters = bc.parameters();
  ::torch::optim::SGD bc_optimizer(
      bc_parameters,
      ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
  for (int epoc = 0; epoc < n_epoc; ++epoc) {
    for (const auto& data : trainData) {
      const auto& source = data.first;
      const auto& targets = data.second;
      bc_optimizer.zero_grad();
      std::vector<IValue> train_inputs{source};
      auto output = bc.forward(train_inputs).toTensor();
      auto loss = ::torch::l1_loss(output, targets);
      loss.backward();
      bc_optimizer.step();
    }
  }

  // Guard the [0] accesses below, then require bit-identical results: both
  // runs apply the same updates to the same initial value.
  ASSERT_FALSE(parameters.empty());
  ASSERT_FALSE(bc_parameters.empty());
  ASSERT_EQ(parameters[0].item<float>(), bc_parameters[0].item<float>());
}
// Checks that a module loaded through the lite interpreter exposes the same
// named parameters (same count, same value under each name) as the full-JIT
// module it was saved from, including parameters of registered child modules.
TEST(MobileTest, NamedParameters) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  // TorchScript follows Python's layout rules: the statements of `add_it`
  // must be indented relative to the `def` line.
  m.define(R"(
    def add_it(self, x):
      b = 4
      return self.foo + x + b
  )");
  Module child("m2");
  // Give "foo" and "bar" different values so that a name mix-up between the
  // two would be caught by the per-name comparison below.
  child.register_parameter("foo", 4 * torch::ones({}), false);
  child.register_parameter("bar", 3 * torch::ones({}), false);
  m.register_module("child1", child);
  m.register_module("child2", child.clone());

  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);

  auto full_params = m.named_parameters();
  auto mobile_params = bc.named_parameters();
  ASSERT_EQ(full_params.size(), mobile_params.size());
  for (const auto& e : full_params) {
    ASSERT_EQ(e.value.item<int>(), mobile_params[e.name].item<int>())
        << "parameter: " << e.name;
  }
}
// Round-trips a lite-interpreter module through mobile::_save_data() /
// mobile::_load_data() and checks that the reloaded named parameters match
// those of the original full-JIT module, by count and by value under each
// name.
TEST(MobileTest, SaveLoadData) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  // TorchScript follows Python's layout rules: the statements of `add_it`
  // must be indented relative to the `def` line.
  m.define(R"(
    def add_it(self, x):
      b = 4
      return self.foo + x + b
  )");
  Module child("m2");
  child.register_parameter("foo", 4 * torch::ones({}), false);
  child.register_parameter("bar", 3 * torch::ones({}), false);
  m.register_module("child1", child);
  m.register_module("child2", child.clone());
  auto full_params = m.named_parameters();

  // Full JIT -> mobile module.
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);

  // Mobile module -> data stream -> named parameters.
  std::stringstream ss_data;
  mobile::_save_data(bc, ss_data);
  auto mobile_params = mobile::_load_data(ss_data).named_parameters();

  ASSERT_EQ(full_params.size(), mobile_params.size());
  for (const auto& e : full_params) {
    ASSERT_EQ(e.value.item<int>(), mobile_params[e.name].item<int>())
        << "parameter: " << e.name;
  }
}
// Saves the named parameters of a lite-interpreter module with
// _save_parameters(), reads them back with _load_parameters(), and checks
// that they match the original full-JIT module's named parameters, by count
// and by value under each name.
TEST(MobileTest, SaveLoadParameters) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  // TorchScript follows Python's layout rules: the statements of `add_it`
  // must be indented relative to the `def` line.
  m.define(R"(
    def add_it(self, x):
      b = 4
      return self.foo + x + b
  )");
  Module child("m2");
  child.register_parameter("foo", 4 * torch::ones({}), false);
  child.register_parameter("bar", 3 * torch::ones({}), false);
  m.register_module("child1", child);
  m.register_module("child2", child.clone());
  auto full_params = m.named_parameters();

  // load mobile module, save mobile named parameters
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::stringstream ss_data;
  _save_parameters(bc.named_parameters(), ss_data);

  // load back the named parameters, compare to full-jit Module's
  auto mobile_params = _load_parameters(ss_data);
  ASSERT_EQ(full_params.size(), mobile_params.size());
  for (const auto& e : full_params) {
    ASSERT_EQ(e.value.item<int>(), mobile_params[e.name].item<int>())
        << "parameter: " << e.name;
  }
}
// Same round trip as the _save_parameters()/_load_parameters() test, but for
// a module tree on which no parameters are registered: the reloaded
// collection must be empty.
TEST(MobileTest, SaveLoadParametersEmpty) {
  Module m("m");
  // TorchScript follows Python's layout rules: the statements of `add_it`
  // must be indented relative to the `def` line.
  m.define(R"(
    def add_it(self, x):
      b = 4
      return x + b
  )");
  Module child("m2");
  m.register_module("child1", child);
  m.register_module("child2", child.clone());

  // load mobile module, save mobile named parameters
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::stringstream ss_data;
  _save_parameters(bc.named_parameters(), ss_data);

  // load back the named parameters, test is empty
  auto mobile_params = _load_parameters(ss_data);
  ASSERT_EQ(mobile_params.size(), 0u);
}
// Trains the same one-parameter model (y = foo * x + b) twice with identical
// hyper-parameters -- the full-JIT module with torch::optim::SGD, the
// lite-interpreter module with torch::jit::mobile::SGD -- and checks that
// both optimizers end with the same parameter value.
TEST(LiteTrainerTest, SGD) {
  Module m("m");
  m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
  // TorchScript follows Python's layout rules: the statements of `forward`
  // must be indented relative to the `def` line.
  m.define(R"(
    def forward(self, x):
      b = 1.0
      return self.foo * x + b
  )");

  const double learning_rate = 0.1;
  const double momentum = 0.1;
  const int n_epoc = 10;
  // init: y = x + 1;
  // target: y = 2 x + 1
  const std::vector<std::pair<Tensor, Tensor>> trainData{
      {1 * torch::ones({1}), 3 * torch::ones({1})},
  };

  // Reference: Full jit and torch::optim::SGD
  std::stringstream ms;
  m.save(ms);
  auto mm = load(ms);
  std::vector<::at::Tensor> parameters;
  for (auto parameter : mm.parameters()) {
    parameters.emplace_back(parameter);
  }
  ::torch::optim::SGD optimizer(
      parameters, ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
  for (int epoc = 0; epoc < n_epoc; ++epoc) {
    for (const auto& data : trainData) {
      const auto& source = data.first;
      const auto& targets = data.second;
      optimizer.zero_grad();
      std::vector<IValue> train_inputs{source};
      auto output = mm.forward(train_inputs).toTensor();
      auto loss = ::torch::l1_loss(output, targets);
      loss.backward();
      optimizer.step();
    }
  }

  // Test: lite interpreter and torch::jit::mobile::SGD
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<::at::Tensor> bc_parameters = bc.parameters();
  ::torch::jit::mobile::SGD bc_optimizer(
      bc_parameters,
      ::torch::jit::mobile::SGDOptions(learning_rate).momentum(momentum));
  for (int epoc = 0; epoc < n_epoc; ++epoc) {
    for (const auto& data : trainData) {
      const auto& source = data.first;
      const auto& targets = data.second;
      bc_optimizer.zero_grad();
      std::vector<IValue> train_inputs{source};
      auto output = bc.forward(train_inputs).toTensor();
      auto loss = ::torch::l1_loss(output, targets);
      loss.backward();
      bc_optimizer.step();
    }
  }

  // Guard the [0] accesses below, then require bit-identical results from
  // the two optimizer implementations.
  ASSERT_FALSE(parameters.empty());
  ASSERT_FALSE(bc_parameters.empty());
  ASSERT_EQ(parameters[0].item<float>(), bc_parameters[0].item<float>());
}
namespace {
// Minimal dataset of ints for data-loader tests: the example stored at
// `index` is the integer `index + 1`, and the dataset reports the size it was
// constructed with.
struct DummyDataset : torch::data::datasets::Dataset<DummyDataset, int> {
  explicit DummyDataset(size_t size = 100) : size_(size) {}

  // Returns `index + 1`. The size_t -> int conversion is spelled out so the
  // narrowing is visible rather than implicit.
  int get(size_t index) override {
    return static_cast<int>(index + 1);
  }

  // Number of examples, as passed to the constructor.
  torch::optional<size_t> size() const override {
    return size_;
  }

  size_t size_;
};
} // namespace
// Drives a data loader over DummyDataset with a sequential sampler and checks
// that the examples come back in order (1, 2, 3, ...) and that every example
// of the dataset is delivered.
//
// NOTE(review): this file includes torch/csrc/jit/mobile/sequential.h, yet
// the sampler instantiated here is torch::data::samplers::SequentialSampler.
// Confirm whether the mobile sampler was the intended subject of this test.
TEST(LiteTrainerTest, SequentialSampler) {
  // test that sampler can be used with dataloader
  const int kBatchSize = 10;
  const int kNumExamples = 25;
  auto data_loader =
      torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
          DummyDataset(kNumExamples), kBatchSize);
  int i = 1;
  for (const auto& batch : *data_loader) {
    for (const auto& example : batch) {
      ASSERT_EQ(i, example);
      i++;
    }
  }
  // Without this check the test would pass vacuously if the loader produced
  // no batches or stopped early. `i` starts at 1, so after seeing every
  // example it is kNumExamples + 1 (this relies on the final partial batch
  // being delivered rather than dropped).
  ASSERT_EQ(i, kNumExamples + 1);
}
} // namespace jit
} // namespace torch