
Commit: Fix format

wenscarl committed Nov 3, 2023
1 parent 8d84c50 commit d307996
Showing 1 changed file with 28 additions and 22 deletions.
flax/linen/fp8_ops.py
@@ -123,28 +123,34 @@ def out_qdq_bwd(compute_dtype, res, g):
 
 
 @partial(custom_jvp, nondiff_argnums=(2, 3, 4))
-def dot_general_with_precision(lhs, rhs, dimension_numbers, precision=None,
-                               preferred_element_type=None):
+def dot_general_with_precision(
+    lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None
+):
   if precision != None or preferred_element_type != None:
-    warnings.warn("The function dot_general_with_precision will set the "
-                  "precision/preferred_element_type and disregard any provided "
-                  "values.")
-  return lax.dot_general(lhs, rhs, dimension_numbers,
-                         precision=lax.Precision.DEFAULT)
-
-@dot_general_with_precision.defjvp
-def dot_general_with_precision_jvp(dimension_numbers, precision,
-                                   preferred_element_type, primals, tangents):
-  lhs, rhs = primals
-  lhs_dot, rhs_dot = tangents
-
-  out = lax.dot_general(lhs, rhs, dimension_numbers,
-                        precision=lax.Precision.DEFAULT)
-  grad_out = (lax.dot_general(lhs_dot, rhs, dimension_numbers,
-                              precision=lax.Precision.HIGHEST) +
-              lax.dot_general(lhs, rhs_dot, dimension_numbers,
-                              precision=lax.Precision.HIGHEST))
-  return out, grad_out
+    warnings.warn(
+        'The function dot_general_with_precision will set the '
+        'precision/preferred_element_type and disregard any provided '
+        'values.')
+  return lax.dot_general(
+      lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT
+  )
+
+
+@dot_general_with_precision.defjvp
+def dot_general_with_precision_jvp(
+    dimension_numbers, precision, preferred_element_type, primals, tangents
+):
+  lhs, rhs = primals
+  lhs_dot, rhs_dot = tangents
+
+  out = lax.dot_general(
+      lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT
+  )
+  grad_out = lax.dot_general(
+      lhs_dot, rhs, dimension_numbers, precision=lax.Precision.HIGHEST
+  ) + lax.dot_general(
+      lhs, rhs_dot, dimension_numbers, precision=lax.Precision.HIGHEST
+  )
+  return out, grad_out
 
 
 class Fp8DotGeneralOp(module.Module):
   amax_history_length: int = 1024
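
Context for readers (not part of the commit): the custom_jvp pairing above runs the primal dot_general at lax.Precision.DEFAULT while both tangent products run at lax.Precision.HIGHEST. A minimal sketch of exercising it with jax.jvp, assuming the function is importable from flax.linen.fp8_ops and using illustrative shapes:

# Hedged sketch, not from this commit: drive dot_general_with_precision
# through jax.jvp so the custom JVP rule above is what computes the tangent.
# The import path and the toy shapes are assumptions for illustration.
import jax
import jax.numpy as jnp
from flax.linen.fp8_ops import dot_general_with_precision

lhs = jnp.ones((4, 8), dtype=jnp.float32)
rhs = jnp.ones((8, 16), dtype=jnp.float32)
# Contract dim 1 of lhs with dim 0 of rhs, no batch dims: a plain matmul.
dimension_numbers = (((1,), (0,)), ((), ()))

# jax.jvp returns (primal_out, tangent_out); the tangent is
# lhs_dot @ rhs + lhs @ rhs_dot, computed at HIGHEST precision.
out, tangent = jax.jvp(
    lambda a, b: dot_general_with_precision(a, b, dimension_numbers),
    (lhs, rhs),
    (jnp.ones_like(lhs), jnp.ones_like(rhs)),
)
print(out.shape, tangent.shape)  # (4, 16) (4, 16)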
@@ -200,7 +206,7 @@ def __call__(self, *args, **kwargs):
     k_qdq = in_qdq(
         comp_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value
     )
-    y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers) # type: ignore
+    y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers)  # type: ignore
     y = out_qdq(
         comp_dtype,
         y_qdq,
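
Usage note (an illustrative sketch, not from this commit): Fp8DotGeneralOp is designed to be injected into layers that accept a dot_general_cls, such as flax.linen.Dense, so the layer's matmul runs through the in_qdq / dot_general_with_precision / out_qdq path shown in the hunk above. The export name, shapes, and dtype below are assumptions about the flax of this era:

# Hedged sketch: swap Fp8DotGeneralOp into a Dense layer.
# Dense's dot_general_cls hook and the nn.Fp8DotGeneralOp export
# are assumed here rather than confirmed by this commit.
import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(features=16, dot_general_cls=nn.Fp8DotGeneralOp)
x = jnp.ones((4, 8), dtype=jnp.float32)
# init returns the kernel/bias params plus the op's fp8 scale/amax state.
variables = model.init(jax.random.PRNGKey(0), x)
y = model.apply(variables, x)  # matmul wrapped in fp8 quantize/dequantize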
