
Commit

Merge pull request #651 from wangzhen38/fat_deepffm
FAT_DeepFFM
frankwhzhang authored Dec 29, 2021
2 parents 5840941 + 2239b88 commit b5d1584
Showing 15 changed files with 772 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -132,6 +132,7 @@ python -u tools/static_trainer.py -m models/rank/dnn/config.yaml # static-graph training
| Rank | [Dnn](models/rank/dnn/)([doc](https://paddlerec.readthedocs.io/en/latest/models/rank/dnn.html)) | [Python CPU/GPU](https://aistudio.baidu.com/aistudio/projectdetail/3240347) ||||| >=2.1.0 | / |
| Rank | [FM](models/rank/fm/)([doc](https://paddlerec.readthedocs.io/en/latest/models/rank/fm.html)) | [Python CPU/GPU](https://aistudio.baidu.com/aistudio/projectdetail/3240371) |||| x | >=2.1.0 | [IEEE Data Mining 2010][Factorization machines](https://analyticsconsultores.com.mx/wp-content/uploads/2019/03/Factorization-Machines-Steffen-Rendle-Osaka-University-2010.pdf) |
| Rank | [BERT4REC](models/rank/bert4rec/) | - |||| x | >=2.1.0 | [CIKM 2019][BERT4Rec: Sequential Recommendation with Bidirectional Encoder Representations from Transformer](https://arxiv.org/pdf/1904.06690.pdf) |
| Rank | [FAT_DeepFFM](models/rank/fat_deepffm/) | - |||| x | >=2.1.0 | [2019][FAT-DeepFFM: Field Attentive Deep Field-aware Factorization Machine](https://arxiv.org/pdf/1905.06336.pdf) |
| Rank | [FFM](models/rank/ffm/)([doc](https://paddlerec.readthedocs.io/en/latest/models/rank/ffm.html)) | [Python CPU/GPU](https://aistudio.baidu.com/aistudio/projectdetail/3240369) |||| x | >=2.1.0 | [RECSYS 2016][Field-aware Factorization Machines for CTR Prediction](https://dl.acm.org/doi/pdf/10.1145/2959100.2959134) |
| Rank | [FNN](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5/models/rank/fnn/) | - |||| x | [1.8.5](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5) | [ECIR 2016][Deep Learning over Multi-field Categorical Data](https://arxiv.org/pdf/1601.02376.pdf) |
| Rank | [Deep Crossing](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5/models/rank/deep_crossing/) | - |||| x | [1.8.5](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5) | [ACM 2016][Deep Crossing: Web-Scale Modeling without Manually Crafted Combinatorial Features](https://www.kdd.org/kdd2016/papers/files/adf0975-shanA.pdf) |
1 change: 1 addition & 0 deletions README_EN.md
@@ -119,6 +119,7 @@ python -u tools/static_trainer.py -m models/rank/dnn/config.yaml # Training wit
| Rank | [Dnn](models/rank/dnn/)([doc](https://paddlerec.readthedocs.io/en/latest/models/rank/dnn.html)) | [Python CPU/GPU](https://aistudio.baidu.com/aistudio/projectdetail/3240347) ||||| >=2.1.0 | / |
| Rank | [FM](models/rank/fm/)([doc](https://paddlerec.readthedocs.io/en/latest/models/rank/fm.html)) | [Python CPU/GPU](https://aistudio.baidu.com/aistudio/projectdetail/3240371) |||| x | >=2.1.0 | [IEEE Data Mining 2010][Factorization machines](https://analyticsconsultores.com.mx/wp-content/uploads/2019/03/Factorization-Machines-Steffen-Rendle-Osaka-University-2010.pdf) |
| Rank | [BERT4REC](models/rank/bert4rec/) | - |||| x | >=2.1.0 | [CIKM 2019][BERT4Rec: Sequential Recommendation with Bidirectional Encoder Representations from Transformer](https://arxiv.org/pdf/1904.06690.pdf) |
| Rank | [FAT_DeepFFM](models/rank/fat_deepffm/) | - |||| x | >=2.1.0 | [2019][FAT-DeepFFM: Field Attentive Deep Field-aware Factorization Machine](https://arxiv.org/pdf/1905.06336.pdf) |
| Rank | [FFM](models/rank/ffm/)([doc](https://paddlerec.readthedocs.io/en/latest/models/rank/ffm.html)) | [Python CPU/GPU](https://aistudio.baidu.com/aistudio/projectdetail/3240369) |||| x | >=2.1.0 | [RECSYS 2016][Field-aware Factorization Machines for CTR Prediction](https://dl.acm.org/doi/pdf/10.1145/2959100.2959134) |
| Rank | [FNN](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5/models/rank/fnn/) | - |||| x | [1.8.5](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5) | [ECIR 2016][Deep Learning over Multi-field Categorical Data](https://arxiv.org/pdf/1601.02376.pdf) |
| Rank | [Deep Crossing](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5/models/rank/deep_crossing/) | - |||| x | [1.8.5](https://github.com/PaddlePaddle/PaddleRec/tree/release/1.8.5) | [ACM 2016][Deep Crossing: Web-Scale Modeling without Manually Crafted Combinatorial Features](https://www.kdd.org/kdd2016/papers/files/adf0975-shanA.pdf) |
135 changes: 135 additions & 0 deletions models/rank/fat_deepffm/README.md
@@ -0,0 +1,135 @@
# CTR Prediction Based on the FAT_DeepFFM Model

The directory structure of this example is briefly described below:

```
├── data # sample data
    ├── sample_data # sample data
        ├── train
            ├── sample_train.txt # sample training data
├── __init__.py
├── README.md # documentation
├── config.yaml # configuration for the sample data
├── config_bigdata.yaml # configuration for the full dataset
├── net.py # core network definition (shared by dynamic and static graphs)
├── criteo_reader.py # data reader
├── dygraph_model.py # dynamic-graph model builder
```

Note: Before reading this example, we recommend that you first go through the following:

[PaddleRec getting-started tutorial](https://github.com/PaddlePaddle/PaddleRec/blob/master/README.md)

## Contents

- [Model overview](#model-overview)
- [Data preparation](#data-preparation)
- [Runtime environment](#runtime-environment)
- [Quick start](#quick-start)
- [Model architecture](#model-architecture)
- [Reproducing results](#reproducing-results)
- [Advanced usage](#advanced-usage)
- [FAQ](#faq)

## Model overview
CTR (Click-Through Rate) is a core metric in recommender systems and computational advertising, and estimating it underpins decisions such as which items to recommend or which ads to serve. Put simply, CTR prediction estimates, for each impression, whether the user will click or not. A CTR model combines a wide range of factors and features, is trained on large amounts of historical data, and ultimately supports business decisions. This example implements the FAT_DeepFFM model from the following paper:

```text
@article{FAT-DeepFFM2019,
title={FAT-DeepFFM: Field Attentive Deep Field-aware Factorization Machine},
author={Junlin Zhang and Tongwen Huang and Zhiqi Zhang},
journal={arXiv preprint arXiv:1905.06336},
year={2019},
url={https://arxiv.org/pdf/1905.06336},
}
```

## Data preparation

The training and test sets are the Criteo dataset used in the [Display Advertising Challenge](https://www.kaggle.com/c/criteo-display-ad-challenge/). The dataset consists of two parts: the training set contains a portion of Criteo's traffic over a period of time, and the test set contains the ad click traffic of the day following the training data.
Each line of data has the following format:
```
<label> <integer feature 1> ... <integer feature 13> <categorical feature 1> ... <categorical feature 26>
```
Here ```<label>``` indicates whether the ad was clicked: 1 for clicked, 0 for not clicked. ```<integer feature>``` denotes the numerical (continuous) features, of which there are 13, and ```<categorical feature>``` denotes the categorical (discrete) features, of which there are 26. Adjacent features are separated by ```\t```, and missing features are left blank. The ```<label>``` field has been removed from the test set.
Sample data for a quick run is provided under the data directory of this model; to use the full dataset, see [Reproducing results](#reproducing-results) below.
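
As a minimal illustration of the raw format above, the hypothetical helper below splits one raw line into its label, dense, and sparse parts. It is only a sketch: the preprocessed data used by this model is already in the `slot:feasign` form consumed by `criteo_reader.py`, and the hashing step plus the `hash_dim` value (chosen to mirror `sparse_feature_number` in `config.yaml`) are illustrative assumptions, not the project's preprocessing script.

```python
# Hypothetical parser for one raw Criteo line; for illustration only.
def parse_criteo_line(line, hash_dim=1000001):
    fields = line.rstrip("\n").split("\t")
    label = int(fields[0])                                    # <label>: 1 = clicked, 0 = not clicked
    dense = [float(v) if v else 0.0 for v in fields[1:14]]    # 13 continuous features, blanks -> 0.0
    sparse = [hash(v) % hash_dim if v else 0 for v in fields[14:40]]  # 26 categorical features, hashed to ids
    return label, dense, sparse
```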

## Runtime environment
PaddlePaddle>=2.0

python 2.7/3.5/3.6/3.7

os: windows/linux/macos

## Quick start
Sample data is provided so you can try this model quickly, and the commands can be executed from any directory. The quick-start commands for the fat_deepffm model directory are:
```bash
# enter the model directory
# cd models/rank/fat_deepffm # can be run from any directory
# dynamic-graph training
python -u ../../../tools/trainer.py -m config.yaml # use config_bigdata.yaml for the full dataset
# dynamic-graph inference
python -u ../../../tools/infer.py -m config.yaml
```

## Model architecture

The FAT_DeepFFM network is defined in `net.py`. Its main components are an embedding layer, a CENet layer, a DeepFFM feature-interaction layer, a DNN layer, and the loss and AUC computation for the classification task. The architecture is shown below:

<img align="center" src="picture/11.jpg" width="400" height="300">


### **CENet layer**

The inputs to FAT_DeepFFM are mainly sparse categorical features. (Dense numerical features are projected up to the embedding dimension and concatenated with the sparse categorical features.)
The sparse features are looked up in the embedding layer to obtain their embedding vectors, and CENet is then used to explicitly model the dependencies between features. The CENet structure is shown below:

<img align="center" src="picture/2.jpg" width="400" height="300">

As the structure above shows, CENet's attention mechanism selectively emphasizes informative features and suppresses less useful ones, following the formula below:

<img align="center" src="picture/3.jpg" width="400" height="60">
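
For concreteness, here is a minimal, simplified sketch of this attention idea in PaddlePaddle. It is not the implementation in `net.py`: the "compose" step is reduced to mean pooling over each field's embedding, the two-layer "excitation" uses ReLU, and the layer names and the `reduction` ratio are assumptions made only for illustration.

```python
import paddle
import paddle.nn as nn

class CENetAttention(nn.Layer):
    """Simplified CENet-style attention that re-weights each field's embedding."""

    def __init__(self, num_fields, reduction=2):
        super().__init__()
        hidden = max(1, num_fields // reduction)
        self.fc1 = nn.Linear(num_fields, hidden)   # "squeeze" the field descriptors
        self.fc2 = nn.Linear(hidden, num_fields)   # "excite" back to one weight per field
        self.relu = nn.ReLU()

    def forward(self, field_emb):
        # field_emb: [batch, num_fields, emb_dim]
        z = paddle.mean(field_emb, axis=2)               # compose: one descriptor per field
        a = self.relu(self.fc2(self.relu(self.fc1(z))))  # attention weight per field
        return field_emb * a.unsqueeze(-1)               # emphasize useful fields, suppress the rest
```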


### **DeepFFM layer**
The DeepFFM structure is shown below:

<img align="center" src="picture/4.jpg" width="400" height="300">

FFM is used to model the interactions between the different feature fields, computed as follows:

<img align="center" src="picture/55.jpg" width="500" height="100">
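
The sketch below illustrates the field-aware pairwise interaction on its own, assuming every field already has a separate embedding for each other field it may interact with. The tensor layout, the function name, and the use of inner products (the DeepFFM interaction layer can equally use element-wise Hadamard products) are illustrative assumptions rather than the exact code in `net.py`.

```python
import paddle

def ffm_inner_products(field_aware_emb):
    """field_aware_emb: [batch, num_fields, num_fields, emb_dim];
    entry [:, i, j] is field i's embedding used when interacting with field j."""
    num_fields = field_aware_emb.shape[1]
    terms = []
    for i in range(num_fields):
        for j in range(i + 1, num_fields):
            # <v_{i,fj}, v_{j,fi}> for every sample in the batch
            dot = paddle.sum(field_aware_emb[:, i, j] * field_aware_emb[:, j, i],
                             axis=1, keepdim=True)
            terms.append(dot)
    return paddle.concat(terms, axis=1)  # [batch, num_fields * (num_fields - 1) / 2] interaction terms
```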



### **Loss and AUC computation**
- To obtain, for each sample, the probabilities of belonging to the negative and positive classes, the prediction is concatenated with `1-predict` into `predict_2d`, which is then used to compute the `auc`.
- The loss of each sample is the negative log loss; the label is cast to float before being fed in.
- The loss of the batch, `avg_cost`, is the sum of the per-sample losses.
- The AUC metric of the predictions is also computed (a minimal sketch of these steps is shown below).
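
The snippet below walks through these steps on a toy batch using standard Paddle APIs. The tensor names and the sum reduction follow the description above and are illustrative; they are not necessarily the exact code in `dygraph_model.py`.

```python
import paddle
import paddle.nn.functional as F

# Toy batch: predict is the model's click probability, label the ground truth.
predict = paddle.to_tensor([[0.8], [0.2], [0.6]], dtype="float32")     # [batch, 1]
label = paddle.to_tensor([[1], [0], [0]], dtype="int64")               # [batch, 1]

predict_2d = paddle.concat([1.0 - predict, predict], axis=1)           # per-class probabilities for AUC
cost = F.log_loss(input=predict, label=paddle.cast(label, "float32"))  # negative log loss per sample
avg_cost = paddle.sum(cost)                                            # batch loss, summed as described above

auc_metric = paddle.metric.Auc("ROC")
auc_metric.update(preds=predict_2d.numpy(), labels=label.numpy())
print("auc:", auc_metric.accumulate())
```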

## Reproducing results
To help users quickly run every model end to end, sample data is provided under each model directory. To reproduce the results in this README, follow the steps below in order.
On the full dataset, the model achieves the following metrics:

| Model | AUC | batch_size | epoch_num | Time per epoch |
| :------ | :------ | :------ | :------ | :------ |
| FAT_DeepFFM | 0.8037 | 1000 | 1 | about 3.5 hours |

1. Make sure your current directory is `PaddleRec/models/rank/fat_deepffm`.
2. Go to the `PaddleRec/datasets/criteo` directory and run the script there; it downloads our preprocessed full Criteo dataset from a mirror server in China and extracts it into the target folder.
``` bash
cd ../../../datasets/criteo
sh run.sh
```
3. Switch back to the model directory and run the commands on the full dataset:
```bash
cd - # switch back to the model directory
# dynamic-graph training
python -u ../../../tools/trainer.py -m config_bigdata.yaml # train on the full dataset with config_bigdata.yaml
python -u ../../../tools/infer.py -m config_bigdata.yaml # run inference on the full dataset with config_bigdata.yaml
```

## Advanced usage

## FAQ
13 changes: 13 additions & 0 deletions models/rank/fat_deepffm/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
56 changes: 56 additions & 0 deletions models/rank/fat_deepffm/config.yaml
@@ -0,0 +1,56 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# workspace
#workspace: "models/rank/fat_deepffm"


runner:
train_data_dir: "data/sample_data/train"
train_reader_path: "criteo_reader" # importlib format
use_gpu: False
use_auc: True
train_batch_size: 1
epochs: 1
print_interval: 10

model_save_path: "output_model_fat_deepffm"
infer_batch_size: 1000
infer_reader_path: "criteo_reader" # importlib format
test_data_dir: "data/sample_data/train"

infer_load_path: "output_model_fat_deepffm"
infer_start_epoch: 0
infer_end_epoch: 1

# distribute_config
sync_mode: "async"
split_file_list: False
thread_num: 1 # 1


# hyper parameters of user-defined network
hyper_parameters:
# optimizer config
optimizer:
class: Adam
learning_rate: 0.0001
strategy: async
# user-defined <key, value> pairs
sparse_inputs_slots: 27
sparse_feature_number: 1000001
sparse_feature_dim: 10
dense_input_dim: 13
distributed_embedding: 0
layer_sizes_dnn: [1600,1600]
56 changes: 56 additions & 0 deletions models/rank/fat_deepffm/config_bigdata.yaml
@@ -0,0 +1,56 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# workspace
#workspace: "models/rank/fat_deepffm"


runner:
train_data_dir: "../../../datasets/criteo/slot_train_data_full" # criteo_sample_train slot_train_data_full
train_reader_path: "criteo_reader" # importlib format
use_gpu: True
use_auc: True
train_batch_size: 1000
epochs: 1
print_interval: 100

model_save_path: "output_model_all_fat_deepffm"
infer_batch_size: 1000
infer_reader_path: "criteo_reader" # importlib format
test_data_dir: "../../../datasets/criteo/slot_test_data_full"

infer_load_path: "output_model_all_fat_deepffm"
infer_start_epoch: 0
infer_end_epoch: 1

# distribute_config
sync_mode: "async"
split_file_list: False
thread_num: 1 # 1


# hyper parameters of user-defined network
hyper_parameters:
# optimizer config
optimizer:
class: Adam
learning_rate: 0.0001
strategy: async
# user-defined <key, value> pairs
sparse_inputs_slots: 27
sparse_feature_number: 1000001
sparse_feature_dim: 10
dense_input_dim: 13
distributed_embedding: 0
layer_sizes_dnn: [1600,1600]
81 changes: 81 additions & 0 deletions models/rank/fat_deepffm/criteo_reader.py
@@ -0,0 +1,81 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import numpy as np

from paddle.io import IterableDataset


class RecDataset(IterableDataset):
def __init__(self, file_list, config):
super(RecDataset, self).__init__()
self.file_list = file_list
self.init()

def init(self):
from operator import mul
padding = 0
sparse_slots = "click 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26"
self.sparse_slots = sparse_slots.strip().split(" ")
self.dense_slots = ["dense_feature"]
self.dense_slots_shape = [13]
self.slots = self.sparse_slots + self.dense_slots
self.slot2index = {}
self.visit = {}
for i in range(len(self.slots)):
self.slot2index[self.slots[i]] = i
self.visit[self.slots[i]] = False
self.padding = padding

def __iter__(self):
full_lines = []
self.data = []
for file in self.file_list:
with open(file, "r") as rf:
for l in rf:
line = l.strip().split(" ")
output = [(i, []) for i in self.slots]
for i in line:
slot_feasign = i.split(":")
slot = slot_feasign[0]
if slot not in self.slots:
continue
if slot in self.sparse_slots:
feasign = int(slot_feasign[1])
else:
feasign = float(slot_feasign[1])
output[self.slot2index[slot]][1].append(feasign)
self.visit[slot] = True
for i in self.visit:
slot = i
if not self.visit[slot]:
if i in self.dense_slots:
output[self.slot2index[i]][1].extend(
[self.padding] *
self.dense_slots_shape[self.slot2index[i]])
else:
output[self.slot2index[i]][1].extend(
[self.padding])
else:
self.visit[slot] = False
# sparse
output_list = []
for key, value in output[:-1]:
output_list.append(np.array(value).astype('int64'))
# dense
output_list.append(
np.array(output[-1][1]).astype("float32"))
# list
yield output_list
