
Kernels for GroupNorm #353

Merged

merged 23 commits into linkedin:main on Nov 7, 2024
Conversation

@pramodith (Collaborator) commented Nov 5, 2024

Summary

Implementation of a GroupNorm kernel that achieves output parity with torch's GroupNorm.

This feature is part of #285.

Details

The formulas/equations involved in GroupNorm are the same as in LayerNorm/BatchNorm; the main differences lie in the axes along which the mean and standard deviation are computed and the shapes of the affine transformation parameters.

In GroupNorm, W and B have shape (n_channels); however, the mean and std are calculated over all the channels in a given group.
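As a point of reference (not the Triton kernel in this PR), here is a minimal PyTorch sketch of the math described above, assuming a (batch, n_channels, hidden_dim) input; the channel count below is illustrative, with 4 channels per group as in the benchmarks:

```python
import torch

def group_norm_ref(x, weight, bias, num_groups, eps=1e-6):
    # x: (batch, n_channels, hidden_dim); weight/bias: (n_channels,)
    b, c, h = x.shape
    xg = x.view(b, num_groups, -1)                      # flatten each group of channels
    mean = xg.mean(dim=-1, keepdim=True)                # one mean per (sample, group)
    var = xg.var(dim=-1, keepdim=True, unbiased=False)  # one variance per (sample, group)
    y = ((xg - mean) / torch.sqrt(var + eps)).view(b, c, h)
    return y * weight.view(1, c, 1) + bias.view(1, c, 1)  # per-channel affine

x = torch.randn(128, 64, 512)                   # batch 128, 64 channels (illustrative), hidden dim 512
weight, bias = torch.ones(64), torch.zeros(64)  # affine parameters are per channel
torch_gn = torch.nn.GroupNorm(num_groups=16, num_channels=64, eps=1e-6)  # 4 channels per group
assert torch.allclose(group_norm_ref(x, weight, bias, 16), torch_gn(x), atol=1e-5)
```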

Testing Done

Testing was done on an A100 PCIe and an A100 SXM4.

We see an increase in speed, while total memory usage remains about the same. Note that benchmarking was done with a batch size of 128, a hidden dimension of 512, and the number of channels per group fixed at 4.

These results look very similar to the LayerNorm benchmarks as well.

Figure: group_norm_memory (memory benchmark plot)

Figure: group_norm_speed (speed benchmark plot)
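For context, a rough timing sketch under the stated configuration (batch 128, hidden dim 512, 4 channels per group; the channel count below is arbitrary). The LigerGroupNorm import path and constructor signature are assumptions, and the repo's actual benchmark harness may differ:

```python
# Rough timing sketch, not the repo's benchmark script.
import torch
from liger_kernel.transformers import LigerGroupNorm  # name/path assumed

batch, channels, hidden, group_size = 128, 64, 512, 4
x = torch.randn(batch, channels, hidden, device="cuda", requires_grad=True)

torch_gn = torch.nn.GroupNorm(channels // group_size, channels).cuda()
liger_gn = LigerGroupNorm(num_channels=channels, num_groups=channels // group_size).cuda()  # signature assumed

def bench_ms(module, iters=100):
    # Time forward + backward with CUDA events.
    torch.cuda.synchronize()
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        module(x).sum().backward()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

print(f"torch GroupNorm: {bench_ms(torch_gn):.3f} ms")
print(f"Liger GroupNorm: {bench_ms(liger_gn):.3f} ms")
```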

  • Hardware Type: A100 (PCIe and SXM4)
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@pramodith marked this pull request as ready for review November 5, 2024 19:44
c2 += tl.sum(wdy)

# Need to ensure additions to the same channel are atomic
tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype))
A collaborator commented:

@ByronHsu is it possible for us to test on multiple GPUs, specifically around

scope (str, optional) – Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are “gpu” (default), “cta” (cooperative thread array, thread block), or “sys” (stands for “SYSTEM”). The default value is “gpu”.

whether the default value works for multi-GPU.
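For illustration only (not from this PR's kernel), here is a tiny standalone Triton kernel exercising the scope argument quoted above; it assumes a Triton version whose atomic ops accept the scope keyword:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _atomic_sum_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program reduces its block locally, then atomically folds the partial
    # sum into a single accumulator shared by all blocks.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    partial = tl.sum(x, axis=0)
    # scope="gpu" (the default) makes the update visible to every block on this
    # device, which is all a single-device launch needs; "sys" widens visibility
    # to other devices/the host.
    tl.atomic_add(out_ptr, partial, scope="gpu")

x = torch.randn(4096, device="cuda")
out = torch.zeros(1, device="cuda")
grid = (triton.cdiv(x.numel(), 1024),)
_atomic_sum_kernel[grid](x, out, x.numel(), BLOCK_SIZE=1024)
print(out.item(), x.sum().item())  # should agree up to float32 rounding
```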

A collaborator replied:

What kind of testing? Running in a 4-GPU environment to ensure the kernel works fine on a single GPU? I'm not sure how this relates to multi-GPU; my understanding is that the kernel only runs on one GPU.

@lancerts (Collaborator) commented Nov 7, 2024

Very solid PR!

lancerts previously approved these changes Nov 7, 2024
@lancerts requested a review from ByronHsu November 7, 2024 19:59
@ByronHsu merged commit a954b73 into linkedin:main Nov 7, 2024
2 checks passed
@ByronHsu (Collaborator) commented Nov 7, 2024

@pramodith can you update the readme to include groupnorm

@pramodith (Collaborator, Author) replied:

> @pramodith can you update the readme to include groupnorm

Will do tomorrow!

@ByronHsu mentioned this pull request Nov 8, 2024