Add module __repr__ methods #2191

Closed · wants to merge 1 commit
9 changes: 8 additions & 1 deletion tensorrt_llm/models/qwen/model.py
@@ -17,6 +17,7 @@
from typing import Optional, Union

from ..._utils import pad_vocab_size
from ...logger import logger
from ...functional import Tensor, recv, send, sigmoid
from ...layers import (MLP, MOE, Attention, AttentionMaskType, ColumnLinear,
                       Embedding, GatedMLP, RmsNorm, RowLinear)
@@ -309,13 +310,19 @@ def from_hugging_face(

        if not use_preloading:
            hf_model = load_hf_qwen(hf_model_dir, load_model_on_cpu)

        logger.info(f"HuggingFace model: {hf_model}")

        model = QWenForCausalLM(config)

        logger.info(f"TensorRT-LLM model: {model}")

        if use_hf_gptq_checkpoint:
            weights = load_weights_from_hf_gptq_model(hf_model, config)
        else:
            weights = load_weights_from_hf_model(hf_model, config)

        check_share_embedding(weights, config)
        model = QWenForCausalLM(config)
        model.load(weights)
        return model

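For reviewers who want to see the new output, a minimal usage sketch; the checkpoint path is hypothetical and the from_hugging_face defaults are assumed rather than taken from this diff:

from tensorrt_llm.logger import logger
from tensorrt_llm.models import QWenForCausalLM

logger.set_level('info')  # surface the new INFO-level messages

# Hypothetical local HuggingFace checkpoint directory.
model = QWenForCausalLM.from_hugging_face('/path/to/qwen-checkpoint')
# The call above now logs the full module hierarchy of both the HuggingFace
# model and the constructed TensorRT-LLM model, rendered by the __repr__
# methods added in tensorrt_llm/module.py below.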
61 changes: 61 additions & 0 deletions tensorrt_llm/module.py
@@ -18,6 +18,18 @@
from .logger import logger


def _addindent(s_, numSpaces):
    s = s_.split('\n')
    # don't do anything for single-line stuff
    if len(s) == 1:
        return s_
    first = s.pop(0)
    s = [(numSpaces * ' ') + line for line in s]
    s = '\n'.join(s)
    s = first + '\n' + s
    return s
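# Illustrative example (not part of this diff): _addindent leaves single-line
# strings untouched and indents every line after the first by numSpaces, e.g.
#   _addindent("Inner(\n  (fc): Linear()\n)", 2)
#   returns "Inner(\n    (fc): Linear()\n  )"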


class Module(object):

    def __init__(self) -> None:
@@ -191,6 +203,23 @@ def update_parameters(self, torch_module):
        for k, v in self.named_parameters():
            v.value = tm[k].detach().cpu().numpy()

    def _get_name(self):
        return self.__class__.__name__

    def __repr__(self):
        # We treat the extra repr like the sub-module, one item per line
        child_lines = []
        for key, module in self._modules.items():
            mod_str = repr(module)
            mod_str = _addindent(mod_str, 2)
            child_lines.append('(' + key + '): ' + mod_str)
        main_str = self._get_name() + '('
        if child_lines:
            # simple one-liner info, which most builtin Modules will use
            main_str += '\n  ' + '\n  '.join(child_lines) + '\n'
        main_str += ')'
        return main_str


class ModuleList(Module):

@@ -221,3 +250,35 @@ def __setitem__(self, idx, module) -> None:

    def __len__(self):
        return len(self._modules)

    def __repr__(self):
        """Return a custom repr for ModuleList that compresses repeated module representations."""
        list_of_reprs = [repr(item) for item in self]
        if len(list_of_reprs) == 0:
            return self._get_name() + "()"

        start_end_indices = [[0, 0]]
        repeated_blocks = [list_of_reprs[0]]
        for i, r in enumerate(list_of_reprs[1:], 1):
            if r == repeated_blocks[-1]:
                start_end_indices[-1][1] += 1
                continue

            start_end_indices.append([i, i])
            repeated_blocks.append(r)

        lines = []
        main_str = self._get_name() + "("
        for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
            local_repr = f"({start_id}): {b}"  # default repr

            if start_id != end_id:
                n = end_id - start_id + 1
                local_repr = f"({start_id}-{end_id}): {n} x {b}"

            local_repr = _addindent(local_repr, 2)
            lines.append(local_repr)

        main_str += "\n  " + "\n  ".join(lines) + "\n"
        main_str += ")"
        return main_str
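As a quick sanity check, a minimal sketch of what the compression produces. Block and Inner are throwaway classes invented for this illustration, and it assumes that sub-modules assigned as attributes are registered in _modules (as in PyTorch):

from tensorrt_llm.module import Module, ModuleList


class Inner(Module):
    pass


class Block(Module):

    def __init__(self):
        super().__init__()
        self.inner = Inner()


blocks = ModuleList([Block() for _ in range(3)])
print(blocks)
# Identical per-item reprs collapse into a single range entry:
# ModuleList(
#   (0-2): 3 x Block(
#     (inner): Inner()
#   )
# )

For a real model this is what lets, say, 32 identical decoder layers print as one "(0-31): 32 x ..." entry instead of 32 copies.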