Init fp8 dense
james77777778 committed Mar 29, 2024
1 parent a063684 commit 9eb9629
Showing 3 changed files with 210 additions and 0 deletions.
178 changes: 178 additions & 0 deletions keras/layers/core/dense.py
@@ -1,3 +1,5 @@
import ml_dtypes

from keras import activations
from keras import backend
from keras import constraints
@@ -81,6 +83,8 @@ def __init__(
kernel_constraint=None,
bias_constraint=None,
lora_rank=None,
amax_history_length=1024,
use_fp8=False,
**kwargs,
):
super().__init__(activity_regularizer=activity_regularizer, **kwargs)
@@ -94,6 +98,8 @@ def __init__(
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.lora_rank = lora_rank
self.amax_history_length = amax_history_length
self.use_fp8 = use_fp8
self.lora_enabled = False
self.input_spec = InputSpec(min_ndim=2)
self.supports_masking = True
@@ -126,6 +132,8 @@ def build(self, input_shape):
self.built = True
if self.lora_rank:
self.enable_lora(self.lora_rank)
if self.use_fp8:
self.fp8_build(input_shape)

@property
def kernel(self):
@@ -140,6 +148,8 @@ def kernel(self):
return self._kernel

def call(self, inputs):
if self.use_fp8:
return self.fp8_call(inputs)
x = ops.matmul(inputs, self.kernel)
if self.bias is not None:
x = ops.add(x, self.bias)
@@ -234,6 +244,7 @@ def get_config(self):
"bias_regularizer": regularizers.serialize(self.bias_regularizer),
"kernel_constraint": constraints.serialize(self.kernel_constraint),
"bias_constraint": constraints.serialize(self.bias_constraint),
"amax_history_length": self.amax_history_length,
}
if self.lora_rank:
config["lora_rank"] = self.lora_rank
@@ -402,3 +413,170 @@ def _get_kernel_with_merged_lora(self):
kernel_scale = ops.squeeze(kernel_scale, axis=0)
return kernel_value, kernel_scale
return self.kernel, None

"""FP8-related methods"""

def fp8_build(self, input_shape):
amax_history_kwargs = {
"shape": (self.amax_history_length,),
"initializer": "zeros",
"trainable": False,
"autocast": False,
}
scale_kwargs = {
"shape": (),
"initializer": "ones",
"trainable": False,
"autocast": False,
}
self.inputs_amax_history = self.add_weight(
name="inputs_amax_history", **amax_history_kwargs
)
self.inputs_scale = self.add_weight(name="inputs_scale", **scale_kwargs)
self.kernel_amax_history = self.add_weight(
name="kernel_amax_history", **amax_history_kwargs
)
self.kernel_scale = self.add_weight(name="kernel_scale", **scale_kwargs)
self.outputs_grad_amax_history = self.add_weight(
name="outputs_grad_amax_history", **amax_history_kwargs
)
self.outputs_grad_scale = self.add_weight(
name="outputs_grad_scale", **scale_kwargs
)
if backend.backend() == "jax":
# For an unknown reason, these weights need to be marked as trainable
# to enable assignment in `outputs_qdq` on the jax backend
self.outputs_grad_amax_history.trainable = True
self.outputs_grad_scale.trainable = True

def fp8_call(self, inputs):
if self.lora_enabled:
raise NotImplementedError(
"Currently, `fp8_call` doesn't support LoRA"
)

def compute_scale(amax, scale, dtype_max, margin=0):
"""Default function to convert amax to scaling factor."""
exp = ops.floor(ops.log2(ops.divide(dtype_max, amax))) - margin
sf = ops.round(ops.power(2.0, ops.abs(exp)))
sf = ops.where(amax > 0.0, sf, scale)
sf = ops.where(ops.isfinite(amax), sf, scale)
sf = ops.where(exp < 0.0, ops.reciprocal(sf), sf)
# The scaling factor we need equals the notion of "scale_inv" in
# TransformerEngine, so we convert sf to its reciprocal.
return ops.reciprocal(sf)
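# Illustrative worked example (not part of the original commit): with
# dtype_max = 448 (float8_e4m3fn) and amax = 3.5,
# exp = floor(log2(448 / 3.5)) = 7, so sf = round(2 ** 7) = 128 and the
# returned "scale_inv" is 1 / 128.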

def compute_scale_and_amax_history(
x, scale, amax_history, quantized_dtype
):
x = ops.convert_to_tensor(x)
scale = ops.convert_to_tensor(scale)
amax_history = ops.convert_to_tensor(amax_history)
quantized_dtype_max = float(ml_dtypes.finfo(quantized_dtype).max)

amax_update = ops.cast(ops.max(ops.abs(x)), scale.dtype)
amax_history_update = ops.scatter_update(
ops.roll(amax_history, shift=-1),
[[0]],
ops.reshape(amax_update, [1]),
)

amax_from_history = ops.max(amax_history)
scale_update = compute_scale(
amax_from_history, scale, quantized_dtype_max
)
return amax_history_update, scale_update

@ops.custom_gradient
def inputs_qdq(inputs):
"""Quantize-dequantize the inputs but not its gradient."""
qdq_inputs = quantizers.quantize_and_dequantize(
inputs,
ops.convert_to_tensor(self.inputs_scale),
"float8_e4m3fn",
self.compute_dtype,
)

def grad(*args, upstream=None):
if upstream is None:
(upstream,) = args
return upstream

return qdq_inputs, grad

@ops.custom_gradient
def kernel_qdq(kernel):
"""Quantize-dequantize the kernel but not its gradient."""
qdq_kernel = quantizers.quantize_and_dequantize(
kernel,
ops.convert_to_tensor(self.kernel_scale),
"float8_e4m3fn",
self.compute_dtype,
)

def grad(*args, upstream=None, variables=None):
if upstream is None:
(upstream,) = args
return upstream

return qdq_kernel, grad

@ops.custom_gradient
def outputs_qdq(outputs, scale, amax_history):
"""Quantize-dequantize the output gradient but not the output."""

def grad(*args, upstream=None):
if upstream is None:
(upstream,) = args
qdq_upstream = quantizers.quantize_and_dequantize(
upstream, scale, "float8_e5m2", self.compute_dtype
)
amax_history_update, scale_update = (
compute_scale_and_amax_history(
upstream, scale, amax_history, "float8_e5m2"
)
)
self.outputs_grad_scale.assign(scale_update)
self.outputs_grad_amax_history.assign(amax_history_update)
return qdq_upstream, None, None

return outputs, grad

x = ops.matmul(
inputs_qdq(inputs), kernel_qdq(ops.convert_to_tensor(self._kernel))
)
# `outputs_qdq` is placed immediately after `ops.matmul` for the sake
# of pattern matching in gemm_rewrite. That way, the qdq will be
# adjacent to the corresponding matmul_bprop in the bprop.
x = outputs_qdq(
x,
ops.convert_to_tensor(self.outputs_grad_scale),
ops.convert_to_tensor(self.outputs_grad_amax_history),
)
if self.bias is not None:
# In non-mixed-precision cases, the F32 bias has to be converted to
# BF16 first to get the BiasAdd fusion support. Ref. PR
# https://github.com/tensorflow/tensorflow/pull/60306
bias = self.bias
if self.dtype_policy.compute_dtype == "float32":
bias_bf16 = ops.cast(bias, "bfloat16")
bias = ops.cast(bias_bf16, bias.dtype)
x = ops.add(x, bias)
if self.activation is not None:
x = self.activation(x)

# Update fp8 stats
amax_history_update, scale_update = compute_scale_and_amax_history(
inputs, self.inputs_scale, self.inputs_amax_history, "float8_e4m3fn"
)
self.inputs_scale.assign(scale_update)
self.inputs_amax_history.assign(amax_history_update)
amax_history_update, scale_update = compute_scale_and_amax_history(
self.kernel,
self.kernel_scale,
self.kernel_amax_history,
"float8_e4m3fn",
)
self.kernel_scale.assign(scale_update)
self.kernel_amax_history.assign(amax_history_update)
return x
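For reference, a minimal usage sketch of the new flag. This is illustrative only and not part of the commit: the constructor arguments and weight names come from the diff above, while the data, shapes, and the final print are assumptions.

import numpy as np
from keras import layers

# Hypothetical usage of the FP8 path added in this commit.
layer = layers.Dense(units=16, use_fp8=True, amax_history_length=1024)
x = np.random.random((8, 32)).astype("float32")
y = layer(x)  # runs `fp8_call`: qdq of inputs/kernel, matmul, qdq of the output grad
# The per-tensor scales and amax histories are stored as non-trainable weights.
print([w.name for w in layer.weights if "scale" in w.name or "amax" in w.name])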
1 change: 1 addition & 0 deletions keras/quantizers/__init__.py
@@ -4,6 +4,7 @@
from keras.quantizers.quantizers import AbsMaxQuantizer
from keras.quantizers.quantizers import Quantizer
from keras.quantizers.quantizers import abs_max_quantize
from keras.quantizers.quantizers import quantize_and_dequantize
from keras.saving import serialization_lib
from keras.utils.naming import to_snake_case

31 changes: 31 additions & 0 deletions keras/quantizers/quantizers.py
@@ -1,3 +1,5 @@
import ml_dtypes

from keras import backend
from keras import ops
from keras.api_export import keras_export
@@ -100,3 +102,32 @@ def get_config(self):
"epsilon": self.epsilon,
"output_dtype": self.output_dtype,
}


def quantize_and_dequantize(inputs, scale, quantized_dtype, compute_dtype):
# Quantize
quantized_dtype_min = float(ml_dtypes.finfo(quantized_dtype).min)
quantized_dtype_max = float(ml_dtypes.finfo(quantized_dtype).max)
x = ops.divide(inputs, ops.cast(scale, compute_dtype))
x = ops.clip(x, quantized_dtype_min, quantized_dtype_max)

# TODO: Introduce float8 dtype
if backend.backend() == "tensorflow":
import tensorflow as tf

x = tf.cast(x, quantized_dtype)
elif backend.backend() == "jax":
x = x.astype(quantized_dtype)
elif backend.backend() == "torch":
import torch

if quantized_dtype == "float8_e5m2":
x = x.to(torch.float8_e5m2)
if quantized_dtype == "float8_e4m3fn":
x = x.to(torch.float8_e4m3fn)

# x = ops.cast(x, quantized_dtype)

# Dequantize
x = ops.multiply(ops.cast(x, compute_dtype), ops.cast(scale, compute_dtype))
return x
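A small standalone sketch of the helper as wired up in keras/quantizers/__init__.py above. The input values and the chosen scale are assumptions for illustration.

import numpy as np
from keras import ops
from keras.quantizers import quantize_and_dequantize

x = ops.convert_to_tensor(np.linspace(-1.0, 1.0, 8, dtype="float32"))
# `scale` plays the role of "scale_inv": inputs are divided by it before the cast.
scale = ops.convert_to_tensor(1.0 / 64.0)
x_qdq = quantize_and_dequantize(x, scale, "float8_e4m3fn", "float32")
# `x_qdq` approximates `x` at float8_e4m3fn precision under the given scale.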
