Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] calib avg for calib dataset arg passed as tensors #254

Merged
merged 1 commit into from
Jul 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions gptqmodel/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,18 @@ def quantize(

# Calculate the average length of the average input_ids
total_input_ids_length = 0
for e in calibration_dataset:
input_ids_length = len(e["input_ids"])
for row in calibration_dataset:
input_ids = row["input_ids"]
if isinstance(input_ids, torch.Tensor):
input_ids_length = input_ids.shape[1]
else:
input_ids_length = len(input_ids)
total_input_ids_length += input_ids_length
avg = total_input_ids_length / len(calibration_dataset)

if avg < min_calibration_dataset_input_ids_avg_length:
logger.warning(f"The average length of input_ids of calibration_dataset should be greater than "
f"{min_calibration_dataset_input_ids_avg_length}! Current AVG is {avg}.")
f"{min_calibration_dataset_input_ids_avg_length}: actual avg: {avg}.")

device_map = self.hf_device_map
if device_map:
Expand Down