-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
136 lines (106 loc) · 4.8 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
from datasets import load_dataset
import evaluate
import numpy as np
# Load the dataset and hold out 20% of it for evaluation. A fixed seed keeps the
# train/test split identical across runs so evaluation metrics are comparable.
dataset = load_dataset("krishna8421/chats_with_title", split="train")
dataset = dataset.train_test_split(test_size=0.2, seed=42)

# Load the tokenizer for the pretrained checkpoint.
checkpoint = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Build the source texts exactly as preprocess_function does: each chat is
# prefixed with "summarize: ". Measuring the raw chats would under-count the
# prefix tokens, and the longest inputs would then be cut by
# truncation=True, max_length=max_source.
source_text = ["summarize: " + chat for chat in dataset["train"]["chats"]]
target_text = list(dataset["train"]["title"])

# Tokenize without truncation or padding so the true lengths are visible.
# Targets are tokenized via text_target, matching how preprocess_function
# builds the labels.
tokenized_source_text = tokenizer(source_text, truncation=False, padding=False)
tokenized_target_text = tokenizer(text_target=target_text, truncation=False, padding=False)

# Longest training sequence (in tokens) for sources and targets; used as the
# fixed padding/truncation length in preprocess_function.
# NOTE(review): only the train split is scanned, so longer test examples are
# still truncated to these lengths.
max_source = max(len(item) for item in tokenized_source_text['input_ids'])
max_target = max(len(item) for item in tokenized_target_text['input_ids'])
# Preprocess function for mapping dataset
def preprocess_function(unit):
    """Tokenize a batch of chat/title rows into model inputs and labels.

    Args:
        unit: batch dict passed by ``Dataset.map(..., batched=True)``. It holds
            list-valued ``"chats"`` and ``"title"`` columns.

    Returns:
        The tokenizer encoding of the prefixed chats, padded/truncated to
        ``max_source``, plus a ``"labels"`` entry. That entry holds the title
        token ids padded/truncated to ``max_target``, with padding positions
        set to -100.
    """
    # Prepend the "summarize: " task prefix to each chat.
    inputs = ["summarize: " + con for con in unit["chats"]]
    model_inputs = tokenizer(inputs, padding='max_length', truncation=True, max_length=max_source)
    labels = tokenizer(text_target=unit["title"], padding='max_length', truncation=True, max_length=max_target)
    # Labels are already padded to max_target here, so the data collator has
    # nothing left to pad with its own -100 label padding. Swap the pad ids for
    # -100 so the loss ignores padded positions instead of training the model to
    # emit pad tokens. compute_metrics maps -100 back to the pad id before decoding.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [tok if tok != pad_id else -100 for tok in seq] for seq in labels["input_ids"]
    ]
    return model_inputs
# Tokenize every split; batched=True hands preprocess_function lists of rows.
tokenized_data = dataset.map(preprocess_function, batched=True)

# Data collator for seq2seq batches.
# `model` is deliberately not passed. The collator only uses it when the object
# exposes `prepare_decoder_input_ids_from_labels`, so a checkpoint-name string
# would be silently ignored. The model object is also only loaded further down.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

# ROUGE metric used by compute_metrics.
rouge = evaluate.load("rouge")
# Compute metrics function
def compute_metrics(preds):
    """Compute ROUGE scores and the mean generated length for an eval pass.

    Args:
        preds: ``(predictions, labels)`` pair supplied by the Trainer. With
            ``predict_with_generate=True`` the predictions are generated token
            ids. Either array may contain -100 at ignored/padded positions.

    Returns:
        dict mapping each key returned by ``rouge.compute`` plus ``"gen_len"``
        to its value rounded to 4 decimals.
    """
    predictions, labels = preds
    # Some models return a tuple of outputs; the token ids come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # -100 is not a valid token id and makes batch_decode fail. Map it to the pad
    # id in both arrays. This also keeps -100 out of the gen_len count below.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Generated length = number of non-pad tokens in each prediction.
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}
# Load the pretrained model to fine-tune.
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

# Output directory for checkpoints and the final model. exist_ok=True avoids the
# check-then-create race of testing os.path.exists first.
out_dir = "models/title-generation-checkpoint"
os.makedirs(out_dir, exist_ok=True)

# Training arguments.
# NOTE(review): newer transformers releases renamed `evaluation_strategy` to
# `eval_strategy` and later dropped the old name; use whichever the installed
# version accepts.
training_args = Seq2SeqTrainingArguments(
    output_dir=out_dir,
    evaluation_strategy="epoch",
    learning_rate=0.001,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=5,
    num_train_epochs=20,
    # Generate sequences during evaluation so compute_metrics receives token ids.
    predict_with_generate=True,
    fp16=False,
    push_to_hub=False,
)

# Trainer.
# NOTE(review): newer transformers versions deprecate `tokenizer=` here in favour
# of `processing_class=`; confirm against the installed version.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Train the model.
trainer.train()
# Save the fine-tuned model under out_dir/trained-model.
trained_dir = os.path.join(out_dir, 'trained-model')
trainer.save_model(trained_dir)

# Reload the saved model for inference. Use the GPU when one is available and
# fall back to CPU instead of crashing on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoModelForSeq2SeqLM.from_pretrained(trained_dir).to(device)

# Example chat to generate a title for. Replace the literal with input() to
# read a chat interactively.
input_text = 'summarize: '+"""Have you ever considered taking up a new hobby or trying something completely out of your comfort zone?
Funny you mention that - I've been thinking about it lately, but I'm not sure where to start.
How about trying salsa dancing? There's a dance studio downtown that offers beginner classes.
Salsa dancing, really? I've never thought about it, but it could be fun. When do they have classes?
They have classes on Tuesday evenings. It's a great way to learn a new skill, stay active, and meet new people.
That actually sounds exciting! Okay, I'm in. Let's give salsa dancing a shot this Tuesday.
Awesome! It's a date then. I'll see you at the dance studio. Get ready to step out of your comfort zone!"""

# Perform inference without tracking gradients.
with torch.no_grad():
    tokenized_text = tokenizer(input_text, truncation=True, padding=True, return_tensors='pt')
    source_ids = tokenized_text['input_ids'].to(device, dtype=torch.long)
    source_mask = tokenized_text['attention_mask'].to(device, dtype=torch.long)
    # Beam search; no_repeat_ngram_size=2 forbids repeating any bigram.
    generated_ids = model.generate(
        input_ids=source_ids,
        attention_mask=source_mask,
        max_length=512,
        num_beams=5,
        repetition_penalty=1,
        length_penalty=1,
        early_stopping=True,
        no_repeat_ngram_size=2,
    )

# Decode the single generated sequence into text.
pred = tokenizer.decode(generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
print("\noutput:\n" + pred)