Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Adds outlines performance improvement #5006

Conversation

lynkz-matt-psaltis
Copy link

This addresses concurrency performance issues when using Outlines leading to high CPU usage and a reduction in batching throughput. Timings showed ~1600% performance increase but I would very much appreciate validation of this method!

Overall approach:

  1. Moved the mask to be created once for each request rather than per iteration to reduce memory allocations.
  2. Use a simple caching object to reduce overall computational and data transformation costs.
  3. Convert to a numpy array as this showed a 200% decrease in cost for Python List to Tensor conversion.
  4. Offload int32 to int64 conversion to the GPU to maximize throughput.
  5. Use non_blocking during tensor move operations to reduce CPU contention.
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    46                                               @line_profiler.profile
    47                                               def __call__(self, input_ids: List[int],
    48                                                            scores: torch.Tensor) -> torch.Tensor:
    49                                                   """Use the FSM to bias the logits before sampling the next token."""
    50     38607     210730.1      5.5      5.6          seq_id = hash(tuple(input_ids))
    51
    52     38607      17504.0      0.5      0.5          if len(input_ids) == 0:
    53        48        430.8      9.0      0.0              self.init_state()
    54                                                   else:
    55     38559       9725.8      0.3      0.3              last_token = input_ids[-1]
    56     38559     148043.9      3.8      3.9              last_seq_id = hash(tuple(input_ids[:-1]))
    57     77118     102817.6      1.3      2.7              self.fsm_state[seq_id] = self.fsm.next_state(
    58     38559      25271.5      0.7      0.7                  self.fsm_state[last_seq_id], last_token)
    59
    60     38607      11364.8      0.3      0.3          state = self.fsm_state[seq_id]
    61
    62                                                   # Retrieve allowed tokens from cache using the current state
    63     38607      11408.9      0.3      0.3          if state not in self.allowed_tokens_cache:
    64                                                       # Cache miss, calculate allowed tokens and cache them
    65       576     175710.0    305.1      4.7              allowed_tokens = self.fsm.allowed_token_ids(state)
    66       576     793771.6   1378.1     21.1              np_allowed_tokens = np.array(allowed_tokens, dtype=np.int32)
    67       576       2933.5      5.1      0.1              allowed_tokens_tensor = torch.from_numpy(np_allowed_tokens)
    68                                                       
    69       576        812.7      1.4      0.0              if (allowed_tokens_tensor.device != scores.device):
    70       576    1434945.6   2491.2     38.1                  allowed_tokens_tensor = allowed_tokens_tensor.to(scores.device, dtype=torch.int64, non_blocking=True)
    71                                                       else:
    72                                                           allowed_tokens_tensor = allowed_tokens_tensor.to(torch.int64)
    73
    74       576        600.9      1.0      0.0              self.allowed_tokens_cache[state] = allowed_tokens_tensor
    75
    76                                                   else:
    77     38031      11473.7      0.3      0.3              allowed_tokens_tensor = self.allowed_tokens_cache[self.fsm_state[seq_id]]
    78
    79     38607      10663.8      0.3      0.3          if self.mask is None:
    80        48        928.5     19.3      0.0              self.mask = torch.full_like(scores, -math.inf)
    81                                                   else:
    82     38559     291444.4      7.6      7.7              self.mask.fill_(-math.inf)
    83                                                   
    84     38607     285348.3      7.4      7.6          self.mask.index_fill_(0, allowed_tokens_tensor, 0)
    85     38607     209957.4      5.4      5.6          scores.add_(self.mask)
    86
    87     38607       7854.3      0.2      0.2          return scores

Fixes #3567


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@lynkz-matt-psaltis lynkz-matt-psaltis force-pushed the feature/outlines-performance branch 10 times, most recently from d839241 to 476f6ab Compare May 23, 2024 14:59
@simon-mo simon-mo self-requested a review May 23, 2024 16:07
Copy link
Collaborator

@simon-mo simon-mo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great! Thank you for the wonderful work debugging and optimization. Just a small concern about the shape of mask.

