-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
executable file
·110 lines (89 loc) · 3.31 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import gzip
import logging
import argparse
import paddle.v2 as paddle
from network_conf import DeepFM
import reader
# Give the root logger a default stderr handler so emitted records are printed.
logging.basicConfig()
# Script-wide logger, registered under the name "paddle" and opened up to INFO.
logger = logging.getLogger(name="paddle")
logger.setLevel(level=logging.INFO)
def parse_args(argv=None):
    """Parse the command-line options of the DeepFM training script.

    Args:
        argv: optional list of argument strings to parse. When None (the
            default) argparse falls back to ``sys.argv[1:]``, which is the
            original behaviour of this function.

    Returns:
        An ``argparse.Namespace`` with the attributes ``train_data_path``,
        ``test_data_path``, ``batch_size``, ``num_passes``, ``factor_size``
        and ``model_output_dir``.
    """
    parser = argparse.ArgumentParser(description="PaddlePaddle DeepFM example")
    parser.add_argument(
        '--train_data_path',
        type=str,
        default='data/train.txt',
        help="The path of training dataset")
    parser.add_argument(
        '--test_data_path',
        type=str,
        default='data/valid.txt',
        help="The path of testing dataset")
    parser.add_argument(
        '--batch_size',
        type=int,
        default=1000,
        help="The size of mini-batch (default:1000)")
    parser.add_argument(
        '--num_passes',
        type=int,
        default=10,
        help="The number of passes to train (default: 10)")
    parser.add_argument(
        '--factor_size',
        type=int,
        default=10,
        help="The factor size for the factorization machine (default:10)")
    parser.add_argument(
        '--model_output_dir',
        type=str,
        default='models',
        help='The path for model to store (default: models)')
    return parser.parse_args(argv)
def _test_and_save(trainer, dataset, args, pass_id, batch_id=None):
    """Optionally evaluate on the test set, then write a parameter checkpoint.

    Args:
        trainer: the ``paddle.trainer.SGD`` instance being trained.
        dataset: the ``reader.Dataset`` whose ``test(path)`` yields test samples.
        args: the namespace returned by ``parse_args``.
        pass_id: index of the current training pass.
        batch_id: index of the batch that just finished, or None for an
            end-of-pass checkpoint.
    """
    # Mid-pass checkpoints are tagged with the batch index, end-of-pass with "end".
    tag = "end" if batch_id is None else batch_id
    # An empty --test_data_path skips evaluation but still saves parameters.
    if args.test_data_path:
        result = trainer.test(
            reader=paddle.batch(
                dataset.test(args.test_data_path),
                batch_size=args.batch_size),
            feeding=reader.feeding)
        logger.warning("Test %d-%s, Cost %f, %s", pass_id, tag, result.cost,
                       result.metrics)
    if batch_id is None:
        name = "model-pass-{}-end.tar.gz".format(pass_id)
    else:
        name = "model-pass-{}-batch-{}.tar.gz".format(pass_id, batch_id)
    path = os.path.join(args.model_output_dir, name)
    with gzip.open(path, 'w') as f:
        trainer.save_parameter_to_tar(f)


def train():
    """Train the DeepFM model configured by the command-line arguments.

    Builds a DeepFM model with ``--factor_size`` factors and an SGD trainer
    driven by Adam (learning rate 1e-4) on a single CPU trainer. It then
    streams shuffled mini-batches from ``--train_data_path`` for
    ``--num_passes`` passes.

    Progress is logged every 100 batches. Parameters are checkpointed into
    ``--model_output_dir`` every 10000 batches and again at the end of every
    pass. Each checkpoint is preceded by a test-set evaluation when
    ``--test_data_path`` is non-empty.
    """
    args = parse_args()
    # makedirs (not mkdir) so nested output paths such as "out/deepfm" work;
    # the isdir guard keeps this compatible with Python 2 (no exist_ok).
    if not os.path.isdir(args.model_output_dir):
        os.makedirs(args.model_output_dir)

    paddle.init(use_gpu=False, trainer_count=1)
    optimizer = paddle.optimizer.Adam(learning_rate=1e-4)
    model = DeepFM(args.factor_size)
    params = paddle.parameters.create(model)
    trainer = paddle.trainer.SGD(cost=model,
                                 parameters=params,
                                 update_equation=optimizer)
    dataset = reader.Dataset()

    def __event_handler__(event):
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 100 == 0:
                # batch_id is 0-based, so batch_id + 1 batches of this pass
                # have finished; the count is per pass, not cumulative.
                num_samples = (event.batch_id + 1) * args.batch_size
                logger.warning("Pass %d, Batch %d, Samples %d, Cost %f, %s",
                               event.pass_id, event.batch_id, num_samples,
                               event.cost, event.metrics)
            if event.batch_id % 10000 == 0:
                _test_and_save(trainer, dataset, args, event.pass_id,
                               event.batch_id)
        elif isinstance(event, paddle.event.EndPass):
            # Without this, updates made after the last in-pass checkpoint
            # (including the end of the final pass) would never reach disk.
            _test_and_save(trainer, dataset, args, event.pass_id)

    trainer.train(
        reader=paddle.batch(
            paddle.reader.shuffle(
                dataset.train(args.train_data_path),
                # The shuffle buffer holds batch_size * 10000 samples.
                buf_size=args.batch_size * 10000),
            batch_size=args.batch_size),
        feeding=reader.feeding,
        event_handler=__event_handler__,
        num_passes=args.num_passes)
if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    train()