
Commit c8fbd6f
Addressed the reviewer's comment. Avoided changing the existing
Fp8DotGeneralOp API.
wenchenvincent committed Jun 24, 2024
1 parent 4edbd2b commit c8fbd6f
Showing 3 changed files with 15 additions and 15 deletions.
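
In effect (a sketch based on the diffs below, not part of the commit itself):
the OCP recipe stays on the unchanged Fp8DotGeneralOp class, whose dtype
defaults are the OCP fp8 formats, while the NANOO variant now only overrides
two dtype attributes. Assuming the flax.linen exports shown below:

  import flax.linen as nn
  from jax import numpy as jnp

  ocp_op = nn.Fp8DotGeneralOp()         # defaults: float8_e4m3fn / float8_e5m2
  nanoo_op = nn.NANOOFp8DotGeneralOp()  # float8_e4m3fnuz / float8_e5m2fnuz
  assert nanoo_op.e4m3_dtype == jnp.float8_e4m3fnuz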
2 changes: 1 addition & 1 deletion flax/linen/__init__.py
@@ -86,7 +86,7 @@
 )
 from .batch_apply import BatchApply as BatchApply
 from .combinators import Sequential as Sequential
-from .fp8_ops import OCPFp8DotGeneralOp as OCPFp8DotGeneralOp
+from .fp8_ops import Fp8DotGeneralOp as Fp8DotGeneralOp
 from .fp8_ops import NANOOFp8DotGeneralOp as NANOOFp8DotGeneralOp
 from .initializers import (
   ones_init as ones_init,
20 changes: 10 additions & 10 deletions flax/linen/fp8_ops.py
@@ -17,6 +17,9 @@
 import warnings
 from functools import partial
 
+from typing import Any
+DType = Any
+
 import jax
 from jax import custom_jvp, custom_vjp, lax, random
 from jax import numpy as jnp
@@ -269,8 +272,9 @@ def dot_general_with_precision_jvp(
 
 
 class Fp8DotGeneralOp(module.Module):
-  fp8_genre: str = 'OCP'
   amax_history_length: int = 1024
+  e4m3_dtype: DType = jnp.float8_e4m3fn
+  e5m2_dtype: DType = jnp.float8_e5m2
 
   def setup(self) -> None:
     scale_args = (
@@ -317,27 +321,23 @@ def __call__(self, *args, **kwargs):
     comp_dtype = k.dtype
     x = jnp.asarray(x, comp_dtype)
 
-    e4m3_dtype, e5m2_dtype = get_fp8_dtypes(self.fp8_genre)
-
     x_qdq = in_qdq(
-      comp_dtype, e4m3_dtype, x, self.input_scale.value, self.input_amax_history.value
+      comp_dtype, self.e4m3_dtype, x, self.input_scale.value, self.input_amax_history.value
     )
     k_qdq = in_qdq(
-      comp_dtype, e4m3_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value
+      comp_dtype, self.e4m3_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value
     )
     y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers)  # type: ignore
     y = out_qdq(
       comp_dtype,
-      e5m2_dtype,
+      self.e5m2_dtype,
       y_qdq,
       self.output_grad_scale.value,
       self.output_grad_amax_history.value,
     )
 
     return y  # type: ignore
 
-class OCPFp8DotGeneralOp(Fp8DotGeneralOp):
-  fp8_genre: str = 'OCP'
-
 class NANOOFp8DotGeneralOp(Fp8DotGeneralOp):
-  fp8_genre: str = 'NANOO'
+  e4m3_dtype: DType = jnp.float8_e4m3fnuz
+  e5m2_dtype: DType = jnp.float8_e5m2fnuz
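
Since the fp8 dtypes are now plain dataclass attributes, the NANOO subclass
is just a pair of default overrides, and the same behavior can be obtained by
instantiating the base class directly. A minimal sketch, assuming the
attributes introduced in this diff:

  import flax.linen as nn
  from jax import numpy as jnp

  nanoo_op = nn.Fp8DotGeneralOp(
    e4m3_dtype=jnp.float8_e4m3fnuz,  # dtype for input/kernel quantization
    e5m2_dtype=jnp.float8_e5m2fnuz,  # dtype for output-gradient quantization
  )
  # ...the same defaults that nn.NANOOFp8DotGeneralOp bakes in.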
8 changes: 4 additions & 4 deletions tests/linen/linen_test.py
@@ -1271,7 +1271,7 @@ def run(fp8_injection, expected_shapes):
       p = nn.DenseGeneral(features=64, name='dense')
       if fp8_injection:
         if fp8_genre == 'OCP':
-          p.dot_general_cls = nn.OCPFp8DotGeneralOp
+          p.dot_general_cls = nn.Fp8DotGeneralOp
         else:
           p.dot_general_cls = nn.NANOOFp8DotGeneralOp
 
@@ -1293,7 +1293,7 @@ def _train(variables, x):
       'params': {'kernel': (32, 64), 'bias': (64,)},
     }
     if fp8_genre == 'OCP':
-      fp8_op_name = 'OCPFp8DotGeneralOp_0'
+      fp8_op_name = 'Fp8DotGeneralOp_0'
     else:
       fp8_op_name = 'NANOOFp8DotGeneralOp_0'
     expected_shapes_new = {
@@ -1326,8 +1326,8 @@ def test_fp8_train_state(self, fp8_genre):
     key, init_key, random_key = random.split(random.PRNGKey(seed=123), 3)
     x = random.uniform(random_key, (16, 16), dtype=jnp.float32)
     if fp8_genre == 'OCP':
-      fp8_dot_op = nn.OCPFp8DotGeneralOp
-      fp8_op_name = 'OCPFp8DotGeneralOp_0'
+      fp8_dot_op = nn.Fp8DotGeneralOp
+      fp8_op_name = 'Fp8DotGeneralOp_0'
     else:
       fp8_dot_op = nn.NANOOFp8DotGeneralOp
       fp8_op_name = 'NANOOFp8DotGeneralOp_0'
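
For reference, a hedged usage sketch mirroring the updated test, using
DenseGeneral's dot_general_cls constructor argument rather than mutating a
constructed module:

  import flax.linen as nn
  from jax import numpy as jnp, random

  model = nn.DenseGeneral(features=64, dot_general_cls=nn.Fp8DotGeneralOp)
  x = random.uniform(random.PRNGKey(0), (16, 32), dtype=jnp.float32)
  variables = model.init(random.PRNGKey(1), x)
  y = model.apply(variables, x)  # matmul with fp8 quantize/dequantize applied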
