Goal conditioning integration #5142
Conversation
```diff
@@ -92,6 +92,11 @@ class ScheduleType(Enum):
     LINEAR = "linear"


+class ConditioningType(Enum):
```
Like we discussed in the design doc, we probably don't need anything more than hyper, but it might still be worth keeping this just in case. For instances where the user finds out that the hypernetwork hurts performance on their task for whatever reason, it might be easier to disable it in the trainer than to rebuild their environment without goals.
Made hyper the default and no conditioning the second option.
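The resolution above (hyper as the default, with a "none" option kept as an opt-out) might look roughly like this. This is a hedged sketch: only `class ConditioningType(Enum):` and the `HYPER` default appear in the diffs in this thread, so the member names and value strings here are assumptions.

```python
from enum import Enum


class ConditioningType(Enum):
    # Default: condition the network body on the goal via a hypernetwork.
    HYPER = "hyper"
    # Opt-out: ignore goal observations in the encoder, so a user whose task
    # is hurt by the hypernetwork can disable it in the trainer config
    # instead of rebuilding their environment without goals.
    NONE = "none"


# attrs-style default, as in the NetworkSettings diff later in the thread
default = ConditioningType.HYPER
```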
I think it should be `hypernetwork` - is that too long?
Co-authored-by: Arthur Juliani <awjuliani@gmail.com>
…ity-Technologies/ml-agents into goal-conditioning-integration
```diff
@@ -115,6 +120,7 @@ def _check_valid_memory_size(self, attribute, value):
     num_layers: int = 2
     vis_encode_type: EncoderType = EncoderType.SIMPLE
     memory: Optional[MemorySettings] = None
+    conditioning_type: ConditioningType = ConditioningType.HYPER
```
Should we maybe use `goal_conditioning_type`?
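Since `conditioning_type` lives in `NetworkSettings`, it would presumably surface in the trainer YAML under `network_settings`. A hedged config sketch: the field name follows the diff at this point in the PR (the comment above suggests renaming it), the `hyper`/`none` values are inferred from the discussion, and the behavior name is hypothetical.

```yaml
behaviors:
  MyBehavior:            # hypothetical behavior name
    trainer_type: ppo
    network_settings:
      num_layers: 2
      conditioning_type: hyper   # or "none" to disable goal conditioning
```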
```diff
@@ -79,9 +94,6 @@ def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
         """
         Encode observations using a list of processors and an RSA.
         :param inputs: List of Tensors corresponding to a set of obs.
-        :param processors: a ModuleList of the input processors to be applied to these obs.
-        :param rsa: Optionally, an RSA to use for variable length obs.
-        :param x_self_encoder: Optionally, an encoder to use for x_self (in this case, the non-variable inputs.).
```
Thanks for removing this 😅
```python
encoding = self.linear_encoder(encoded_self)
if isinstance(self.linear_encoder, ConditionalEncoder):
    goal = self.observation_encoder.get_goal_encoding(inputs)
    encoding = self.linear_encoder(encoded_self, goal)
```
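The snippet above dispatches on the encoder type: a `ConditionalEncoder` takes the goal encoding as a second argument. The core idea being integrated is that a hypernetwork generates the body encoder's weights from the goal. A minimal NumPy sketch of that mechanism, assuming illustrative names and shapes (this is not the actual ml-agents implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
obs_size, goal_size, hidden_size = 8, 4, 16

# Hypernetwork parameters: map the goal encoding to a flattened
# (weight, bias) pair for the body's linear layer.
W_hyper = rng.normal(0.0, 0.1, size=(goal_size, obs_size * hidden_size + hidden_size))


def conditional_encode(encoded_self: np.ndarray, goal: np.ndarray) -> np.ndarray:
    """Generate per-goal weights, then apply them to the observation encoding."""
    params = goal @ W_hyper                                   # (B, obs*hidden + hidden)
    w = params[:, : obs_size * hidden_size].reshape(-1, hidden_size, obs_size)
    b = params[:, obs_size * hidden_size:]                    # (B, hidden)
    out = np.einsum("bho,bo->bh", w, encoded_self) + b        # batched matmul
    return np.maximum(out, 0.0)                               # ReLU


batch = 5
encoding = conditional_encode(
    rng.normal(size=(batch, obs_size)),
    rng.normal(size=(batch, goal_size)),
)
```

Each observation in the batch is thus encoded by a different linear layer, whose weights depend on that sample's goal.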
If we're going to do both Conditional and Linear, I think we should rename it to just `encoder` or `body_encoder`, but not really that necessary.
```python
@property
def total_goal_enc_size(self) -> int:
    """
    Returns the total encoding size for this ObservationEncoder.
```
Suggested change:
```diff
-    Returns the total encoding size for this ObservationEncoder.
+    Returns the total goal encoding size for this ObservationEncoder.
```
Proposed change(s)
[Do not merge, first merge the modules, then merge this into main]
Integration of the hypernetwork into the network body.
Useful links (Github issues, JIRA tickets, ML-Agents forum threads etc.)
Types of change(s)
Checklist
Other comments