Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test more architectures in ORTModel #675

Merged
merged 14 commits into from
Jan 16, 2023
14 changes: 13 additions & 1 deletion optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import os
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Type, Union

from transformers import PretrainedConfig, is_tf_available, is_torch_available
from transformers.utils import TF2_WEIGHTS_NAME, WEIGHTS_NAME, logging
Expand Down Expand Up @@ -141,6 +141,7 @@ class TasksManager:
"stable-diffusion": "diffusers",
}

# TODO: some models here support causal-lm export but are not supported in ORTModelForCausalLM
# Set of model topologies we support associated to the tasks supported by each topology and the factory
_SUPPORTED_MODEL_TYPE = {
"audio-spectrogram-transformer": supported_tasks_mapping(
Expand Down Expand Up @@ -725,6 +726,17 @@ def get_supported_tasks_for_model_type(
else:
return TasksManager._SUPPORTED_MODEL_TYPE[model_type][exporter]

@staticmethod
def get_supported_model_type_for_task(task: str, exporter: str) -> List[str]:
"""
Returns the list of supported architectures by the exporter for a given task.
"""
return [
model_type.replace("-", "_")
for model_type in TasksManager._SUPPORTED_MODEL_TYPE
if task in TasksManager._SUPPORTED_MODEL_TYPE[model_type][exporter]
]

@staticmethod
def format_task(task: str) -> str:
return task.replace("-with-past", "")
Expand Down
Loading