Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add aitm model #756

Merged
merged 11 commits into from
May 11, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions models/rank/aitm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
43 changes: 43 additions & 0 deletions models/rank/aitm/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# global settings

runner:
train_data_dir: "./data/sample_data/train"
train_reader_path: "reader" # importlib format
use_gpu: True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

demo数据下不建议使用gpu,可以采取减小数据集,减小epoch等方式缩短demo训练时间至一分钟以内

train_batch_size: 10
epochs: 2
print_interval: 10
#model_init_path: "output_model/0" # init model
model_save_path: "output_model_aitm_test/"
test_data_dir: "./data/sample_data/test"
infer_reader_path: "reader" # importlib format
infer_batch_size: 10
infer_load_path: "output_model_aitm_test/"
infer_start_epoch: 0
infer_end_epoch: 2

# hyper parameters of user-defined network
hyper_parameters:
# optimizer config
optimizer:
class: Adam
learning_rate: 0.0001
# user-defined <key, value> pairs
embedding_size: 5
dims: [128, 64, 32]
drop_prob: [0.1, 0.3, 0.3]
feature_vocabulary: [['101', 238635], ['121', 98], ['122', 14], ['124', 3], ['125', 8], ['126', 4], ['127', 4], ['128', 3], ['129', 5], ['205', 467298], ['206', 6929], ['207', 263942], ['216', 106399], ['508', 5888], ['509', 104830], ['702', 51878], ['853', 37148], ['301', 4]]
43 changes: 43 additions & 0 deletions models/rank/aitm/config_bigdata.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# global settings

runner:
train_data_dir: "./data/whole_data/train"
train_reader_path: "reader" # importlib format
use_gpu: True
train_batch_size: 2000
epochs: 6
print_interval: 500
#model_init_path: "output_model/0" # init model
model_save_path: "output_model_aitm/"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

保存路径命名可以参考其他模型,改为output_model_aitm_all

test_data_dir: "./data/whole_data/test"
infer_reader_path: "reader" # importlib format
infer_batch_size: 2000
infer_load_path: "output_model_aitm/"
infer_start_epoch: 0
infer_end_epoch: 6

