Commit 2c1eef6: small changes
jik0730 committed Dec 6, 2018 (parent: 5e6b6da)

Showing 2 changed files with 25 additions and 18 deletions.
37 changes: 22 additions & 15 deletions evaluate.py
@@ -4,7 +4,7 @@

 import torch
 import numpy as np
-from torch.autograd import Variable
+from collections import OrderedDict
 from src.data_loader import fetch_dataloaders


@@ -57,20 +57,6 @@ def evaluate(model, loss_fn, meta_classes, task_lr, task_type, metrics, params,
         X_sup, Y_sup = X_sup.cuda(async=True), Y_sup.cuda(async=True)
         X_que, Y_que = X_que.cuda(async=True), Y_que.cuda(async=True)

-        # # Adapt parameters by single gradient step
-        # Y_sup_hat = model(X_sup)
-        # loss = loss_fn(Y_sup_hat, Y_sup)
-        # optimizer.zero_grad()
-        # loss.backward()
-
-        # # follows train_single_task
-        # adapted_state_dict = model.cloned_state_dict()
-        # for key, val in model.named_parameters():
-        #     adapted_state_dict[key] = val - task_lr * val.grad
-
-        # # compute predictions for query set
-        # Y_que_hat = model(X_que, adapted_state_dict)
-
         # Direct optimization
         net_clone = copy.deepcopy(model)
         optim = torch.optim.SGD(net_clone.parameters(), lr=task_lr)
@@ -83,6 +69,27 @@ def evaluate(model, loss_fn, meta_classes, task_lr, task_type, metrics, params,
         Y_que_hat = net_clone(X_que)
         loss = loss_fn(Y_que_hat, Y_que)

+        # # clear previous gradients, compute gradients of all variables wrt loss
+        # def zero_grad(params):
+        #     for p in params:
+        #         if p.grad is not None:
+        #             p.grad.zero_()
+
+        # # NOTE In Meta-SGD paper, num_eval_updates=1 is enough
+        # for _ in range(num_eval_updates):
+        #     Y_sup_hat = model(X_sup)
+        #     loss = loss_fn(Y_sup_hat, Y_sup)
+        #     zero_grad(model.parameters())
+        #     grads = torch.autograd.grad(loss, model.parameters())
+        #     # step() manually
+        #     adapted_state_dict = model.cloned_state_dict()
+        #     adapted_params = OrderedDict()
+        #     for (key, val), grad in zip(model.named_parameters(), grads):
+        #         adapted_params[key] = val - task_lr * grad
+        #         adapted_state_dict[key] = adapted_params[key]
+        #     Y_que_hat = model(X_que, adapted_state_dict)
+        # loss = loss_fn(Y_que_hat, Y_que)  # NOTE !!!!!!!!
+
         # extract data from torch Variable, move to cpu, convert to numpy arrays
         Y_que_hat = Y_que_hat.data.cpu().numpy()
         Y_que = Y_que.data.cpu().numpy()
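For orientation, the change above toggles between two ways of adapting the meta-learned model to a task at evaluation time: the live "Direct optimization" path deep-copies the network and runs plain SGD on the support set, while the newly added (still commented-out) block performs the same update functionally with torch.autograd.grad and an OrderedDict of adapted weights, which is what motivates the import swap from torch.autograd.Variable to collections.OrderedDict. The sketch below reproduces both variants in self-contained form; the linear model, random tensors, and hyper-parameter values are hypothetical stand-ins rather than the repository's actual classes. (As an aside, .cuda(async=True) in the context lines is the pre-0.4 PyTorch spelling of what is now non_blocking=True.)

    import copy
    from collections import OrderedDict

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Hypothetical stand-ins for the repository's model and one sampled task.
    model = nn.Linear(64, 5)
    loss_fn = nn.CrossEntropyLoss()
    task_lr = 0.1
    X_sup, Y_sup = torch.randn(5, 64), torch.randint(0, 5, (5,))
    X_que, Y_que = torch.randn(10, 64), torch.randint(0, 5, (10,))

    # Variant 1 ("direct optimization", the path this commit leaves active):
    # clone the meta-learned model, take one SGD step on the support set,
    # then score the clone on the query set.
    net_clone = copy.deepcopy(model)
    optim = torch.optim.SGD(net_clone.parameters(), lr=task_lr)
    sup_loss = loss_fn(net_clone(X_sup), Y_sup)
    optim.zero_grad()
    sup_loss.backward()
    optim.step()
    que_loss_direct = loss_fn(net_clone(X_que), Y_que)

    # Variant 2 (the functional update sketched in the new commented block):
    # compute gradients w.r.t. the original parameters and build an adapted
    # parameter dict without mutating the model; this is where OrderedDict
    # comes in.
    sup_loss = loss_fn(model(X_sup), Y_sup)
    grads = torch.autograd.grad(sup_loss, model.parameters())
    adapted = OrderedDict(
        (name, p - task_lr * g)
        for (name, p), g in zip(model.named_parameters(), grads)
    )
    que_logits = F.linear(X_que, adapted["weight"], adapted["bias"])
    que_loss_functional = loss_fn(que_logits, Y_que)

The functional variant keeps the adapted weights inside the autograd graph of the original parameters, which is what MAML-style meta-training needs; the deepcopy variant severs that link, which is harmless here because evaluation never backpropagates to the meta-parameters.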
6 changes: 3 additions & 3 deletions experiments/base_model/params.json
@@ -2,13 +2,13 @@
     "SEED": 1,
     "dataset": "Omniglot",
     "meta_lr": 1e-3,
-    "task_lr": 0.4,
+    "task_lr": 1e-1,
     "num_episodes": 10000,
     "num_classes": 5,
-    "num_samples": 5,
+    "num_samples": 1,
     "num_query": 10,
     "num_steps": 100,
-    "num_inner_tasks": 32,
+    "num_inner_tasks": 8,
     "num_train_updates": 1,
     "num_eval_updates": 3,
     "save_summary_steps": 100,
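These hyper-parameter edits move the experiment from 5-way 5-shot to 5-way 1-shot Omniglot, lower the inner-loop (task) learning rate from 0.4 to 0.1, and shrink the meta-batch from 32 to 8 tasks per update. A minimal sketch of reading the file, assuming a plain json.load rather than any Params utility the repository may define:

    import json

    # Load the experiment configuration edited above; the path mirrors the
    # diff header and the keys are exactly those shown in the file.
    with open("experiments/base_model/params.json") as f:
        params = json.load(f)

    # After this commit: 5-way (num_classes) 1-shot (num_samples) episodes,
    # inner-loop lr 0.1, and 8 tasks per meta-update.
    print(params["num_classes"], params["num_samples"],
          params["task_lr"], params["num_inner_tasks"])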