from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel
from torch import Tensor, from_numpy, full_like, int64
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stylistically we would like to keep the torch namespace so we are using torch.Tensor etc...

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sweet on it

allowed_tokens_tensor = self.allowed_tokens_cache[state]

if self.mask is None:
self.mask = full_like(scores, -math.inf)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Originally the shape is (scores.shape[-1], ) which is not the full shape of scores.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh good catch thank you!

@Yard1
Copy link
Collaborator

Yard1 commented May 23, 2024

@lynkz-matt-psaltis I would suggest using PyTorch profiler (using with_stack=True option) for profiling, CPU-only profiling will not take into account GPU operations and may lead to misleading results as CPU operations block waiting on GPU.

Comment on lines 65 to 66
allowed_tokens_tensor = allowed_tokens_tensor.to(
scores.device, dtype=int64, non_blocking=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

non_blocking has no impact unless the CPU tensor is using pinned memory

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Yard1 that is very true :)

@lynkz-matt-psaltis
Copy link
Author

@lynkz-matt-psaltis I would suggest using PyTorch profiler (using with_stack=True option) for profiling, CPU-only profiling will not take into account GPU operations and may lead to misleading results as CPU operations block waiting on GPU.

Epic thanks! I've never done any torch work before so that feedback is golden.

@Yard1
Copy link
Collaborator

Yard1 commented May 23, 2024

You can use the benchmark_latency.py script we have, just need to modify it to run the LLM/request with outlines enabled and add with_stack=True to the pytorch profiler context (suggest using a very small number of iterations and output tokens when profiling to reduce profile size). You can then run the script with --profile --profile-result-dir DIRECTORY arguments and view the generated json file in ui.perfetto.dev

Also https://pytorch.org/docs/stable/notes/cuda.html#asynchronous-execution is a great read :)

@lynkz-matt-psaltis
Copy link
Author

You can use the benchmark_latency.py script we have, just need to modify it to run the LLM/request with outlines enabled and add with_stack=True to the pytorch profiler context (suggest using a very small number of iterations and output tokens when profiling to reduce profile size). You can then run the script with --profile --profile-result-dir DIRECTORY arguments and view the generated json file in ui.perfetto.dev

Also https://pytorch.org/docs/stable/notes/cuda.html#asynchronous-execution is a great read :)

Amazing! Thanks @Yard1. Can I just say you've all been such a welcoming community love it and huge shout out to everyone contributing!!!

@lynkz-matt-psaltis lynkz-matt-psaltis force-pushed the feature/outlines-performance branch 3 times, most recently from 7848018 to 6eb8382 Compare May 25, 2024 05:14
@lynkz-matt-psaltis
Copy link
Author

Using state directly for the cache key wasn't sufficient - Some tests were failing for some FSMs. I swapped to a tuple on the returned array which is stable but reduces performance to around 400% increase instead of 1600% due to the hashing computational costs. I noticed that Outlines has moved to a Guide structure and the transition has been implemented here: #4109. I'm thinking I'd be better off putting something on top of this PR so its all aligned for an outlines update? Thoughts?

@simon-mo
Copy link
Collaborator

Since the update has been done (thx! @br3no) is there update to this PR and I would assume it still applies?

@njhill
Copy link
Member

njhill commented Jun 13, 2024

@lynkz-matt-psaltis I guess #5053 is the replacement?

@lynkz-matt-psaltis
Copy link
Author

@lynkz-matt-psaltis I guess #5053 is the replacement?

Kind of - I've been exploring going further with this branch against latest outlines by optimising further. Some of these required outlines changes so I'm worried about the sequencing & coordination aspects of those changes going in. Given the direction of this conversation and the potential for further api signature changes, that may not be as big an issue now.

I can look at either:

  1. Strip this back to just vllm only optimisations
  2. Finish the current approach which would require outlines changes as well.

Any advice?

Copy link

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

Copy link

This pull request has been automatically closed due to inactivity. Please feel free to reopen if you intend to continue working on it. Thank you!

@github-actions github-actions bot closed this Nov 25, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Misc]: Throughput/Latency for guided_json with ~100% GPU cache utilization
4 participants