-
Notifications
You must be signed in to change notification settings - Fork 645
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add aitm model #756
Merged
Merged
Add aitm model #756
Changes from 4 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
0552236
Add aitm model
renmada 5ba0784
Fix code style
renmada 660d0db
Fix code style
renmada f8347a7
Fix code style
renmada bdfb610
Merge remote-tracking branch 'upstream/master' into aitm
renmada bf9a3c0
Fix
renmada 198aa25
Fix
renmada 580c98d
move to datasets loader
renmada b6cfc37
move to datasets loader
renmada 99e15bf
Update
renmada fa1dd3d
Fix tipc data path
renmada File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# global settings | ||
|
||
runner: | ||
train_data_dir: "./data/sample_data/train" | ||
train_reader_path: "reader" # importlib format | ||
use_gpu: True | ||
train_batch_size: 10 | ||
epochs: 2 | ||
print_interval: 10 | ||
#model_init_path: "output_model/0" # init model | ||
model_save_path: "output_model_aitm_test/" | ||
test_data_dir: "./data/sample_data/test" | ||
infer_reader_path: "reader" # importlib format | ||
infer_batch_size: 10 | ||
infer_load_path: "output_model_aitm_test/" | ||
infer_start_epoch: 0 | ||
infer_end_epoch: 2 | ||
|
||
# hyper parameters of user-defined network | ||
hyper_parameters: | ||
# optimizer config | ||
optimizer: | ||
class: Adam | ||
learning_rate: 0.0001 | ||
# user-defined <key, value> pairs | ||
embedding_size: 5 | ||
dims: [128, 64, 32] | ||
drop_prob: [0.1, 0.3, 0.3] | ||
feature_vocabulary: [['101', 238635], ['121', 98], ['122', 14], ['124', 3], ['125', 8], ['126', 4], ['127', 4], ['128', 3], ['129', 5], ['205', 467298], ['206', 6929], ['207', 263942], ['216', 106399], ['508', 5888], ['509', 104830], ['702', 51878], ['853', 37148], ['301', 4]] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# global settings | ||
|
||
runner: | ||
train_data_dir: "./data/whole_data/train" | ||
train_reader_path: "reader" # importlib format | ||
use_gpu: True | ||
train_batch_size: 2000 | ||
epochs: 6 | ||
print_interval: 500 | ||
#model_init_path: "output_model/0" # init model | ||
model_save_path: "output_model_aitm/" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 保存路径命名可以参考其他模型,改为output_model_aitm_all |
||
test_data_dir: "./data/whole_data/test" | ||
infer_reader_path: "reader" # importlib format | ||
infer_batch_size: 2000 | ||
infer_load_path: "output_model_aitm/" | ||
infer_start_epoch: 0 | ||
infer_end_epoch: 6 | ||
|
||
# hyper parameters of user-defined network | ||
hyper_parameters: | ||
# optimizer config | ||
optimizer: | ||
class: Adam | ||
learning_rate: 0.0002 | ||
# user-defined <key, value> pairs | ||
embedding_size: 5 | ||
dims: [128, 64, 32] | ||
drop_prob: [0.1, 0.3, 0.3] | ||
feature_vocabulary: [['101', 238635], ['121', 98], ['122', 14], ['124', 3], ['125', 8], ['126', 4], ['127', 4], ['128', 3], ['129', 5], ['205', 467298], ['206', 6929], ['207', 263942], ['216', 106399], ['508', 5888], ['509', 104830], ['702', 51878], ['853', 37148], ['301', 4]] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
0,0,180467,4,1,2,3,1,1,2,1,157266,5673,164308,73943,753,0,0,32454,2 | ||
0,0,180467,4,1,2,3,1,1,2,1,426254,5673,216630,0,753,82564,0,24700,2 | ||
1,0,180467,4,1,2,3,1,1,2,1,172023,2114,63355,0,2261,80294,0,0,2 | ||
0,0,225145,39,3,2,4,1,1,2,2,70837,1668,46734,61858,1952,0,0,36223,2 | ||
0,0,225145,39,3,2,4,1,1,2,2,45118,6423,137020,39048,0,0,0,0,2 | ||
0,0,225145,39,3,2,4,1,1,2,2,0,3384,55310,13430,287,0,0,35494,2 | ||
0,1,225145,39,3,2,4,1,1,2,2,0,2015,0,81649,467,0,10904,0,2 | ||
0,0,225145,39,3,2,4,1,1,2,2,27259,4109,88551,49425,4417,0,39516,0,2 | ||
1,0,225145,39,3,2,4,1,1,2,2,297113,4109,219686,5769,4417,0,0,0,2 | ||
0,0,225145,39,3,2,4,1,1,2,2,151749,3228,148375,89123,2599,47614,10053,0,2 | ||
0,0,225145,39,3,2,4,1,1,2,2,0,1081,47185,0,3865,0,0,0,2 | ||
0,0,225145,39,3,2,4,1,1,2,2,0,2015,0,81649,467,0,10904,0,2 | ||
0,0,225145,39,3,2,4,1,1,2,2,26903,4630,159418,0,0,80716,0,24471,2 | ||
0,0,225145,39,3,2,4,1,1,2,2,282069,2967,230832,84386,5400,0,0,0,2 | ||
0,1,186821,75,10,1,3,2,1,2,2,62621,4544,58151,70939,0,0,0,0,3 | ||
0,0,186821,75,10,1,3,2,1,2,2,93138,532,115823,87562,3213,0,0,16010,3 | ||
0,0,186821,75,10,1,3,2,1,2,2,45206,2820,183822,13990,3138,63787,37851,4086,3 | ||
0,0,146759,60,11,2,6,0,2,2,0,368243,2215,123885,92072,5027,0,0,0,1 | ||
0,1,146759,60,11,2,6,0,2,2,0,215196,959,195714,0,3110,0,0,9556,1 | ||
0,0,131084,23,11,2,6,3,1,2,2,239163,1830,99024,90467,2743,0,0,0,1 | ||
0,0,131084,23,11,2,6,3,1,2,2,264944,2096,111223,67585,4293,0,0,0,1 | ||
0,0,131084,23,11,2,6,3,1,2,2,233483,598,197331,53113,0,0,0,0,1 | ||
0,0,131084,23,11,2,6,3,1,2,2,312012,3757,205654,22313,1892,0,0,0,1 | ||
0,0,131084,23,11,2,6,3,1,2,2,198046,4010,46170,0,567,0,0,0,1 | ||
0,0,25124,60,5,2,5,0,1,2,2,314484,6224,138638,0,533,37733,0,0,1 | ||
0,0,25124,60,5,2,5,0,1,2,2,258277,5670,74133,72066,1197,72538,20372,0,1 | ||
0,0,25124,60,5,2,5,0,1,2,2,0,4770,137189,37029,1778,0,0,0,1 | ||
0,0,25124,60,5,2,5,0,1,2,2,424602,2114,130310,60811,2261,0,0,32948,1 | ||
0,0,25124,60,5,2,5,0,1,2,2,209585,3746,60855,34322,3483,87904,26605,0,1 | ||
1,0,25124,60,5,2,5,0,1,2,2,433251,572,36404,0,319,103340,0,0,1 | ||
0,0,139588,60,10,1,3,0,1,2,2,319344,6048,176737,0,0,0,0,0,1 | ||
0,0,139588,60,10,1,3,0,1,2,2,170111,5294,255332,26976,0,0,0,0,1 | ||
0,0,139588,60,10,1,3,0,1,2,2,35,5674,100397,21139,0,0,0,0,1 | ||
0,0,233279,60,13,1,5,0,1,2,3,225154,5748,229278,0,4488,0,0,28266,3 | ||
0,0,233279,60,13,1,5,0,1,2,3,0,3612,158255,46162,5182,0,0,0,3 | ||
0,0,233279,60,13,1,5,0,1,2,3,86976,2541,234649,75046,4702,0,0,0,3 | ||
0,0,233279,60,13,1,5,0,1,2,3,161648,2748,15385,0,3270,0,0,12825,3 | ||
0,0,233279,60,13,1,5,0,1,2,3,71718,2371,74049,83673,4259,36871,46702,9076,3 | ||
0,0,233279,60,13,1,5,0,1,2,3,108993,2139,107889,0,5313,0,0,0,3 | ||
0,0,57653,60,10,1,3,0,1,2,1,405019,5473,238445,86159,1489,0,31263,14752,3 | ||
0,0,57653,60,10,1,3,0,1,2,1,103212,5473,62663,13073,1489,0,0,0,3 | ||
0,0,57653,60,10,1,3,0,1,2,1,464808,887,212115,0,3249,0,0,14835,3 | ||
0,0,94847,0,0,0,0,0,0,0,0,10251,4770,1730,0,1778,0,0,989,1 | ||
0,0,94847,0,0,0,0,0,0,0,0,206244,2139,246373,97164,5313,0,0,11014,1 | ||
0,0,94847,0,0,0,0,0,0,0,0,222670,4770,91229,13685,1778,0,0,8594,1 | ||
0,0,94847,0,0,0,0,0,0,0,0,308509,4010,50918,104006,0,0,0,0,1 | ||
0,0,94847,0,0,0,0,0,0,0,0,60836,3639,47235,65915,662,0,0,0,1 | ||
0,0,94847,0,0,0,0,0,0,0,0,437045,2697,6369,105353,0,0,0,0,1 | ||
0,0,94847,0,0,0,0,0,0,0,0,28761,1005,91855,91674,5425,0,0,0,1 | ||
0,0,94847,0,0,0,0,0,0,0,0,0,990,0,0,0,0,0,0,1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
click,purchase,101,121,122,124,125,126,127,128,129,205,206,207,216,508,509,702,853,301 | ||
0,0,85077,60,5,2,5,0,1,2,1,303983,3751,249750,4124,5565,0,0,6260,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,200850,1005,156972,315,5425,0,0,31950,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,361912,2748,116642,103627,3270,0,0,0,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,235052,4791,18014,92858,1575,18172,39447,12632,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,236563,1464,49011,27526,2400,0,40351,0,2 | ||
1,1,85077,60,5,2,5,0,1,2,1,0,2668,73088,0,1349,36448,0,3202,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,868,2114,134232,27903,2261,0,42361,8186,2 | ||
0,1,85077,60,5,2,5,0,1,2,1,278769,887,252320,35421,0,0,0,0,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,378778,959,237727,24961,3110,0,12437,0,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,286476,2668,36149,50475,1349,0,17247,24785,2 | ||
1,0,85077,60,5,2,5,0,1,2,1,37325,1219,210814,0,732,0,0,0,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,0,6396,145671,104006,0,0,10562,0,2 | ||
0,1,85077,60,5,2,5,0,1,2,1,0,6396,187630,3695,0,0,39953,0,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,174938,2542,121978,25017,2752,0,0,16432,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,126218,2748,99220,3695,3270,1641,39953,0,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,137678,2668,156813,76812,1349,0,0,0,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,0,6639,75496,0,1641,0,0,0,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,0,3358,28909,24961,1715,33178,12437,0,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,463355,2321,143602,0,2827,0,0,0,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,208394,5441,165858,100800,3175,0,0,14526,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,0,4770,160421,49203,1778,104211,12570,0,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,0,5130,152781,34023,3305,0,0,0,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,0,1219,71745,37124,732,0,43897,0,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,0,4770,78413,0,1778,68142,0,0,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,0,4770,0,51246,1778,0,12494,0,2 | ||
0,0,85077,60,5,2,5,0,1,2,1,0,4770,12306,99058,1778,94838,23330,0,2 | ||
0,0,126936,75,10,1,3,1,1,2,2,436088,5342,251939,55848,5686,0,43802,891,3 | ||
0,0,225894,36,5,2,5,1,1,2,1,236983,2114,100033,45859,2261,104041,37050,32948,2 | ||
0,0,225894,36,5,2,5,1,1,2,1,0,2668,81805,94691,1349,80860,23323,36346,2 | ||
1,0,225894,36,5,2,5,1,1,2,1,236983,2114,100033,45859,2261,104041,37050,32948,2 | ||
0,0,225894,36,5,2,5,1,1,2,1,76908,1323,2845,4224,0,0,0,0,2 | ||
0,0,225894,36,5,2,5,1,1,2,1,0,2139,27250,65082,5313,86671,24195,0,2 | ||
0,0,214292,60,13,1,5,0,1,2,2,405168,4408,200625,53253,2913,0,0,0,3 | ||
0,0,214292,60,13,1,5,0,1,2,2,0,3711,105033,0,3877,0,0,0,3 | ||
0,0,214292,60,13,1,5,0,1,2,2,342042,3346,230129,0,610,0,0,0,3 | ||
0,0,214292,60,13,1,5,0,1,2,2,140466,3711,160865,33806,3877,0,0,19879,3 | ||
0,0,214292,60,13,1,5,0,1,2,2,243201,3924,130907,30524,1289,0,0,0,3 | ||
0,0,214292,60,13,1,5,0,1,2,2,229158,3643,123318,92849,0,0,0,0,3 | ||
0,0,214292,60,13,1,5,0,1,2,2,438610,4782,215293,100092,0,0,0,0,3 | ||
0,0,214292,60,13,1,5,0,1,2,2,106945,5334,16167,34467,4967,0,0,0,3 | ||
0,0,214292,60,13,1,5,0,1,2,2,72148,2177,182249,0,5689,93690,0,13168,3 | ||
0,0,214292,60,13,1,5,0,1,2,2,145491,989,248896,65379,0,0,0,0,3 | ||
0,0,214292,60,13,1,5,0,1,2,2,49677,5673,57894,0,0,0,0,0,3 | ||
0,0,214292,60,13,1,5,0,1,2,2,401746,3334,54001,104747,0,83894,0,0,3 | ||
0,0,214292,60,13,1,5,0,1,2,2,293581,1685,30366,47952,2631,0,0,12116,3 | ||
0,0,214292,60,13,1,5,0,1,2,2,211989,2815,162936,92766,0,0,0,0,3 | ||
0,0,214292,60,13,1,5,0,1,2,2,398227,4782,12082,65379,0,0,0,0,3 | ||
0,0,214292,60,13,1,5,0,1,2,2,0,3713,0,68249,1813,0,27110,0,3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import paddle | ||
import paddle.nn as nn | ||
import paddle.nn.functional as F | ||
import net | ||
|
||
|
||
class DygraphModel(): | ||
def create_model(self, config): | ||
feature_vocabulary = config.get("hyper_parameters.feature_vocabulary") | ||
embedding_size = config.get("hyper_parameters.embedding_size") | ||
tower_dims = config.get("hyper_parameters.dims") | ||
drop_prob = config.get('hyper_parameters.drop_prob') | ||
feature_vocabulary = dict(feature_vocabulary) | ||
model = net.AITM(feature_vocabulary, embedding_size, tower_dims, | ||
drop_prob) | ||
return model | ||
|
||
# define feeds which convert numpy of batch data to paddle.tensor | ||
def create_feeds(self, batch_data, config): | ||
click, conversion, features = batch_data | ||
return click.astype('float32'), conversion.astype('float32'), features | ||
|
||
# define loss function by predicts and label | ||
def create_loss(self, | ||
click_pred, | ||
conversion_pred, | ||
click_label, | ||
conversion_label, | ||
constraint_weight=0.6): | ||
click_loss = F.binary_cross_entropy(click_pred, click_label) | ||
conversion_loss = F.binary_cross_entropy(conversion_pred, | ||
conversion_label) | ||
|
||
label_constraint = paddle.maximum(conversion_pred - click_pred, | ||
paddle.zeros_like(click_label)) | ||
constraint_loss = paddle.sum(label_constraint) | ||
|
||
loss = click_loss + conversion_loss + constraint_weight * constraint_loss | ||
return loss | ||
|
||
# define optimizer | ||
def create_optimizer(self, dy_model, config): | ||
lr = config.get("hyper_parameters.optimizer.learning_rate", 0.0001) | ||
optimizer = paddle.optimizer.Adam( | ||
learning_rate=lr, | ||
parameters=dy_model.parameters(), | ||
weight_decay=1e-6) | ||
return optimizer | ||
|
||
# define metrics such as auc/acc | ||
# multi-task need to define multi metric | ||
def create_metrics(self): | ||
metrics_list_name = ["click_auc", "purchase_auc"] | ||
metrics_list = [ | ||
paddle.metric.Auc("ROC", num_thresholds=100000), | ||
paddle.metric.Auc("ROC", num_thresholds=100000) | ||
] | ||
return metrics_list, metrics_list_name | ||
|
||
# construct train forward phase | ||
def train_forward(self, dy_model, metrics_list, batch_data, config): | ||
click, conversion, features = self.create_feeds(batch_data, config) | ||
click_pred, conversion_pred = dy_model.forward(features) | ||
loss = self.create_loss(click_pred, conversion_pred, click, conversion) | ||
# update metrics | ||
|
||
self.update_auc(click_pred, click, metrics_list[0]) | ||
self.update_auc(conversion_pred, conversion, metrics_list[1]) | ||
print_dict = {'loss': loss} | ||
return loss, metrics_list, print_dict | ||
|
||
@staticmethod | ||
def update_auc(prob, label, metrics): | ||
if prob.ndim == 1: | ||
prob = prob.unsqueeze(-1) | ||
assert prob.ndim == 2 | ||
predict_2d = paddle.concat(x=[1 - prob, prob], axis=1) | ||
metrics.update(predict_2d, label) | ||
|
||
def infer_forward(self, dy_model, metrics_list, batch_data, config): | ||
click, conversion, features = self.create_feeds(batch_data, config) | ||
with paddle.no_grad(): | ||
click_pred, conversion_pred = dy_model.forward(features) | ||
# update metrics | ||
self.update_auc(click_pred, click, metrics_list[0]) | ||
self.update_auc(conversion_pred, conversion, metrics_list[1]) | ||
return metrics_list, None | ||
|
||
def forward(self, dy_model, batch_data, config): | ||
click, conversion, features = self.create_feeds(batch_data, config) | ||
with paddle.no_grad(): | ||
click_pred, conversion_pred = dy_model.forward(features) | ||
# update metrics | ||
return click, click_pred, conversion, conversion_pred |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
demo数据下不建议使用gpu,可以采取减小数据集,减小epoch等方式缩短demo训练时间至一分钟以内