Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Import simplification & dev dependencies #51

Merged
merged 7 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/codeflash.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ jobs:
if: steps.bot_check.outputs.skip_remaining_steps == 'no'
run: curl -LsSf https://astral.sh/uv/install.sh | sh
- if: steps.bot_check.outputs.skip_remaining_steps == 'no'
run: uv sync
run: |-
uv sync
uv pip install codeflash
- name: Run CodeFlash on fhaviary
if: steps.bot_check.outputs.skip_remaining_steps == 'no'
run: uv run codeflash
Expand Down
2 changes: 1 addition & 1 deletion packages/gsm8k/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ where = ["src"]

[tool.setuptools_scm]
root = "../.."
version_file = "src/aviary/gsm8k/version.py"
version_file = "src/aviary/envs/gsm8k/version.py"
5 changes: 2 additions & 3 deletions packages/gsm8k/tests/test_gsm8k_env.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import pytest

from aviary.env import Environment, TaskDataset
from aviary.gsm8k.env import CalculatorEnv, CalculatorEnvConfig
from aviary.tools import ToolCall, ToolRequestMessage
from aviary.core import Environment, TaskDataset, ToolCall, ToolRequestMessage
from aviary.envs.gsm8k import CalculatorEnv, CalculatorEnvConfig


@pytest.mark.asyncio
Expand Down
2 changes: 1 addition & 1 deletion packages/hotpotqa/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ where = ["src"]

[tool.setuptools_scm]
root = "../.."
version_file = "src/aviary/hotpotqa/version.py"
version_file = "src/aviary/envs/hotpotqa/version.py"
4 changes: 2 additions & 2 deletions packages/hotpotqa/tests/test_hotpotqa_env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from aviary.env import Environment, TaskDataset
from aviary.hotpotqa import HotPotQAEnv
from aviary.core import Environment, TaskDataset
from aviary.envs.hotpotqa import HotPotQAEnv


