Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added main entry point #109

Merged
merged 11 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ repos:
- id: mypy
additional_dependencies:
- fastapi>=0.109 # Match pyproject.toml
- fhaviary>=0.6 # Match pyproject.toml
- fhaviary>=0.8 # Match pyproject.toml
- httpx
- litellm>=1.49.3 # Match pyproject.toml
- numpy>=1.20 # Match pyproject.toml
Expand Down
6 changes: 5 additions & 1 deletion ldp/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def named_ops(self) -> Iterable[tuple[str, Op]]:
"""Analogous to torch.nn.Module.named_parameters()."""
return _find_ops(self)

@classmethod
def from_name(cls, name: str, **kwargs) -> Agent:
return _AGENT_REGISTRY[name](**kwargs)


class AgentConfig(BaseModel):
"""Configuration for specifying the type of agent i.e. the subclass of Agent above."""
Expand All @@ -96,7 +100,7 @@ class AgentConfig(BaseModel):
)

def construct_agent(self) -> Agent:
return _AGENT_REGISTRY[self.agent_type](**self.agent_kwargs)
return Agent.from_name(self.agent_type, **self.agent_kwargs)

def __hash__(self) -> int:
return hash(self.agent_type + json.dumps(self.agent_kwargs, sort_keys=True))
Expand Down
50 changes: 50 additions & 0 deletions ldp/alg/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,3 +475,53 @@ async def after_eval_loop(self) -> None:
await super().after_eval_loop() # Call the parent to compute means
if self.eval_means:
self._log_filtered_metrics(self.eval_means, step_type="Eval")


class TerminalLoggingCallback(Callback):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to ignore for this PR, but we should refactor our logging callback to be general, and move this terminal logging callback to be a composition of the logging callback with a stdout handler or a RichHandler

Our callbacks are a point of tech debt in this framework imo

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think they do different things though. The one I made is just printing the action actions/observations. The logging callback is a MeanMetricsCallback that logs metrics to logs. Feel like they're pretty different. Maybe we need a guide or something for logging or callbacks :)

"""Callback that prints action, observation, and timing information to the terminal."""
jamesbraza marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self):
self.start_time = None
# try now, rather than start running and die
try:
from rich.pretty import pprint # noqa: F401
except ImportError:
raise ImportError(
f"rich is required for {type(self).__name__}. Please install it with `pip install rich`."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should make a rich-specific extra and move our current rich extra to be rich-progress

Or we can leave as is

)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not raise ImportError("...") from exc? Maybe it's me, but I find having the "niceified" error connected to the un-niceified error to be useful


async def before_transition(
self,
traj_id: str,
agent: Agent,
env: Environment,
agent_state: Any,
obs: list[Message],
) -> None:
"""Start the timer before each transition."""
self.start_time = time.time()

async def after_agent_get_asv(
self,
traj_id: str,
action: OpResult[ToolRequestMessage],
next_agent_state: Any,
value: float,
) -> None:
from rich.pretty import pprint
print("\nAction:")
pprint(action.value, expand_all=True)
Comment on lines +513 to +514
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why have 2+ print invocations? Can we make it just one with \n?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are different - print vs pprint?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah why not just always use pprint?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would you concat them?

pprint("Action\n", action.value) is not valid and pprint(f"Action\n{str(action.value)}") defeats the purpose of using pprint

and pprint("Action\n") prints the python rep of that string instead of literally printing the words Action followed by a newline.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see, you're right. Thanks for going through this with me 👍


async def after_env_step(
self, traj_id: str, obs: list[Message], reward: float, done: bool, trunc: bool
) -> None:
from rich.pretty import pprint
# Compute elapsed time
if self.start_time is not None:
elapsed_time = time.time() - self.start_time
self.start_time = None # Reset timer
else:
elapsed_time = 0.0
print("\nObservation:")
pprint(obs, expand_all=True)
print(f"Elapsed time: {elapsed_time:.2f} seconds")
1 change: 1 addition & 0 deletions ldp/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def from_jsonl(cls, filename: str | os.PathLike) -> Self:
return traj

