How to give Prompt to trained RETRO Model? #33
#23 contains a notebook with a good example. I think putting it together with the README instructions looks like this:

```python
import torch
from retro_pytorch import RETRO, TrainingWrapper
# instantiate RETRO, fit it into the TrainingWrapper with correct settings
retro = RETRO(
    max_seq_len = 2048,                    # max sequence length
    enc_dim = 896,                         # encoder model dimension
    enc_depth = 3,                         # encoder depth
    dec_dim = 768,                         # decoder model dimension
    dec_depth = 12,                        # decoder depth
    dec_cross_attn_layers = (1, 3, 6, 9),  # decoder cross attention layers (with causal chunk cross attention)
    heads = 8,                             # attention heads
    dim_head = 64,                         # dimension per head
    dec_attn_dropout = 0.25,               # decoder attention dropout
    dec_ff_dropout = 0.25                  # decoder feedforward dropout
).cuda()
wrapper = TrainingWrapper(
    retro = retro,                                # the RETRO instance defined above
    knn = 2,                                      # knn (2 in paper was sufficient)
    chunk_size = 64,                              # chunk size (64 in paper)
    documents_path = './text_folder',             # path to folder of text
    glob = '**/*.txt',                            # text glob
    chunks_memmap_path = './train.chunks.dat',    # path to chunks
    seqs_memmap_path = './train.seq.dat',         # path to sequence data
    doc_ids_memmap_path = './train.doc_ids.dat',  # path to document ids per chunk (used for filtering neighbors belonging to same document)
    max_chunks = 1_000_000,                       # maximum cap to chunks
    max_seqs = 100_000,                           # maximum seqs
    knn_extra_neighbors = 100,                    # num extra neighbors to fetch
    max_index_memory_usage = '100m',
    current_memory_available = '1G'
)
# get the dataloader and optimizer (AdamW with all the correct settings)
train_dl = iter(wrapper.get_dataloader(batch_size = 2, shuffle = True))
optim = wrapper.get_optimizer(lr = 3e-4, wd = 0.01)
# now do your training
# ex. one gradient step
seq, retrieved = map(lambda t: t.cuda(), next(train_dl))
# seq - (2, 2049) - 1 extra token since split by seq[:, :-1], seq[:, 1:]
# retrieved - (2, 32, 2, 128) - 128 since chunk + continuation, each 64 tokens
loss = retro(
    seq,
    retrieved,
    return_loss = True
)
# one gradient step
loss.backward()
optim.step()
optim.zero_grad()
# do above for many steps, then ...
# encode prompt
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
prompt_str = "The movie Dune was released in"
prompt_ids = tokenizer(prompt_str)['input_ids'][1:-1]
prompt = torch.tensor([prompt_ids]).cuda() # move the prompt to the same device as the model
sampled = wrapper.generate(prompt, filter_thres = 0.9, temperature = 1.0)
# decode sample
decoded = tokenizer.decode(sampled.tolist()[0])
print(decoded)
```

The code in the notebook for training over several epochs is probably needed for good results, though.
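For what it's worth, here is a minimal sketch of what such a longer training run could look like, reusing the `retro`, `wrapper`, and `optim` objects from the snippet above. The step count and logging interval are arbitrary placeholders, not values from the lost notebook:

```python
# a minimal sketch of a longer training run; num_steps and the logging
# interval are arbitrary placeholders, not values from the notebook
num_steps = 10_000

dl = wrapper.get_dataloader(batch_size = 2, shuffle = True)

step = 0
while step < num_steps:
    for seq, retrieved in dl:  # one full pass over the dataset per outer iteration
        seq, retrieved = seq.cuda(), retrieved.cuda()

        loss = retro(seq, retrieved, return_loss = True)
        loss.backward()

        optim.step()
        optim.zero_grad()

        if step % 100 == 0:
            print(f'step {step}: loss {loss.item():.4f}')

        step += 1
        if step >= num_steps:
            break
```

Re-iterating over the dataloader in the outer loop gives multiple epochs without exhausting the single `iter(...)` used in the one-step example above.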
@filipesilva Can you please share the notebook you are referencing? It's not accessible. Or if you can share code for training multiple epochs, that would be really helpful. Thanks!
@aakashgoel12 looks like the notebook that was in #23 is not there anymore. I don't have a copy of it, unfortunately. All the code I have is what I put in the comment.
Thanks @filipesilva. Can you please check if what I have written below is correct or needs some modification? Thanks in advance.
I really can't tell 😅 I only played around with this a couple of months ago and never really tried again.
Hello, is this in the retro path? Or what dataset is it?
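For context: there is no bundled dataset here. The `documents_path` argument of the `TrainingWrapper` above just points at a folder of plain `.txt` files that you supply yourself. A hypothetical sketch of preparing such a folder, with made-up file names and contents:

```python
# hypothetical example of preparing the './text_folder' directory that
# TrainingWrapper reads from; the file names and contents are made up
from pathlib import Path

folder = Path('./text_folder')
folder.mkdir(exist_ok = True)

documents = {
    'dune.txt': 'Dune is a 1965 science fiction novel by Frank Herbert...',
    'arrakis.txt': 'Arrakis is the desert planet on which much of Dune is set...',
}

for name, text in documents.items():
    (folder / name).write_text(text)
```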
I am following the instructions in the RETRO-pytorch GitHub repo. After training my model, how do I go about using it to generate responses?
Now, when I want to give this model a text input (any prompt), how would I go about doing that? Which method or function would I use? Which model/tokenizer should I use to encode the input prompt and then decode the model output tensor? Is there a method for that?
Example Prompt:
"The movie Dune was released in"