-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_utils.py
51 lines (39 loc) · 2.13 KB
/
run_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import random
import argparse
import numpy as np
# import torch
# from lora import run_lora
import jittor as jt
def set_random_seed(seed) -> None:
    """Seed the Python, NumPy and Jittor random number generators.

    Args:
        seed: Value forwarded unchanged to every seeding call below.
    """
    # Standard-library and NumPy generators first.
    random.seed(seed)
    np.random.seed(seed)

    # Every Jittor seeding entry point is invoked, in this fixed order.
    # NOTE(review): these three calls look like they overlap -- confirm
    # against the Jittor documentation before removing any of them.
    jt.set_global_seed(seed)
    jt.seed(seed)
    jt.set_seed(seed)
def get_arguments(argv=None):
    """Build the command-line parser and return the parsed options.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what this function always did
            before the parameter existed.

    Returns:
        argparse.Namespace with one attribute per option declared below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', default=1, type=int)

    # Dataset arguments
    parser.add_argument('--root_path', type=str, default='')
    parser.add_argument('--dataset', type=str, default='')
    parser.add_argument('--shots', default=4, type=int)

    # Model arguments
    parser.add_argument('--backbone', default='ViT-B/32', type=str)

    # Training arguments
    parser.add_argument('--lr', default=2e-4, type=float)
    parser.add_argument('--n_iters', default=200, type=int)
    parser.add_argument('--batch_size', default=32, type=int)

    # LoRA arguments
    parser.add_argument(
        '--position',
        type=str,
        default='all',
        choices=['bottom', 'mid', 'up', 'half-up', 'half-bottom', 'all', 'top3'],
        help='where to put the LoRA modules',
    )
    parser.add_argument(
        '--encoder',
        type=str,
        choices=['text', 'vision', 'both'],
        default='both',
    )
    parser.add_argument(
        '--params',
        metavar='N',
        type=str,
        nargs='+',
        default=['q', 'k', 'v'],
        help='list of attention matrices where putting a LoRA',
    )
    parser.add_argument(
        '--r',
        default=2,
        type=int,
        help='the rank of the low-rank matrices',
    )
    parser.add_argument(
        '--alpha',
        default=1,
        type=int,
        help='scaling (see LoRA paper)',
    )
    parser.add_argument(
        '--dropout_rate',
        default=0.25,
        type=float,
        help='dropout rate applied before the LoRA module',
    )

    # Saving / evaluation arguments
    parser.add_argument(
        '--save_path',
        default='caches',
        help='path to save the lora modules after training, not saved if None',
    )
    parser.add_argument(
        '--filename',
        default='lora_weights',
        help='file name to save the lora weights (.pt extension will be added)',
    )
    parser.add_argument(
        '--eval_only',
        default=False,
        action='store_true',
        help='only evaluate the LoRA modules (save_path should not be None)',
    )

    # Passing None lets argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)