Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support remove_input_padding for BertForSequenceClassification models #1834

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 99 additions & 37 deletions examples/bert/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def parse_arguments():
parser.add_argument('--max_input_len', type=int, default=512)
parser.add_argument('--gpus_per_node', type=int, default=8)
parser.add_argument('--output_dir', type=str, default='bert_outputs')

parser.add_argument('--remove_input_padding', default=False, action='store_true')
parser.add_argument('--use_bert_attention_plugin',
nargs='?',
const='float16',
Expand Down Expand Up @@ -101,17 +103,91 @@ def parse_arguments():
'RobertaForQuestionAnswering',
'RobertaForSequenceClassification',
])
parser.add_argument('--model_dir', type=str, required=True)
nv-guomingz marked this conversation as resolved.
Show resolved Hide resolved
return parser.parse_args()

def prepare_inputs():
# opt_shape is set to half of max batch_size and seq_len by default
# tune this according to real data distribution
bs_range = [1, (args.max_batch_size + 1) // 2, args.max_batch_size]
inlen_range = [1, (args.max_input_len + 1) // 2, args.max_input_len]
num_tokens_range = [
1,
(args.max_input_len * args.max_batch_size + 1) // 2,
args.max_input_len * args.max_batch_size,
]
if not args.remove_input_padding:
input_ids = tensorrt_llm.Tensor(
name='input_ids',
dtype=trt.int32,
shape=[-1, -1],
dim_range=OrderedDict([('batch_size', [bs_range]),
('input_len', [inlen_range])]),
)
# also called segment_ids
token_type_ids = tensorrt_llm.Tensor(
name='token_type_ids',
dtype=trt.int32,
shape=[-1, -1],
dim_range=OrderedDict([('batch_size', [bs_range]),
('input_len', [inlen_range])]),
)
position_ids = tensorrt_llm.Tensor(
name='position_ids',
dtype=trt.int32,
shape=[-1, -1],
dim_range=OrderedDict([('batch_size', [bs_range]),
('input_len', [inlen_range])]),
)
else:
input_ids = tensorrt_llm.Tensor(
name="input_ids",
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([("num_tokens", [num_tokens_range])]),
)
token_type_ids = tensorrt_llm.Tensor(
name='token_type_ids',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([('num_tokens', [num_tokens_range])]),
)
position_ids = tensorrt_llm.Tensor(
name='position_ids',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([('num_tokens',
[num_tokens_range])]),
)

input_lengths = tensorrt_llm.Tensor(
name='input_lengths',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([('batch_size', [bs_range])])
)
max_input_length = tensorrt_llm.Tensor(
name="max_input_length",
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([("max_input_length", [inlen_range])]),
)

inputs = {
'input_ids': input_ids,
'input_lengths': input_lengths,
'token_type_ids': token_type_ids,
'position_ids': position_ids,
'max_input_length': max_input_length,
}
return inputs

if __name__ == '__main__':
args = parse_arguments()
tensorrt_llm.logger.set_level(args.log_level)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)

bs_range = [1, (args.max_batch_size + 1) // 2, args.max_batch_size]
inlen_range = [1, (args.max_input_len + 1) // 2, args.max_input_len]
torch_dtype = torch.float16 if args.dtype == 'float16' else torch.float32
trt_dtype = trt.float16 if args.dtype == 'float16' else trt.float32

Expand All @@ -126,22 +202,27 @@ def parse_arguments():
max_input_len=args.max_input_len,
)
# Initialize model

if 'Roberta' in args.model:
model_type = 'Roberta'
else:
model_type = 'Bert'

bert_config = globals()[f'{model_type}Config'](
# initialize config with input arguments and update from json
config_cls = globals()[f'{model_type}Config']
config = dict(
vocab_size=args.vocab_size,
hidden_size=args.n_embd,
num_labels=args.n_labels,
num_hidden_layers=args.n_layer,
max_position_embeddings=args.n_positions,
hidden_size=args.n_embd,
num_attention_heads=args.n_head,
intermediate_size=4 * args.n_embd,
intermediate_size=4 * args.n_embd if args.n_embd else None,
hidden_act=args.hidden_act,
max_position_embeddings=args.n_positions,
torch_dtype=torch_dtype,
)
json_config = config_cls.get_config_dict(args.model_dir)[0]
config.update((k,v) for k,v in json_config.items() if v is not None)
bert_config = config_cls.from_dict(config)

output_name = 'hidden_states'
if args.model == 'BertModel' or args.model == 'RobertaModel':
Expand Down Expand Up @@ -198,8 +279,10 @@ def parse_arguments():
)
output_name = 'logits'
elif args.model == 'BertForSequenceClassification' or args.model == 'RobertaForSequenceClassification':
hf_bert = globals()[f'{model_type}ForSequenceClassification'](
bert_config).cuda().to(torch_dtype).eval()
hf_bert = globals()[f'{model_type}ForSequenceClassification'](config=bert_config)
state_dict = torch.load(os.path.join(args.model_dir, "pytorch_model.bin"))
hf_bert.load_state_dict(state_dict, strict=False)

tensorrt_llm_bert = tensorrt_llm.models.BertForSequenceClassification(
num_layers=bert_config.num_hidden_layers,
num_heads=bert_config.num_attention_heads,
Expand All @@ -210,8 +293,7 @@ def parse_arguments():
type_vocab_size=bert_config.type_vocab_size,
pad_token_id=bert_config.pad_token_id,
is_roberta=(model_type == 'Roberta'),
num_labels=args.
n_labels, # TODO: this might just need to be a constant
num_labels=bert_config.num_labels,
mapping=Mapping(world_size=args.world_size,
rank=args.rank,
tp_size=args.world_size), # TP only
Expand All @@ -231,6 +313,10 @@ def parse_arguments():
# Module -> Network
network = builder.create_network()
network.plugin_config.to_legacy_setting()
if args.remove_input_padding:
assert args.model == "BertForSequenceClassification", \
"remove_input_padding is only supported for BertForSequenceClassification models"
network.plugin_config.remove_input_padding = True
if args.use_bert_attention_plugin:
network.plugin_config.bert_attention_plugin = args.use_bert_attention_plugin
if args.use_gemm_plugin:
Expand All @@ -250,34 +336,10 @@ def parse_arguments():
network.set_named_parameters(tensorrt_llm_bert.named_parameters())

# Forward
input_ids = tensorrt_llm.Tensor(
name='input_ids',
dtype=trt.int32,
shape=[-1, -1],
dim_range=OrderedDict([('batch_size', [bs_range]),
('input_len', [inlen_range])]),
)

# also called segment_ids
token_type_ids = tensorrt_llm.Tensor(
name='token_type_ids',
dtype=trt.int32,
shape=[-1, -1],
dim_range=OrderedDict([('batch_size', [bs_range]),
('input_len', [inlen_range])]),
)

input_lengths = tensorrt_llm.Tensor(name='input_lengths',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([
('batch_size', [bs_range])
]))
inputs = prepare_inputs()

# logits for QA BERT, or hidden_state for vanilla BERT
output = tensorrt_llm_bert(input_ids=input_ids,
input_lengths=input_lengths,
token_type_ids=token_type_ids)
output = tensorrt_llm_bert(**inputs)

# Mark outputs
output_dtype = trt.float16 if args.dtype == 'float16' else trt.float32
Expand Down
4 changes: 4 additions & 0 deletions examples/bert/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ def parse_arguments():
config_path = os.path.join(args.engine_dir, 'config.json')
with open(config_path, 'r') as f:
config = json.load(f)

assert config["plugin_config"]["remove_input_padding"] == False, \
"Please refer to run_remove_input_padding.py for running BERT models with remove_input_padding enabled"

dtype = config['builder_config']['precision']
world_size = config['builder_config']['tensor_parallel']
assert world_size == tensorrt_llm.mpi_world_size(), \
Expand Down
149 changes: 149 additions & 0 deletions examples/bert/run_remove_input_padding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import random
from typing import List

# isort: off
import torch
import tensorrt as trt
# isort: on

import tensorrt_llm
from tensorrt_llm import logger
from tensorrt_llm.runtime import Session, TensorInfo

from build import get_engine_name # isort:skip


def trt_dtype_to_torch(dtype):
if dtype == trt.float16:
return torch.float16
elif dtype == trt.float32:
return torch.float32
elif dtype == trt.int32:
return torch.int32
else:
raise TypeError("%s is not supported" % dtype)


def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--log_level", type=str, default="info")
parser.add_argument("--engine_dir", type=str)

return parser.parse_args()

def process_input(input_ids_list: List[torch.Tensor],
token_type_ids_list: List[torch.Tensor]):
input_lengths = []
position_ids_list = []
max_input_length = 0
for i, input_ids in enumerate(input_ids_list):
input_len = len(input_ids)
assert input_len == len(token_type_ids_list[i]), f"sample {i}: len(input_ids)={len(input_ids)}, " \
f"len(token_type_ids)={len(token_type_ids_list[i])}, not equal"
input_lengths.append(input_len)
position_ids_list.append(torch.arange(0, input_len, dtype=torch.int32))
max_input_length = max(max_input_length, input_len)

# [num_tokens]
input_ids = torch.concat(input_ids_list).int().cuda()
token_type_ids = torch.concat(token_type_ids_list).int().cuda()
position_ids = torch.concat(position_ids_list).int().cuda()

input_lengths = torch.tensor(input_lengths).int().cuda() # [batch_size]
max_input_length = torch.empty((max_input_length, )).int().cuda()
return input_ids, input_lengths, token_type_ids, position_ids, max_input_length

if __name__ == '__main__':
args = parse_arguments()

tensorrt_llm.logger.set_level(args.log_level)

config_path = os.path.join(args.engine_dir, 'config.json')
with open(config_path, 'r') as f:
config = json.load(f)
dtype = config['builder_config']['precision']
world_size = config['builder_config']['tensor_parallel']
assert world_size == tensorrt_llm.mpi_world_size(), \
f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'

model_name = config['builder_config']['name']
runtime_rank = tensorrt_llm.mpi_rank() if world_size > 1 else 0

runtime_mapping = tensorrt_llm.Mapping(world_size,
runtime_rank,
tp_size=world_size)
torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)

serialize_path = get_engine_name(model_name, dtype, world_size,
runtime_rank)
serialize_path = os.path.join(args.engine_dir, serialize_path)

stream = torch.cuda.current_stream().cuda_stream
logger.info(f'Loading engine from {serialize_path}')
with open(serialize_path, 'rb') as f:
engine_buffer = f.read()
logger.info(f'Creating session from engine')
session = Session.from_serialized_engine(engine_buffer)

remove_input_padding = config["plugin_config"]["remove_input_padding"]
assert remove_input_padding, "This is a demo for BERT models with remove_input_padding enabled"


for i in range(3):
batch_size = (i + 1) * 4
# use list of tensor to represent unpadded samples
input_ids = []
token_type_ids = []
for _ in range(batch_size):
seq_len = random.randint(64, 128)
input_ids.append(torch.randint(100, size=(seq_len, )).int().cuda())
token_type_ids.append(torch.randint(0, 1, size=(seq_len, )).int().cuda())

input_ids, input_lengths, token_type_ids, position_ids, max_input_length = \
process_input(input_ids, token_type_ids)
inputs = {
"input_ids": input_ids,
"input_lengths": input_lengths,
"token_type_ids": token_type_ids,
"position_ids": position_ids,
"max_input_length": max_input_length
}
output_info = session.infer_shapes([
TensorInfo("input_ids", trt.DataType.INT32, input_ids.shape),
TensorInfo("input_lengths", trt.DataType.INT32, input_lengths.shape),
TensorInfo("token_type_ids", trt.DataType.INT32, token_type_ids.shape),
TensorInfo("position_ids", trt.DataType.INT32, position_ids.shape),
TensorInfo("max_input_length", trt.DataType.INT32, max_input_length.shape)
])
outputs = {
t.name: torch.empty(tuple(t.shape),
dtype=trt_dtype_to_torch(t.dtype),
device='cuda')
for t in output_info
}
output_name = "logits"
assert output_name in outputs, f'{output_name} not found in outputs, check if build.py set output name correctly'

ok = session.run(inputs, outputs, stream)
assert ok, "Runtime execution failed"
torch.cuda.synchronize()
res = outputs[output_name]
print(res)

Loading