-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
95 lines (78 loc) · 3.26 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import streamlit as st
from transformers import pipeline
import time
import pandas as pd
import numpy as np
import joblib
from transformers import AutoModel, AutoTokenizer, AutoConfig
import torch
import torch.nn as nn
import pickle
class ToxicSimpleNNModel(nn.Module):
    """Pretrained encoder topped with a pooled linear classification head.

    The head averages and max-reduces the encoder's first output over
    dimension 1, concatenates both results and projects them to 8 logits.
    """

    def __init__(self, path):
        """Create the encoder, dropout and linear head.

        Args:
            path: Location handed to ``AutoModel.from_pretrained``.
        """
        super().__init__()
        self.backbone = AutoModel.from_pretrained(path)
        self.dropout = nn.Dropout(0.3)
        # Mean- and max-pooled features are concatenated, hence twice the
        # width reported by the encoder's pooler layer.
        hidden_width = self.backbone.pooler.dense.out_features
        self.linear = nn.Linear(in_features=hidden_width * 2, out_features=8)

    def forward(self, input_ids, attention_masks):
        """Compute the 8 logits for a batch.

        Args:
            input_ids: Token id tensor forwarded to the encoder.
            attention_masks: Mask tensor forwarded to the encoder as
                ``attention_mask``.

        Returns:
            The output of the linear head for the pooled features.
        """
        token_states, _ = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_masks,
            return_dict=False,
        )
        mean_pooled = token_states.mean(dim=1)
        max_pooled = token_states.max(dim=1).values
        pooled = torch.cat([mean_pooled, max_pooled], dim=1)
        return self.linear(self.dropout(pooled))
def load_topic_model(base_path, model_path):
    """Build the topic classifier and restore its saved weights on CPU.

    Args:
        base_path: Encoder location passed on to ``ToxicSimpleNNModel``.
        model_path: File containing the saved ``state_dict``.

    Returns:
        The model with its weights loaded, switched to evaluation mode.
    """
    net = ToxicSimpleNNModel(base_path)
    # NOTE(review): torch.load unpickles the file, so only checkpoints from
    # a trusted source should be loaded here.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    net.load_state_dict(state_dict)
    # A freshly built nn.Module is in training mode; without eval() the
    # Dropout layer in the head stays active and the same input can yield
    # different predictions on every call.
    net.eval()
    return net
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def Topic_generation_load(base_path, model_path):
    """Load the topic model through Streamlit's cache.

    Args:
        base_path: Forwarded unchanged to ``load_topic_model``.
        model_path: Forwarded unchanged to ``load_topic_model``.

    Returns:
        Whatever ``load_topic_model`` returns for the two paths.
    """
    # Printed only when the function body actually runs.
    print('loading topic_model')
    return load_topic_model(base_path, model_path)
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_summarization_model():
    """Create the ``t5-small`` summarization pipeline through Streamlit's cache.

    Returns:
        The object returned by ``pipeline('summarization', model='t5-small')``.
    """
    print('loading summarization model')
    summarization_pipe = pipeline('summarization', model='t5-small')
    # The previous message ('sentiment model loading') named the wrong model
    # and was printed after loading had already finished.
    print('summarization model loaded')
    return summarization_pipe
def get_summarization(text, summarization, max_lenght):
    """Summarize ``text`` and return the summary string.

    Args:
        text: Input passed as the first argument to ``summarization``.
        summarization: Callable accepting ``text`` and a ``max_length``
            keyword and returning a sequence of dicts.
        max_lenght: Value forwarded as ``max_length``.

    Returns:
        The ``'summary_text'`` entry of the first returned result.
    """
    candidates = summarization(text, max_length=max_lenght)
    first_candidate = candidates[0]
    return first_candidate['summary_text']
# Topic name for each output index of the classifier.
_TOPIC_LABELS = (
    'business',
    'elections',
    'entertainment',
    'news',
    'opinion',
    'sci-tech',
    'society',
    'sport',
)


def predict_topic(text, tokenizer, top_k=3):
    """Return the names of the highest-scoring topics for ``text``.

    Uses the module-level ``model`` for inference.

    Args:
        text: The text to classify.
        tokenizer: Tokenizer object exposing ``encode_plus``.
        top_k: Number of topic names to return (default 3). Values larger
            than the number of available topics are capped.

    Returns:
        A list of topic names ordered from highest to lowest score.
    """
    encoded = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        # Explicit replacements for the deprecated ``pad_to_max_length``;
        # truncation is requested explicitly so over-long inputs are cut
        # to ``max_length`` instead of relying on legacy defaults.
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    # Make sure dropout is disabled and no autograd graph is built: this is
    # pure inference and must be deterministic.
    model.eval()
    with torch.no_grad():
        logits = model(encoded['input_ids'], encoded['attention_mask'])
    scores = torch.sigmoid(logits)
    k = max(0, min(top_k, len(_TOPIC_LABELS), scores.shape[1]))
    best_indices = torch.topk(scores, k=k, dim=1).indices[0].tolist()
    return [_TOPIC_LABELS[index] for index in best_indices]
# --- Streamlit page: load resources once, then render the UI. ---

# NOTE(review): pickle.load can execute code embedded in the file; keep
# 'tokenizer.obj' under the app's control and never accept it from users.
with open('tokenizer.obj', 'rb') as f:
    tokenizer = pickle.load(f)

summarization_pipe = load_summarization_model()
model = Topic_generation_load('tiny-bert', 'model.bin')

st.title('News Summary Generation and Topic Prediction')
st.markdown('Here you can enter the News in first text box and can get news summary around the subject')
text = st.text_input('Enter News here:', key=0)
# Neither model should be run on a blank prompt.
has_text = bool(text.strip())

if st.checkbox('Start Generate Summary'):
    st.write('uncheck the box if you are done')
    option = st.sidebar.selectbox(label='Max_Lenght', options=['20', '40', '50'])
    if has_text:
        summary = get_summarization(text, summarization_pipe, int(option))
        st.write('Final summary ' + summary)
    else:
        st.warning('Please enter some news text first.')

st.markdown('Once done you can get the top Topics the news relate too')
if st.button('Predict Topics'):
    if has_text:
        topics = predict_topic(text, tokenizer)
        st.markdown('Top topics related are ' + ' , '.join(topics))
    else:
        st.warning('Please enter some news text first.')