Commit 024d175: change cache_seed_gen name
XianBW committed Oct 21, 2024
1 parent 3bc8911 commit 024d175
Showing 3 changed files with 13 additions and 13 deletions.
rdagent/core/utils.py: 3 additions & 3 deletions
```diff
@@ -108,15 +108,15 @@ def get_next_seed(self) -> int:
         return random.randint(0, 10000)  # noqa: S311
 
 
-cache_seed_gen = CacheSeedGen()
+LLM_CACHE_SEED_GEN = CacheSeedGen()
 
 
 def _subprocess_wrapper(f: Callable, seed: int, args: list) -> Any:
     """
     It is a function wrapper. To ensure the subprocess has a fixed start seed.
     """
 
-    cache_seed_gen.set_seed(seed)
+    LLM_CACHE_SEED_GEN.set_seed(seed)
     return f(*args)
 
 
@@ -146,7 +146,7 @@ def multiprocessing_wrapper(func_calls: list[tuple[Callable, tuple]], n: int) ->
 
     with mp.Pool(processes=max(1, min(n, len(func_calls)))) as pool:
         results = [
-            pool.apply_async(_subprocess_wrapper, args=(f, cache_seed_gen.get_next_seed(), args))
+            pool.apply_async(_subprocess_wrapper, args=(f, LLM_CACHE_SEED_GEN.get_next_seed(), args))
             for f, args in func_calls
         ]
         return [result.get() for result in results]
```
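The renamed object is a process-wide singleton that hands out cache seeds deterministically; the upper-case name follows the usual convention for module-level constants. Below is a minimal runnable sketch of the pattern. Only `get_next_seed`'s body is visible in this diff, so the `__init__`, `set_seed`, and default-seed details are assumptions, not the project's exact code:

```python
import random


class CacheSeedGen:
    """Deterministic seed sequence for the LLM cache (sketch).

    Only get_next_seed's body appears in the diff; __init__ and set_seed
    here are plausible reconstructions, not RD-Agent's exact code.
    """

    def __init__(self, seed: int = 42) -> None:  # default seed is assumed
        self.set_seed(seed)

    def set_seed(self, seed: int) -> None:
        random.seed(seed)  # assumed: get_next_seed draws from the global RNG

    def get_next_seed(self) -> int:
        return random.randint(0, 10000)  # noqa: S311  (as in the diff)


LLM_CACHE_SEED_GEN = CacheSeedGen()

# Replaying the same start seed replays the same sequence of per-call seeds.
LLM_CACHE_SEED_GEN.set_seed(10)
first = [LLM_CACHE_SEED_GEN.get_next_seed() for _ in range(4)]
LLM_CACHE_SEED_GEN.set_seed(10)
assert first == [LLM_CACHE_SEED_GEN.get_next_seed() for _ in range(4)]
```

Note that `multiprocessing_wrapper` above draws one seed per task in the parent before forking, so the whole batch's seed schedule is fixed before any worker starts; each subprocess then re-seeds its own copy of the generator via `_subprocess_wrapper`.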
rdagent/oai/llm_utils.py: 2 additions & 2 deletions
```diff
@@ -18,7 +18,7 @@
 import tiktoken
 
 from rdagent.core.conf import RD_AGENT_SETTINGS
-from rdagent.core.utils import SingletonBaseClass, cache_seed_gen
+from rdagent.core.utils import LLM_CACHE_SEED_GEN, SingletonBaseClass
 from rdagent.log import LogColors
 from rdagent.log import rdagent_logger as logger
 from rdagent.oai.llm_conf import LLM_SETTINGS
@@ -597,7 +597,7 @@ def _create_chat_completion_inner_function(  # noqa: C901, PLR0912, PLR0915
         This seed is different from `self.chat_seed` for GPT. It is for the local cache mechanism enabled by RD-Agent locally.
         """
         if seed is None and RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen:
-            seed = cache_seed_gen.get_next_seed()
+            seed = LLM_CACHE_SEED_GEN.get_next_seed()
 
         # TODO: we can add this function back to avoid so much `self.cfg.log_llm_chat_content`
         if LLM_SETTINGS.log_llm_chat_content:
```
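Per the docstring above, this seed feeds RD-Agent's local chat cache rather than the provider's sampling seed. A hypothetical illustration of the mechanism follows; the `chat_cache_key` helper and its key derivation are invented for this sketch and are not RD-Agent's actual implementation:

```python
import hashlib
import json


def chat_cache_key(messages: list[dict], seed: int | None) -> str:
    """Hypothetical cache key: hash the conversation together with the seed."""
    payload = json.dumps(messages, sort_keys=True) + f"|seed={seed}"
    return hashlib.md5(payload.encode()).hexdigest()  # noqa: S324  (non-crypto use)


msgs = [{"role": "user", "content": "What is 2 + 2?"}]

# Auto seed generation gives repeated identical prompts different seeds,
# so they miss the cache and can receive fresh answers within one run ...
assert chat_cache_key(msgs, seed=1) != chat_cache_key(msgs, seed=2)

# ... while replaying the same seed sequence reproduces the same keys,
# so a rerun is served deterministically from the cache.
assert chat_cache_key(msgs, seed=1) == chat_cache_key(msgs, seed=1)
```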
test/oai/test_completion.py: 8 additions & 8 deletions
```diff
@@ -61,7 +61,7 @@ def test_chat_cache(self) -> None:
         - cache is not missed & same question get different answer.
         """
         from rdagent.core.conf import RD_AGENT_SETTINGS
-        from rdagent.core.utils import cache_seed_gen
+        from rdagent.core.utils import LLM_CACHE_SEED_GEN
         from rdagent.oai.llm_conf import LLM_SETTINGS
 
         system_prompt = "You are a helpful assistant."
@@ -78,7 +78,7 @@ def test_chat_cache(self) -> None:
 
         RD_AGENT_SETTINGS.use_auto_chat_cache_seed_gen = True
 
-        cache_seed_gen.set_seed(10)
+        LLM_CACHE_SEED_GEN.set_seed(10)
         response1 = APIBackend().build_messages_and_create_chat_completion(
             system_prompt=system_prompt,
             user_prompt=user_prompt,
@@ -88,7 +88,7 @@ def test_chat_cache(self) -> None:
             user_prompt=user_prompt,
         )
 
-        cache_seed_gen.set_seed(20)
+        LLM_CACHE_SEED_GEN.set_seed(20)
         response3 = APIBackend().build_messages_and_create_chat_completion(
             system_prompt=system_prompt,
             user_prompt=user_prompt,
@@ -98,7 +98,7 @@ def test_chat_cache(self) -> None:
             user_prompt=user_prompt,
         )
 
-        cache_seed_gen.set_seed(10)
+        LLM_CACHE_SEED_GEN.set_seed(10)
         response5 = APIBackend().build_messages_and_create_chat_completion(
             system_prompt=system_prompt,
             user_prompt=user_prompt,
@@ -133,7 +133,7 @@ def test_chat_cache_multiprocess(self) -> None:
         - cache is not missed & same question get different answer.
         """
         from rdagent.core.conf import RD_AGENT_SETTINGS
-        from rdagent.core.utils import cache_seed_gen, multiprocessing_wrapper
+        from rdagent.core.utils import LLM_CACHE_SEED_GEN, multiprocessing_wrapper
         from rdagent.oai.llm_conf import LLM_SETTINGS
 
         system_prompt = "You are a helpful assistant."
@@ -152,11 +152,11 @@ def test_chat_cache_multiprocess(self) -> None:
 
         func_calls = [(_worker, (system_prompt, user_prompt)) for _ in range(4)]
 
-        cache_seed_gen.set_seed(10)
+        LLM_CACHE_SEED_GEN.set_seed(10)
         responses1 = multiprocessing_wrapper(func_calls, n=4)
-        cache_seed_gen.set_seed(20)
+        LLM_CACHE_SEED_GEN.set_seed(20)
         responses2 = multiprocessing_wrapper(func_calls, n=4)
-        cache_seed_gen.set_seed(10)
+        LLM_CACHE_SEED_GEN.set_seed(10)
         responses3 = multiprocessing_wrapper(func_calls, n=4)
 
         # Reset, for other tests
```
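The assertion lines are collapsed in this view, but the seed choreography (10, then 20, then 10 again) implies checks along these lines. The helper below is a sketch of the presumed assertions for the multiprocess test, not the test's actual code:

```python
def check_cache_behavior(responses1: list, responses2: list, responses3: list) -> None:
    """Assertions the collapsed test body presumably makes (assumed, not shown)."""
    # Different start seeds (10 vs. 20) yield different per-call cache seeds,
    # so the same questions should get different answers.
    assert responses1 != responses2
    # Replaying start seed 10 reproduces the exact seed sequence, so every
    # call should be served from the cache with identical answers.
    assert responses1 == responses3
```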
