Added main entry point #109
Changes from 5 commits
c8c79a7
cc2e5c9
e55a561
a00b757
bda1c90
a438aa1
67e55e1
103e83d
fa93902
0840ccc
87ad64c
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -475,3 +475,53 @@ async def after_eval_loop(self) -> None: | |
await super().after_eval_loop() # Call the parent to compute means | ||
if self.eval_means: | ||
self._log_filtered_metrics(self.eval_means, step_type="Eval") | ||
|
||
|
||
class TerminalLoggingCallback(Callback): | ||
"""Callback that prints action, observation, and timing information to the terminal.""" | ||
|
||
|
||
def __init__(self): | ||
self.start_time = None | ||
# try now, rather than start running and die | ||
try: | ||
from rich.pretty import pprint # noqa: F401 | ||
except ImportError: | ||
raise ImportError( | ||
f"rich is required for {type(self).__name__}. Please install it with `pip install rich`." | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should make a Or we can leave as is |
||
) | ||
Comment: Why not
|
||
async def before_transition( | ||
self, | ||
traj_id: str, | ||
agent: Agent, | ||
env: Environment, | ||
agent_state: Any, | ||
obs: list[Message], | ||
) -> None: | ||
"""Start the timer before each transition.""" | ||
self.start_time = time.time() | ||
|
||
async def after_agent_get_asv( | ||
self, | ||
traj_id: str, | ||
action: OpResult[ToolRequestMessage], | ||
next_agent_state: Any, | ||
value: float, | ||
) -> None: | ||
from rich.pretty import pprint | ||
print("\nAction:") | ||
pprint(action.value, expand_all=True) | ||
Comment on lines +513 to +514
Comment: Why have 2+ print invocations? Can we make it just one with
Comment: These are different -
Comment: Yeah why not just always use
Comment: How would you concat them? and
Comment: Ah I see, you're right. Thanks for going through this with me 👍
|
||
async def after_env_step( | ||
self, traj_id: str, obs: list[Message], reward: float, done: bool, trunc: bool | ||
) -> None: | ||
from rich.pretty import pprint | ||
# Compute elapsed time | ||
if self.start_time is not None: | ||
elapsed_time = time.time() - self.start_time | ||
self.start_time = None # Reset timer | ||
else: | ||
elapsed_time = 0.0 | ||
print("\nObservation:") | ||
pprint(obs, expand_all=True) | ||
print(f"Elapsed time: {elapsed_time:.2f} seconds") |
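The timing logic in the callback above can be tried in isolation. The following is a minimal sketch, assuming a stand-in class named `TimingCallback` (hypothetical, not part of the PR) with simplified hook signatures. It uses `time.perf_counter` rather than the `time.time` call in the diff, which is a substitution made here because a monotonic clock is not affected by wall-clock adjustments.

```python
import asyncio
import time


class TimingCallback:
    """Hypothetical stand-in for the timing part of TerminalLoggingCallback."""

    def __init__(self) -> None:
        self.start_time: float | None = None

    async def before_transition(self) -> None:
        # Start the timer before each transition.
        self.start_time = time.perf_counter()

    async def after_env_step(self) -> float:
        if self.start_time is None:
            # The hook fired without a matching before_transition call.
            return 0.0
        elapsed = time.perf_counter() - self.start_time
        self.start_time = None  # reset so a stale start is never reused
        return elapsed


async def demo() -> float:
    callback = TimingCallback()
    await callback.before_transition()
    await asyncio.sleep(0.01)  # stands in for the agent and environment step
    return await callback.after_env_step()


elapsed = asyncio.run(demo())
print(f"Elapsed time: {elapsed:.2f} seconds")
```

Resetting `start_time` after each step means an `after_env_step` call without a preceding `before_transition` reports 0.0 instead of a misleading duration.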
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,73 @@ | ||||||||||||||||||||||||
import argparse | ||||||||||||||||||||||||
import asyncio | ||||||||||||||||||||||||
import pickle | ||||||||||||||||||||||||
from os import PathLike | ||||||||||||||||||||||||
from pathlib import Path | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
from aviary.env import Environment | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
from ldp.agent import Agent | ||||||||||||||||||||||||
from ldp.alg.callbacks import TerminalLoggingCallback | ||||||||||||||||||||||||
from ldp.alg.rollout import RolloutManager | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
def agent_factory(agent: Agent | str | PathLike) -> Agent: | ||||||||||||||||||||||||
Comment: Normally a factory is like a Also, can you add a docstring stating the possible modes checked?
"""
Construct an agent.
Args:
agent: Either an agent instance, a name for an agent, or a path to an agent pickle.
"""
Comment: This is more logically connected to parsing user input - I extracted it out of I changed the name. Feel like the type hints and name are enough documentation.
Comment: What do you think of naming this
if isinstance(agent, Agent): | ||||||||||||||||||||||||
return agent | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if isinstance(agent, str): | ||||||||||||||||||||||||
try: | ||||||||||||||||||||||||
return Agent.from_name(agent) | ||||||||||||||||||||||||
except KeyError: | ||||||||||||||||||||||||
pass | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
path = Path(agent) | ||||||||||||||||||||||||
if not path.exists(): | ||||||||||||||||||||||||
raise ValueError(f"Could not resolve agent: {agent}") | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
with path.open("rb") as f: | ||||||||||||||||||||||||
return pickle.load(f) # noqa: S301 | ||||||||||||||||||||||||
Comment on lines +23 to +28
Comment: Suggested change. Alternate way of doing this
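The three resolution modes discussed in this thread (instance, registered name, path to a pickle) can be exercised without ldp installed. This is a rough sketch under stated assumptions: `Agent` below is a stand-in with a toy name registry, not ldp's class, and `resolve_agent` is a hypothetical name chosen for illustration, not the name settled on in the review.

```python
import pickle
from os import PathLike
from pathlib import Path


class Agent:
    """Hypothetical stand-in for ldp.agent.Agent, with a tiny name registry."""

    _registry: dict = {}

    def __init_subclass__(cls, **kwargs) -> None:
        super().__init_subclass__(**kwargs)
        Agent._registry[cls.__name__] = cls

    @classmethod
    def from_name(cls, name: str) -> "Agent":
        return cls._registry[name]()  # raises KeyError for unknown names


class SimpleAgent(Agent):
    pass


def resolve_agent(agent: "Agent | str | PathLike") -> Agent:
    """Resolve an agent instance, a registered agent name, or a path to an agent pickle."""
    if isinstance(agent, Agent):
        return agent
    if isinstance(agent, str):
        try:
            return Agent.from_name(agent)
        except KeyError:
            pass  # not a registered name, so fall through and treat it as a path
    path = Path(agent)
    if not path.exists():
        raise ValueError(f"Could not resolve agent: {agent}")
    with path.open("rb") as f:
        # Only unpickle files you trust: pickle.load can execute arbitrary code.
        return pickle.load(f)


by_name = resolve_agent("SimpleAgent")
existing = SimpleAgent()
same = resolve_agent(existing)
print(type(by_name).__name__, same is existing)
```

A string that is neither a registered name nor an existing path ends in a `ValueError`, so a typo in the agent name surfaces as a resolution error rather than a file-not-found error.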
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
def environment_factory(environment: Environment | str, task: str) -> Environment: | ||||||||||||||||||||||||
if isinstance(environment, Environment): | ||||||||||||||||||||||||
return environment | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if isinstance(environment, str): | ||||||||||||||||||||||||
try: | ||||||||||||||||||||||||
return Environment.from_name(environment, task=task) | ||||||||||||||||||||||||
except ValueError: | ||||||||||||||||||||||||
pass | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||||
f"Could not resolve environment: {environment}. Available environments: {Environment.available()}" | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
async def main( | ||||||||||||||||||||||||
task: str, | ||||||||||||||||||||||||
environment: Environment | str, | ||||||||||||||||||||||||
agent: Agent | str | PathLike = "SimpleAgent", | ||||||||||||||||||||||||
): | ||||||||||||||||||||||||
agent = agent_factory(agent) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
callback = TerminalLoggingCallback() | ||||||||||||||||||||||||
rollout_manager = RolloutManager(agent=agent, callbacks=[callback]) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
_ = await rollout_manager.sample_trajectories( | ||||||||||||||||||||||||
environment_factory=lambda: environment_factory(environment, task) | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
if __name__ == "__main__": | ||||||||||||||||||||||||
parser = argparse.ArgumentParser() | ||||||||||||||||||||||||
parser.add_argument("task", help="Task to prompt environment with.") | ||||||||||||||||||||||||
parser.add_argument( | ||||||||||||||||||||||||
"--env", required=True, help="Environment to sample trajectories from." | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
parser.add_argument( | ||||||||||||||||||||||||
"--agent", default="SimpleAgent", help="Agent to sample trajectories with." | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
args = parser.parse_args() | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
asyncio.run(main(args.task, args.env, args.agent)) |
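To see how the command-line arguments above are parsed, the parser can be built in a function and fed an explicit argument list. This is a small sketch: `build_parser` is a hypothetical helper, and the task text and the environment name `dummy` are placeholders rather than values confirmed by the PR.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build the same three arguments the entry point defines."""
    parser = argparse.ArgumentParser()
    parser.add_argument("task", help="Task to prompt environment with.")
    parser.add_argument(
        "--env", required=True, help="Environment to sample trajectories from."
    )
    parser.add_argument(
        "--agent", default="SimpleAgent", help="Agent to sample trajectories with."
    )
    return parser


# "dummy" is a placeholder environment name used only for this illustration.
args = build_parser().parse_args(["Write a haiku", "--env", "dummy"])
print(args.task, args.env, args.agent)
```

Because `--agent` has a default, only the positional task and `--env` need to be supplied.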
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,7 +17,7 @@ classifiers = [ | |
dependencies = [ | ||
"aiofiles", | ||
"dm-tree", | ||
"fhaviary>=0.6", # For MalformedMessageError | ||
"fhaviary>=0.8", # For from_task | ||
"httpx", | ||
"litellm>=1.40.15", # For LITELLM_LOG addition | ||
"networkx[default]~=3.4", # Pin for pydot fix | ||
|
@@ -317,6 +317,7 @@ ignore = [ | |
"ARG003", # Thrown all the time when we are subclassing | ||
"ASYNC109", # Buggy, SEE: https://github.com/astral-sh/ruff/issues/12353 | ||
"ASYNC2", # It's ok to mix async and sync ops (like opening a file) | ||
"B904", | ||
Comment: I like B904 because it makes the code explicit, but I am not that attached to it. Would you be open to reverting this? I don't think this PR uses it.
"BLE001", # Don't care to enforce blind exception catching | ||
"COM812", # Trailing comma with black leads to wasting lines | ||
"D100", # D100, D101, D102, D103, D104, D105, D106, D107: don't always need docstrings | ||
|
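For context on the B904 comment above: that Ruff rule flags a `raise` inside an `except` block that has no `from` clause. The sketch below shows the form the rule accepts. The module name `a_module_that_does_not_exist` and the function `load_pretty_printer` are hypothetical, chosen only so the import fails.

```python
def load_pretty_printer():
    try:
        # Hypothetical missing dependency, used to force an ImportError.
        from a_module_that_does_not_exist import pprint
    except ImportError as exc:
        # `from exc` records the original error as __cause__, which satisfies B904.
        raise ImportError("an optional dependency is required") from exc
    return pprint


try:
    load_pretty_printer()
except ImportError as err:
    cause = err.__cause__  # keep a reference; `err` is unbound after this block

print(type(cause).__name__)
```

Using `from None` instead of `from exc` also satisfies the rule when the original traceback should be hidden on purpose.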
Comment: Feel free to ignore for this PR, but we should refactor our logging callback to be general, and move this terminal logging callback to be a composition of the logging callback with a stdout handler or a RichHandler. Our callbacks are a point of tech debt in this framework imo.
Comment: I think they do different things though. The one I made just prints the actions/observations. The logging callback is a MeanMetricsCallback that writes metrics to the logs. They feel pretty different. Maybe we need a guide or something for logging or callbacks :)
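The composition suggested earlier in this thread could look roughly like the sketch below, which uses only the standard `logging` module. `LoggingCallback` and `make_terminal_callback` are hypothetical names with a simplified hook signature, not ldp's existing classes. rich's `RichHandler` could presumably be swapped in for `logging.StreamHandler` to get formatted output.

```python
import asyncio
import io
import logging
import sys


class LoggingCallback:
    """Hypothetical general callback: formats events and hands them to a logger."""

    def __init__(self, logger: logging.Logger) -> None:
        self.logger = logger

    async def after_env_step(self, obs: list, reward: float) -> None:
        self.logger.info("Observation: %r (reward=%.2f)", obs, reward)


def make_terminal_callback(stream=sys.stdout) -> LoggingCallback:
    """Compose the general callback with a handler that writes to a stream."""
    logger = logging.getLogger("terminal_callback_demo")
    logger.setLevel(logging.INFO)
    logger.propagate = False  # keep output out of the root logger
    logger.handlers.clear()  # keep repeated calls from stacking handlers
    logger.addHandler(logging.StreamHandler(stream))
    return LoggingCallback(logger)


# A StringIO buffer stands in for the terminal so the output can be inspected.
buffer = io.StringIO()
asyncio.run(make_terminal_callback(buffer).after_env_step(["hello"], 1.0))
print(buffer.getvalue(), end="")
```

With this shape, where the output goes is decided by the handler, so the same callback could write to a file or a terminal without changes.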