From 3b43f4cbd303309714538a0481e681b90922db70 Mon Sep 17 00:00:00 2001 From: magpiezhang Date: Tue, 25 Jun 2024 14:29:28 +0800 Subject: [PATCH 1/3] support remove_input_padding for BertForSequenceClassification models --- examples/bert/build.py | 136 ++++++++++++++----- examples/bert/run_remove_input_padding.py | 149 ++++++++++++++++++++ tensorrt_llm/models/bert/model.py | 158 ++++++++++++++-------- 3 files changed, 350 insertions(+), 93 deletions(-) create mode 100644 examples/bert/run_remove_input_padding.py diff --git a/examples/bert/build.py b/examples/bert/build.py index df72c102c..07ce897cf 100644 --- a/examples/bert/build.py +++ b/examples/bert/build.py @@ -70,6 +70,8 @@ def parse_arguments(): parser.add_argument('--max_input_len', type=int, default=512) parser.add_argument('--gpus_per_node', type=int, default=8) parser.add_argument('--output_dir', type=str, default='bert_outputs') + + parser.add_argument('--remove_input_padding', default=False, action='store_true') parser.add_argument('--use_bert_attention_plugin', nargs='?', const='float16', @@ -101,8 +103,84 @@ def parse_arguments(): 'RobertaForQuestionAnswering', 'RobertaForSequenceClassification', ]) + parser.add_argument('--model_dir', type=str, required=True) return parser.parse_args() +def prepare_inputs(): + # opt_shape is set to half of max batch_size and seq_len by default + # tune this according to real data distribution + bs_range = [1, (args.max_batch_size + 1) // 2, args.max_batch_size] + inlen_range = [1, (args.max_input_len + 1) // 2, args.max_input_len] + num_tokens_range = [ + 1, + (args.max_input_len * args.max_batch_size + 1) // 2, + args.max_input_len * args.max_batch_size, + ] + if not args.remove_input_padding: + input_ids = tensorrt_llm.Tensor( + name='input_ids', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([('batch_size', [bs_range]), + ('input_len', [inlen_range])]), + ) + # also called segment_ids + token_type_ids = tensorrt_llm.Tensor( + name='token_type_ids', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([('batch_size', [bs_range]), + ('input_len', [inlen_range])]), + ) + position_ids = tensorrt_llm.Tensor( + name='position_ids', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([('batch_size', [bs_range]), + ('input_len', [inlen_range])]), + ) + else: + input_ids = tensorrt_llm.Tensor( + name="input_ids", + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([("num_tokens", [num_tokens_range])]), + ) + token_type_ids = tensorrt_llm.Tensor( + name='token_type_ids', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([('num_tokens', [num_tokens_range])]), + ) + position_ids = tensorrt_llm.Tensor( + name='position_ids', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([('num_tokens', + [num_tokens_range])]), + ) + + input_lengths = tensorrt_llm.Tensor( + name='input_lengths', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([('batch_size', [bs_range])]) + ) + max_input_length = tensorrt_llm.Tensor( + name="max_input_length", + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([("max_input_length", [inlen_range])]), + ) + + inputs = { + 'input_ids': input_ids, + 'input_lengths': input_lengths, + 'token_type_ids': token_type_ids, + 'position_ids': position_ids, + 'max_input_length': max_input_length, + } + return inputs if __name__ == '__main__': args = parse_arguments() @@ -110,8 +188,6 @@ def parse_arguments(): if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) - bs_range = [1, (args.max_batch_size + 1) // 2, args.max_batch_size] - inlen_range = [1, (args.max_input_len + 1) // 2, args.max_input_len] torch_dtype = torch.float16 if args.dtype == 'float16' else torch.float32 trt_dtype = trt.float16 if args.dtype == 'float16' else trt.float32 @@ -126,22 +202,27 @@ def parse_arguments(): max_input_len=args.max_input_len, ) # Initialize model - if 'Roberta' in args.model: model_type = 'Roberta' else: model_type = 'Bert' - bert_config = globals()[f'{model_type}Config']( + # initialize config with input arguments and update from json + config_cls = globals()[f'{model_type}Config'] + config = dict( vocab_size=args.vocab_size, - hidden_size=args.n_embd, + num_labels=args.n_labels, num_hidden_layers=args.n_layer, + max_position_embeddings=args.n_positions, + hidden_size=args.n_embd, num_attention_heads=args.n_head, - intermediate_size=4 * args.n_embd, + intermediate_size=4 * args.n_embd if args.n_embd else None, hidden_act=args.hidden_act, - max_position_embeddings=args.n_positions, torch_dtype=torch_dtype, ) + json_config = config_cls.get_config_dict(args.model_dir)[0] + config.update((k,v) for k,v in json_config.items() if v is not None) + bert_config = config_cls.from_dict(config) output_name = 'hidden_states' if args.model == 'BertModel' or args.model == 'RobertaModel': @@ -198,8 +279,10 @@ def parse_arguments(): ) output_name = 'logits' elif args.model == 'BertForSequenceClassification' or args.model == 'RobertaForSequenceClassification': - hf_bert = globals()[f'{model_type}ForSequenceClassification']( - bert_config).cuda().to(torch_dtype).eval() + hf_bert = globals()[f'{model_type}ForSequenceClassification'](config=bert_config) + state_dict = torch.load(os.path.join(args.model_dir, "pytorch_model.bin")) + hf_bert.load_state_dict(state_dict, strict=False) + tensorrt_llm_bert = tensorrt_llm.models.BertForSequenceClassification( num_layers=bert_config.num_hidden_layers, num_heads=bert_config.num_attention_heads, @@ -210,8 +293,7 @@ def parse_arguments(): type_vocab_size=bert_config.type_vocab_size, pad_token_id=bert_config.pad_token_id, is_roberta=(model_type == 'Roberta'), - num_labels=args. - n_labels, # TODO: this might just need to be a constant + num_labels=bert_config.num_labels, mapping=Mapping(world_size=args.world_size, rank=args.rank, tp_size=args.world_size), # TP only @@ -231,6 +313,10 @@ def parse_arguments(): # Module -> Network network = builder.create_network() network.plugin_config.to_legacy_setting() + if args.remove_input_padding: + assert args.model == "BertForSequenceClassification", \ + "remove_input_padding is only supported for BertForSequenceClassification models" + network.plugin_config.remove_input_padding = True if args.use_bert_attention_plugin: network.plugin_config.bert_attention_plugin = args.use_bert_attention_plugin if args.use_gemm_plugin: @@ -250,34 +336,10 @@ def parse_arguments(): network.set_named_parameters(tensorrt_llm_bert.named_parameters()) # Forward - input_ids = tensorrt_llm.Tensor( - name='input_ids', - dtype=trt.int32, - shape=[-1, -1], - dim_range=OrderedDict([('batch_size', [bs_range]), - ('input_len', [inlen_range])]), - ) - - # also called segment_ids - token_type_ids = tensorrt_llm.Tensor( - name='token_type_ids', - dtype=trt.int32, - shape=[-1, -1], - dim_range=OrderedDict([('batch_size', [bs_range]), - ('input_len', [inlen_range])]), - ) - - input_lengths = tensorrt_llm.Tensor(name='input_lengths', - dtype=trt.int32, - shape=[-1], - dim_range=OrderedDict([ - ('batch_size', [bs_range]) - ])) + inputs = prepare_inputs() # logits for QA BERT, or hidden_state for vanilla BERT - output = tensorrt_llm_bert(input_ids=input_ids, - input_lengths=input_lengths, - token_type_ids=token_type_ids) + output = tensorrt_llm_bert(**inputs) # Mark outputs output_dtype = trt.float16 if args.dtype == 'float16' else trt.float32 diff --git a/examples/bert/run_remove_input_padding.py b/examples/bert/run_remove_input_padding.py new file mode 100644 index 000000000..07f0ad100 --- /dev/null +++ b/examples/bert/run_remove_input_padding.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os +import random +from typing import List + +# isort: off +import torch +import tensorrt as trt +# isort: on + +import tensorrt_llm +from tensorrt_llm import logger +from tensorrt_llm.runtime import Session, TensorInfo + +from build import get_engine_name # isort:skip + + +def trt_dtype_to_torch(dtype): + if dtype == trt.float16: + return torch.float16 + elif dtype == trt.float32: + return torch.float32 + elif dtype == trt.int32: + return torch.int32 + else: + raise TypeError("%s is not supported" % dtype) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("--log_level", type=str, default="info") + parser.add_argument("--engine_dir", type=str) + + return parser.parse_args() + +def process_input(input_ids_list: List[torch.Tensor], + token_type_ids_list: List[torch.Tensor]): + input_lengths = [] # [batch_size] + position_ids_list = [] # [batch_size, seq_len] + max_input_length = 0 + for i, input_ids in enumerate(input_ids_list): + input_len = len(input_ids) + assert input_len == len(token_type_ids_list[i]), f"sample {i}: len(input_ids)={len(input_ids)}, " \ + f"len(token_type_ids)={len(token_type_ids_list[i])}, not equal" + input_lengths.append(input_len) + position_ids_list.append(torch.arange(0, input_len, dtype=torch.int32)) + max_input_length = max(max_input_length, input_len) + + # [num_tokens] + input_ids = torch.concat(input_ids_list).int().cuda() + token_type_ids = torch.concat(token_type_ids_list).int().cuda() + position_ids = torch.concat(position_ids_list).int().cuda() + + input_lengths = torch.tensor(input_lengths).int().cuda() # [batch_size] + max_input_length = torch.empty((max_input_length, )).int().cuda() + return input_ids, input_lengths, token_type_ids, position_ids, max_input_length + +if __name__ == '__main__': + args = parse_arguments() + + tensorrt_llm.logger.set_level(args.log_level) + + config_path = os.path.join(args.engine_dir, 'config.json') + with open(config_path, 'r') as f: + config = json.load(f) + dtype = config['builder_config']['precision'] + world_size = config['builder_config']['tensor_parallel'] + assert world_size == tensorrt_llm.mpi_world_size(), \ + f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})' + + model_name = config['builder_config']['name'] + runtime_rank = tensorrt_llm.mpi_rank() if world_size > 1 else 0 + + runtime_mapping = tensorrt_llm.Mapping(world_size, + runtime_rank, + tp_size=world_size) + torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node) + + serialize_path = get_engine_name(model_name, dtype, world_size, + runtime_rank) + serialize_path = os.path.join(args.engine_dir, serialize_path) + + stream = torch.cuda.current_stream().cuda_stream + logger.info(f'Loading engine from {serialize_path}') + with open(serialize_path, 'rb') as f: + engine_buffer = f.read() + logger.info(f'Creating session from engine') + session = Session.from_serialized_engine(engine_buffer) + + remove_input_padding = config["plugin_config"]["remove_input_padding"] + assert remove_input_padding, "This is a demo for BERT models with remove_input_padding enabled" + + + for i in range(3): + batch_size = (i + 1) * 4 + # use list of tensor to represent unpadded samples + input_ids = [] + token_type_ids = [] + for _ in range(batch_size): + seq_len = random.randint(64, 128) + input_ids.append(torch.randint(100, size=(seq_len, )).int().cuda()) + token_type_ids.append(torch.randint(0, 1, size=(seq_len, )).int().cuda()) + + input_ids, input_lengths, token_type_ids, position_ids, max_input_length = \ + process_input(input_ids, token_type_ids) + inputs = { + "input_ids": input_ids, + "input_lengths": input_lengths, + "token_type_ids": token_type_ids, + "position_ids": position_ids, + "max_input_length": max_input_length + } + output_info = session.infer_shapes([ + TensorInfo("input_ids", trt.DataType.INT32, input_ids.shape), + TensorInfo("input_lengths", trt.DataType.INT32, input_lengths.shape), + TensorInfo("token_type_ids", trt.DataType.INT32, token_type_ids.shape), + TensorInfo("position_ids", trt.DataType.INT32, position_ids.shape), + TensorInfo("max_input_length", trt.DataType.INT32, max_input_length.shape) + ]) + outputs = { + t.name: torch.empty(tuple(t.shape), + dtype=trt_dtype_to_torch(t.dtype), + device='cuda') + for t in output_info + } + output_name = "logits" + assert output_name in outputs, f'{output_name} not found in outputs, check if build.py set output name correctly' + + ok = session.run(inputs, outputs, stream) + assert ok, "Runtime execution failed" + torch.cuda.synchronize() + res = outputs[output_name] + print(res) + diff --git a/tensorrt_llm/models/bert/model.py b/tensorrt_llm/models/bert/model.py index 122a32247..39a33a322 100644 --- a/tensorrt_llm/models/bert/model.py +++ b/tensorrt_llm/models/bert/model.py @@ -18,8 +18,8 @@ from ..._common import default_net from ...functional import (ACT2FN, bert_attention, cast, concat, constant, - expand, expand_mask, matmul, select, shape, slice, - softmax, split, unsqueeze) + expand, expand_mask, matmul, select, shape, + slice, softmax, split, unsqueeze, cumsum, index_select) from ...layers import MLP, ColumnLinear, Embedding, LayerNorm, Linear, RowLinear from ...mapping import Mapping from ...module import Module, ModuleList @@ -82,7 +82,11 @@ def __init__(self, tp_group=tp_group, tp_size=tp_size) - def forward(self, hidden_states, attention_mask=None, input_lengths=None): + def forward(self, + hidden_states, + attention_mask=None, + input_lengths=None, + max_input_length=None): qkv = self.qkv(hidden_states) # attention @@ -90,9 +94,12 @@ def forward(self, hidden_states, attention_mask=None, input_lengths=None): assert input_lengths is not None context = bert_attention(qkv, input_lengths, self.num_attention_heads, - self.attention_head_size, 1.0) + self.attention_head_size, + q_scaling=1.0, + max_input_length=max_input_length) else: - + assert not default_net().plugin_config.remove_input_padding, \ + "remove_input_padding requires bert_attention_plugin enabled" def transpose_for_scores(x): new_x_shape = concat([ shape(x, 0), @@ -156,12 +163,17 @@ def __init__(self, self.post_layernorm = LayerNorm(normalized_shape=hidden_size, dtype=dtype) - def forward(self, hidden_states, attention_mask=None, input_lengths=None): + def forward(self, + hidden_states, + attention_mask=None, + input_lengths=None, + max_input_length=None): residual = hidden_states attention_output = self.attention(hidden_states, attention_mask=attention_mask, - input_lengths=input_lengths) + input_lengths=input_lengths, + max_input_length=max_input_length) hidden_states = residual + attention_output @@ -219,53 +231,57 @@ def forward(self, input_lengths=None, position_ids=None, token_type_ids=None, - hidden_states=None): - - seq_len_2d = concat([1, shape(input_ids, 1)]) - - # create position ids - position_ids_buffer = constant( - np.expand_dims( - np.arange(self.max_position_embeddings).astype(np.int32), 0)) - tmp_position_ids = slice(position_ids_buffer, - starts=[0, 0], - sizes=seq_len_2d) - tmp_position_ids = expand(tmp_position_ids, shape(input_ids)) #BxL - tmp_input_lengths = unsqueeze(input_lengths, 1) #Bx1 - tmp_input_lengths = expand(tmp_input_lengths, shape(input_ids)) #BxL - mask = tmp_position_ids < tmp_input_lengths # BxL - mask = mask.cast('int32') - - if position_ids is None: - if self.is_roberta: - # see create_position_ids_from_input_ids() in https://github.com/huggingface/transformers/blob/main/src/transformers/models/roberta/modeling_roberta.py - position_ids = (tmp_position_ids + 1) * mask - position_ids = position_ids + self.padding_idx - else: - position_ids = slice(position_ids_buffer, + hidden_states=None, + max_input_length=None): + # remove_input_padding requires these fields as explicit input + extended_attention_mask = None + if not default_net().plugin_config.remove_input_padding: + seq_len_2d = concat([1, shape(input_ids, 1)]) + + # create position ids + position_ids_buffer = constant( + np.expand_dims( + np.arange(self.max_position_embeddings).astype(np.int32), 0)) + tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d) - position_ids = expand(position_ids, shape(input_ids)) - - # create extended_attention_mask as https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py - extended_attention_mask = expand_mask(mask, tgt_len=1) # BxL -> Bx1x1xL - - # create token_type_ids - if token_type_ids is None: - token_type_ids_buffer = constant( - np.expand_dims( - np.zeros(self.max_position_embeddings).astype(np.int32), 0)) - token_type_ids = slice(token_type_ids_buffer, - starts=[0, 0], - sizes=seq_len_2d) - token_type_ids = expand(token_type_ids, shape(input_ids)) + tmp_position_ids = expand(tmp_position_ids, shape(input_ids)) #BxL + tmp_input_lengths = unsqueeze(input_lengths, 1) #Bx1 + tmp_input_lengths = expand(tmp_input_lengths, shape(input_ids)) #BxL + mask = tmp_position_ids < tmp_input_lengths # BxL + mask = mask.cast('int32') + + if position_ids is None: + if self.is_roberta: + # see create_position_ids_from_input_ids() in https://github.com/huggingface/transformers/blob/main/src/transformers/models/roberta/modeling_roberta.py + position_ids = (tmp_position_ids + 1) * mask + position_ids = position_ids + self.padding_idx + else: + position_ids = slice(position_ids_buffer, + starts=[0, 0], + sizes=seq_len_2d) + position_ids = expand(position_ids, shape(input_ids)) + + # create extended_attention_mask as https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py + extended_attention_mask = expand_mask(mask, tgt_len=1) # BxL -> Bx1x1xL + + # create token_type_ids + if token_type_ids is None: + token_type_ids_buffer = constant( + np.expand_dims( + np.zeros(self.max_position_embeddings).astype(np.int32), 0)) + token_type_ids = slice(token_type_ids_buffer, + starts=[0, 0], + sizes=seq_len_2d) + token_type_ids = expand(token_type_ids, shape(input_ids)) hidden_states = self.embedding(input_ids, position_ids, token_type_ids) for layer in self.layers: hidden_states = layer(hidden_states=hidden_states, input_lengths=input_lengths, - attention_mask=extended_attention_mask) + attention_mask=extended_attention_mask, + max_input_length=max_input_length) return hidden_states @@ -325,10 +341,25 @@ def __init__(self, hidden_size, dtype): self.dense = Linear(hidden_size, hidden_size, dtype=dtype) self.activation = ACT2FN['tanh'] - def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = select(hidden_states, 1, 0) + def forward(self, hidden_states, input_lengths, remove_input_padding): + if not remove_input_padding: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = select(hidden_states, 1, 0) + else: + # when remove_input_padding is enabled, the shape of hidden_states is [num_tokens, hidden_size] + # We can take the first token of each sequence according to input_lengths, + # and then do pooling similar to padding mode. + # For example, if input_lengths is [8, 5, 6], then the indices of first tokens + # should be [0, 8, 13] + first_token_indices = cumsum( + concat([0, + slice(input_lengths, starts=[0], + sizes=(shape(input_lengths) - constant(np.array([1], dtype=np.int32)))) + ]), 0 + ) + first_token_tensor = index_select(hidden_states, 0, first_token_indices) + pooled_output = self.dense(first_token_tensor) pooled_output = self.activation(pooled_output) return pooled_output @@ -389,20 +420,35 @@ def __init__(self, dtype=dtype) def forward(self, - input_ids=None, - input_lengths=None, + input_ids, + input_lengths, token_type_ids=None, position_ids=None, - hidden_states=None): - + hidden_states=None, + max_input_length=None): + + remove_input_padding = default_net().plugin_config.remove_input_padding + + # required as explicit input in remove_input_padding mode + # see examples/bert/run_remove_input_padding.py for how to create them from input_ids and input_lengths + if remove_input_padding: + assert token_type_ids is not None and \ + position_ids is not None and \ + max_input_length is not None, \ + "token_type_ids, position_ids, max_input_length is required " \ + "in remove_input_padding mode" + hidden_states = self.bert.forward(input_ids=input_ids, input_lengths=input_lengths, token_type_ids=token_type_ids, position_ids=position_ids, - hidden_states=hidden_states) + hidden_states=hidden_states, + max_input_length=max_input_length) if not self.is_roberta: - pooled_output = self.pooler(hidden_states) + pooled_output = self.pooler(hidden_states=hidden_states, + input_lengths=input_lengths, + remove_input_padding=remove_input_padding) logits = self.classifier(pooled_output) else: logits = self.classifier(hidden_states) From cbf60cc425230cb71d35858a913fd066103c06df Mon Sep 17 00:00:00 2001 From: magpiezhang Date: Tue, 25 Jun 2024 18:02:52 +0800 Subject: [PATCH 2/3] support remove_input_padding for BertForSequenceClassification models --- examples/bert/run.py | 4 ++++ examples/bert/run_remove_input_padding.py | 4 ++-- tensorrt_llm/models/bert/model.py | 4 ++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/bert/run.py b/examples/bert/run.py index d1a5b0cff..e9307e1ff 100644 --- a/examples/bert/run.py +++ b/examples/bert/run.py @@ -55,6 +55,10 @@ def parse_arguments(): config_path = os.path.join(args.engine_dir, 'config.json') with open(config_path, 'r') as f: config = json.load(f) + + assert config["plugin_config"]["remove_input_padding"] == False, \ + "Please refer to run_remove_input_padding.py for running BERT models with remove_input_padding enabled" + dtype = config['builder_config']['precision'] world_size = config['builder_config']['tensor_parallel'] assert world_size == tensorrt_llm.mpi_world_size(), \ diff --git a/examples/bert/run_remove_input_padding.py b/examples/bert/run_remove_input_padding.py index 07f0ad100..5231b5e44 100644 --- a/examples/bert/run_remove_input_padding.py +++ b/examples/bert/run_remove_input_padding.py @@ -50,8 +50,8 @@ def parse_arguments(): def process_input(input_ids_list: List[torch.Tensor], token_type_ids_list: List[torch.Tensor]): - input_lengths = [] # [batch_size] - position_ids_list = [] # [batch_size, seq_len] + input_lengths = [] + position_ids_list = [] max_input_length = 0 for i, input_ids in enumerate(input_ids_list): input_len = len(input_ids) diff --git a/tensorrt_llm/models/bert/model.py b/tensorrt_llm/models/bert/model.py index 39a33a322..09fdd9283 100644 --- a/tensorrt_llm/models/bert/model.py +++ b/tensorrt_llm/models/bert/model.py @@ -18,8 +18,8 @@ from ..._common import default_net from ...functional import (ACT2FN, bert_attention, cast, concat, constant, - expand, expand_mask, matmul, select, shape, - slice, softmax, split, unsqueeze, cumsum, index_select) + expand, expand_mask, matmul, select, shape, slice, + softmax, split, unsqueeze, cumsum, index_select) from ...layers import MLP, ColumnLinear, Embedding, LayerNorm, Linear, RowLinear from ...mapping import Mapping from ...module import Module, ModuleList From 37e3579282215f23891c3ef6ec9486ea3327400c Mon Sep 17 00:00:00 2001 From: magpiezhang Date: Wed, 26 Jun 2024 12:47:05 +0800 Subject: [PATCH 3/3] backward compatibilty for no model_dir arg --- examples/bert/build.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/bert/build.py b/examples/bert/build.py index 07ce897cf..a9650df34 100644 --- a/examples/bert/build.py +++ b/examples/bert/build.py @@ -103,7 +103,7 @@ def parse_arguments(): 'RobertaForQuestionAnswering', 'RobertaForSequenceClassification', ]) - parser.add_argument('--model_dir', type=str, required=True) + parser.add_argument('--model_dir', type=str, default=None) return parser.parse_args() def prepare_inputs(): @@ -279,9 +279,11 @@ def prepare_inputs(): ) output_name = 'logits' elif args.model == 'BertForSequenceClassification' or args.model == 'RobertaForSequenceClassification': - hf_bert = globals()[f'{model_type}ForSequenceClassification'](config=bert_config) - state_dict = torch.load(os.path.join(args.model_dir, "pytorch_model.bin")) - hf_bert.load_state_dict(state_dict, strict=False) + hf_bert = globals()[f'{model_type}ForSequenceClassification']( + bert_config).cuda().to(torch_dtype).eval() + if args.model_dir: + state_dict = torch.load(os.path.join(args.model_dir, "pytorch_model.bin")) + hf_bert.load_state_dict(state_dict, strict=False) tensorrt_llm_bert = tensorrt_llm.models.BertForSequenceClassification( num_layers=bert_config.num_hidden_layers,