# hyper parameters of user-defined network
hyper_parameters:
# optimizer config
optimizer:
class: Adam
learning_rate: 0.0002
# user-defined <key, value> pairs
embedding_size: 5
dims: [128, 64, 32]
drop_prob: [0.1, 0.3, 0.3]
feature_vocabulary: [['101', 238635], ['121', 98], ['122', 14], ['124', 3], ['125', 8], ['126', 4], ['127', 4], ['128', 3], ['129', 5], ['205', 467298], ['206', 6929], ['207', 263942], ['216', 106399], ['508', 5888], ['509', 104830], ['702', 51878], ['853', 37148], ['301', 4]]
50 changes: 50 additions & 0 deletions models/rank/aitm/data/sample_data/test/test.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
0,0,180467,4,1,2,3,1,1,2,1,157266,5673,164308,73943,753,0,0,32454,2
0,0,180467,4,1,2,3,1,1,2,1,426254,5673,216630,0,753,82564,0,24700,2
1,0,180467,4,1,2,3,1,1,2,1,172023,2114,63355,0,2261,80294,0,0,2
0,0,225145,39,3,2,4,1,1,2,2,70837,1668,46734,61858,1952,0,0,36223,2
0,0,225145,39,3,2,4,1,1,2,2,45118,6423,137020,39048,0,0,0,0,2
0,0,225145,39,3,2,4,1,1,2,2,0,3384,55310,13430,287,0,0,35494,2
0,1,225145,39,3,2,4,1,1,2,2,0,2015,0,81649,467,0,10904,0,2
0,0,225145,39,3,2,4,1,1,2,2,27259,4109,88551,49425,4417,0,39516,0,2
1,0,225145,39,3,2,4,1,1,2,2,297113,4109,219686,5769,4417,0,0,0,2
0,0,225145,39,3,2,4,1,1,2,2,151749,3228,148375,89123,2599,47614,10053,0,2
0,0,225145,39,3,2,4,1,1,2,2,0,1081,47185,0,3865,0,0,0,2
0,0,225145,39,3,2,4,1,1,2,2,0,2015,0,81649,467,0,10904,0,2
0,0,225145,39,3,2,4,1,1,2,2,26903,4630,159418,0,0,80716,0,24471,2
0,0,225145,39,3,2,4,1,1,2,2,282069,2967,230832,84386,5400,0,0,0,2
0,1,186821,75,10,1,3,2,1,2,2,62621,4544,58151,70939,0,0,0,0,3
0,0,186821,75,10,1,3,2,1,2,2,93138,532,115823,87562,3213,0,0,16010,3
0,0,186821,75,10,1,3,2,1,2,2,45206,2820,183822,13990,3138,63787,37851,4086,3
0,0,146759,60,11,2,6,0,2,2,0,368243,2215,123885,92072,5027,0,0,0,1
0,1,146759,60,11,2,6,0,2,2,0,215196,959,195714,0,3110,0,0,9556,1
0,0,131084,23,11,2,6,3,1,2,2,239163,1830,99024,90467,2743,0,0,0,1
0,0,131084,23,11,2,6,3,1,2,2,264944,2096,111223,67585,4293,0,0,0,1
0,0,131084,23,11,2,6,3,1,2,2,233483,598,197331,53113,0,0,0,0,1
0,0,131084,23,11,2,6,3,1,2,2,312012,3757,205654,22313,1892,0,0,0,1
0,0,131084,23,11,2,6,3,1,2,2,198046,4010,46170,0,567,0,0,0,1
0,0,25124,60,5,2,5,0,1,2,2,314484,6224,138638,0,533,37733,0,0,1
0,0,25124,60,5,2,5,0,1,2,2,258277,5670,74133,72066,1197,72538,20372,0,1
0,0,25124,60,5,2,5,0,1,2,2,0,4770,137189,37029,1778,0,0,0,1
0,0,25124,60,5,2,5,0,1,2,2,424602,2114,130310,60811,2261,0,0,32948,1
0,0,25124,60,5,2,5,0,1,2,2,209585,3746,60855,34322,3483,87904,26605,0,1
1,0,25124,60,5,2,5,0,1,2,2,433251,572,36404,0,319,103340,0,0,1
0,0,139588,60,10,1,3,0,1,2,2,319344,6048,176737,0,0,0,0,0,1
0,0,139588,60,10,1,3,0,1,2,2,170111,5294,255332,26976,0,0,0,0,1
0,0,139588,60,10,1,3,0,1,2,2,35,5674,100397,21139,0,0,0,0,1
0,0,233279,60,13,1,5,0,1,2,3,225154,5748,229278,0,4488,0,0,28266,3
0,0,233279,60,13,1,5,0,1,2,3,0,3612,158255,46162,5182,0,0,0,3
0,0,233279,60,13,1,5,0,1,2,3,86976,2541,234649,75046,4702,0,0,0,3
0,0,233279,60,13,1,5,0,1,2,3,161648,2748,15385,0,3270,0,0,12825,3
0,0,233279,60,13,1,5,0,1,2,3,71718,2371,74049,83673,4259,36871,46702,9076,3
0,0,233279,60,13,1,5,0,1,2,3,108993,2139,107889,0,5313,0,0,0,3
0,0,57653,60,10,1,3,0,1,2,1,405019,5473,238445,86159,1489,0,31263,14752,3
0,0,57653,60,10,1,3,0,1,2,1,103212,5473,62663,13073,1489,0,0,0,3
0,0,57653,60,10,1,3,0,1,2,1,464808,887,212115,0,3249,0,0,14835,3
0,0,94847,0,0,0,0,0,0,0,0,10251,4770,1730,0,1778,0,0,989,1
0,0,94847,0,0,0,0,0,0,0,0,206244,2139,246373,97164,5313,0,0,11014,1
0,0,94847,0,0,0,0,0,0,0,0,222670,4770,91229,13685,1778,0,0,8594,1
0,0,94847,0,0,0,0,0,0,0,0,308509,4010,50918,104006,0,0,0,0,1
0,0,94847,0,0,0,0,0,0,0,0,60836,3639,47235,65915,662,0,0,0,1
0,0,94847,0,0,0,0,0,0,0,0,437045,2697,6369,105353,0,0,0,0,1
0,0,94847,0,0,0,0,0,0,0,0,28761,1005,91855,91674,5425,0,0,0,1
0,0,94847,0,0,0,0,0,0,0,0,0,990,0,0,0,0,0,0,1
49 changes: 49 additions & 0 deletions models/rank/aitm/data/sample_data/train/train.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
click,purchase,101,121,122,124,125,126,127,128,129,205,206,207,216,508,509,702,853,301
0,0,85077,60,5,2,5,0,1,2,1,303983,3751,249750,4124,5565,0,0,6260,2
0,0,85077,60,5,2,5,0,1,2,1,200850,1005,156972,315,5425,0,0,31950,2
0,0,85077,60,5,2,5,0,1,2,1,361912,2748,116642,103627,3270,0,0,0,2
0,0,85077,60,5,2,5,0,1,2,1,235052,4791,18014,92858,1575,18172,39447,12632,2
0,0,85077,60,5,2,5,0,1,2,1,236563,1464,49011,27526,2400,0,40351,0,2
1,1,85077,60,5,2,5,0,1,2,1,0,2668,73088,0,1349,36448,0,3202,2
0,0,85077,60,5,2,5,0,1,2,1,868,2114,134232,27903,2261,0,42361,8186,2
0,1,85077,60,5,2,5,0,1,2,1,278769,887,252320,35421,0,0,0,0,2
0,0,85077,60,5,2,5,0,1,2,1,378778,959,237727,24961,3110,0,12437,0,2
0,0,85077,60,5,2,5,0,1,2,1,286476,2668,36149,50475,1349,0,17247,24785,2
1,0,85077,60,5,2,5,0,1,2,1,37325,1219,210814,0,732,0,0,0,2
0,0,85077,60,5,2,5,0,1,2,1,0,6396,145671,104006,0,0,10562,0,2
0,1,85077,60,5,2,5,0,1,2,1,0,6396,187630,3695,0,0,39953,0,2
0,0,85077,60,5,2,5,0,1,2,1,174938,2542,121978,25017,2752,0,0,16432,2
0,0,85077,60,5,2,5,0,1,2,1,126218,2748,99220,3695,3270,1641,39953,0,2
0,0,85077,60,5,2,5,0,1,2,1,137678,2668,156813,76812,1349,0,0,0,2
0,0,85077,60,5,2,5,0,1,2,1,0,6639,75496,0,1641,0,0,0,2
0,0,85077,60,5,2,5,0,1,2,1,0,3358,28909,24961,1715,33178,12437,0,2
0,0,85077,60,5,2,5,0,1,2,1,463355,2321,143602,0,2827,0,0,0,2
0,0,85077,60,5,2,5,0,1,2,1,208394,5441,165858,100800,3175,0,0,14526,2
0,0,85077,60,5,2,5,0,1,2,1,0,4770,160421,49203,1778,104211,12570,0,2
0,0,85077,60,5,2,5,0,1,2,1,0,5130,152781,34023,3305,0,0,0,2
0,0,85077,60,5,2,5,0,1,2,1,0,1219,71745,37124,732,0,43897,0,2
0,0,85077,60,5,2,5,0,1,2,1,0,4770,78413,0,1778,68142,0,0,2
0,0,85077,60,5,2,5,0,1,2,1,0,4770,0,51246,1778,0,12494,0,2
0,0,85077,60,5,2,5,0,1,2,1,0,4770,12306,99058,1778,94838,23330,0,2
0,0,126936,75,10,1,3,1,1,2,2,436088,5342,251939,55848,5686,0,43802,891,3
0,0,225894,36,5,2,5,1,1,2,1,236983,2114,100033,45859,2261,104041,37050,32948,2
0,0,225894,36,5,2,5,1,1,2,1,0,2668,81805,94691,1349,80860,23323,36346,2
1,0,225894,36,5,2,5,1,1,2,1,236983,2114,100033,45859,2261,104041,37050,32948,2
0,0,225894,36,5,2,5,1,1,2,1,76908,1323,2845,4224,0,0,0,0,2
0,0,225894,36,5,2,5,1,1,2,1,0,2139,27250,65082,5313,86671,24195,0,2
0,0,214292,60,13,1,5,0,1,2,2,405168,4408,200625,53253,2913,0,0,0,3
0,0,214292,60,13,1,5,0,1,2,2,0,3711,105033,0,3877,0,0,0,3
0,0,214292,60,13,1,5,0,1,2,2,342042,3346,230129,0,610,0,0,0,3
0,0,214292,60,13,1,5,0,1,2,2,140466,3711,160865,33806,3877,0,0,19879,3
0,0,214292,60,13,1,5,0,1,2,2,243201,3924,130907,30524,1289,0,0,0,3
0,0,214292,60,13,1,5,0,1,2,2,229158,3643,123318,92849,0,0,0,0,3
0,0,214292,60,13,1,5,0,1,2,2,438610,4782,215293,100092,0,0,0,0,3
0,0,214292,60,13,1,5,0,1,2,2,106945,5334,16167,34467,4967,0,0,0,3
0,0,214292,60,13,1,5,0,1,2,2,72148,2177,182249,0,5689,93690,0,13168,3
0,0,214292,60,13,1,5,0,1,2,2,145491,989,248896,65379,0,0,0,0,3
0,0,214292,60,13,1,5,0,1,2,2,49677,5673,57894,0,0,0,0,0,3
0,0,214292,60,13,1,5,0,1,2,2,401746,3334,54001,104747,0,83894,0,0,3
0,0,214292,60,13,1,5,0,1,2,2,293581,1685,30366,47952,2631,0,0,12116,3
0,0,214292,60,13,1,5,0,1,2,2,211989,2815,162936,92766,0,0,0,0,3
0,0,214292,60,13,1,5,0,1,2,2,398227,4782,12082,65379,0,0,0,0,3
0,0,214292,60,13,1,5,0,1,2,2,0,3713,0,68249,1813,0,27110,0,3
108 changes: 108 additions & 0 deletions models/rank/aitm/dygraph_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import net


