In `tokenize_data.py`:
```python
def embed_dataset_batch(model: InversionModel, batch: Dict) -> Dict:
    assert "input_ids" in batch.keys(), f"invalid keys {batch.keys()}"
    assert hasattr(model, "call_embedding_model")

    input_ids = batch["input_ids"]
    inputs_str = model.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
    emb_input_ids = model.embedder_tokenizer(
        inputs_str,
        max_length=model.config.max_seq_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    ).to(next(model.parameters()).device)
    with torch.no_grad():
        batch["frozen_embeddings"] = model.call_embedding_model(**emb_input_ids)
    return batch
```
the embedder tokenizer's tokens (`emb_input_ids`) are sent to `call_embedding_model`.
But in `models/inversion_from_logits.py`:
```python
def call_embedding_model(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
) -> torch.Tensor:
    embedder = self.embedder
    inputs_str = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
    emb_input_ids = self.embedder_tokenizer(
        inputs_str,
        max_length=self.config.max_seq_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    ).to(next(self.parameters()).device)
    model_output = embedder(**emb_input_ids)
    return self._process_embedder_output(model_output, emb_input_ids.attention_mask)
```
This function expects `model.tokenizer`'s tokens, not `model.embedder_tokenizer`'s: it decodes the incoming ids with `self.tokenizer`, so ids produced by the embedder tokenizer are interpreted in the wrong vocabulary. This causes gibberish tokens to be sent to the embedder.
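A minimal sketch of the mismatch (the two tokenizers below are illustrative stand-ins, not necessarily the pair any given config uses):

```python
from transformers import AutoTokenizer

# Hypothetical stand-ins: any two tokenizers with different vocabularies.
tokenizer = AutoTokenizer.from_pretrained("t5-base")        # plays model.tokenizer
embedder_tokenizer = AutoTokenizer.from_pretrained("gpt2")  # plays model.embedder_tokenizer

# embed_dataset_batch re-tokenizes the text with the embedder tokenizer...
emb_ids = embedder_tokenizer("hello world", return_tensors="pt").input_ids

# ...but call_embedding_model decodes those ids with model.tokenizer,
# interpreting them in the wrong vocabulary:
print(tokenizer.batch_decode(emb_ids, skip_special_tokens=True))
# prints gibberish rather than "hello world"
```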
I can raise a PR if needed.
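For reference, one possible fix, sketched under the assumption that the ids reaching this method always come from the embedder tokenizer (other call sites would need to confirm that): decode with `self.embedder_tokenizer` so the round trip stays within a single vocabulary.

```python
def call_embedding_model(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
) -> torch.Tensor:
    embedder = self.embedder
    # Decode with the same tokenizer that produced the ids, so the
    # string round trip stays within one vocabulary.
    inputs_str = self.embedder_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
    emb_input_ids = self.embedder_tokenizer(
        inputs_str,
        max_length=self.config.max_seq_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    ).to(next(self.parameters()).device)
    model_output = embedder(**emb_input_ids)
    return self._process_embedder_output(model_output, emb_input_ids.attention_mask)
```

If that assumption holds at every call site, the decode/re-encode could even be dropped entirely and `input_ids`/`attention_mask` passed straight to the embedder.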
@themurtazanazir thank you for finding this – a pull request would be amazing!