Skip to content

Commit

Permalink
Merge pull request #522 from yinhaofeng/maml
Browse files Browse the repository at this point in the history
maml
  • Loading branch information
seemingwang authored Aug 26, 2021
2 parents a3b328e + 7599ff7 commit 80520bd
Show file tree
Hide file tree
Showing 111 changed files with 632 additions and 0 deletions.
6 changes: 6 additions & 0 deletions datasets/omniglot/download.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
wget https://paddlerec.bj.bcebos.com/datasets/omniglot/omniglot_python.zip
unzip omniglot_python.zip
mv images_evaluation/* images_background/
mv images_background omniglot_raw
rm -rf demo.py images_background_small1 images_background_small2 images_evaluation/ one-shot-classification strokes_*
python preprocess.py
52 changes: 52 additions & 0 deletions datasets/omniglot/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import cv2
import numpy as np
import random
import shutil

data_folder = './omniglot_raw' # omniglot数据集路径

character_folders = [os.path.join(data_folder, family, character) \
for family in os.listdir(data_folder) \
if os.path.isdir(os.path.join(data_folder, family)) \
for character in os.listdir(os.path.join(data_folder, family))]
print("The number of character folders: {}".format(len(
character_folders))) # 1623
random.seed(1)
random.shuffle(character_folders)
train_folders = character_folders[:973]
val_folders = character_folders[973:1298]
test_folders = character_folders[1298:]
print('The number of train characters is {}'.format(len(train_folders))) # 973
print('The number of validation characters is {}'.format(len(
val_folders))) # 325
print('The number of test characters is {}'.format(len(test_folders))) # 325

for char_fold in train_folders:
path = char_fold.split("/")
path[1] = "omniglot_train"
shutil.copytree(char_fold, "/".join(path))

for char_fold in val_folders:
path = char_fold.split("/")
path[1] = "omniglot_valid"
shutil.copytree(char_fold, "/".join(path))

for char_fold in test_folders:
path = char_fold.split("/")
path[1] = "omniglot_test"
shutil.copytree(char_fold, "/".join(path))
2 changes: 2 additions & 0 deletions datasets/omniglot/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
wget https://paddlerec.bj.bcebos.com/datasets/omniglot/omniglot.tar
tar -xf omniglot.tar
1 change: 1 addition & 0 deletions datasets/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ sh data_process.sh
|[FourSquare](https://paddlerec.bj.bcebos.com/datasets/FourSquare/FourSquare.zip)|此数据集包含在纽约和东京进行的大约10个月收集的签到。每个签到都有其时间戳,GPS坐标及其语义相关联。|[Kaggle](https://www.kaggle.com/chetanism/foursquare-nyc-and-tokyo-checkin-dataset)|
|[AmazonBook](https://paddlerec.bj.bcebos.com/datasets/AmazonBook/AmazonBook.tar.gz)|论文原作者处理过的AmazonBook数据集 |[《Controllable Multi-Interest Framework for Recommendation》](https://arxiv.org/abs/2005.09347)|
|[Ali_Display_Ad_Click](https://paddlerec.bj.bcebos.com/datasets/dmr/dataset_full.zip)|预处理过的Alimama数据集 |[Deep Match to Rank Model for Personalized Click-Through Rate Prediction](https://github.com/lvze92/DMR)|
|[omniglot](https://paddlerec.bj.bcebos.com/datasets/omniglot/omniglot.tar)|预处理过的omniglot数据集 |[Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks](https://arxiv.org/pdf/1703.03400.pdf)|
Binary file added doc/imgs/maml.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
54 changes: 54 additions & 0 deletions models/multitask/maml/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# global settings

runner:
train_data_dir: "./data"
train_reader_path: "omniglot_reader" # importlib format
use_gpu: True
use_auc: False
train_batch_size: 32
epochs: 1
print_interval: 10
model_save_path: "output_model_maml"
test_data_dir: "./data"
infer_reader_path: "omniglot_reader" # importlib format
infer_batch_size: 32
infer_load_path: "output_model_maml"
infer_start_epoch: 0
infer_end_epoch: 1

# hyper parameters of user-defined network
hyper_parameters:
# optimizer config
meta_optimizer:
class: Adam
learning_rate: 0.001
strategy: async
base_optimizer:
class: SGD
learning_rate: 0.1
strategy: async
# user-defined <key, value> pairs
update_step: 5
update_step_test: 5
n_way: 5
k_spt: 1
k_query: 15
imgsize: 28
conv_stride: 1
conv_padding: 1
conv_kernal: [3, 3]
bn_channel: 64
54 changes: 54 additions & 0 deletions models/multitask/maml/config_bigdata.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# global settings

runner:
train_data_dir: "../../../datasets/omniglot/omniglot_train"
train_reader_path: "omniglot_reader" # importlib format
use_gpu: True
use_auc: False
train_batch_size: 32
epochs: 100
print_interval: 10
model_save_path: "output_model_all_maml"
test_data_dir: "../../../datasets/omniglot/omniglot_test"
infer_reader_path: "omniglot_reader" # importlib format
infer_batch_size: 32
infer_load_path: "output_model_all_maml"
infer_start_epoch: 90
infer_end_epoch: 100

# hyper parameters of user-defined network
hyper_parameters:
# optimizer config
meta_optimizer:
class: Adam
learning_rate: 0.001
strategy: async
base_optimizer:
class: SGD
learning_rate: 0.1
strategy: async
# user-defined <key, value> pairs
update_step: 5
update_step_test: 5
n_way: 5
k_spt: 1
k_query: 15
imgsize: 28
conv_stride: 1
conv_padding: 1
conv_kernal: [3, 3]
bn_channel: 64
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
143 changes: 143 additions & 0 deletions models/multitask/maml/dygraph_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import math
import copy
import numpy as np
import net


class DygraphModel():
# define model
def create_model(self, config):
conv_stride = config.get("hyper_parameters.conv_stride")
conv_padding = config.get("hyper_parameters.conv_padding")
conv_kernal = config.get("hyper_parameters.conv_kernal")
bn_channel = config.get("hyper_parameters.bn_channel")
maml_model = net.MAMLLayer(conv_stride, conv_padding, conv_kernal,
bn_channel)
return maml_model

# define feeds which convert numpy of batch data to paddle.tensor
def create_feeds(self, batch_data, config):
x_spt = paddle.to_tensor(batch_data[0].numpy().astype("float32"))
y_spt = paddle.to_tensor(batch_data[1].numpy().astype("int64"))
x_qry = paddle.to_tensor(batch_data[2].numpy().astype("float32"))
y_qry = paddle.to_tensor(batch_data[3].numpy().astype("int64"))
#print("x_spt",x_spt.shape,"y_spt",y_spt.shape,"x_qry",x_qry.shape,"y_qry",y_qry.shape)
return x_spt, y_spt, x_qry, y_qry

# define optimizer
def create_optimizer(self, dy_model, config):
meta_lr = config.get("hyper_parameters.meta_optimizer.learning_rate",
0.001)
optimizer = paddle.optimizer.Adam(
learning_rate=meta_lr, parameters=dy_model.parameters())
return optimizer

# define metrics such as auc/acc
# multi-task need to define multi metric
def create_metrics(self):
metrics_list_name = []
metrics_list = []
return metrics_list, metrics_list_name

# construct train forward phase
def train_forward(self, dy_model, metrics_list, batch_data, config):
np.random.seed(12345)
x_spt, y_spt, x_qry, y_qry = self.create_feeds(batch_data, config)
update_step = config.get("hyper_parameters.update_step", 5)
task_num = x_spt.shape[0]
query_size = x_qry.shape[
1] # 75 = 15 * 5, x_qry.shape = [32,75,1,28,28]
loss_list = []
loss_list.clear()
correct_list = []
correct_list.clear()
task_grad = [[] for _ in range(task_num)]

for i in range(task_num):
# 外循环
task_net = copy.deepcopy(dy_model)
base_lr = config.get(
"hyper_parameters.base_optimizer.learning_rate", 0.1)
task_optimizer = paddle.optimizer.SGD(
learning_rate=base_lr, parameters=task_net.parameters())
for j in range(update_step):
#内循环
task_optimizer.clear_grad() # 梯度清零
y_hat = task_net(x_spt[i]) # (setsz, ways) [5,5]
loss_spt = F.cross_entropy(y_hat, y_spt[i])
loss_spt.backward()
task_optimizer.step()

y_hat = task_net(x_qry[i])
loss_qry = F.cross_entropy(y_hat, y_qry[i])
loss_qry.backward()
for k in task_net.parameters():
task_grad[i].append(k.grad)
loss_list.append(loss_qry)
pred_qry = F.softmax(y_hat, axis=1).argmax(axis=1)
correct = paddle.equal(pred_qry, y_qry[i]).numpy().sum().item()
correct_list.append(correct)

loss_average = paddle.add_n(loss_list) / task_num
acc = sum(correct_list) / (query_size * task_num)

for num, k in enumerate(dy_model.parameters()):
tmp_list = [task_grad[i][num] for i in range(task_num)]
if tmp_list[0] is not None:
k._set_grad_ivar(paddle.add_n(tmp_list) / task_num)

acc = paddle.to_tensor(acc)
print_dict = {'loss': loss_average, "acc": acc}
_ = paddle.ones(shape=[5, 5], dtype="float32")
return _, metrics_list, print_dict

def infer_forward(self, dy_model, metrics_list, batch_data, config):
dy_model.train()
x_spt, y_spt, x_qry, y_qry = self.create_feeds(batch_data, config)
x_spt = x_spt[0]
y_spt = y_spt[0]
x_qry = x_qry[0]
y_qry = y_qry[0]
update_step = config.get("hyper_parameters.update_step_test", 5)
query_size = x_qry.shape[0]
correct_list = []
correct_list.clear()

task_net = copy.deepcopy(dy_model)
base_lr = config.get("hyper_parameters.base_optimizer.learning_rate",
0.1)
task_optimizer = paddle.optimizer.SGD(learning_rate=base_lr,
parameters=task_net.parameters())
for j in range(update_step):
task_optimizer.clear_grad()
y_hat = task_net(x_spt)
loss_spt = F.cross_entropy(y_hat, y_spt)
loss_spt.backward()
task_optimizer.step()

y_hat = task_net(x_qry)
pred_qry = F.softmax(y_hat, axis=1).argmax(axis=1)
correct = paddle.equal(pred_qry, y_qry).numpy().sum().item()
correct_list.append(correct)
acc = sum(correct_list) / query_size
acc = paddle.to_tensor(acc)
print_dict = {"acc": acc}

return metrics_list, print_dict
Loading

0 comments on commit 80520bd

Please sign in to comment.