class DygraphModel():
def create_model(self, config):
feature_vocabulary = config.get("hyper_parameters.feature_vocabulary")
embedding_size = config.get("hyper_parameters.embedding_size")
tower_dims = config.get("hyper_parameters.dims")
drop_prob = config.get('hyper_parameters.drop_prob')
feature_vocabulary = dict(feature_vocabulary)
model = net.AITM(feature_vocabulary, embedding_size, tower_dims,
drop_prob)
return model

# define feeds which convert numpy of batch data to paddle.tensor
def create_feeds(self, batch_data, config):
click, conversion, features = batch_data
return click.astype('float32'), conversion.astype('float32'), features

# define loss function by predicts and label
def create_loss(self,
click_pred,
conversion_pred,
click_label,
conversion_label,
constraint_weight=0.6):
click_loss = F.binary_cross_entropy(click_pred, click_label)
conversion_loss = F.binary_cross_entropy(conversion_pred,
conversion_label)

label_constraint = paddle.maximum(conversion_pred - click_pred,
paddle.zeros_like(click_label))
constraint_loss = paddle.sum(label_constraint)

loss = click_loss + conversion_loss + constraint_weight * constraint_loss
return loss

# define optimizer
def create_optimizer(self, dy_model, config):
lr = config.get("hyper_parameters.optimizer.learning_rate", 0.0001)
optimizer = paddle.optimizer.Adam(
learning_rate=lr,
parameters=dy_model.parameters(),
weight_decay=1e-6)
return optimizer

