
Commit 633922d

WIP

tengomucho committed Jan 14, 2025
1 parent 8b6eb13
Showing 5 changed files with 218 additions and 90 deletions.
@@ -25,7 +25,7 @@ def model_can_use_jetstream_pt(model_path: str) -> bool:
     """
     config = AutoConfig.from_pretrained(model_path)
     # For now few models are supported
-    supported_models = ["llama", "gemma", "mixtral"]
+    supported_models = ["llama", "gemma", "mixtral", "qwen2"]
     if config.model_type not in supported_models:
         return False
     if jetstream_pt_available():
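
For context, a minimal standalone sketch of the check this hunk extends. The function name, the AutoConfig lookup, and the supported-model list come from the diff; the docstring and the commented example are illustrative assumptions.

from transformers import AutoConfig

def model_can_use_jetstream_pt(model_path: str) -> bool:
    """Return True when the checkpoint's architecture is supported.

    Sketch based on the hunk above; the real function also consults
    jetstream_pt_available(), which is elided here.
    """
    config = AutoConfig.from_pretrained(model_path)
    # Only a few model families are supported for now; this commit
    # adds "qwen2" to the list.
    supported_models = ["llama", "gemma", "mixtral", "qwen2"]
    return config.model_type in supported_models

# Qwen2 checkpoints report model_type == "qwen2", so they now pass:
# model_can_use_jetstream_pt("Qwen/Qwen2-0.5B")  # hypothetical id -> True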
@@ -24,7 +24,7 @@
 from transformers import AutoConfig
 
 from .compatibility import model_can_use_jetstream_pt
-from .models import GemmaModel, LlamaModel, MixtralModel
+from .models import GemmaModel, LlamaModel, MixtralModel, Qwen2Model
 
 
 class OptimumJetstreamEngine(PyTorchEngine):
@@ -66,6 +66,8 @@ def load_model_info(config: "PretrainedConfig") -> Any:
         model_class = GemmaModel
     elif config.model_type == "mixtral":
         model_class = MixtralModel
+    elif config.model_type == "qwen2":
+        model_class = Qwen2Model
     else:
         raise ValueError(f"Unsupported model type {config.model_type}")
     model_info = fetch_models.ModelInfo(
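
The if/elif dispatch in load_model_info grows one branch per architecture. An equivalent table-driven sketch (an illustrative alternative, not code from this commit) keeps the mapping in one place next to the imports:

from .models import GemmaModel, LlamaModel, MixtralModel, Qwen2Model

# Illustrative alternative to the if/elif chain; the dict itself is an
# assumption, while the four wrapper classes are the ones imported above.
_MODEL_CLASSES = {
    "llama": LlamaModel,
    "gemma": GemmaModel,
    "mixtral": MixtralModel,
    "qwen2": Qwen2Model,
}

def _resolve_model_class(config):
    model_class = _MODEL_CLASSES.get(config.model_type)
    if model_class is None:
        raise ValueError(f"Unsupported model type {config.model_type}")
    return model_class

Either form behaves the same; the table makes it harder for the compatibility list and the engine dispatch to drift apart as more architectures are added.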
@@ -1,3 +1,4 @@
 from .gemma_model_hf import GemmaModelHf as GemmaModel
 from .llama_model_exportable_hf import TransformerHf as LlamaModel
 from .mixtral_model_hf import MixtralModelHf as MixtralModel
+from .qwen2_model import Qwen2Model
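
Note that Qwen2Model is re-exported under its own name, while the older wrappers are aliased on import. A short sanity sketch (hypothetical, assuming it runs from a sibling module in the same package, whose absolute path is not visible in this view):

from .models import GemmaModel, LlamaModel, MixtralModel, Qwen2Model

# Each name should now resolve to a concrete wrapper class.
assert all(isinstance(cls, type)
           for cls in (GemmaModel, LlamaModel, MixtralModel, Qwen2Model))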
