Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made it easier to instantiate environments #74

Merged
merged 4 commits into from
Oct 18, 2024
Merged

Made it easier to instantiate environments #74

merged 4 commits into from
Oct 18, 2024

Conversation

whitead
Copy link
Contributor

@whitead whitead commented Oct 17, 2024

Added a new way to create environments based on an explicit task. Also made it possible to see available environments (although it is approximate).

@dosubot dosubot bot added size:M This PR changes 30-99 lines, ignoring generated files. enhancement New feature or request labels Oct 17, 2024
@whitead whitead requested a review from albertbou92 October 17, 2024 22:41
Comment on lines 255 to 260
@classmethod
def from_name(cls, name: str, task: str | None = None, **env_kwargs) -> Self:
new_cls = _get_cls_from_name(ENV_REGISTRY, name)
if task is not None:
return new_cls.from_task(task)
return new_cls(**env_kwargs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this become a problem if a subclass uses task in its init?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will add a check for that

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok - added a check

src/aviary/env.py Outdated Show resolved Hide resolved
self.end_immediately = end_immediately
self.task = task
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we name this self.subject = task?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would rather just keep it consistent - don't think it's so important here about what is done with the task but like to be able to trace its path.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good

src/aviary/env.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@jamesbraza jamesbraza left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.end_immediately = end_immediately
self.task = task
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good

@@ -432,7 +469,7 @@ def __bool__(self) -> bool:
return True


def _construct_obj_from_name(registry: dict[str, tuple[str, str]], name: str, **kwargs):
def _get_cls_from_name(registry: dict[str, tuple[str, str]], name: str):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice change here, like it

@@ -248,8 +248,35 @@ async def close(self) -> None:
"""

@classmethod
def from_name(cls, name: str, **env_kwargs) -> Self:
return _construct_obj_from_name(ENV_REGISTRY, name, **env_kwargs)
def from_task(cls, task: str) -> Self:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I understand the use-case for from_task. Is there a specific environment where this kind of behavior is desirable?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand is for inference time?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Presumably task is not an arbitrary string as a user prompt would be? It seems as though it must correspond to a valid problem_id?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for inference time - so that you can have user defined tasks instead of tasks coming from a training or eval set.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An example would help I think

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See Future-House/ldp#109

This is to enable scripts/entry points so that an end user can use the environments

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a comment above, it makes the most sense to me for environments like HotpotQA where question can be open-ended. I don't know what happens, however, when the user passes in an arbitrary problem_id to environments like GSM8K? Similarly in cloning, I can't see where self.problem_id is used?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missed that the problem argument is also being set as task

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That now makes sense

@Ryan-Rhys
Copy link
Contributor

68747470733a2f2f6d65646961302e67697068792e636f6d2f6d656469612f336f6868777069546b4c4362317a734354432f67697068792e676966

LOL

Copy link
Contributor

@albertbou92 albertbou92 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@@ -50,6 +50,10 @@ def __init__(
self.check_tool = Tool.from_function(self.check_answer)
self.tools = [self.calc_tool, self.check_tool]

@classmethod
def from_task(cls, task: str) -> "CalculatorEnv":
return cls(problem_id="task", problem=task, answer=0.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is self.problem_id used in the GSM8K environment aside from being exported as a dictionary in export_frame?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know - it's a required argument so I just put a placeholder. Do you think I should refactor to make id optional?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I missed that task is being passed in to problem

@@ -191,6 +191,10 @@ def __init__(
create_tool(self.finish, "Finish"),
]

@classmethod
def from_task(cls, task: str) -> "HotPotQAEnv":
return cls(question=task, correct_answer=0.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think HotpotQA makes the most sense intuitively for me since question can be open-ended allowing the user to pass an arbitrary question at inference time.

in calling an LLM. This is how the environment should be used after training
and in deployment. We don't take config here, because the default environment config
should be general for arbitrary tasks. Or, the config should be coupled to the agent
training (future TODO).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be worth adding some examples in the docstring e.g.

  1. For the HotpotQA environment, a question not featured in the HotpotQA dataset.
  2. For the GSM8K environment, a math word question not featured in the GSM8K dataset.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok - will add these in next revision set

@whitead whitead merged commit a20118e into main Oct 18, 2024
5 of 6 checks passed
@whitead whitead deleted the env-tass branch October 18, 2024 18:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request size:M This PR changes 30-99 lines, ignoring generated files.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants