Merge pull request #35 from kaixih:use_fast_accumulation_fp8
PiperOrigin-RevId: 582770579
pax authors committed Nov 15, 2023
2 parents 5420b56 + 2cefb21 commit e7c8561
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion praxis/layers/injection/fp8_nvidia_gpu.py
@@ -103,7 +103,9 @@ def __call__(self, equation: str, *args: pytypes.JTensor) -> pytypes.JTensor:
     k_qdq = fp8_ops.in_qdq(
         comp_dtype, k, theta.kernel_scale, theta.kernel_amax_history
     )
-    y_qdq = jnp.einsum(equation, x_qdq, k_qdq)
+    y_qdq = jnp.einsum(
+        equation, x_qdq, k_qdq, _dot_general=fp8_ops.dot_general_with_precision
+    )
     y = fp8_ops.out_qdq(
         comp_dtype,
         y_qdq,
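
Note on the change: the new call passes a custom contraction function to jnp.einsum through its private _dot_general keyword, so the einsum contraction is routed through that function instead of the default lax.dot_general. The sketch below only illustrates this hook. The function dot_general_default_precision and the list traced_calls are hypothetical names introduced here; this is not the implementation of fp8_ops.dot_general_with_precision, which is not shown in this commit. Pinning the precision to lax.Precision.DEFAULT is an assumption made for illustration.

```python
import jax.numpy as jnp
from jax import lax

# Records every contraction that einsum routes through the hook (filled at trace time).
traced_calls = []


def dot_general_default_precision(lhs, rhs, dimension_numbers, precision=None,
                                  preferred_element_type=None, **kwargs):
    """Hypothetical stand-in for a custom dot_general; not the fp8_ops version."""
    traced_calls.append(dimension_numbers)
    # Assumption for illustration: ignore the requested precision and always
    # use Precision.DEFAULT. Extra keyword arguments that newer JAX versions
    # may pass are forwarded unchanged.
    return lax.dot_general(
        lhs,
        rhs,
        dimension_numbers,
        precision=lax.Precision.DEFAULT,
        preferred_element_type=preferred_element_type,
        **kwargs,
    )


x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
k = jnp.ones((3, 4), dtype=jnp.float32)

# Same equation, once through the custom hook and once through the default path.
y_hook = jnp.einsum("bd,dh->bh", x, k, _dot_general=dot_general_default_precision)
y_ref = jnp.einsum("bd,dh->bh", x, k)
```

Because _dot_general is a private argument of jnp.einsum, its calling convention may change between JAX releases; the **kwargs passthrough above is one way to stay tolerant of that.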