From c8fbd6f84a036907d49f5f07f188b6b8c016a906 Mon Sep 17 00:00:00 2001
From: Wen Chen
Date: Mon, 24 Jun 2024 15:17:59 +0000
Subject: [PATCH] Addressed the reviewer's comment. Avoided changing the existing Fp8DotGeneralOp API.

---
 flax/linen/__init__.py    |  2 +-
 flax/linen/fp8_ops.py     | 20 ++++++++++----------
 tests/linen/linen_test.py |  8 ++++----
 3 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py
index 557a00afa4..f01ed92880 100644
--- a/flax/linen/__init__.py
+++ b/flax/linen/__init__.py
@@ -86,7 +86,7 @@
 )
 from .batch_apply import BatchApply as BatchApply
 from .combinators import Sequential as Sequential
-from .fp8_ops import OCPFp8DotGeneralOp as OCPFp8DotGeneralOp
+from .fp8_ops import Fp8DotGeneralOp as Fp8DotGeneralOp
 from .fp8_ops import NANOOFp8DotGeneralOp as NANOOFp8DotGeneralOp
 from .initializers import (
   ones_init as ones_init,
diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py
index f33cafaebb..53a65854e6 100644
--- a/flax/linen/fp8_ops.py
+++ b/flax/linen/fp8_ops.py
@@ -17,6 +17,9 @@
 import warnings
 from functools import partial
 
+from typing import Any
+DType = Any
+
 import jax
 from jax import custom_jvp, custom_vjp, lax, random
 from jax import numpy as jnp
@@ -269,8 +272,9 @@ def dot_general_with_precision_jvp(
 
 
 class Fp8DotGeneralOp(module.Module):
-  fp8_genre: str = 'OCP'
   amax_history_length: int = 1024
+  e4m3_dtype: DType = jnp.float8_e4m3fn
+  e5m2_dtype: DType = jnp.float8_e5m2
 
   def setup(self) -> None:
     scale_args = (
@@ -317,18 +321,16 @@ def __call__(self, *args, **kwargs):
       comp_dtype = k.dtype
       x = jnp.asarray(x, comp_dtype)
 
-    e4m3_dtype, e5m2_dtype = get_fp8_dtypes(self.fp8_genre)
-
     x_qdq = in_qdq(
-      comp_dtype, e4m3_dtype, x, self.input_scale.value, self.input_amax_history.value
+      comp_dtype, self.e4m3_dtype, x, self.input_scale.value, self.input_amax_history.value
     )
     k_qdq = in_qdq(
-      comp_dtype, e4m3_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value
+      comp_dtype, self.e4m3_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value
     )
     y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers)  # type: ignore
     y = out_qdq(
       comp_dtype,
-      e5m2_dtype,
+      self.e5m2_dtype,
       y_qdq,
       self.output_grad_scale.value,
       self.output_grad_amax_history.value,
@@ -336,8 +338,6 @@ def __call__(self, *args, **kwargs):
     return y  # type: ignore
 
 
-class OCPFp8DotGeneralOp(Fp8DotGeneralOp):
-  fp8_genre: str = 'OCP'
-
 class NANOOFp8DotGeneralOp(Fp8DotGeneralOp):
-  fp8_genre: str = 'NANOO'
+  e4m3_dtype: DType = jnp.float8_e4m3fnuz
+  e5m2_dtype: DType = jnp.float8_e5m2fnuz
diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py
index bbc5d75570..0736c6f5f1 100644
--- a/tests/linen/linen_test.py
+++ b/tests/linen/linen_test.py
@@ -1271,7 +1271,7 @@ def run(fp8_injection, expected_shapes):
       p = nn.DenseGeneral(features=64, name='dense')
       if fp8_injection:
         if fp8_genre == 'OCP':
-          p.dot_general_cls = nn.OCPFp8DotGeneralOp
+          p.dot_general_cls = nn.Fp8DotGeneralOp
         else:
           p.dot_general_cls = nn.NANOOFp8DotGeneralOp
 
@@ -1293,7 +1293,7 @@ def _train(variables, x):
         'params': {'kernel': (32, 64), 'bias': (64,)},
       }
       if fp8_genre == 'OCP':
-        fp8_op_name = 'OCPFp8DotGeneralOp_0'
+        fp8_op_name = 'Fp8DotGeneralOp_0'
       else:
         fp8_op_name = 'NANOOFp8DotGeneralOp_0'
       expected_shapes_new = {
@@ -1326,8 +1326,8 @@ def test_fp8_train_state(self, fp8_genre):
     key, init_key, random_key = random.split(random.PRNGKey(seed=123), 3)
     x = random.uniform(random_key, (16, 16), dtype=jnp.float32)
     if fp8_genre == 'OCP':
-      fp8_dot_op = nn.OCPFp8DotGeneralOp
-      fp8_op_name = 'OCPFp8DotGeneralOp_0'
+      fp8_dot_op = nn.Fp8DotGeneralOp
+      fp8_op_name = 'Fp8DotGeneralOp_0'
    else:
      fp8_dot_op = nn.NANOOFp8DotGeneralOp
      fp8_op_name = 'NANOOFp8DotGeneralOp_0'
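
Note for reviewers (not part of the patch): a minimal usage sketch of the renamed op. With this change, Fp8DotGeneralOp keeps the OCP defaults (jnp.float8_e4m3fn / jnp.float8_e5m2) and NANOOFp8DotGeneralOp only overrides the two dtype fields with the fnuz variants, so no fp8_genre string is needed. The constructor-kwarg form `dot_general_cls=...` (rather than setting the attribute afterwards, as the test does), the input shape, and the PRNG keys below are illustrative assumptions, not taken from this patch.

import jax
import jax.numpy as jnp
import flax.linen as nn

# Select the fp8 dot_general op on a dense layer. Using nn.Fp8DotGeneralOp
# gives the OCP defaults; swap in nn.NANOOFp8DotGeneralOp for the fnuz dtypes.
model = nn.DenseGeneral(features=64, dot_general_cls=nn.Fp8DotGeneralOp)

x = jax.random.uniform(jax.random.PRNGKey(0), (16, 32), dtype=jnp.float32)
variables = model.init(jax.random.PRNGKey(1), x)

# Besides 'params', init also returns the op's quantization state
# (scales and amax histories) in a separate variable collection.
y = model.apply(variables, x)
print(jax.tree_util.tree_map(jnp.shape, variables))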