diff --git a/cacheflow/models/model_utils.py b/cacheflow/models/model_utils.py
index 80b7c01092474..0e7e4d3b2dd09 100644
--- a/cacheflow/models/model_utils.py
+++ b/cacheflow/models/model_utils.py
@@ -25,7 +25,8 @@ def get_model(
         torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]
     else:
         torch_dtype = dtype
-    for model_class, model in MODEL_CLASSES.items():
+    for model_class, hf_model in MODEL_CLASSES.items():
         if model_class in model_name:
-            return model.from_pretrained(model_name, torch_dtype=torch_dtype)
+            model = hf_model.from_pretrained(model_name, torch_dtype=torch_dtype)
+            return model.eval()
     raise ValueError(f'Invalid model name: {model_name}')
diff --git a/cacheflow/models/opt.py b/cacheflow/models/opt.py
index 422ec2632d264..14d38d4073195 100644
--- a/cacheflow/models/opt.py
+++ b/cacheflow/models/opt.py
@@ -232,6 +232,28 @@ def __init__(self, config):
         # Initialize weights and apply final processing
         self.post_init()
 
+    # NOTE(woosuk): While the following methods are not called in the model code,
+    # they may be internally used by the transformers library.
+    # For example, tie_weights() does not work without these methods.
+    # Thus, do not delete these methods.
+    def get_input_embeddings(self):
+        return self.model.decoder.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.decoder.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model.decoder = decoder
+
+    def get_decoder(self):
+        return self.model.decoder
+
     def forward(
         self,
         input_ids: torch.LongTensor,