def compute_discounted_returns(self, discount: float = 1.0) -> list[float]:
"""Compute the discounted returns for each step in the trajectory."""
return discounted_returns(
rewards=[step.reward for step in self.steps],
terminated=[step.truncated for step in self.steps],
Expand Down
73 changes: 73 additions & 0 deletions ldp/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import argparse
import asyncio
import pickle
from os import PathLike
from pathlib import Path

from aviary.env import Environment

from ldp.agent import Agent
from ldp.alg.callbacks import TerminalLoggingCallback
from ldp.alg.rollout import RolloutManager


def agent_factory(agent: Agent | str | PathLike) -> Agent:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Normally a factory is like a classmethod going from a type to an instance. Can we have this support type[Agent]?

Also, can you add a docstring stating the possible modes checked?

"""
Construct an agent.

Args:
    agent: Either an agent instance, a name for an agent, or a path to an agent pickle.
"""

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is more logically connected to parsing user input - I extracted it out of main - and so I think it's better here where I use it.

I changed the name. Feel like the type hints and name are enough documentation.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think of naming this resolve_agent? Also, feel free to ignore this one

if isinstance(agent, Agent):
return agent

if isinstance(agent, str):
try:
return Agent.from_name(agent)
except KeyError:
pass

path = Path(agent)
if not path.exists():
raise ValueError(f"Could not resolve agent: {agent}")

with path.open("rb") as f:
return pickle.load(f) # noqa: S301
Comment on lines +23 to +28
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
path = Path(agent)
if not path.exists():
raise ValueError(f"Could not resolve agent: {agent}")
with path.open("rb") as f:
return pickle.load(f) # noqa: S301
try:
with path.open("rb") as f:
return pickle.load(f) # noqa: S301
except FileNotFoundError:
raise ValueError(f"Could not resolve agent: {agent}") from None

Alternate way of doing this



def environment_factory(environment: Environment | str, task: str) -> Environment:
if isinstance(environment, Environment):
return environment

if isinstance(environment, str):
try:
return Environment.from_name(environment, task=task)
except ValueError:
pass

raise ValueError(
f"Could not resolve environment: {environment}. Available environments: {Environment.available()}"
)


async def main(
task: str,
environment: Environment | str,
agent: Agent | str | PathLike = "SimpleAgent",
):
agent = agent_factory(agent)

callback = TerminalLoggingCallback()
rollout_manager = RolloutManager(agent=agent, callbacks=[callback])

_ = await rollout_manager.sample_trajectories(
environment_factory=lambda: environment_factory(environment, task)
)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("task", help="Task to prompt environment with.")
parser.add_argument(
"--env", required=True, help="Environment to sample trajectories from."
)
parser.add_argument(
"--agent", default="SimpleAgent", help="Agent to sample trajectories with."
)
args = parser.parse_args()

asyncio.run(main(args.task, args.env, args.agent))
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ classifiers = [
dependencies = [
"aiofiles",
"dm-tree",
"fhaviary>=0.6", # For MalformedMessageError
"fhaviary>=0.8", # For from_task
"httpx",
"litellm>=1.40.15", # For LITELLM_LOG addition
"networkx[default]~=3.4", # Pin for pydot fix
Expand Down Expand Up @@ -317,6 +317,7 @@ ignore = [
"ARG003", # Thrown all the time when we are subclassing
"ASYNC109", # Buggy, SEE: https://github.com/astral-sh/ruff/issues/12353
"ASYNC2", # It's ok to mix async and sync ops (like opening a file)
"B904",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like B904 because it makes the code explicit, but I am not that attached to it.

Would you be open to reverting this? I don't think this PR uses it

"BLE001", # Don't care to enforce blind exception catching
"COM812", # Trailing comma with black leads to wasting lines
"D100", # D100, D101, D102, D103, D104, D105, D106, D107: don't always need docstrings
Expand Down
Loading