Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix Unet global_cond_dim to use state dim, not action dim #278

Merged
merged 1 commit into from
Jun 17, 2024

Conversation

ojh6404
Copy link
Contributor

@ojh6404 ojh6404 commented Jun 17, 2024

This PR fixes global_cond_dim of Diffusion's unet when dim of observation.state != dim of action


This change is Reviewable

@aliberts aliberts added the 🧠 Policies Something policies-related label Jun 17, 2024
Copy link
Collaborator

@alexander-soare alexander-soare left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved! Pending on fixing the linting error.

Thanks for this!

@@ -165,7 +165,7 @@ def __init__(self, config: DiffusionConfig):
num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
self.unet = DiffusionConditionalUnet1d(
config,
global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim * num_images)
global_cond_dim=(config.input_shapes["observation.state"][0] + self.rgb_encoder.feature_dim * num_images)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change should fix the linting test

Suggested change
global_cond_dim=(config.input_shapes["observation.state"][0] + self.rgb_encoder.feature_dim * num_images)
global_cond_dim=(
config.input_shapes["observation.state"][0] + self.rgb_encoder.feature_dim * num_images
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alexander-soare Actually, I just fixed linting problem on my forked repo.

@alexander-soare alexander-soare merged commit b72d574 into huggingface:main Jun 17, 2024
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
🧠 Policies Something policies-related
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants