Skip to content

Commit

Permalink
v3.5.0: mc.llm_parallel() & mc.utils.run_parallel(): parallel inference
Browse files Browse the repository at this point in the history
  • Loading branch information
Nayjest committed Apr 20, 2024
1 parent 8170ceb commit 760fc9c
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 2 deletions.
5 changes: 3 additions & 2 deletions microcore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .types import BadAIJsonAnswer, BadAIAnswer
from .wrappers.prompt_wrapper import PromptWrapper
from .wrappers.llm_response_wrapper import LLMResponse
from ._llm_functions import llm, allm
from ._llm_functions import llm, allm, llm_parallel
from .utils import parse, dedent


Expand Down Expand Up @@ -101,6 +101,7 @@ def delete(self, collection: str, what: str | list[str] | dict):
__all__ = [
"llm",
"allm",
"llm_parallel",
"tpl",
"prompt",
"fmt",
Expand Down Expand Up @@ -136,4 +137,4 @@ def delete(self, collection: str, what: str | list[str] | dict):
# "wrappers",
]

__version__ = "3.4.0"
__version__ = "3.5.0"
14 changes: 14 additions & 0 deletions microcore/_llm_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import datetime

from .message_types import Msg
from .utils import run_parallel
from .wrappers.llm_response_wrapper import LLMResponse
from ._env import env

Expand Down Expand Up @@ -93,3 +94,16 @@ async def allm(
...
[h(response) for h in env().llm_after_handlers]
return response


async def llm_parallel(
    prompts: list, max_concurrent_tasks: int | None = None, **kwargs
) -> list[str] | list[LLMResponse]:
    """
    Run LLM inference for each prompt concurrently.

    Args:
        prompts: Prompts to send to the LLM; one request is issued per item.
        max_concurrent_tasks: Upper bound on simultaneous requests.
            Falls back to ``Config.MAX_CONCURRENT_TASKS`` when omitted;
            if neither is set, all prompts run at once.
        **kwargs: Passed through to ``allm()`` for every prompt.

    Returns:
        Responses in the same order as ``prompts``.
    """
    tasks = [allm(prompt, **kwargs) for prompt in prompts]

    if max_concurrent_tasks is None:
        # MAX_CONCURRENT_TASKS defaults to None (from_env(None));
        # int(None) would raise TypeError, so only coerce a present value.
        configured = env().config.MAX_CONCURRENT_TASKS
        max_concurrent_tasks = int(configured) if configured else 0
    if not max_concurrent_tasks:
        # No limit configured: allow every task to run concurrently.
        max_concurrent_tasks = len(tasks)

    return await run_parallel(tasks, max_concurrent_tasks=max_concurrent_tasks)
2 changes: 2 additions & 0 deletions microcore/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ class Config(LLMConfig):
TEXT_TO_SPEECH_PATH: str | Path = from_env()
"""Path to the folder with generated voice files"""

MAX_CONCURRENT_TASKS: int = from_env(None)

def __post_init__(self):
if self.JINJA2_AUTO_ESCAPE is None:
self.JINJA2_AUTO_ESCAPE = get_bool_from_env("JINJA2_AUTO_ESCAPE", False)
Expand Down
11 changes: 11 additions & 0 deletions microcore/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import builtins
import dataclasses
import inspect
Expand Down Expand Up @@ -276,3 +277,13 @@ def dedent(text: str):
else:
dedented_lines = lines
return "\n".join(dedented_lines)


async def run_parallel(tasks: list, max_concurrent_tasks: int):
    """
    Await the given awaitables concurrently, capping how many run at once.

    Args:
        tasks: Awaitables (e.g. coroutine objects) to execute.
        max_concurrent_tasks: Maximum number allowed to run simultaneously.

    Returns:
        Results in the same order as ``tasks``.
    """
    gate = asyncio.Semaphore(max_concurrent_tasks)

    async def _guarded(awaitable):
        # Hold one semaphore slot for the lifetime of the awaitable.
        async with gate:
            return await awaitable

    wrapped = (_guarded(item) for item in tasks)
    return await asyncio.gather(*wrapped)

0 comments on commit 760fc9c

Please sign in to comment.