Skip to content

Commit

Permalink
Replace the use of assert in non-test code (#80)
Browse files Browse the repository at this point in the history
* Replace `assert`s in the `conversable_agent` module with `if-log-raise`.

* Use a `logger` object in the `code_utils` module.

* Replace use of `assert` with `if-log-raise` in the `code_utils` module.

* Replace use of `assert` in the `math_utils` module with `if-not-raise`.

* Replace `assert` with `if` in the `oai.completion` module.

* Replace `assert` in the `retrieve_utils` module with an if statement.

* Add missing `not`.

* Blacken `completion.py`.

* Test `generate_reply` and `a_generate_reply` raise an assertion error
when there are neither `messages` nor a `sender`.

* Test `execute_code` raises an `AssertionError` when neither code nor
filename is provided.

* Test `split_text_to_chunks` raises when passed an invalid chunk mode.

* * Add `tiktoken` and `chromadb` to test dependencies as they're used in
the `test_retrieve_utils` module.

* Sort the test requirements alphabetically.
  • Loading branch information
cipherself authored Oct 3, 2023
1 parent 39c145d commit a3547f8
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 22 deletions.
16 changes: 14 additions & 2 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict
import copy
import json
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from autogen import oai
from .agent import Agent
Expand All @@ -21,6 +22,9 @@ def colored(x, *args, **kwargs):
return x


logger = logging.getLogger(__name__)


class ConversableAgent(Agent):
"""(In preview) A class for generic conversable agents which can be configured as assistant or user proxy.
Expand Down Expand Up @@ -757,7 +761,11 @@ def generate_reply(
Returns:
str or dict or None: reply. None if no reply is generated.
"""
assert messages is not None or sender is not None, "Either messages or sender must be provided."
if all((messages is None, sender is None)):
error_msg = f"Either {messages=} or {sender=} must be provided."
logger.error(error_msg)
raise AssertionError(error_msg)

if messages is None:
messages = self._oai_messages[sender]

Expand Down Expand Up @@ -804,7 +812,11 @@ async def a_generate_reply(
Returns:
str or dict or None: reply. None if no reply is generated.
"""
assert messages is not None or sender is not None, "Either messages or sender must be provided."
if all((messages is None, sender is None)):
error_msg = f"Either {messages=} or {sender=} must be provided."
logger.error(error_msg)
raise AssertionError(error_msg)

if messages is None:
messages = self._oai_messages[sender]

Expand Down
10 changes: 8 additions & 2 deletions autogen/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
WIN32 = sys.platform == "win32"
PATH_SEPARATOR = WIN32 and "\\" or "/"

logger = logging.getLogger(__name__)


def infer_lang(code):
"""infer the language for the code.
Expand Down Expand Up @@ -250,7 +252,11 @@ def execute_code(
str: The error message if the code fails to execute; the stdout otherwise.
image: The docker image name after container run when docker is used.
"""
assert code is not None or filename is not None, "Either code or filename must be provided."
if all((code is None, filename is None)):
error_msg = f"Either {code=} or {filename=} must be provided."
logger.error(error_msg)
raise AssertionError(error_msg)

timeout = timeout or DEFAULT_TIMEOUT
original_filename = filename
if WIN32 and lang in ["sh", "shell"]:
Expand All @@ -276,7 +282,7 @@ def execute_code(
f".\\{filename}" if WIN32 else filename,
]
if WIN32:
logging.warning("SIGALRM is not supported on Windows. No timeout will be enforced.")
logger.warning("SIGALRM is not supported on Windows. No timeout will be enforced.")
result = subprocess.run(
cmd,
cwd=work_dir,
Expand Down
14 changes: 9 additions & 5 deletions autogen/math_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ def remove_boxed(string: str) -> Optional[str]:
"""
left = "\\boxed{"
try:
assert string[: len(left)] == left
assert string[-1] == "}"
if not all((string[: len(left)] == left, string[-1] == "}")):
raise AssertionError

return string[len(left) : -1]
except Exception:
return None
Expand Down Expand Up @@ -94,7 +95,8 @@ def _fix_fracs(string: str) -> str:
new_str += substr
else:
try:
assert len(substr) >= 2
if not len(substr) >= 2:
raise AssertionError
except Exception:
return string
a = substr[0]
Expand Down Expand Up @@ -129,7 +131,8 @@ def _fix_a_slash_b(string: str) -> str:
try:
a = int(a_str)
b = int(b_str)
assert string == "{}/{}".format(a, b)
if not string == "{}/{}".format(a, b):
raise AssertionError
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except Exception:
Expand All @@ -143,7 +146,8 @@ def _remove_right_units(string: str) -> str:
"""
if "\\text{ " in string:
splits = string.split("\\text{ ")
assert len(splits) == 2
if not len(splits) == 2:
raise AssertionError
return splits[0]
else:
return string
Expand Down
29 changes: 20 additions & 9 deletions autogen/oai/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,23 +582,31 @@ def eval_func(responses, **data):
cls._prompts = space.get("prompt")
if cls._prompts is None:
cls._messages = space.get("messages")
assert isinstance(cls._messages, list) and isinstance(
cls._messages[0], (dict, list)
), "messages must be a list of dicts or a list of lists."
if not all((isinstance(cls._messages, list), isinstance(cls._messages[0], (dict, list)))):
error_msg = "messages must be a list of dicts or a list of lists."
logger.error(error_msg)
raise AssertionError(error_msg)
if isinstance(cls._messages[0], dict):
cls._messages = [cls._messages]
space["messages"] = tune.choice(list(range(len(cls._messages))))
else:
assert space.get("messages") is None, "messages and prompt cannot be provided at the same time."
assert isinstance(cls._prompts, (str, list)), "prompt must be a string or a list of strings."
if space.get("messages") is not None:
error_msg = "messages and prompt cannot be provided at the same time."
logger.error(error_msg)
raise AssertionError(error_msg)
if not isinstance(cls._prompts, (str, list)):
error_msg = "prompt must be a string or a list of strings."
logger.error(error_msg)
raise AssertionError(error_msg)
if isinstance(cls._prompts, str):
cls._prompts = [cls._prompts]
space["prompt"] = tune.choice(list(range(len(cls._prompts))))
cls._stops = space.get("stop")
if cls._stops:
assert isinstance(
cls._stops, (str, list)
), "stop must be a string, a list of strings, or a list of lists of strings."
if not isinstance(cls._stops, (str, list)):
error_msg = "stop must be a string, a list of strings, or a list of lists of strings."
logger.error(error_msg)
raise AssertionError(error_msg)
if not (isinstance(cls._stops, list) and isinstance(cls._stops[0], list)):
cls._stops = [cls._stops]
space["stop"] = tune.choice(list(range(len(cls._stops))))
Expand Down Expand Up @@ -969,7 +977,10 @@ def eval_func(responses, **data):
elif isinstance(agg_method, dict):
for key in metric_keys:
metric_agg_method = agg_method[key]
assert callable(metric_agg_method), "please provide a callable for each metric"
if not callable(metric_agg_method):
error_msg = "please provide a callable for each metric"
logger.error(error_msg)
raise AssertionError(error_msg)
result_agg[key] = metric_agg_method([r[key] for r in result_list])
else:
raise ValueError(
Expand Down
4 changes: 3 additions & 1 deletion autogen/retrieve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"yml",
"pdf",
]
VALID_CHUNK_MODES = frozenset({"one_line", "multi_lines"})


def num_tokens_from_text(
Expand Down Expand Up @@ -96,7 +97,8 @@ def split_text_to_chunks(
overlap: int = 10,
):
"""Split a long text into chunks of max_tokens."""
assert chunk_mode in {"one_line", "multi_lines"}
if chunk_mode not in VALID_CHUNK_MODES:
raise AssertionError
if chunk_mode == "one_line":
must_break_at_empty_line = False
chunks = []
Expand Down
9 changes: 6 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,18 @@
install_requires=install_requires,
extras_require={
"test": [
"pytest>=6.1.1",
"chromadb",
"coverage>=5.3",
"pre-commit",
"datasets",
"ipykernel",
"nbconvert",
"nbformat",
"ipykernel",
"pre-commit",
"pydantic==1.10.9",
"pytest-asyncio",
"pytest>=6.1.1",
"sympy",
"tiktoken",
"wolframalpha",
],
"blendsearch": ["flaml[blendsearch]"],
Expand Down
22 changes: 22 additions & 0 deletions test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,17 @@
from autogen.agentchat import ConversableAgent


@pytest.fixture
def conversable_agent():
return ConversableAgent(
"conversable_agent_0",
max_consecutive_auto_reply=10,
code_execution_config=False,
llm_config=False,
human_input_mode="NEVER",
)


def test_trigger():
agent = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
agent1 = ConversableAgent("a1", max_consecutive_auto_reply=0, human_input_mode="NEVER")
Expand Down Expand Up @@ -217,6 +228,17 @@ def add_num(num_to_be_added):
), "generate_reply not working when messages is None"


def test_generate_reply_raises_on_messages_and_sender_none(conversable_agent):
with pytest.raises(AssertionError):
conversable_agent.generate_reply(messages=None, sender=None)


@pytest.mark.asyncio
async def test_a_generate_reply_raises_on_messages_and_sender_none(conversable_agent):
with pytest.raises(AssertionError):
await conversable_agent.a_generate_reply(messages=None, sender=None)


if __name__ == "__main__":
test_trigger()
# test_context()
Expand Down
5 changes: 5 additions & 0 deletions test/test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,11 @@ def test_execute_code(use_docker=None):
assert isinstance(image, str) or docker is None or os.path.exists("/.dockerenv") or use_docker is False


def test_execute_code_raises_when_code_and_filename_are_both_none():
with pytest.raises(AssertionError):
execute_code(code=None, filename=None)


@pytest.mark.skipif(
sys.platform in ["darwin"],
reason="do not run on MacOS",
Expand Down
4 changes: 4 additions & 0 deletions test/test_retrieve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def test_split_text_to_chunks(self):
chunks = split_text_to_chunks(long_text, max_tokens=1000)
assert all(num_tokens_from_text(chunk) <= 1000 for chunk in chunks)

def test_split_text_to_chunks_raises_on_invalid_chunk_mode(self):
with pytest.raises(AssertionError):
split_text_to_chunks("A" * 10000, chunk_mode="bogus_chunk_mode")

def test_extract_text_from_pdf(self):
pdf_file_path = os.path.join(test_dir, "example.pdf")
assert "".join(expected_text.split()) == "".join(extract_text_from_pdf(pdf_file_path).strip().split())
Expand Down

0 comments on commit a3547f8

Please sign in to comment.