Speed up the inference #24

Open

Symbolk opened this issue Jun 28, 2023 · 3 comments

Comments


Symbolk commented Jun 28, 2023

Hi, this model seems nice, but I find that the inference speed is very slow (70 ms/token on a single A100), so I want to speed it up.

It seems to be related to MPT itself: https://huggingface.co/mosaicml/mpt-7b-instruct/discussions/23

Any suggestions or best practices for speeding it up? E.g., FasterTransformer (a bit low-level), ONNX Runtime, or OneFlow?

@madhavatreplit
Contributor

Are you using Triton flash attention and bfloat16, as described in the model's Hugging Face README?

You can also accelerate this model with FasterTransformer.
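For reference, a minimal sketch of the README-style setup (the model name 'replit/replit-code-v1-3b' and the `attn_config` key follow that README; adjust for the checkpoint you are actually running):

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Model name and config keys are taken from the replit-code-v1-3b Hugging Face
# README; swap in your own checkpoint if it differs.
name = 'replit/replit-code-v1-3b'

config = AutoConfig.from_pretrained(name, trust_remote_code=True)
config.attn_config['attn_impl'] = 'triton'  # use the Triton flash-attention kernel

tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    name,
    config=config,
    torch_dtype=torch.bfloat16,  # load weights in bfloat16
    trust_remote_code=True,
)
model.to(device='cuda:0')
model.eval()
```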


Symbolk commented Jul 1, 2023

Are you using Triton flash attention and bfloat16, as described in the model's Hugging Face README?

You can also accelerate this model with FasterTransformer.

Thanks for the reply! I will try the right configuration for Triton and bfloat16. With them enabled, how many milliseconds per token should I expect on an A100-80G or V100-32G?
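For comparing numbers, a rough way to measure per-token decode latency (a sketch that assumes the `model` and `tokenizer` from the loading example above; the prompt and generation settings are placeholders):

```python
import time
import torch

prompt = 'def fibonacci(n):'  # placeholder prompt
inputs = tokenizer(prompt, return_tensors='pt').to('cuda:0')

torch.cuda.synchronize()
start = time.time()
with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        use_cache=True,  # KV caching; decoding is much slower with it disabled
    )
torch.cuda.synchronize()
elapsed = time.time() - start

new_tokens = out.shape[1] - inputs['input_ids'].shape[1]
print(f'{1000 * elapsed / new_tokens:.1f} ms/token over {new_tokens} new tokens')
```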


Symbolk commented Jul 1, 2023

Hi, I enabled Triton and bfloat16 inside the Docker image provided here: https://github.com/mosaicml/llm-foundry/, with the dependencies installed, but the following error is thrown:


0it [00:03, ?it/s]
╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ in _fwd_kernel:21                                                            │
╰──────────────────────────────────────────────────────────────────────────────╯
KeyError: 
('2-.-0-.-0-d82511111ad128294e9d31a6ac684238-2b0c5161c53c71b37ae20a9996ee4bb8-c1
f92808b4e4644c1732e8338187ac87-d962222789c30252d492a16cca3bf467-12f7ac1ca211e037
f62a7c0c323d9990-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff51
98-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', 
(torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.bfloat16,
torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 
'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 
'i32', 'i32', 'i32', 'i32', 'i32'), ('vector', True, 128, False, False, False, 
128, 128), (True, True, True, True, True, True, True, (False,), (True, False), 
(True, False), (True, False), (True, False), (True, False), (True, False), 
(True, False), (True, False), (True, False), (True, False), (False, False), 
(True, False), (True, False), (True, False), (True, False), (True, False), 
(False, False), (False, False), (True, False), (True, False), (False, False), 
(False, False)))

During handling of the above exception, another exception occurred:

╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /cache/pretrained_model/replit_new/eval.py:176 in <module>                   │
│                                                                              │
│   173 if __name__ == '__main__':                                             │
│   174 │                                                                      │
│   175 │   logger.info(f'CUDA version: {torch.version.cuda}')                 │
│ ❱ 176 │   run()                                                              │
│   177                                                                        │
│                                                                              │
│ /cache/pretrained_model/replit_new/eval.py:134 in run                        │
│                                                                              │
│   131 │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │      │
│   132 │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │      │
│   133 │   │   start_time = time.time()                                       │
│ ❱ 134 │   │   generated_snippet = model.generate(x, max_length=256,          │
│   135 │   │   │   │   │   │   │   │   │   │      do_sample=True, use_cache=F │
│   136 │   │   │   │   │   │   │   │   │   │      num_return_sequences=1, eos │
│   137 │   │   end_time = time.time()                                         │
│                                                                              │
│ /usr/lib/python3/dist-packages/torch/autograd/grad_mode.py:27 in             │
│ decorate_context                                                             │
│                                                                              │
│    24 │   │   @functools.wraps(func)                                         │
│    25 │   │   def decorate_context(*args, **kwargs):                         │
│    26 │   │   │   with self.clone():                                         │
│ ❱  27 │   │   │   │   return func(*args, **kwargs)                           │
│    28 │   │   return cast(F, decorate_context)                               │
│    29 │                                                                      │
│    30 │   def _wrap_generator(self, func):                                   │
│                                                                              │
│ /usr/lib/python3/dist-packages/transformers/generation/utils.py:1572 in      │
│ generate                                                                     │
│                                                                              │
│   1569 │   │   │   )                                                         │
│   1570 │   │   │                                                             │
│   1571 │   │   │   # 13. run sample                                          │
│ ❱ 1572 │   │   │   return self.sample(                                       │
│   1573 │   │   │   │   input_ids,                                            │
│   1574 │   │   │   │   logits_processor=logits_processor,                    │
│   1575 │   │   │   │   logits_warper=logits_warper,                          │
│                                                                              │
│ /usr/lib/python3/dist-packages/transformers/generation/utils.py:2619 in      │
│ sample                                                                       │
│                                                                              │
│   2616 │   │   │   model_inputs = self.prepare_inputs_for_generation(input_i │
│   2617 │   │   │                                                             │
│   2618 │   │   │   # forward pass to get next token                          │
│ ❱ 2619 │   │   │   outputs = self(                                           │
│   2620 │   │   │   │   **model_inputs,                                       │
│   2621 │   │   │   │   return_dict=True,                                     │
│   2622 │   │   │   │   output_attentions=output_attentions,                  │
│                                                                              │
│ /usr/lib/python3/dist-packages/torch/nn/modules/module.py:1194 in _call_impl │
│                                                                              │
│   1191 │   │   # this function, and just call forward.                       │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                     │
│   1195 │   │   # Do not call functions when jit is used                      │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /home/mosaicml/.cache/huggingface/modules/transformers_modules/replit_new/mo │
│ deling_mpt.py:239 in forward                                                 │
│                                                                              │
│   236 │   def forward(self, input_ids: torch.LongTensor, past_key_values: Op │
│   237 │   │   return_dict = return_dict if return_dict is not None else self │
│   238 │   │   use_cache = use_cache if use_cache is not None else self.confi │
│ ❱ 239 │   │   outputs = self.transformer(input_ids=input_ids, past_key_value │
│   240 │   │   logits = F.linear(outputs.last_hidden_state, self.transformer. │
│   241 │   │   if self.logit_scale is not None:                               │
│   242 │   │   │   if self.logit_scale == 0:                                  │
│                                                                              │
│ /usr/lib/python3/dist-packages/torch/nn/modules/module.py:1194 in _call_impl │
│                                                                              │
│   1191 │   │   # this function, and just call forward.                       │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                     │
│   1195 │   │   # Do not call functions when jit is used                      │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /home/mosaicml/.cache/huggingface/modules/transformers_modules/replit_new/mo │
│ deling_mpt.py:185 in forward                                                 │
│                                                                              │
│   182 │   │   │   │   assert all_hidden_states is not None                   │
│   183 │   │   │   │   all_hidden_states = all_hidden_states + (x,)           │
│   184 │   │   │   past_key_value = past_key_values[b_idx] if past_key_values │
│ ❱ 185 │   │   │   (x, past_key_value) = block(x, past_key_value=past_key_val │
│   186 │   │   │   if past_key_values is not None:                            │
│   187 │   │   │   │   past_key_values[b_idx] = past_key_value                │
│   188 │   │   x = self.norm_f(x)                                             │
│                                                                              │
│ /usr/lib/python3/dist-packages/torch/nn/modules/module.py:1194 in _call_impl │
│                                                                              │
│   1191 │   │   # this function, and just call forward.                       │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                     │
│   1195 │   │   # Do not call functions when jit is used                      │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /home/mosaicml/.cache/huggingface/modules/transformers_modules/replit_new/bl │
│ ocks.py:36 in forward                                                        │
│                                                                              │
│   33 │                                                                       │
│   34 │   def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[t │
│   35 │   │   a = self.norm_1(x)                                              │
│ ❱ 36 │   │   (b, _, past_key_value) = self.attn(a, past_key_value=past_key_v │
│   37 │   │   x = x + self.resid_attn_dropout(b)                              │
│   38 │   │   m = self.norm_2(x)                                              │
│   39 │   │   n = self.ffn(m)                                                 │
│                                                                              │
│ /usr/lib/python3/dist-packages/torch/nn/modules/module.py:1194 in _call_impl │
│                                                                              │
│   1191 │   │   # this function, and just call forward.                       │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                     │
│   1195 │   │   # Do not call functions when jit is used                      │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /home/mosaicml/.cache/huggingface/modules/transformers_modules/replit_new/at │
│ tention.py:172 in forward                                                    │
│                                                                              │
│   169 │   │   │   past_key_value = (key, value)                              │
│   170 │   │   if attn_bias is not None:                                      │
│   171 │   │   │   attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1): │
│ ❱ 172 │   │   (context, attn_weights) = self.attn_fn(query, key, value, self │
│   173 │   │   return (self.out_proj(context), attn_weights, past_key_value)  │
│   174                                                                        │
│   175 class MultiQueryAttention(nn.Module):                                  │
│                                                                              │
│ /home/mosaicml/.cache/huggingface/modules/transformers_modules/replit_new/at │
│ tention.py:111 in triton_flash_attn_fn                                       │
│                                                                              │
│   108 │   │   key = key.expand(*key.shape[:2], n_heads, key.size(-1))        │
│   109 │   │   value = value.expand(*value.shape[:2], n_heads, value.size(-1) │
│   110 │   reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_ │
│ ❱ 111 │   attn_output = flash_attn_triton.flash_attn_func(query, key, value, │
│   112 │   output = attn_output.view(*attn_output.shape[:2], -1)              │
│   113 │   return (output, None)                                              │
│   114                                                                        │
│                                                                              │
│ /usr/lib/python3/dist-packages/flash_attn/flash_attn_triton.py:810 in        │
│ forward                                                                      │
│                                                                              │
│   807 │   │   """                                                            │
│   808 │   │   # Make sure that the last dimension is contiguous              │
│   809 │   │   q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in │
│ ❱ 810 │   │   o, lse, ctx.softmax_scale = _flash_attn_forward(               │
│   811 │   │   │   q, k, v, bias=bias, causal=causal, softmax_scale=softmax_s │
│   812 │   │   )                                                              │
│   813 │   │   ctx.save_for_backward(q, k, v, o, lse, bias)                   │
│                                                                              │
│ /usr/lib/python3/dist-packages/flash_attn/flash_attn_triton.py:623 in        │
│ _flash_attn_forward                                                          │
│                                                                              │
│   620 │   BLOCK = 128                                                        │
│   621 │   num_warps = 4 if d <= 64 else 8                                    │
│   622 │   grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch │
│ ❱ 623 │   _fwd_kernel[grid](                                                 │
│   624 │   │   q, k, v, bias, o,                                              │
│   625 │   │   lse, tmp,                                                      │
│   626 │   │   softmax_scale,                                                 │
│                                                                              │
│ /home/mosaicml/.local/lib/python3.10/site-packages/triton/runtime/jit.py:106 │
│ in launcher                                                                  │
│                                                                              │
│   103 │   │   memorizes the grid.                                            │
│   104 │   │   """                                                            │
│   105 │   │   def launcher(*args, **kwargs):                                 │
│ ❱ 106 │   │   │   return self.run(*args, grid=grid, **kwargs)                │
│   107 │   │   return launcher                                                │
│   108                                                                        │
│   109                                                                        │
│                                                                              │
│ /home/mosaicml/.local/lib/python3.10/site-packages/triton/runtime/autotuner. │
│ py:200 in run                                                                │
│                                                                              │
│   197 │   def run(self, *args, **kwargs):                                    │
│   198 │   │   for v, heur in self.values.items():                            │
│   199 │   │   │   kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwa │
│ ❱ 200 │   │   return self.fn.run(*args, **kwargs)                            │
│   201                                                                        │
│   202                                                                        │
│   203 def heuristics(values):                                                │
│ in _fwd_kernel:41                                                            │
│                                                                              │
│ /home/mosaicml/.local/lib/python3.10/site-packages/triton/compiler.py:1268   │
│ in compile                                                                   │
│                                                                              │
│   1265 │   if warm_cache_only:                                               │
│   1266 │   │   return  # load_binary() requires a valid cuda context         │
│   1267 │                                                                     │
│ ❱ 1268 │   return CompiledKernel(name, so_cache_manager._make_path(so_name), │
│   1269                                                                       │
│   1270                                                                       │
│   1271 class CompiledKernel:                                                 │
│                                                                              │
│ /home/mosaicml/.local/lib/python3.10/site-packages/triton/compiler.py:1281   │
│ in __init__                                                                  │
│                                                                              │
│   1278 │   │   # initialize launcher                                         │
│   1279 │   │   import importlib.util                                         │
│   1280 │   │   spec = importlib.util.spec_from_file_location("launcher", so_ │
│ ❱ 1281 │   │   mod = importlib.util.module_from_spec(spec)                   │
│   1282 │   │   spec.loader.exec_module(mod)                                  │
│   1283 │   │   self.c_wrapper = getattr(mod, "launch")                       │
│   1284 │   │   # initialize metadata                                         │
│ in module_from_spec:571                                                      │
│ in create_module:1176                                                        │
│ in _call_with_frames_removed:241                                             │
╰──────────────────────────────────────────────────────────────────────────────╯
ImportError: 
/home/mosaicml/.triton/cache/ab77933fc177e6d77b0dd8896210d966/_fwd_kernel.so: 
undefined symbol: cuLaunchKernel
