-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
79 lines (61 loc) · 2.31 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import argparse
import pandas as pd
import scpp
import util
parser = argparse.ArgumentParser()
parser.add_argument('--input_fn', type=str, default=None)
parser.add_argument('--output_dir', type=str, default='out/tmp/')
parser.add_argument('--replace_outdir', action='store_true')
parser.add_argument('--gamma', type=float, default=1)
parser.add_argument('--beta', type=float, default=1)
parser.add_argument('--a', type=float, default=1)
parser.add_argument('--b', type=float, default=1)
# Learning setting
parser.add_argument('--tol', type=float, default=100)
parser.add_argument('--max_iter', type=int, default=20)
parser.add_argument('--verbose', type=bool, default=True)
# Options
parser.add_argument('--time_col', type=str, default='date')
parser.add_argument('--item_col', type=str, default='item')
parser.add_argument('--user_col', type=str, default='user')
parser.add_argument('--n_sample', type=int, default=100000)
parser.add_argument('--sampling_rate', type=str, default='D')
# DEMO
parser.add_argument('--retail', action='store_true')
config = parser.parse_args()
# Data preparation
if config.retail:
data = pd.read_csv('data/retail_transaction.csv')
data = util.sample_events(data, 10000)
data = util.encode_timestamp(data, 'date', 'D')
data = util.encode_attribute(data, col='item_id', prefix='item')
data = util.encode_attribute(data, col='user_id', prefix='user')
config.output_dir = 'out/retail/'
else:
if config.input_fn is None:
raise ValueError("Specify your input filename")
data = pd.read_csv(config.input_fn)
if config.n_sample > 0:
data = util.sample_events(data, config.n_sample)
data = util.encode_timestamp(data,
datetime_col=config.time_col,
freq=config.sampling_rate)
data = util.encode_attribute(data, col=config.item_col, prefix='item')
data = util.encode_attribute(data, col=config.user_col, prefix='user')
print()
print(" INPUT ")
print("=======")
# print(data.head())
print(data.nunique())
print()
util.prepare_workspace(config.output_dir, replace=config.replace_outdir)
model = scpp.SCPP()
model.fit(data,
gamma=config.gamma,
beta=config.beta,
a=config.a,
b=config.b,
max_iter=config.max_iter,
tol=config.tol,
verbose=config.verbose)
model.save(config.output_dir)