-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathget_model_cspa.py
71 lines (53 loc) · 1.46 KB
/
get_model_cspa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
from collections import namedtuple
import torch
import pickle
import argparse
import yaml
from torchtyping import TensorType as TT
import torch
from utils.cspa_main import (
get_cspa_per_checkpoint
)
from utils.data_processing import get_ckpts
# Module-level default torch device: CUDA when a GPU is visible, else CPU.
# NOTE: main() binds its own local ``device``, so it does not read this value.
device = "cuda" if torch.cuda.is_available() else "cpu"
def get_args(argv=None):
    """Parse command-line arguments for the CSPA-per-checkpoint script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse parses ``sys.argv[1:]``, matching the original behaviour.

    Returns:
        argparse.Namespace with:
          * ``config`` -- path to the YAML config file.
          * ``device`` -- torch device string; defaults to ``"cuda"`` when
            CUDA is available, otherwise ``"cpu"``.
    """
    parser = argparse.ArgumentParser(
        description="Get CSPA per checkpoint and attention head"
    )
    parser.add_argument(
        "-c",
        "--config",
        default="./configs/cspa/160m-canonical.yml",
        help="Path to config file",
    )
    # main() already looks for ``args.device``; expose it so the device can be
    # overridden from the command line. The default mirrors the automatic
    # CUDA/CPU choice, so runs that never pass --device behave as before.
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Torch device to run on (default: cuda if available, else cpu)",
    )
    return parser.parse_args(argv)
def read_config(config_path):
    """Load a YAML config file and return its top-level mapping.

    Args:
        config_path: Path to the YAML config file.

    Returns:
        dict: The parsed configuration.

    Raises:
        ValueError: If the file does not contain a top-level mapping (for
            example an empty file, for which the YAML loader returns ``None``).
    """
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(config_path, "r", encoding="utf-8") as f:
        # NOTE(review): FullLoader accepts more than yaml.safe_load does; if
        # configs only hold plain scalars/lists/mappings, safe_load would be
        # the safer choice -- confirm before switching.
        config = yaml.load(f, Loader=yaml.FullLoader)
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a YAML mapping, "
            f"got {type(config).__name__}"
        )
    return config
def main(args):
    """Compute CSPA for every checkpoint listed in the config file.

    Args:
        args: Parsed CLI namespace. Must provide ``config`` (path to a YAML
            file). It may provide ``device``; a missing or ``None`` device
            falls back to ``"cuda"`` when available, otherwise ``"cpu"``.

    The config is read for the keys ``base_model``, ``variant``, ``cache``,
    ``checkpoint_schedule``, ``start_layer`` and ``overwrite``. A missing
    key raises ``KeyError``.
    """
    # Treat an absent *or* None device the same way, so a namespace built
    # with device=None still gets a usable device instead of passing None on.
    device = getattr(args, "device", None)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Disable autograd globally for the rest of the process.
    torch.set_grad_enabled(False)
    print(f"Using device: {device}")

    config = read_config(args.config)
    # checkpoint_schedule is passed through verbatim rather than being
    # expanded with get_ckpts().
    checkpoints = config['checkpoint_schedule']
    print(config)

    get_cspa_per_checkpoint(
        config['base_model'],
        config['variant'],
        config['cache'],
        device,
        checkpoints,
        start_layer=config["start_layer"],
        overwrite=config["overwrite"],
        display_all=False,
    )
if __name__ == "__main__":
    # Script entry point: parse the CLI flags and hand them straight to main().
    main(get_args())