Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a bug in tying OPT embeddings #1

Merged
merged 1 commit into from
Feb 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cacheflow/models/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def get_model(
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]
else:
torch_dtype = dtype
for model_class, model in MODEL_CLASSES.items():
for model_class, hf_model in MODEL_CLASSES.items():
if model_class in model_name:
return model.from_pretrained(model_name, torch_dtype=torch_dtype)
model = hf_model.from_pretrained(model_name, torch_dtype=torch_dtype)
return model.eval()
raise ValueError(f'Invalid model name: {model_name}')
22 changes: 22 additions & 0 deletions cacheflow/models/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,28 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

# NOTE(woosuk): While the following methods are not called in the model code,
# they may be internally used by the transformers library.
# For example, tie_weights() does not work without these methods.
# Thus, do not delete these methods.
def get_input_embeddings(self):
    """Return the token-embedding module used by the decoder.

    Required by the transformers library (e.g. ``tie_weights``) even
    though the model code never calls it directly.
    """
    decoder = self.model.decoder
    return decoder.embed_tokens

def set_input_embeddings(self, value):
    """Replace the decoder's token-embedding module with ``value``.

    Required by the transformers library (e.g. ``tie_weights``) even
    though the model code never calls it directly.
    """
    decoder = self.model.decoder
    decoder.embed_tokens = value

def get_output_embeddings(self):
    """Return the LM head (output projection) of the model.

    Required by the transformers library (e.g. ``tie_weights``) even
    though the model code never calls it directly.
    """
    head = self.lm_head
    return head

def set_output_embeddings(self, new_embeddings):
    """Replace the LM head (output projection) with ``new_embeddings``.

    Required by the transformers library (e.g. ``tie_weights``) even
    though the model code never calls it directly.
    """
    setattr(self, "lm_head", new_embeddings)

def set_decoder(self, decoder):
    """Install ``decoder`` as the model's decoder submodule.

    Part of the accessor set the transformers library may call
    internally; not used by the model code itself.
    """
    model = self.model
    model.decoder = decoder

def get_decoder(self):
    """Return the model's decoder submodule.

    Part of the accessor set the transformers library may call
    internally; not used by the model code itself.
    """
    model = self.model
    return model.decoder

def forward(
self,
input_ids: torch.LongTensor,
Expand Down