# define metrics such as auc/acc
# multi-task need to define multi metric
def create_metrics(self):
metrics_list_name = ["click_auc", "purchase_auc"]
metrics_list = [
paddle.metric.Auc("ROC", num_thresholds=100000),
paddle.metric.Auc("ROC", num_thresholds=100000)
]
return metrics_list, metrics_list_name

# construct train forward phase
def train_forward(self, dy_model, metrics_list, batch_data, config):
click, conversion, features = self.create_feeds(batch_data, config)
click_pred, conversion_pred = dy_model.forward(features)
loss = self.create_loss(click_pred, conversion_pred, click, conversion)
# update metrics

self.update_auc(click_pred, click, metrics_list[0])
self.update_auc(conversion_pred, conversion, metrics_list[1])
print_dict = {'loss': loss}
return loss, metrics_list, print_dict

@staticmethod
def update_auc(prob, label, metrics):
if prob.ndim == 1:
prob = prob.unsqueeze(-1)
assert prob.ndim == 2
predict_2d = paddle.concat(x=[1 - prob, prob], axis=1)
metrics.update(predict_2d, label)

def infer_forward(self, dy_model, metrics_list, batch_data, config):
click, conversion, features = self.create_feeds(batch_data, config)
with paddle.no_grad():
click_pred, conversion_pred = dy_model.forward(features)
# update metrics
self.update_auc(click_pred, click, metrics_list[0])
self.update_auc(conversion_pred, conversion, metrics_list[1])
return metrics_list, None

def forward(self, dy_model, batch_data, config):
click, conversion, features = self.create_feeds(batch_data, config)
with paddle.no_grad():
click_pred, conversion_pred = dy_model.forward(features)
# update metrics
return click, click_pred, conversion, conversion_pred
Loading