def test_env_construction() -> None:
Expand Down
42 changes: 22 additions & 20 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,27 @@ requires-python = ">=3.11"
cloud = [
"boto3",
]
dev = [
"SQLAlchemy[aiosqlite]~=2.0", # Match aviary dependencies
"aviary.gsm8k[typing]", # So `uv sync` pulls this in, and for type stubs
"aviary.hotpotqa", # So `uv sync` pulls this in
"fhaviary[image,llm,server,typing,xml]",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this actually work (referring to other extras within this extra)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep - just tested uv pip install -e '.[dev]' on a clean env and all extras were pulled in.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also CI is passing, which would need these dependencies

"ipython>=8", # Pin to keep recent
"mypy>=1.8", # Pin for mutable-override
"pre-commit~=3.4",
"pydantic~=2.9", # Pydantic 2.9 changed JSON schema exports 'allOf', so ensure tests match
"pylint-pydantic",
"pylint>=3.2",
"pytest-asyncio",
"pytest-recording",
"pytest-subtests",
"pytest-sugar",
"pytest-timer[colorama]",
"pytest-xdist",
"pytest>=8", # Pin to keep recent
"refurb>=2", # Pin to keep recent
"typeguard",
]
gsm8k = ["aviary.gsm8k"]
hotpotqa = ["aviary.hotpotqa"]
image = [
Expand Down Expand Up @@ -394,26 +415,7 @@ trailing_comma_inline_array = true

[tool.uv]
dev-dependencies = [
"SQLAlchemy[aiosqlite]~=2.0", # Match aviary dependencies
"aviary.gsm8k[typing]", # So `uv sync` pulls this in, and for type stubs
"aviary.hotpotqa", # So `uv sync` pulls this in
"codeflash",
"fhaviary[image,llm,server,typing,xml]",
"ipython>=8", # Pin to keep recent
"mypy>=1.8", # Pin for mutable-override
"pre-commit~=3.4",
"pydantic~=2.9", # Pydantic 2.9 changed JSON schema exports 'allOf', so ensure tests match
"pylint-pydantic",
"pylint>=3.2",
"pytest-asyncio",
"pytest-recording",
"pytest-subtests",
"pytest-sugar",
"pytest-timer[colorama]",
"pytest-xdist",
"pytest>=8", # Pin to keep recent
"refurb>=2", # Pin to keep recent
"typeguard",
"fhaviary[dev]",
]

[tool.uv.sources]
Expand Down
69 changes: 69 additions & 0 deletions src/aviary/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from aviary.env import (
TASK_DATASET_REGISTRY,
DummyEnv,
DummyEnvState,
DummyTaskDataset,
Environment,
Frame,
TaskConfig,
TaskDataset,
)
from aviary.env_client import EnvironmentClient
from aviary.message import MalformedMessageError, Message, join
from aviary.render import Renderer
from aviary.tools import (
INVALID_TOOL_NAME,
FunctionInfo,
Messages,
MessagesAdapter,
Parameters,
Tool,
ToolCall,
ToolCallFunction,
ToolRequestMessage,
ToolResponseMessage,
Tools,
ToolsAdapter,
ToolSelector,
ToolSelectorLedger,
argref_by_name,
eval_answer,
wraps_doc_only,
)
from aviary.utils import encode_image_to_base64, is_coroutine_callable, partial_format

__all__ = [
"INVALID_TOOL_NAME",
"TASK_DATASET_REGISTRY",
"DummyEnv",
"DummyEnvState",
"DummyTaskDataset",
"Environment",
"EnvironmentClient",
"Frame",
"FunctionInfo",
"MalformedMessageError",
"Message",
"Messages",
"MessagesAdapter",
"Parameters",
"Renderer",
"TaskConfig",
"TaskDataset",
"Tool",
"ToolCall",
"ToolCallFunction",
"ToolRequestMessage",
"ToolResponseMessage",
"ToolSelector",
"ToolSelectorLedger",
"Tools",
"ToolsAdapter",
"argref_by_name",
"encode_image_to_base64",
"eval_answer",
"is_coroutine_callable",
"join",
"partial_format",
"wraps_doc_only",
]
8 changes: 4 additions & 4 deletions src/aviary/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,8 @@ def from_name(cls, name: str, **env_kwargs) -> Self:
# Maps baseline environment names to their module and class names
ENV_REGISTRY: dict[str, tuple[str, str]] = {
"dummy": ("aviary.env", "DummyEnv"),
"calculator": ("aviary.gsm8k.env", "CalculatorEnv"),
"hotpotqa": ("aviary.hotpotqa.env", "HotPotQAEnv"),
"calculator": ("aviary.envs.gsm8k.env", "CalculatorEnv"),
"hotpotqa": ("aviary.envs.hotpotqa.env", "HotPotQAEnv"),
}

TEnvironment = TypeVar("TEnvironment", bound=Environment)
Expand Down Expand Up @@ -319,8 +319,8 @@ def iter_batches(
# Maps baseline task dataset names to their module and class names
TASK_DATASET_REGISTRY: dict[str, tuple[str, str]] = {
"dummy": ("aviary.env", "DummyTaskDataset"),
"gsm8k": ("aviary.gsm8k.env", "GSM8kDataset"),
"hotpotqa": ("aviary.hotpotqa.env", "HotPotQADataset"),
"gsm8k": ("aviary.envs.gsm8k.env", "GSM8kDataset"),
"hotpotqa": ("aviary.envs.hotpotqa.env", "HotPotQADataset"),
}


Expand Down
12 changes: 8 additions & 4 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@
import pytest
from pydantic import BaseModel

from aviary.env import DummyEnv, DummyEnvState, Environment, Frame, TaskDataset
from aviary.message import Message
from aviary.render import Renderer
from aviary.tools import (
from aviary.core import (
DummyEnv,
DummyEnvState,
Environment,
Frame,
Message,
Renderer,
TaskDataset,
Tool,
ToolCall,
ToolRequestMessage,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import numpy as np
import pytest

from aviary.message import Message
from aviary.tools import (
from aviary.core import (
Message,
ToolCall,
ToolCallFunction,
ToolRequestMessage,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from pydantic import BaseModel, Field
from pytest_subtests import SubTests

from aviary.env import DummyEnv
from aviary.tools import (
from aviary.core import (
INVALID_TOOL_NAME,
DummyEnv,
FunctionInfo,
Tool,
ToolCall,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from aviary.tools import eval_answer
from aviary.core import eval_answer


@pytest.mark.vcr
Expand Down
Loading