-
Notifications
You must be signed in to change notification settings - Fork 3
/
run_latency_attention.py
196 lines (164 loc) · 7.36 KB
/
run_latency_attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import sys
import logging
from functools import partial
import torch
import argparse
import socket
from datetime import datetime
from transformers.models.llama.modeling_llama import LlamaConfig, DynamicCache, LlamaAttention
from kernel.palu_attention import LlamaPaluAttention
TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"
def trace_handler(prof: torch.profiler.profile, file_postfix="prefilling", device="cuda:0"):
# Prefix for file names.
host_name = socket.gethostname()
timestamp = datetime.now().strftime(TIME_FORMAT_STR)
file_prefix = f"{host_name}_{timestamp}"
# Construct the trace file.
prof.export_chrome_trace(f"{file_prefix}_{file_postfix}.json.gz")
# Construct the memory timeline file.
prof.export_memory_timeline(f"{file_prefix}_{file_postfix}.html", device=device)
def build_attention(args):
device = "cuda:0"
dtype = torch.float16
logging.info(f"Creating Attention, dtype: {dtype}, device: {device}")
config = LlamaConfig()
config.max_position_embeddings = 300000
attention = LlamaAttention(config, layer_idx=0).to(device, dtype)
return attention, config
def build_attention_palu(args):
device = "cuda:0"
dtype = torch.float16
logging.info(f"Creating Attention_Palu, dtype: {dtype}, device: {device}")
config = LlamaConfig()
config.max_position_embeddings = 300000
config.group_size = args.group_size
config.num_groups = config.num_attention_heads // args.group_size
config.total_rank_k = args.rank_k
config.total_rank_v = args.rank_v
logging.info(f"rank_k: {config.total_rank_k}, rank_v: {config.total_rank_v}, group_size: {config.group_size}, num_groups: {config.num_groups}")
attention = LlamaAttention(config, layer_idx=0)
attention_palu = LlamaPaluAttention.from_attention(attention, config).to(device, dtype)
return attention_palu, config
def profile_tpot(model, cache_size_k, cache_size_v, cache_type=torch.float16, batch_size=1, prompt_len=1024, repeats=100,
cache_graph=False, torch_profile=False, outfile=""):
logging.info(">>> Profiling TPOT (generation stage)")
device = next(iter(model.parameters())).device
cache_k = torch.randn(cache_size_k, dtype=cache_type, device=device)
cache_v = torch.randn(cache_size_v, dtype=cache_type, device=device)
past_key_value = DynamicCache()
past_key_value.update(cache_k, cache_v, 0)
position_ids = torch.arange(prompt_len, prompt_len+1)
hidden_dim = model.config.hidden_size
input_token = torch.randn((batch_size, 1, hidden_dim), dtype=torch.float16, device=device) # only input 1 token at a time
# warmup
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.no_grad():
with torch.cuda.stream(s):
for _ in range(25):
_ = model(input_token, past_key_value=past_key_value, position_ids=position_ids)
torch.cuda.current_stream().wait_stream(s)
if cache_graph:
with torch.no_grad():
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
out = model(input_token, past_key_value=past_key_value, position_ids=position_ids)
def generate(new_input_token, past_key_value, position_ids):
input_token.copy_(new_input_token)
graph.replay()
return out
else:
def generate(new_input_token, past_key_value, position_ids):
out = model(new_input_token, past_key_value=past_key_value, position_ids=position_ids)
return out
new_input_token = torch.randn((batch_size, 1, hidden_dim), dtype=torch.float16, device=device) # only input 1 token at a time
with torch.no_grad():
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(repeats):
generate(new_input_token, past_key_value=past_key_value, position_ids=position_ids)
end.record()
torch.cuda.synchronize()
dur = start.elapsed_time(end)
logging.info(f"Finished, prompt_len: {prompt_len}, latency: {dur/repeats:.2f} milliseconds (cache_graph={cache_graph})")
if torch_profile:
outfile_postfix = f"{outfile}"
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(wait=1, warmup=5, active=6, repeat=1),
record_shapes=True,
profile_memory=True,
with_stack=True,
on_trace_ready=partial(
trace_handler, file_postfix=outfile_postfix, device="cuda:0"
)
) as prof:
with torch.no_grad():
# (wait=1, warmup=5, active=6) , repeat=1
for _ in range(12):
generate(new_input_token, past_key_value, position_ids=position_ids)
prof.step()
def main(args):
bs = 1
if args.palu:
attention, config = build_attention_palu(args)
attention.eval()
num_groups = config.num_groups
# NOTE: Assuming uniform head_dim
group_dim_k = config.total_rank_k // config.num_groups
group_dim_v = config.total_rank_v // config.num_groups
cache_size_k = (bs, num_groups, args.prompt_len, group_dim_k)
cache_size_v = (bs, num_groups, args.prompt_len, group_dim_v)
profile_tpot(attention, cache_size_k, cache_size_v, torch.float16, bs, args.prompt_len, args.repeats, args.cache_graph, args.torch_profile, "tpot_palu_fp16")
else:
attention, config = build_attention(args)
attention.eval()
num_heads = config.num_attention_heads
head_dim = config.hidden_size // num_heads
cache_size = (bs, num_heads, args.prompt_len, head_dim)
profile_tpot(attention, cache_size, cache_size, torch.float16, bs, args.prompt_len, args.repeats, args.cache_graph, args.torch_profile, "tpot_fp16")
if __name__ =='__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--palu', action='store_true',
help='Whether to use PALU attention.'
)
parser.add_argument(
'--rank_k', type=int, default=1024,
help='The rank of key matrix for PALU attention.'
)
parser.add_argument(
'--rank_v', type=int, default=2048,
help='The rank of value matrix for PALU attention.'
)
parser.add_argument(
'--group_size', type=int, default=4,
help='The group size for PALU attention.'
)
parser.add_argument(
'--repeats', type=int, default=100,
help='The number of profiling to repeat (default: 100)'
)
parser.add_argument(
'--prompt_len', type=int, default=1024,
help='The number of input tokens to model. (default: 1024)'
)
parser.add_argument(
'--cache_graph', action='store_true', default=False,
help='To enable CUDA graph cache, this only works for the generation stage (TPOT and TTLT)'
)
parser.add_argument(
'--torch_profile', action='store_true',
help='Whether to launch the pytorch profiler.'
)
args = parser.parse_args()
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] %(levelname)s [%(filename)s:%(lineno)3d] %(message)s",
datefmt="%d/%b/%Y %H:%M:%S",
stream=sys.stdout)
main(args)