Skip to content

Commit

Permalink
Improve ORTModel documentation (#1245)
Browse files Browse the repository at this point in the history
* why wouldn't this be fine?

* build the doc pls

* see if it builds

* missing dep docbuild

* meaningful documentation

* fix

* some more improvements

* we actually need to inherit first from ORTStableDiffusionPipelineBase for some class attributes

* would that work?

* fix
  • Loading branch information
fxmarty authored Aug 28, 2023
1 parent 0ee2dfa commit 128ce3e
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 67 deletions.
2 changes: 1 addition & 1 deletion docs/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/doc-builder.git

RUN git clone $clone_url && cd optimum && git checkout $commit_sha
RUN python3 -m pip install --no-cache-dir ./optimum[onnxruntime,benchmark,quality,exporters-tf,doc-build]
RUN python3 -m pip install --no-cache-dir ./optimum[onnxruntime,benchmark,quality,exporters-tf,doc-build,diffusers]
43 changes: 10 additions & 33 deletions docs/source/onnxruntime/package_reference/modeling_ort.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -26,44 +26,32 @@ The following ORT classes are available for the following natural language proce

### ORTModelForCausalLM

This class officially supports bloom, codegen, gpt2, gpt_bigcode, gpt_neo, gpt_neox, gptj, llama.

[[autodoc]] onnxruntime.ORTModelForCausalLM
- forward

### ORTModelForMaskedLM

This class officially supports albert, bert, camembert, convbert, data2vec_text, deberta, deberta_v2, distilbert, electra, flaubert, ibert, mobilebert, roberta, roformer, squeezebert, xlm, xlm_roberta.

[[autodoc]] onnxruntime.ORTModelForMaskedLM

### ORTModelForSeq2SeqLM

This class officially supports bart, blenderbot, blenderbot_small, longt5, m2m_100, marian, mbart, mt5, pegasus, t5.

[[autodoc]] onnxruntime.ORTModelForSeq2SeqLM
- forward

### ORTModelForSequenceClassification

This class officially supports albert, bart, bert, camembert, convbert, data2vec_text, deberta, deberta_v2, distilbert, electra, flaubert, ibert, mbart, mobilebert, nystromformer, roberta, roformer, squeezebert, xlm, xlm_roberta.

[[autodoc]] onnxruntime.ORTModelForSequenceClassification

### ORTModelForTokenClassification

This class officially supports albert, bert, bloom, camembert, convbert, data2vec_text, deberta, deberta_v2, distilbert, electra, flaubert, gpt2, ibert, mobilebert, roberta, roformer, squeezebert, xlm, xlm_roberta.

[[autodoc]] onnxruntime.ORTModelForTokenClassification

### ORTModelForMultipleChoice

This class officially supports albert, bert, camembert, convbert, data2vec_text, deberta_v2, distilbert, electra, flaubert, ibert, mobilebert, nystromformer, roberta, roformer, squeezebert, xlm, xlm_roberta.

[[autodoc]] onnxruntime.ORTModelForMultipleChoice

### ORTModelForQuestionAnswering

This class officially supports albert, bart, bert, camembert, convbert, data2vec_text, deberta, deberta_v2, distilbert, electra, flaubert, gptj, ibert, mbart, mobilebert, nystromformer, roberta, roformer, squeezebert, xlm, xlm_roberta.

[[autodoc]] onnxruntime.ORTModelForQuestionAnswering

## Computer vision
Expand All @@ -72,14 +60,10 @@ The following ORT classes are available for the following computer vision tasks.

### ORTModelForImageClassification

This class officially supports beit, convnext, data2vec_vision, deit, levit, mobilenet_v1, mobilenet_v2, mobilevit, poolformer, resnet, segformer, swin, vit.

[[autodoc]] onnxruntime.ORTModelForImageClassification

### ORTModelForSemanticSegmentation

This class officially supports segformer.

[[autodoc]] onnxruntime.ORTModelForSemanticSegmentation

## Audio
Expand All @@ -88,32 +72,23 @@ The following ORT classes are available for the following audio tasks.

### ORTModelForAudioClassification

This class officially supports audio_spectrogram_transformer, data2vec_audio, hubert, sew, sew_d, unispeech, unispeech_sat, wavlm, wav2vec2, wav2vec2-conformer.

[[autodoc]] onnxruntime.ORTModelForAudioClassification

### ORTModelForAudioFrameClassification

This class officially supports data2vec_audio, unispeech_sat, wavlm, wav2vec2, wav2vec2-conformer.

[[autodoc]] onnxruntime.ORTModelForAudioFrameClassification

### ORTModelForCTC

This class officially supports data2vec_audio, hubert, sew, sew_d, unispeech, unispeech_sat, wavlm, wav2vec2, wav2vec2-conformer.

[[autodoc]] onnxruntime.ORTModelForCTC

### ORTModelForSpeechSeq2Seq

This class officially supports whisper, speech_to_text.

[[autodoc]] onnxruntime.ORTModelForSpeechSeq2Seq
- forward

### ORTModelForAudioXVector

This class officially supports data2vec_audio, unispeech_sat, wavlm, wav2vec2, wav2vec2-conformer.

[[autodoc]] onnxruntime.ORTModelForAudioXVector

## Multimodal
Expand All @@ -122,15 +97,13 @@ The following ORT classes are available for the following multimodal tasks.

### ORTModelForVision2Seq

This class officially supports trocr and vision-encoder-decoder.

[[autodoc]] onnxruntime.ORTModelForVision2Seq
- forward

### ORTModelForPix2Struct

This class officially supports pix2struct.

[[autodoc]] onnxruntime.ORTModelForPix2Struct
- forward

## Custom Tasks

Expand All @@ -149,20 +122,24 @@ The following ORT classes are available for the following custom tasks.
#### ORTStableDiffusionPipeline

[[autodoc]] onnxruntime.ORTStableDiffusionPipeline
- __call__

#### ORTStableDiffusionImg2ImgPipeline

[[autodoc]] onnxruntime.ORTStableDiffusionImg2ImgPipeline
- __call__

#### ORTStableDiffusionInpaintPipeline

[[autodoc]] onnxruntime.ORTStableDiffusionInpaintPipeline

- __call__

#### ORTStableDiffusionXLPipeline

[[autodoc]] onnxruntime.ORTStableDiffusionXLPipeline
- __call__

#### ORTStableDiffusionXLImg2ImgPipeline

[[autodoc]] onnxruntime.ORTStableDiffusionXLImg2ImgPipeline
- __call__
7 changes: 4 additions & 3 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from transformers import AutoModelForCausalLM, GenerationConfig
from transformers.file_utils import add_start_docstrings_to_model_forward
from transformers.file_utils import add_end_docstrings, add_start_docstrings_to_model_forward
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

import onnxruntime
Expand All @@ -35,7 +35,7 @@
from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from .base import ORTDecoder
from .constants import DECODER_MERGED_ONNX_FILE_PATTERN, DECODER_ONNX_FILE_PATTERN, DECODER_WITH_PAST_ONNX_FILE_PATTERN
from .modeling_ort import ORTModel
from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel
from .models.bloom import bloom_convert_to_bloom_cache, bloom_convert_to_standard_cache
from .utils import (
ONNX_DECODER_NAME,
Expand Down Expand Up @@ -622,9 +622,10 @@ def to(self, device: Union[torch.device, str, int]):
return self


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForCausalLM(ORTModelDecoder, GenerationMixin):
"""
ONNX model with a causal language modeling head for ONNX Runtime inference.
ONNX model with a causal language modeling head for ONNX Runtime inference. This class officially supports bloom, codegen, gpt2, gpt_bigcode, gpt_neo, gpt_neox, gptj, llama.
"""

auto_model_class = AutoModelForCausalLM
Expand Down
45 changes: 34 additions & 11 deletions optimum/onnxruntime/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from diffusers.utils import CONFIG_NAME
from huggingface_hub import snapshot_download
from transformers import CLIPFeatureExtractor, CLIPTokenizer
from transformers.file_utils import add_end_docstrings

import onnxruntime as ort

Expand All @@ -51,7 +52,7 @@
DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER,
DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
)
from .modeling_ort import ORTModel
from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel
from .utils import (
_ORT_TO_NP_TYPE,
ONNX_WEIGHTS_NAME,
Expand Down Expand Up @@ -288,6 +289,7 @@ def _from_pretrained(
patterns = set(config.keys())
sub_models_to_load = patterns.intersection({"feature_extractor", "tokenizer", "tokenizer_2", "scheduler"})

print("GO HERE")
if not os.path.isdir(model_id):
patterns.update({"vae_encoder", "vae_decoder"})
allow_patterns = {os.path.join(k, "*") for k in patterns if not k.startswith("_")}
Expand Down Expand Up @@ -441,6 +443,7 @@ def to(self, device: Union[torch.device, str, int]):

@classmethod
def _load_config(cls, config_name_or_path: Union[str, os.PathLike], **kwargs):
print("cls here", cls)
return cls.load_config(config_name_or_path, **kwargs)

def _save_config(self, save_directory):
Expand Down Expand Up @@ -531,19 +534,31 @@ def forward(self, sample: np.ndarray):
return outputs


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTStableDiffusionPipeline(ORTStableDiffusionPipelineBase, StableDiffusionPipelineMixin):
def __call__(self, *args, **kwargs):
return StableDiffusionPipelineMixin.__call__(self, *args, **kwargs)
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline).
"""

__call__ = StableDiffusionPipelineMixin.__call__


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTStableDiffusionImg2ImgPipeline(ORTStableDiffusionPipelineBase, StableDiffusionImg2ImgPipelineMixin):
def __call__(self, *args, **kwargs):
return StableDiffusionImg2ImgPipelineMixin.__call__(self, *args, **kwargs)
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/img2img#diffusers.StableDiffusionImg2ImgPipeline).
"""

__call__ = StableDiffusionImg2ImgPipelineMixin.__call__


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTStableDiffusionInpaintPipeline(ORTStableDiffusionPipelineBase, StableDiffusionInpaintPipelineMixin):
def __call__(self, *args, **kwargs):
return StableDiffusionInpaintPipelineMixin.__call__(self, *args, **kwargs)
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionInpaintPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/inpaint#diffusers.StableDiffusionInpaintPipeline).
"""

__call__ = StableDiffusionInpaintPipelineMixin.__call__


class ORTStableDiffusionXLPipelineBase(ORTStableDiffusionPipelineBase):
Expand Down Expand Up @@ -585,11 +600,19 @@ def __init__(
self.watermark = StableDiffusionXLWatermarker()


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTStableDiffusionXLPipeline(ORTStableDiffusionXLPipelineBase, StableDiffusionXLPipelineMixin):
def __call__(self, *args, **kwargs):
return StableDiffusionXLPipelineMixin.__call__(self, *args, **kwargs)
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline).
"""

__call__ = StableDiffusionXLPipelineMixin.__call__


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTStableDiffusionXLImg2ImgPipeline(ORTStableDiffusionXLPipelineBase, StableDiffusionXLImg2ImgPipelineMixin):
def __call__(self, *args, **kwargs):
return StableDiffusionXLImg2ImgPipelineMixin.__call__(self, *args, **kwargs)
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLImg2ImgPipeline).
"""

__call__ = StableDiffusionXLImg2ImgPipelineMixin.__call__
27 changes: 14 additions & 13 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@
_PROCESSOR_FOR_DOC = "AutoProcessor"

ONNX_MODEL_END_DOCSTRING = r"""
This model inherits from [`~onnxruntime.modeling_ort.ORTModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving)
This model inherits from [`~onnxruntime.modeling_ort.ORTModel`], check its documentation for the generic methods the
library implements for all its model (such as downloading or saving).
This class should be initialized using the [`onnxruntime.modeling_ort.ORTModel.from_pretrained`] method.
"""
Expand Down Expand Up @@ -969,7 +969,7 @@ def forward(
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForMaskedLM(ORTModel):
"""
ONNX Model with a MaskedLMOutput for masked language modeling tasks.
ONNX Model with a MaskedLMOutput for masked language modeling tasks. This class officially supports albert, bert, camembert, convbert, data2vec_text, deberta, deberta_v2, distilbert, electra, flaubert, ibert, mobilebert, roberta, roformer, squeezebert, xlm, xlm_roberta.
"""

auto_model_class = AutoModelForMaskedLM
Expand Down Expand Up @@ -1072,7 +1072,7 @@ def forward(
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForQuestionAnswering(ORTModel):
"""
ONNX Model with a QuestionAnsweringModelOutput for extractive question-answering tasks like SQuAD.
ONNX Model with a QuestionAnsweringModelOutput for extractive question-answering tasks like SQuAD. This class officially supports albert, bart, bert, camembert, convbert, data2vec_text, deberta, deberta_v2, distilbert, electra, flaubert, gptj, ibert, mbart, mobilebert, nystromformer, roberta, roformer, squeezebert, xlm, xlm_roberta.
"""

auto_model_class = AutoModelForQuestionAnswering
Expand Down Expand Up @@ -1195,7 +1195,7 @@ def forward(
class ORTModelForSequenceClassification(ORTModel):
"""
ONNX Model with a sequence classification/regression head on top (a linear layer on top of the
pooled output) e.g. for GLUE tasks.
pooled output) e.g. for GLUE tasks. This class officially supports albert, bart, bert, camembert, convbert, data2vec_text, deberta, deberta_v2, distilbert, electra, flaubert, ibert, mbart, mobilebert, nystromformer, roberta, roformer, squeezebert, xlm, xlm_roberta.
"""

auto_model_class = AutoModelForSequenceClassification
Expand Down Expand Up @@ -1296,7 +1296,8 @@ def forward(
class ORTModelForTokenClassification(ORTModel):
"""
ONNX Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
for Named-Entity-Recognition (NER) tasks.
for Named-Entity-Recognition (NER) tasks. This class officially supports albert, bert, bloom, camembert, convbert, data2vec_text, deberta, deberta_v2, distilbert, electra, flaubert, gpt2, ibert, mobilebert, roberta, roformer, squeezebert, xlm, xlm_roberta.
"""

auto_model_class = AutoModelForTokenClassification
Expand Down Expand Up @@ -1394,7 +1395,7 @@ def forward(
class ORTModelForMultipleChoice(ORTModel):
"""
ONNX Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
softmax) e.g. for RocStories/SWAG tasks.
softmax) e.g. for RocStories/SWAG tasks. This class officially supports albert, bert, camembert, convbert, data2vec_text, deberta_v2, distilbert, electra, flaubert, ibert, mobilebert, nystromformer, roberta, roformer, squeezebert, xlm, xlm_roberta.
"""

auto_model_class = AutoModelForMultipleChoice
Expand Down Expand Up @@ -1499,7 +1500,7 @@ def forward(
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForImageClassification(ORTModel):
"""
ONNX Model for image-classification tasks.
ONNX Model for image-classification tasks. This class officially supports beit, convnext, data2vec_vision, deit, levit, mobilenet_v1, mobilenet_v2, mobilevit, poolformer, resnet, segformer, swin, vit.
"""

auto_model_class = AutoModelForImageClassification
Expand Down Expand Up @@ -1593,7 +1594,7 @@ def forward(
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForSemanticSegmentation(ORTModel):
"""
ONNX Model for semantic-segmentation, with an all-MLP decode head on top e.g. for ADE20k, CityScapes.
ONNX Model for semantic-segmentation, with an all-MLP decode head on top e.g. for ADE20k, CityScapes. This class officially supports segformer.
"""

auto_model_class = AutoModelForSemanticSegmentation
Expand Down Expand Up @@ -1700,7 +1701,7 @@ def _prepare_onnx_inputs(self, use_torch: bool, **kwargs):
class ORTModelForAudioClassification(ORTModel):
"""
ONNX Model for audio-classification, with a sequence classification head on top (a linear layer over the pooled output) for tasks like
SUPERB Keyword Spotting.
SUPERB Keyword Spotting. This class officially supports audio_spectrogram_transformer, data2vec_audio, hubert, sew, sew_d, unispeech, unispeech_sat, wavlm, wav2vec2, wav2vec2-conformer.
"""

auto_model_class = AutoModelForAudioClassification
Expand Down Expand Up @@ -1785,7 +1786,7 @@ def forward(
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForCTC(ORTModel):
"""
ONNX Model with a language modeling head on top for Connectionist Temporal Classification (CTC).
ONNX Model with a language modeling head on top for Connectionist Temporal Classification (CTC). This class officially supports data2vec_audio, hubert, sew, sew_d, unispeech, unispeech_sat, wavlm, wav2vec2, wav2vec2-conformer.
"""

auto_model_class = AutoModelForCTC
Expand Down Expand Up @@ -1868,7 +1869,7 @@ def forward(
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForAudioXVector(ORTModel):
"""
ONNX Model with an XVector feature extraction head on top for tasks like Speaker Verification.
ONNX Model with an XVector feature extraction head on top for tasks like Speaker Verification. This class officially supports data2vec_audio, unispeech_sat, wavlm, wav2vec2, wav2vec2-conformer.
"""

auto_model_class = AutoModelForAudioXVector
Expand Down Expand Up @@ -1957,7 +1958,7 @@ def forward(
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForAudioFrameClassification(ORTModel):
"""
ONNX Model with a frame classification head on top for tasks like Speaker Diarization.
ONNX Model with a frame classification head on top for tasks like Speaker Diarization. This class officially supports data2vec_audio, unispeech_sat, wavlm, wav2vec2, wav2vec2-conformer.
"""

auto_model_class = AutoModelForAudioFrameClassification
Expand Down
Loading

0 comments on commit 128ce3e

Please sign in to comment.