
Commit

Auto damp recovery (#326)
CSY-ModelCloud authored Aug 2, 2024
1 parent db15847 commit 4473139
Showing 4 changed files with 50 additions and 32 deletions.
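
This commit adds automatic recovery from failed Cholesky factorizations during GPTQ quantization: when the damped Hessian is not positive-definite, fasterquant() now raises damp_percent by a new damp_auto_increment step (default 0.0015) and retries, instead of aborting with a linear-algebra error. The setting is threaded through QuantizeConfig and the optimum integration; setting it to 0 restores the old fail-fast behavior.

A minimal usage sketch, mirroring how quantizer.py builds the config below (the import path is inferred from this commit's file layout; the field names and defaults come from the diff):

    from gptqmodel.quantization.config import QuantizeConfig

    quantize_config = QuantizeConfig()
    quantize_config.damp_percent = 0.005          # starting dampening fraction
    quantize_config.damp_auto_increment = 0.0015  # step added after each failed Cholesky; 0 disables retries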
9 changes: 9 additions & 0 deletions gptqmodel/integration/optimum/quantizer.py
@@ -65,6 +65,7 @@ def __init__(
         dataset: Optional[Union[List[str], str]] = None,
         group_size: int = 128,
         damp_percent: float = 0.1,
+        damp_auto_increment: float = 0.0015,
         desc_act: bool = False,
         sym: bool = True,
         true_sequential: bool = True,
@@ -87,6 +88,7 @@ def __init__(
         self.dataset = dataset
         self.group_size = group_size
         self.damp_percent = damp_percent
+        self.damp_auto_increment = damp_auto_increment
         self.desc_act = desc_act
         self.sym = sym
         self.true_sequential = true_sequential
@@ -125,6 +127,7 @@ def __init__(
         dataset: Optional[Union[List[str], str]] = None,
         group_size: int = 128,
         damp_percent: float = 0.1,
+        damp_auto_increment: float = 0.0015,
         desc_act: bool = False,
         sym: bool = True,
         true_sequential: bool = True,
@@ -200,6 +203,7 @@ def __init__(
         self.dataset = dataset
         self.group_size = group_size
         self.damp_percent = damp_percent
+        self.damp_auto_increment = damp_auto_increment
         self.desc_act = desc_act
         self.sym = sym
         self.true_sequential = true_sequential
@@ -218,6 +222,7 @@ def __init__(
         quantize_config = QuantizeConfig()
         quantize_config.group_size = self.group_size
         quantize_config.damp_percent = self.damp_percent
+        quantize_config.damp_auto_increment = self.damp_auto_increment
         quantize_config.desc_act = self.desc_act
         quantize_config.sym = self.sym
         quantize_config.true_sequential = self.true_sequential
@@ -229,6 +234,7 @@ def __init__(
             "dataset",
             "group_size",
             "damp_percent",
+            "damp_auto_increment",
             "desc_act",
             "sym",
             "true_sequential",
@@ -243,6 +249,9 @@ def __init__(
         if not (0 < self.damp_percent < 1):
             raise ValueError("damp_percent must be between 0 and 1.")

+        if self.damp_auto_increment < 0:
+            raise ValueError("damp_auto_increment must be greater than or equal to 0.")
+
         if self.exllama_config is None:
             self.exllama_config = {"version": ExllamaVersion.TWO}
         else:
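
The optimum-integration quantizer above accepts the new keyword, validates it, and forwards it into QuantizeConfig. A hedged call-site sketch (the GPTQQuantizer class name and import path are assumptions, as is the bits parameter; only the parameters shown in the hunks above appear in this diff):

    from gptqmodel.integration.optimum.quantizer import GPTQQuantizer  # assumed name/path

    quantizer = GPTQQuantizer(
        bits=4,                       # assumed parameter, not shown in these hunks
        group_size=128,
        damp_percent=0.1,             # this file's default per the diff
        damp_auto_increment=0.0015,   # new: damp step used on Cholesky failure
    )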
36 changes: 13 additions & 23 deletions gptqmodel/models/base.py
@@ -494,29 +494,19 @@ def tmp(_, inp, out):
                 for name in subset:
                     layer_pb.set_description(f"Quantizing {name} in layer {i} of {layer_count - 1}")

-                    try:
-                        scale, zero, g_idx, duration, avg_loss, bits = gptq[name].fasterquant(
-                            percdamp=self.quantize_config.damp_percent,
-                            group_size=self.quantize_config.group_size,
-                            actorder=self.quantize_config.desc_act,
-                            static_groups=self.quantize_config.static_groups,
-                        )
-                        if self.quantize_config.dynamic_bits is not None:
-                            stat = {"layer": i, "module": name, "avg_loss": f"{avg_loss:.5f}", "bits": bits,
-                                    "time": f"{duration:.3f}"}
-                        else:
-                            stat = {"layer": i, "module": name, "avg_loss": f"{avg_loss:.5f}",
-                                    "time": f"{duration:.3f}"}
-
-                        quant_log.append(stat)
-                        logger.info(stat)
-
-                    except torch._C._LinAlgError as e:
-                        if "not positive-definite" in str(e).lower():
-                            logger.warning(
-                                "Please increase damp or nsamples for calibration data to avoid the following quant error. "
-                            )
-                        raise e
+                    scale, zero, g_idx, duration, avg_loss, bits, damp_percent = gptq[name].fasterquant(
+                        percdamp=self.quantize_config.damp_percent,
+                        damp_auto_increment=self.quantize_config.damp_auto_increment,
+                        group_size=self.quantize_config.group_size,
+                        actorder=self.quantize_config.desc_act,
+                        static_groups=self.quantize_config.static_groups,
+                    )
+                    stat = {"layer": i, "module": name, "avg_loss": f"{avg_loss:.5f}", "damp_percent": f"{damp_percent:.5f}", "time": f"{duration:.3f}"}
+                    if self.quantize_config.dynamic_bits is not None:
+                        stat["bits"] = f"{bits}"
+
+                    quant_log.append(stat)
+                    logger.info(stat)

                 quantizers[f"{self.layers_node}.{i}.{name}"] = (
                     gptq[name].quantizer.to(CPU if force_layer_back_to_cpu else cur_layer_device),
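
With the try/except removed from base.py, fasterquant() handles recovery itself and returns the damp value it finally used, which is now logged per module. An illustrative quant_log entry after this change (the values and module name are made up; the keys match the stat dict built above):

    stat = {"layer": 0, "module": "mlp.down_proj", "avg_loss": "0.01234",
            "damp_percent": "0.00650", "time": "1.234"}
    # "bits" is appended only when quantize_config.dynamic_bits is set.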
7 changes: 6 additions & 1 deletion gptqmodel/quantization/config.py
@@ -84,6 +84,7 @@ class QuantizeConfig():
    group_size: int = field(default=128)
    # increase damp if NaN is encountered during `.quantize()` and/or increase calib dataset size
    damp_percent: float = field(default=0.005)
+    damp_auto_increment: float = field(default=0.0015)
    desc_act: bool = field(default=True)
    static_groups: bool = field(default=False)
    sym: bool = field(default=True)
@@ -126,11 +127,14 @@ def __post_init__(self):
                        f"Layer {layer}: only supports quantizing to {fields_info[0].metadata['choices']} bits.")

        if self.group_size != -1 and self.group_size <= 0:
-            raise ValueError("unless equal to -1, group_size must greater then 0.")
+            raise ValueError("unless equal to -1, group_size must be greater than 0.")

        if not (0 < self.damp_percent < 1):
            raise ValueError("damp_percent must be between 0 and 1.")

+        if self.damp_auto_increment < 0:
+            raise ValueError("damp_auto_increment must be greater than or equal to 0.")
+
        # validate meta
        if self.meta is not None:
            if not isinstance(self.meta, dict):
@@ -302,6 +306,7 @@ def to_dict(self):
            "sym": self.sym,
            "lm_head": self.lm_head,
            "damp_percent": self.damp_percent,
+            "damp_auto_increment": self.damp_auto_increment,
            "true_sequential": self.true_sequential,
            # TODO: deprecate?
            "model_name_or_path": self.model_name_or_path,
30 changes: 22 additions & 8 deletions gptqmodel/quantization/gptq.py
@@ -66,6 +66,7 @@ def fasterquant(
        self,
        blocksize=128,
        percdamp=0.01,
+        damp_auto_increment=0.0015,
        group_size=-1,
        actorder=False,
        static_groups=False,
@@ -114,13 +115,26 @@ def fasterquant(
        Losses = torch.zeros_like(W)
        Q = torch.zeros_like(W)

-        damp = percdamp * torch.mean(torch.diag(H))
-        diag = torch.arange(self.columns, device=self.dev)
-        H[diag, diag] += damp
-        H = torch.linalg.cholesky(H)
-        H = torch.cholesky_inverse(H)
-        H = torch.linalg.cholesky(H, upper=True)
-        Hinv = H
+        while 0 < percdamp < 1:
+            try:
+                damp = percdamp * torch.mean(torch.diag(H))
+                diag = torch.arange(self.columns, device=self.dev)
+                H[diag, diag] += damp
+                H = torch.linalg.cholesky(H)
+                H = torch.cholesky_inverse(H)
+                H = torch.linalg.cholesky(H, upper=True)
+                Hinv = H
+                break
+            except torch._C._LinAlgError as e:
+                if damp_auto_increment != 0:
+                    logger.warning(f"Current damp={percdamp:.5f} is too low, increasing it by {damp_auto_increment:.5f}")
+                    percdamp += damp_auto_increment
+                else:
+                    logger.warning("Please increase damp or the number of calibration samples to avoid this quantization error.")
+                    raise e
+
+        if not (0 < percdamp < 1):
+            raise ValueError(f"damp_percent must be between 0 and 1. Current value: {percdamp}")

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
@@ -195,7 +209,7 @@ def fasterquant(
                zero.append(self.quantizer.zero)
        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)
-        return scale, zero, g_idx, duration, avg_loss, bits
+        return scale, zero, g_idx, duration, avg_loss, bits, percdamp

    def free(self):
        if os.environ.get("DEBUG"):
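
The loop above damps H by percdamp times the mean of its diagonal, then runs the Cholesky / cholesky_inverse / upper-Cholesky sequence, retrying with a larger percdamp on each torch._C._LinAlgError until it succeeds or percdamp leaves (0, 1). A self-contained sketch of the same idea (note one deliberate difference: this variant damps a fresh copy of H on every retry, while the code above mutates H in place, so its retries also accumulate the damp added by earlier attempts):

    import torch

    def damped_cholesky_inverse(H, percdamp=0.01, damp_auto_increment=0.0015):
        # Returns (upper-triangular Cholesky factor of H^-1, damp actually used).
        diag = torch.arange(H.shape[0], device=H.device)
        while 0 < percdamp < 1:
            try:
                damped = H.clone()                        # retry from the undamped H
                damped[diag, diag] += percdamp * torch.mean(torch.diag(H))
                L = torch.linalg.cholesky(damped)         # raises if not positive-definite
                Hinv = torch.cholesky_inverse(L)          # invert via the factor
                return torch.linalg.cholesky(Hinv, upper=True), percdamp
            except torch._C._LinAlgError:
                if damp_auto_increment == 0:
                    raise                                 # recovery disabled
                percdamp += damp_auto_increment           # try again with more damp
        raise ValueError(f"damp_percent must be between 0 and 1, got {percdamp}")

    # A singular toy "Hessian": undamped Cholesky would fail, damping makes it factorable.
    H = torch.ones(8, 8)
    Hinv_factor, used_damp = damped_cholesky_inverse(H)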
