-
Notifications
You must be signed in to change notification settings - Fork 726
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #522 from yinhaofeng/maml
maml
- Loading branch information
Showing
111 changed files
with
632 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
wget https://paddlerec.bj.bcebos.com/datasets/omniglot/omniglot_python.zip | ||
unzip omniglot_python.zip | ||
mv images_evaluation/* images_background/ | ||
mv images_background omniglot_raw | ||
rm -rf demo.py images_background_small1 images_background_small2 images_evaluation/ one-shot-classification strokes_* | ||
python preprocess.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import cv2 | ||
import numpy as np | ||
import random | ||
import shutil | ||
|
||
data_folder = './omniglot_raw' # omniglot数据集路径 | ||
|
||
character_folders = [os.path.join(data_folder, family, character) \ | ||
for family in os.listdir(data_folder) \ | ||
if os.path.isdir(os.path.join(data_folder, family)) \ | ||
for character in os.listdir(os.path.join(data_folder, family))] | ||
print("The number of character folders: {}".format(len( | ||
character_folders))) # 1623 | ||
random.seed(1) | ||
random.shuffle(character_folders) | ||
train_folders = character_folders[:973] | ||
val_folders = character_folders[973:1298] | ||
test_folders = character_folders[1298:] | ||
print('The number of train characters is {}'.format(len(train_folders))) # 973 | ||
print('The number of validation characters is {}'.format(len( | ||
val_folders))) # 325 | ||
print('The number of test characters is {}'.format(len(test_folders))) # 325 | ||
|
||
for char_fold in train_folders: | ||
path = char_fold.split("/") | ||
path[1] = "omniglot_train" | ||
shutil.copytree(char_fold, "/".join(path)) | ||
|
||
for char_fold in val_folders: | ||
path = char_fold.split("/") | ||
path[1] = "omniglot_valid" | ||
shutil.copytree(char_fold, "/".join(path)) | ||
|
||
for char_fold in test_folders: | ||
path = char_fold.split("/") | ||
path[1] = "omniglot_test" | ||
shutil.copytree(char_fold, "/".join(path)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
wget https://paddlerec.bj.bcebos.com/datasets/omniglot/omniglot.tar | ||
tar -xf omniglot.tar |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# global settings | ||
|
||
runner: | ||
train_data_dir: "./data" | ||
train_reader_path: "omniglot_reader" # importlib format | ||
use_gpu: True | ||
use_auc: False | ||
train_batch_size: 32 | ||
epochs: 1 | ||
print_interval: 10 | ||
model_save_path: "output_model_maml" | ||
test_data_dir: "./data" | ||
infer_reader_path: "omniglot_reader" # importlib format | ||
infer_batch_size: 32 | ||
infer_load_path: "output_model_maml" | ||
infer_start_epoch: 0 | ||
infer_end_epoch: 1 | ||
|
||
# hyper parameters of user-defined network | ||
hyper_parameters: | ||
# optimizer config | ||
meta_optimizer: | ||
class: Adam | ||
learning_rate: 0.001 | ||
strategy: async | ||
base_optimizer: | ||
class: SGD | ||
learning_rate: 0.1 | ||
strategy: async | ||
# user-defined <key, value> pairs | ||
update_step: 5 | ||
update_step_test: 5 | ||
n_way: 5 | ||
k_spt: 1 | ||
k_query: 15 | ||
imgsize: 28 | ||
conv_stride: 1 | ||
conv_padding: 1 | ||
conv_kernal: [3, 3] | ||
bn_channel: 64 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# global settings | ||
|
||
runner: | ||
train_data_dir: "../../../datasets/omniglot/omniglot_train" | ||
train_reader_path: "omniglot_reader" # importlib format | ||
use_gpu: True | ||
use_auc: False | ||
train_batch_size: 32 | ||
epochs: 100 | ||
print_interval: 10 | ||
model_save_path: "output_model_all_maml" | ||
test_data_dir: "../../../datasets/omniglot/omniglot_test" | ||
infer_reader_path: "omniglot_reader" # importlib format | ||
infer_batch_size: 32 | ||
infer_load_path: "output_model_all_maml" | ||
infer_start_epoch: 90 | ||
infer_end_epoch: 100 | ||
|
||
# hyper parameters of user-defined network | ||
hyper_parameters: | ||
# optimizer config | ||
meta_optimizer: | ||
class: Adam | ||
learning_rate: 0.001 | ||
strategy: async | ||
base_optimizer: | ||
class: SGD | ||
learning_rate: 0.1 | ||
strategy: async | ||
# user-defined <key, value> pairs | ||
update_step: 5 | ||
update_step_test: 5 | ||
n_way: 5 | ||
k_spt: 1 | ||
k_query: 15 | ||
imgsize: 28 | ||
conv_stride: 1 | ||
conv_padding: 1 | ||
conv_kernal: [3, 3] | ||
bn_channel: 64 |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import paddle | ||
import paddle.nn as nn | ||
import paddle.nn.functional as F | ||
import math | ||
import copy | ||
import numpy as np | ||
import net | ||
|
||
|
||
class DygraphModel(): | ||
# define model | ||
def create_model(self, config): | ||
conv_stride = config.get("hyper_parameters.conv_stride") | ||
conv_padding = config.get("hyper_parameters.conv_padding") | ||
conv_kernal = config.get("hyper_parameters.conv_kernal") | ||
bn_channel = config.get("hyper_parameters.bn_channel") | ||
maml_model = net.MAMLLayer(conv_stride, conv_padding, conv_kernal, | ||
bn_channel) | ||
return maml_model | ||
|
||
# define feeds which convert numpy of batch data to paddle.tensor | ||
def create_feeds(self, batch_data, config): | ||
x_spt = paddle.to_tensor(batch_data[0].numpy().astype("float32")) | ||
y_spt = paddle.to_tensor(batch_data[1].numpy().astype("int64")) | ||
x_qry = paddle.to_tensor(batch_data[2].numpy().astype("float32")) | ||
y_qry = paddle.to_tensor(batch_data[3].numpy().astype("int64")) | ||
#print("x_spt",x_spt.shape,"y_spt",y_spt.shape,"x_qry",x_qry.shape,"y_qry",y_qry.shape) | ||
return x_spt, y_spt, x_qry, y_qry | ||
|
||
# define optimizer | ||
def create_optimizer(self, dy_model, config): | ||
meta_lr = config.get("hyper_parameters.meta_optimizer.learning_rate", | ||
0.001) | ||
optimizer = paddle.optimizer.Adam( | ||
learning_rate=meta_lr, parameters=dy_model.parameters()) | ||
return optimizer | ||
|
||
# define metrics such as auc/acc | ||
# multi-task need to define multi metric | ||
def create_metrics(self): | ||
metrics_list_name = [] | ||
metrics_list = [] | ||
return metrics_list, metrics_list_name | ||
|
||
# construct train forward phase | ||
def train_forward(self, dy_model, metrics_list, batch_data, config): | ||
np.random.seed(12345) | ||
x_spt, y_spt, x_qry, y_qry = self.create_feeds(batch_data, config) | ||
update_step = config.get("hyper_parameters.update_step", 5) | ||
task_num = x_spt.shape[0] | ||
query_size = x_qry.shape[ | ||
1] # 75 = 15 * 5, x_qry.shape = [32,75,1,28,28] | ||
loss_list = [] | ||
loss_list.clear() | ||
correct_list = [] | ||
correct_list.clear() | ||
task_grad = [[] for _ in range(task_num)] | ||
|
||
for i in range(task_num): | ||
# 外循环 | ||
task_net = copy.deepcopy(dy_model) | ||
base_lr = config.get( | ||
"hyper_parameters.base_optimizer.learning_rate", 0.1) | ||
task_optimizer = paddle.optimizer.SGD( | ||
learning_rate=base_lr, parameters=task_net.parameters()) | ||
for j in range(update_step): | ||
#内循环 | ||
task_optimizer.clear_grad() # 梯度清零 | ||
y_hat = task_net(x_spt[i]) # (setsz, ways) [5,5] | ||
loss_spt = F.cross_entropy(y_hat, y_spt[i]) | ||
loss_spt.backward() | ||
task_optimizer.step() | ||
|
||
y_hat = task_net(x_qry[i]) | ||
loss_qry = F.cross_entropy(y_hat, y_qry[i]) | ||
loss_qry.backward() | ||
for k in task_net.parameters(): | ||
task_grad[i].append(k.grad) | ||
loss_list.append(loss_qry) | ||
pred_qry = F.softmax(y_hat, axis=1).argmax(axis=1) | ||
correct = paddle.equal(pred_qry, y_qry[i]).numpy().sum().item() | ||
correct_list.append(correct) | ||
|
||
loss_average = paddle.add_n(loss_list) / task_num | ||
acc = sum(correct_list) / (query_size * task_num) | ||
|
||
for num, k in enumerate(dy_model.parameters()): | ||
tmp_list = [task_grad[i][num] for i in range(task_num)] | ||
if tmp_list[0] is not None: | ||
k._set_grad_ivar(paddle.add_n(tmp_list) / task_num) | ||
|
||
acc = paddle.to_tensor(acc) | ||
print_dict = {'loss': loss_average, "acc": acc} | ||
_ = paddle.ones(shape=[5, 5], dtype="float32") | ||
return _, metrics_list, print_dict | ||
|
||
def infer_forward(self, dy_model, metrics_list, batch_data, config): | ||
dy_model.train() | ||
x_spt, y_spt, x_qry, y_qry = self.create_feeds(batch_data, config) | ||
x_spt = x_spt[0] | ||
y_spt = y_spt[0] | ||
x_qry = x_qry[0] | ||
y_qry = y_qry[0] | ||
update_step = config.get("hyper_parameters.update_step_test", 5) | ||
query_size = x_qry.shape[0] | ||
correct_list = [] | ||
correct_list.clear() | ||
|
||
task_net = copy.deepcopy(dy_model) | ||
base_lr = config.get("hyper_parameters.base_optimizer.learning_rate", | ||
0.1) | ||
task_optimizer = paddle.optimizer.SGD(learning_rate=base_lr, | ||
parameters=task_net.parameters()) | ||
for j in range(update_step): | ||
task_optimizer.clear_grad() | ||
y_hat = task_net(x_spt) | ||
loss_spt = F.cross_entropy(y_hat, y_spt) | ||
loss_spt.backward() | ||
task_optimizer.step() | ||
|
||
y_hat = task_net(x_qry) | ||
pred_qry = F.softmax(y_hat, axis=1).argmax(axis=1) | ||
correct = paddle.equal(pred_qry, y_qry).numpy().sum().item() | ||
correct_list.append(correct) | ||
acc = sum(correct_list) / query_size | ||
acc = paddle.to_tensor(acc) | ||
print_dict = {"acc": acc} | ||
|
||
return metrics_list, print_dict |
Oops, something went wrong.