Commit

Merge pull request #3441 from kaixih:update_scale_compute
PiperOrigin-RevId: 578045483
Flax Authors committed Oct 31, 2023
2 parents 6f69bb6 + 3a0e87a commit 8d09772
Showing 1 changed file with 9 additions and 9 deletions.
flax/linen/fp8_ops.py (18 changes: 9 additions & 9 deletions)
@@ -49,16 +49,16 @@ def quantize_dequantize(x, q_dtype, scale, compute_dtype):


 def compute_scale(amax, scale, fp8_max, margin=0):
   """Default function to convert amax to scaling factor."""
-  # This function copied from the TransformerEngine is used to compute its
-  # `scale`. However, our scale matches its `scale_inv` concept. So, we apply
-  # the reciprocal operation at the entry and exit of the function.
+  # The algorithm for computing the new scale is sourced from
+  #   https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#transformer_engine.jax.update_fp8_metas
+  # wherein the `original_scale` corresponds to the reciprocal of the `scale`
+  # passed in this function.
   scale = 1.0 / scale

-  exp = jnp.floor(jnp.log2(fp8_max / amax)) - margin
-  sf = jnp.round(lax.pow(2., jnp.abs(exp)))
+  sf = (fp8_max / amax) / (2**margin)
   sf = jnp.where(amax > 0.0, sf, scale)
-  sf = jnp.where(lax.is_finite(amax), sf, scale)
-  sf = jnp.where(exp < 0, 1.0 / sf, sf)
+  sf = jnp.where(jnp.isfinite(amax), sf, scale)

   return 1.0 / sf
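
To make the behavioral change concrete, here is a minimal standalone sketch of the two versions. The names compute_scale_old / compute_scale_new and the sample value fp8_max = 448.0 (the finite max of float8_e4m3fn) are supplied for illustration and are not part of the commit. The old code snapped the scaling factor to a power of two via floor(log2(...)); the new code keeps the exact ratio (fp8_max / amax) / 2**margin, per the TransformerEngine update_fp8_metas algorithm linked in the new comment:

import jax.numpy as jnp
from jax import lax

def compute_scale_old(amax, scale, fp8_max, margin=0):
  # Pre-commit behavior: round the scaling factor to a power of two.
  scale = 1.0 / scale
  exp = jnp.floor(jnp.log2(fp8_max / amax)) - margin
  sf = jnp.round(lax.pow(2., jnp.abs(exp)))
  sf = jnp.where(amax > 0.0, sf, scale)
  sf = jnp.where(lax.is_finite(amax), sf, scale)
  sf = jnp.where(exp < 0, 1.0 / sf, sf)
  return 1.0 / sf

def compute_scale_new(amax, scale, fp8_max, margin=0):
  # Post-commit behavior: use the exact ratio.
  scale = 1.0 / scale
  sf = (fp8_max / amax) / (2**margin)
  sf = jnp.where(amax > 0.0, sf, scale)
  sf = jnp.where(jnp.isfinite(amax), sf, scale)
  return 1.0 / sf

amax, prev_scale = jnp.float32(3.2), jnp.float32(1.0)
fp8_max = 448.0  # finite max of float8_e4m3fn (illustrative)
print(compute_scale_old(amax, prev_scale, fp8_max))  # 0.0078125  (= 2**-7)
print(compute_scale_new(amax, prev_scale, fp8_max))  # ~0.0071429 (= 3.2 / 448)

Both versions keep the same fallbacks (reuse the previous scale when amax is zero or non-finite); the only difference is that the new scale is no longer rounded to a power of two.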


@@ -155,7 +155,7 @@ def setup(self) -> None:
       OVERWRITE_WITH_GRADIENT, 'output_grad_scale', *scale_args)

-  def __call__(self, *args, **kwargs) -> jnp.ndarray:
+  def __call__(self, *args, **kwargs):

     assert len(args) == 3
     x = args[0]
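
For context, the __call__ edited above belongs to the fp8 dot_general module defined in this file. A hedged usage sketch follows, assuming the Fp8DotGeneralOp export from flax.linen and the dot_general_cls hook on nn.Dense; both exist in Flax around this time, but neither appears in this diff:

import jax
import jax.numpy as jnp
import flax.linen as nn

# Route Dense's matmul through the fp8 op; its setup() (see above) creates
# scale and amax-history variables such as output_grad_scale in the
# OVERWRITE_WITH_GRADIENT collection, updated through gradients rather
# than by mutation during apply.
model = nn.Dense(features=16, dot_general_cls=nn.Fp8DotGeneralOp)
x = jnp.ones((4, 8))
variables = model.init(jax.random.PRNGKey(0), x)
y = model.apply(variables, x)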
