Skip to content

Commit

Permalink
Fix: copy feature_extractor for whisper model
Browse files Browse the repository at this point in the history
Signed-off-by: sagewe <sagewe@fastmail.com>
  • Loading branch information
sagewe committed Nov 18, 2024
1 parent 6040dfb commit a8264a2
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
43 changes: 43 additions & 0 deletions mergekit/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,19 @@ def run_merge(
"Chat template specified but no tokenizer found. Chat template will not be saved."
)

# Copy feature_extractor if it is a whisper model
if options.copy_feature_extractor and arch_info.definition.expected_model_type == "whisper":
try:
_copy_feature_extractor(
merge_config, out_path, trust_remote_code=options.trust_remote_code
)
except Exception as e:
logging.error(
"Failed to copy feature_extractor. The merge was still successful, just copy it from somewhere else.",
exc_info=e,
)


if tokenizer:
logging.info("Saving tokenizer")
_set_chat_template(tokenizer, merge_config)
Expand Down Expand Up @@ -229,6 +242,36 @@ def _copy_tokenizer(
tokenizer.save_pretrained(out_path, safe_serialization=True)


def _copy_feature_extractor(
    merge_config: MergeConfiguration, out_path: str, trust_remote_code: bool = False
):
    """Place the donor model's feature extractor config in the output directory.

    The donor is ``merge_config.base_model`` when set, otherwise the first
    entry of ``merge_config.referenced_models()``.

    If the donor path is a local directory containing
    ``preprocessor_config.json``, that file is copied verbatim. Otherwise the
    feature extractor is loaded through ``transformers.AutoFeatureExtractor``
    and re-serialized into ``out_path``.

    Args:
        merge_config: Merge configuration used to pick the donor model.
        out_path: Directory the feature extractor config is written to.
        trust_remote_code: Forwarded to ``from_pretrained`` in the fallback
            path.

    Raises:
        IndexError: If there is no base model and no referenced models.
        Exception: Whatever the file copy or ``transformers`` loading/saving
            raises is propagated to the caller.
    """
    donor_model = merge_config.base_model or (merge_config.referenced_models()[0])
    donor_path = donor_model.model.path

    config_name = "preprocessor_config.json"
    config_src = os.path.join(donor_path, config_name)

    # Fast path: the donor is a local checkout, so a plain file copy keeps the
    # config byte-identical to the original.
    if os.path.exists(config_src):
        logging.info(f"Copying feature_extractor from {donor_model}")
        shutil.copy(config_src, os.path.join(out_path, config_name))
        return

    # fallback: try actually loading the feature_extractor and saving it
    logging.info(f"Reserializing feature_extractor from {donor_model}")
    feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(
        donor_path,
        revision=donor_model.model.revision,
        trust_remote_code=trust_remote_code,
    )
    # NOTE: no chat template is applied here. Chat templates are a tokenizer
    # setting; setting one on a feature extractor would only add a stray
    # attribute to the saved preprocessor config.
    feature_extractor.save_pretrained(out_path, safe_serialization=True)
def _model_out_config(
config: MergeConfiguration,
arch_info: ArchitectureInfo,
Expand Down
1 change: 1 addition & 0 deletions mergekit/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class MergeOptions(BaseModel):
low_cpu_memory: bool = False
out_shard_size: int = parse_kmb("5B")
copy_tokenizer: bool = True
copy_feature_extractor: bool = True
clone_tensors: bool = False
trust_remote_code: bool = False
random_seed: Optional[int] = None
Expand Down

0 comments on commit a8264a2

Please sign in to comment.