-
Notifications
You must be signed in to change notification settings - Fork 1
/
evaluate_flops.py
25 lines (20 loc) · 1.02 KB
/
evaluate_flops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from transformers import AutoTokenizer,AutoModelForCausalLM
import argparse
from calflops import calculate_flops
def main(args):
    """Load a causal-LM checkpoint and print its FLOPs, MACs and parameter count.

    Args:
        args: Parsed command-line namespace. ``args.path`` is handed to both
            ``from_pretrained`` calls as the checkpoint location, and
            ``args.seqlen`` is used as the second dimension of the input
            shape passed to ``calculate_flops``.

    Returns:
        None. The three values returned by ``calculate_flops`` are printed
        to stdout.
    """
    # NOTE(review): trust_remote_code=True is passed to both loaders; per the
    # transformers docs this allows code shipped with the checkpoint to run,
    # so only point --path at checkpoints you trust.
    tokenizer = AutoTokenizer.from_pretrained(
        args.path,
        use_fast=False,
        add_bos_token=False,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.path,
        device_map="auto",
        trust_remote_code=True,
    )

    # The measurement is taken for a single sequence of the requested length.
    input_shape = (1, args.seqlen)
    flops, macs, params = calculate_flops(
        model=model,
        input_shape=input_shape,
        transformer_tokenizer=tokenizer,
        output_precision=2,
    )

    summary = "FLOPs:%s MACs:%s Params:%s \n" % (flops, macs, params)
    print(summary)
if __name__ == "__main__":
    # Command-line entry point: collect the checkpoint location and the
    # sequence length, then hand the parsed namespace to main().
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument(
        "--path",
        type=str,
        required=True,
        help="model checkpoint location",
    )
    cli_parser.add_argument(
        "--seqlen",
        type=int,
        default=128,
    )
    args = cli_parser.parse_args()
    main(args)