
Does VICReg loss work correctly in a multi-GPU setting? #535

Closed
NoTody opened this issue Oct 8, 2022 · 2 comments
NoTody (Contributor) commented Oct 8, 2022

Hi, Kevin

I noticed that the current implementation of VICReg doesn't include a gather layer for the input embeddings. Does it still work as expected when the mini-batch is split across multiple GPUs, i.e. are the embeddings from all GPUs used to compute the loss? See https://github.com/facebookresearch/vicreg/blob/a73f567660ae507b0667c68f685945ae6e2f62c3/main_vicreg.py#L200 for the original VICReg implementation.

Regards,
Notody
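To illustrate why the missing gather matters: VICReg's variance and covariance terms are batch statistics, so computing them on each GPU's local shard of the batch gives a different loss than computing them on the full, gathered batch. A minimal numpy sketch of the variance term (illustrative only, not the library's code):

```python
import numpy as np

def variance_term(z, eps=1e-4):
    # Hinge on the per-dimension standard deviation, as in the VICReg paper.
    std = np.sqrt(z.var(axis=0) + eps)
    return float(np.mean(np.maximum(0.0, 1.0 - std)))

rng = np.random.default_rng(0)
full_batch = 0.5 * rng.normal(size=(8, 4))   # full batch of 8 embeddings
shards = np.split(full_batch, 2)             # as if sharded across 2 GPUs

global_v = variance_term(full_batch)                          # gathered batch
local_v = float(np.mean([variance_term(s) for s in shards]))  # per-GPU average

# The two quantities disagree, so skipping the gather changes the loss.
print(global_v, local_v)
```

Without a gather, each GPU effectively trains against the statistics of its own smaller shard, which is not the loss the VICReg paper describes.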

KevinMusgrave (Owner) commented
Yeah, I don't think it does. For me, the ideal solution would be to make it compatible with DistributedLossWrapper. That means either:

- making it have the same input format as BaseMetricLossFunction (I don't know if there's a way to do this that makes sense), or
- adding an if-statement in DistributedLossWrapper to catch the VICReg loss type.
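For reference, the gather the question points to is an all-gather whose backward pass routes gradients back to each rank's local shard, modeled on the FullGatherLayer in the original VICReg repo. A hedged sketch (the names here are illustrative, not the pytorch-metric-learning API):

```python
import torch
import torch.distributed as dist

class GatherLayer(torch.autograd.Function):
    """All-gather that keeps gradients flowing to each rank's local shard."""

    @staticmethod
    def forward(ctx, x):
        out = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(out, x)
        return tuple(out)

    @staticmethod
    def backward(ctx, *grads):
        all_grads = torch.stack(grads)
        dist.all_reduce(all_grads)  # sum gradient contributions from every rank
        return all_grads[dist.get_rank()]

def gather_embeddings(x):
    # Gather across ranks when distributed training is active; otherwise a
    # no-op, so the same code path works in single-GPU runs.
    if dist.is_available() and dist.is_initialized():
        return torch.cat(GatherLayer.apply(x), dim=0)
    return x
```

Wrapping the embeddings with something like `gather_embeddings` before the loss computation would make the batch statistics global, which is what DistributedLossWrapper does for the other losses.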

KevinMusgrave added this to the v2.0 milestone on Jan 21, 2023
KevinMusgrave (Owner) commented

In v2.0 this will work:

```python
loss_fn = DistributedLossWrapper(loss=VICRegLoss())
loss = loss_fn(embeddings, ref_emb=ref_emb)
```
