
[Training] Error building gradient graph for bert models for on-device training #22465

Open
riccardopinosio opened this issue Oct 16, 2024 · 5 comments
Labels
contributions welcome lower priority issues for the core ORT teams training issues related to ONNX Runtime training; typically submitted using template

Comments

@riccardopinosio

Describe the issue

Hello,

See also this discussion. I'm opening this as an issue because, from sifting through previous issues, training should work for BERT models.

I am trying to generate training artifacts for DistilBERT like so:

from pathlib import Path

import onnx
import torch
from onnxruntime.training import artifacts
from transformers import AutoModel

modelName = "distilbert/distilbert-base-uncased"

model = AutoModel.from_pretrained(modelName)
# Dummy inputs of shape (1, 10): random token ids and an all-ones attention mask
example_input = (
    torch.randint(10, (1, 10)),
    torch.ones(10, dtype=int).view(1, 10)
)
model_path = Path("./embedding_test")

# Export in training mode so that training-only behavior (e.g. dropout) is kept
torch.onnx.export(
    model,
    example_input,
    str(model_path),
    export_params=True,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
    input_names=["input_ids",
                 "attention_mask"],
    output_names=["output"])

onnx_model = onnx.load(str(model_path))

p = Path("./embedding_training")
p.mkdir(exist_ok=True, parents=True)

# Train every initializer, with an MSE loss on the "output" graph output
artifacts.generate_artifacts(onnx_model,
                             frozen_params=[],
                             requires_grad=[initializer.name for initializer in onnx_model.graph.initializer],
                             loss=artifacts.LossType.MSELoss,
                             optimizer=artifacts.OptimType.AdamW,
                             loss_input_names=["output"],
                             artifact_directory=p)

The exported ONNX model works fine for inference, but artifact generation fails with:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File /home/rpinosio/repositories/knights/hugot/python/generate_embedding_training_model.py:1
----> 1 artifacts.generate_artifacts(onnx_model,
      2 frozen_params=[],
      3 requires_grad=[initializer.name for initializer in onnx_model.graph.initializer],
      4 loss=artifacts.LossType.MSELoss,
      5 optimizer=artifacts.OptimType.AdamW,
      6 loss_input_names=["output"],
      7 artifact_directory=p)

File ~/miniconda3/envs/hugoTrainer/lib/python3.12/site-packages/onnxruntime/training/artifacts.py:193, in generate_artifacts(model, requires_grad, frozen_params, loss, optimizer, artifact_directory, prefix, ort_format, custom_op_library, additional_output_names, nominal_checkpoint, loss_input_names)
    186     custom_op_library_path = pathlib.Path(custom_op_library)
    188 with onnxblock.base(loaded_model, model_path), (
    189     onnxblock.custom_op_library(custom_op_library_path)
    190     if custom_op_library is not None
    191     else contextlib.nullcontext()
    192 ):
--> 193     _ = training_block(*[output.name for output in loaded_model.graph.output])
    194     training_model, eval_model = training_block.to_model_proto()
    195     model_params = training_block.parameters()

File ~/miniconda3/envs/hugoTrainer/lib/python3.12/site-packages/onnxruntime/training/onnxblock/onnxblock.py:204, in TrainingBlock.__call__(self, *args, **kwargs)
    196 self._parameters = _training_graph_utils.get_model_parameters(model, self._requires_grad, self._frozen_params)
    198 # Build the gradient graph. The gradient graph building is composed of the following steps:
    199 #   - Move all model parameters to model inputs.
    200 #   - Run orttraining graph transformers on the model.
    201 #   - Add the gradient graph to the optimized model.
    202 # The order of model inputs after gradient graph building is: user inputs, model parameters as inputs
    203 # The order of the model outputs is: user outputs, model parameter gradients (in the order of parameter inputs)
--> 204 self._training_model, self._eval_model = _training_graph_utils.build_gradient_graph(
    205     model, self._requires_grad, self._frozen_params, output, accessor._GLOBAL_CUSTOM_OP_LIBRARY
    206 )
    208 logging.debug("Adding gradient accumulation nodes for training block %s", self.__class__.__name__)
    210 _training_graph_utils.build_gradient_accumulation_graph(self._training_model, self._requires_grad)

File ~/miniconda3/envs/hugoTrainer/lib/python3.12/site-packages/onnxruntime/training/onnxblock/_training_graph_utils.py:130, in build_gradient_graph(model, requires_grad, frozen_params, output_names, custom_op_library)
    127 optimized_model = onnx.load_from_string(get_optimized_model(model.SerializeToString(), requires_grad, options))
    129 # Assumption is that the first graph output is the loss output
--> 130 gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names[0], options)
    132 _reorder_outputs(gradient_model, output_names, requires_grad)
    134 return gradient_model, eval_model

File ~/miniconda3/envs/hugoTrainer/lib/python3.12/site-packages/onnxruntime/training/onnxblock/_training_graph_utils.py:84, in _gradient_model_for(model, requires_grad, loss_name, options)
     79 logging.debug(
     80     "The loss output is %s. The gradient graph will be built starting from %s_grad.", loss_name, loss_name
     81 )
     83 builder = GradientGraphBuilder(model.SerializeToString(), {loss_name}, requires_grad, loss_name, options)
---> 84 builder.build()
     85 return onnx.load_from_string(builder.get_model())

RuntimeError: /onnxruntime_src/orttraining/orttraining/core/graph/gradient_builder_base.h:123 onnxruntime::training::ArgDef onnxruntime::training::GradientBuilderBase::O(size_t, bool) const i < node_->OutputDefs().size() was false.

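It seems to fail while building the gradient graph: a gradient builder asks for an output index that is out of bounds for the node's OutputDefs (the check i < node_->OutputDefs().size() fails).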
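Since the failing check compares an output index against the number of outputs a node declares, one way to narrow things down is to list, for each op type in the exported graph, how many outputs its nodes actually carry. Some ONNX ops have optional outputs (for example Dropout's mask, or LayerNormalization's Mean and InvStdDev), so an exporter may emit fewer outputs than the op can have. This is a minimal diagnostic sketch, not a confirmed root cause: the helper name output_counts_by_op is hypothetical, and it only assumes the graph behaves like an onnx GraphProto (a node field whose items have op_type and output).

```python
from collections import defaultdict
from typing import Dict, List


def output_counts_by_op(graph) -> Dict[str, List[int]]:
    """Map each op type to the sorted distinct output counts of its nodes.

    `graph` is expected to behave like onnx.GraphProto: an iterable `node`
    field whose items expose `op_type` and a repeated `output` field.
    """
    counts = defaultdict(set)
    for node in graph.node:
        # Count the outputs exactly as the exporter wrote them into the node
        counts[node.op_type].add(len(node.output))
    return {op: sorted(sizes) for op, sizes in sorted(counts.items())}
```

For the model above this would be called as output_counts_by_op(onnx.load("./embedding_test").graph); ops whose nodes carry fewer outputs than their schema allows are the first candidates to compare against the failing gradient builder.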

To reproduce

See the code provided above.

Urgency

It's blocking development of Go bindings for ONNX Runtime training, which we want to use in our product.

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.19.2

PyTorch Version

2.4.1+cu121

Execution Provider

Default CPU

Execution Provider Library Version

No response

@riccardopinosio riccardopinosio added the training issues related to ONNX Runtime training; typically submitted using template label Oct 16, 2024
@github-actions github-actions bot added the model:transformer issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc. label Oct 16, 2024
@snnn snnn added contributions welcome lower priority issues for the core ORT teams and removed model:transformer issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc. labels Oct 16, 2024
@jkbeavers
Contributor

jkbeavers commented Oct 23, 2024

This looks similar to the issue I had and fixed in #22414. You can verify that it's the same issue by changing your loss to cross-entropy and checking that artifact generation succeeds.
If so, it should be resolved (with MSE loss) in a nightly build, a local build from master, or the upcoming 1.20 release.
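The check suggested above (does the failure go away when only the loss type changes?) can be scripted so that each loss is tried with otherwise identical arguments. This is a minimal sketch; probe_losses is a hypothetical helper, and it takes the generation function as a parameter so it can wrap artifacts.generate_artifacts without assuming anything beyond the keyword loss seen in the traceback.

```python
from typing import Callable, Dict, Iterable, Optional


def probe_losses(generate: Callable[..., object], losses: Iterable) -> Dict[str, Optional[str]]:
    """Call generate(loss=...) once per loss type and record the outcome.

    Returns a mapping from str(loss) to None on success, or to the first
    line of the RuntimeError message on failure.
    """
    results: Dict[str, Optional[str]] = {}
    for loss in losses:
        try:
            generate(loss=loss)
        except RuntimeError as exc:
            message = str(exc).strip()
            results[str(loss)] = message.splitlines()[0] if message else repr(exc)
        else:
            results[str(loss)] = None
    return results
```

For the DistilBERT repro, generate could be functools.partial(artifacts.generate_artifacts, onnx_model, requires_grad=..., frozen_params=[], optimizer=artifacts.OptimType.AdamW, loss_input_names=["output"], artifact_directory=p), with losses [artifacts.LossType.MSELoss, artifacts.LossType.CrossEntropyLoss]. The cross-entropy run is only a diagnostic here: both runs write into the same artifact directory, and whether CrossEntropyLoss fits this model's output shape is a separate question.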

@riccardopinosio
Author

@jkbeavers thanks, will try that. By the way, I thought ONNX Runtime training was going to be deprecated?

@rkoystart

rkoystart commented Nov 2, 2024

I am also facing a similar issue. Below is the code to reproduce.

from onnxruntime.training.api import CheckpointState, Module, Optimizer
from onnxruntime import InferenceSession
import os
import torch.nn as nn
import torch
import numpy as np
import time
import onnx
from sklearn.metrics import f1_score, accuracy_score
import json
import io
from tqdm import tqdm
from constants import TEST_FILE,TRAIN_FILE,NUM_CLASSES,BATCH_SIZE,LEARNING_RATE,EPOCHS,DROPOUT_RATE,HIDDEN_SIZE,LABEL_FILE,DEVICE, EMBEDDING_DIM
from typing import Any
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
import datetime
from dataset_reader import TextDataset2, collate_batch, TEXT_PIPELINE

class TextClassificationModel(nn.Module):
    def __init__(self, embedding_dim, hidden_size, num_classes):
        super(TextClassificationModel, self).__init__()
        self.fc_1 = nn.Linear(embedding_dim, hidden_size)
        self.drop = torch.nn.Dropout(DROPOUT_RATE)
        self.relu = torch.nn.ReLU()
        self.fc_2 = nn.Linear(hidden_size, num_classes)
        self.layernorm = nn.LayerNorm(num_classes)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.fc_1.weight.data.uniform_(-initrange, initrange)
        self.fc_1.bias.data.zero_()
        self.fc_2.weight.data.uniform_(-initrange, initrange)
        self.fc_2.bias.data.zero_()

    def forward(self, text):
        y = self.fc_1(text)
        y = self.drop(y)
        y = self.relu(y)
        y = self.fc_2(y)        
        y = self.layernorm(y)
        return y


pt_model = TextClassificationModel(
    EMBEDDING_DIM,
    HIDDEN_SIZE,
    NUM_CLASSES,
).to(DEVICE)

t1 = datetime.datetime.now()
train_loader = TextDataset2(
    TRAIN_FILE,
    LABEL_FILE
)
t2 = datetime.datetime.now()
# logger.info(f'Dataset loading time: {t2-t1}')
train_dataloader = DataLoader(
    train_loader, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)

for labels,model_inputs in train_dataloader:
    break

model_outputs = pt_model(model_inputs)

print(model_outputs)

if isinstance(model_outputs, torch.Tensor):
    model_outputs = [model_outputs]
input_names = ["input"]
output_names = ["output"]
dynamic_axes = {"input": {0: "batch_size"}, "output": {0: "batch_size"}}

f = io.BytesIO()
torch.onnx.export(
    pt_model,
    model_inputs,
    "model.onnx",
    input_names=input_names,
    output_names=output_names,
    opset_version=17,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
    dynamic_axes=dynamic_axes,
    export_params=True,
    keep_initializers_as_inputs=False,
)

onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)

# print(model_inputs)

from onnxruntime.training import artifacts
requires_grad = [name for name, param in pt_model.named_parameters() if param.requires_grad]

frozen_params = [name for name, param in pt_model.named_parameters() if not param.requires_grad]

print(requires_grad, frozen_params)
artifacts.generate_artifacts(
    onnx_model,
    optimizer=artifacts.OptimType.AdamW,
    loss=artifacts.LossType.BCEWithLogitsLoss,
    # loss=artifacts.LossType.MSELoss,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    artifact_directory="files",
    additional_output_names=["output"])

RuntimeError: /onnxruntime_src/orttraining/orttraining/core/graph/gradient_builder_base.h:123 onnxruntime::training::ArgDef onnxruntime::training::GradientBuilderBase::O(size_t, bool) const i < node_->OutputDefs().size() was false.

constants.py

from sentence_transformers import SentenceTransformer
import json

DROPOUT_RATE = 0.5
DEVICE = 'cpu'
HIDDEN_SIZE = 768
BATCH_SIZE = 32
LEARNING_RATE = 0.1
EPOCHS = 50
EMB_MODEL = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
EMBEDDING_DIM = EMB_MODEL.get_sentence_embedding_dimension()
TRAIN_FILE = "train.json"
TEST_FILE = "test.json"
LABEL_FILE = 'labels.json'
NUM_CLASSES = 31 # len(json.load(open(LABEL_FILE)))

It did not work despite changing the loss function to CrossEntropyLoss.
Also, version 1.20.0 is not available for the onnxruntime-training package, and using previous versions of onnxruntime-training does not resolve the issue either.

@jkbeavers @riccardopinosio can you help me with this?
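In the repro above, requires_grad and frozen_params are built from PyTorch's named_parameters(), whereas generate_artifacts works on the initializers of the exported ONNX graph (the first repro passes initializer names directly). For a simple module exported this way the names usually coincide, but that is an assumption worth ruling out before digging into the gradient builder. A minimal sketch; missing_initializers is a hypothetical helper that only assumes the graph behaves like an onnx GraphProto.

```python
from typing import Iterable, List


def missing_initializers(graph, names: Iterable[str]) -> List[str]:
    """Return the names that are not initializers of `graph`, in input order.

    `graph` is expected to behave like onnx.GraphProto: an `initializer`
    field whose items expose a `name`.
    """
    initializer_names = {init.name for init in graph.initializer}
    return [name for name in names if name not in initializer_names]
```

With the repro's variables, missing_initializers(onnx_model.graph, requires_grad + frozen_params) should return an empty list; anything it returns is a parameter name that the exported graph does not contain.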

@riccardopinosio
Author

@rkoystart I believe 1.20.0 does not include ONNX Runtime training because it's being deprecated, at least according to this page. I'm not sure whether Microsoft plans to support a training flow in ONNX Runtime going forward, so I didn't spend any more time on this.

@martinkorelic

@riccardopinosio I'm also hoping that training support hasn't been deprecated for good; the release notes give no explanation.

I suspect the error may come from the model output being passed to the ONNX-defined loss function. Perhaps you could avoid the built-in loss function and instead compute your own loss inside the model, returning the loss as the model output?
Some Hugging Face transformers models already compute the loss internally when given a "labels" argument. From my understanding, if you provide dummy labels as an input and name the output "loss", torch.onnx.export should trace the loss computation as well.
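A minimal sketch of that suggestion, assuming a model whose forward returns a plain tensor (such as the TextClassificationModel earlier in the thread; a Hugging Face model returns a ModelOutput instead). ModelWithLoss is a hypothetical name, and MSE is used purely as an example loss; whether this sidesteps the gradient builder error has not been verified.

```python
import torch
import torch.nn as nn


class ModelWithLoss(nn.Module):
    """Hypothetical wrapper whose forward() returns a scalar MSE loss.

    `base` must return a tensor with the same shape as `target`.
    """

    def __init__(self, base: nn.Module):
        super().__init__()
        self.base = base
        self.loss_fn = nn.MSELoss()

    def forward(self, inputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # The loss is part of the traced graph, so no ONNX loss block is needed
        return self.loss_fn(self.base(inputs), target)
```

Such a wrapper would be exported with the target as a second input and the loss as the only output, e.g. torch.onnx.export(ModelWithLoss(pt_model), (model_inputs, targets), "model_with_loss.onnx", input_names=["input", "target"], output_names=["loss"], training=torch.onnx.TrainingMode.TRAINING, do_constant_folding=False). Note that the wrapper's parameter names gain a "base." prefix (for example base.fc_1.weight), and that artifact generation would then be run without a built-in loss (loss=None), relying on the first graph output being the loss as the traceback's comment describes.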
