Skip to content

Commit

Permalink
280: Add new models, add text preprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
Smixie committed Dec 23, 2024
1 parent 63618e3 commit c5b0c55
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 68 deletions.
6 changes: 4 additions & 2 deletions nlp/src/globals.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
language_model = None
sentiment_tokenizer = None
sentiment_model = None
sentiment_tokenizer_english = None
sentiment_model_english = None
sentiment_model_polish = None
sarcastic_pipeline = None
spam_pipeline = None
political_pipeline = None
hate_speech_english_pipeline = None
hate_speech_polish_pipeline = None
clickbait_pipeline = None
keyword_model = None
ai_detector_pipeline_english = None
156 changes: 101 additions & 55 deletions nlp/src/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import src.authorization as auth
import src.globals as globals
import src.globals as gl
import time

from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.params import Depends
from langcodes import tag_is_valid, Language
from src.base_models import *
from src.logger_config import logger
from src.models_loading import *
Expand All @@ -32,7 +31,7 @@ async def lifespan(application: FastAPI):
raise

models = [
"language", "sentiment", "sarcasm", "spam",
"language", "sentiment", "ai_detector", "sarcasm", "spam",
"political", "hate_speech", "clickbait", "keywords"
]

Expand Down Expand Up @@ -65,7 +64,11 @@ async def get_sentiment(request: TextRequest):
Returns:
TextResponse: Sentiment analysis results with labels and confidence scores.
"""
if globals.sentiment_model is None or globals.sentiment_tokenizer is None:
if gl.language_model is None:
raise HTTPException(status_code=500, detail="Language model not loaded")

if (gl.sentiment_model_english is None or gl.sentiment_tokenizer is None
or gl.sentiment_model_polish is None):
raise HTTPException(status_code=500, detail="Model not loaded")

labels = {
Expand All @@ -78,21 +81,31 @@ async def get_sentiment(request: TextRequest):
}

def process_fn(text):
tokenized_text = globals.sentiment_tokenizer([preprocess_data(text)],
padding=True, truncation=True,
max_length=128,
return_tensors="pt")
output = globals.sentiment_model(**tokenized_text)
probs = output.logits.softmax(dim=-1).tolist()[0]
confidence = max(probs)
prediction = probs.index(confidence)
result = AnalysisResult(
labels=[labels[prediction]],
confidences=[confidence_output(confidence)]
)
text = preprocess_data(text)
lang_name = detect_language(text)
if lang_name == "Polish":
predict = gl.sentiment_model_polish(text)
result = AnalysisResult(
labels=[predict[0]["label"]],
confidences=[confidence_output(predict[0]["score"])]
)
else:
tokenized_text = gl.sentiment_tokenizer([preprocess_data(text)],
padding=True, truncation=True,
max_length=128,
return_tensors="pt")
output = gl.sentiment_model_english(**tokenized_text)
probs = output.logits.softmax(dim=-1).tolist()[0]
confidence = max(probs)
prediction = probs.index(confidence)
result = AnalysisResult(
labels=[labels[prediction]],
confidences=[confidence_output(confidence)]
)
return result

results, generation_time = measure_execution_time(lambda: process_inputs(process_fn, request.text))
results, generation_time = measure_execution_time(
lambda: process_inputs(process_fn, request.text))

return TextResponse(
kind="sentiment",
Expand All @@ -113,12 +126,13 @@ async def get_language(request: TextRequest):
Returns:
TextResponse: Language detection results with labels and confidence scores.
"""
if globals.language_model is None:
if gl.language_model is None:
raise HTTPException(status_code=500, detail="Model not loaded")

def process_fn(text):
languages = []
prediction = globals.language_model.predict(text, k=2)
text = preprocess_data(text)
prediction = gl.language_model.predict(text, k=2)

for predicted_lang in prediction[0]:
lang_tag = predicted_lang.rsplit("_")[-2]
Expand All @@ -134,7 +148,8 @@ def process_fn(text):
)
return result

results, generation_time = measure_execution_time(lambda: process_inputs(process_fn, request.text))
results, generation_time = measure_execution_time(
lambda: process_inputs(process_fn, request.text))

return TextResponse(
kind="language",
Expand All @@ -155,7 +170,7 @@ async def get_sarcasm(request: TextRequest):
Returns:
TextResponse: Sarcasm detection results with labels and confidence scores.
"""
if globals.sarcastic_pipeline is None:
if gl.sarcastic_pipeline is None:
raise HTTPException(status_code=500, detail="Model not loaded")

labels = {
Expand All @@ -164,14 +179,16 @@ async def get_sarcasm(request: TextRequest):
}

def process_fn(text):
predict = globals.sarcastic_pipeline(text)
text = preprocess_data(text)
predict = gl.sarcastic_pipeline(text)
result = AnalysisResult(
labels=[labels[predict[0]["label"].replace("LABEL_", "")]],
confidences=[confidence_output(predict[0]["score"])]
)
return result

results, generation_time = measure_execution_time(lambda: process_inputs(process_fn, request.text))
results, generation_time = measure_execution_time(
lambda: process_inputs(process_fn, request.text))

return TextResponse(
kind="sarcasm",
Expand All @@ -194,18 +211,20 @@ async def get_keywords(request: TextRequest):
"""
# Remember add author if we want to use this model
# How keywords will work with long text?
if globals.keyword_model is None:
if gl.keyword_model is None:
raise HTTPException(status_code=500, detail="Model not loaded")

def process_fn(text):
keywords = globals.keyword_model.extract_keywords(text, top_n=5)
text = preprocess_data(text)
keywords = gl.keyword_model.extract_keywords(text, top_n=5)
result = AnalysisResult(
labels=[kw[0] for kw in keywords],
confidences=[kw[1] for kw in keywords]
)
return result

results, generation_time = measure_execution_time(lambda: process_inputs(process_fn, request.text))
results, generation_time = measure_execution_time(
lambda: process_inputs(process_fn, request.text))

return TextResponse(
kind="keywords",
Expand All @@ -226,7 +245,7 @@ async def get_spam(request: TextRequest):
Returns:
TextResponse: Spam detection results with labels and confidence scores.
"""
if globals.spam_pipeline is None:
if gl.spam_pipeline is None:
raise HTTPException(status_code=500, detail="Model not loaded")

labels = {
Expand All @@ -235,14 +254,16 @@ async def get_spam(request: TextRequest):
}

def process_fn(text):
predict = globals.spam_pipeline(text)
text = preprocess_data(text)
predict = gl.spam_pipeline(text)
result = AnalysisResult(
labels=[labels[predict[0]["label"].replace("LABEL_", "")]],
confidences=[confidence_output(predict[0]["score"])]
)
return result

results, generation_time = measure_execution_time(lambda: process_inputs(process_fn, request.text))
results, generation_time = measure_execution_time(
lambda: process_inputs(process_fn, request.text))

return TextResponse(
kind="spam",
Expand All @@ -265,18 +286,20 @@ async def get_politics(request: TextRequest):
"""
# Model accuracy may not hold up on pieces of text longer than a tweet.
# Slice it to smaller pieces if needed?
if globals.political_pipeline is None:
if gl.political_pipeline is None:
raise HTTPException(status_code=500, detail="Model not loaded")

def process_fn(text):
predict = globals.political_pipeline(text)
text = preprocess_data(text)
predict = gl.political_pipeline(text)
result = AnalysisResult(
labels=[predict[0]["label"]],
confidences=[confidence_output(predict[0]["score"])]
)
return result

results, generation_time = measure_execution_time(lambda: process_inputs(process_fn, request.text))
results, generation_time = measure_execution_time(
lambda: process_inputs(process_fn, request.text))

return TextResponse(
kind="politics",
Expand All @@ -300,35 +323,29 @@ async def get_hate_speech(request: TextRequest):
# Do zamieszczenia bibliografie z linku
# https://huggingface.co/Hate-speech-CNERG/dehatebert-mono-english
# Pamiętamy
if globals.language_model is None:
if gl.language_model is None:
raise HTTPException(status_code=500, detail="Language model not loaded")

if (globals.hate_speech_english_pipeline is None or
globals.hate_speech_polish_pipeline is None):
if (gl.hate_speech_english_pipeline is None or
gl.hate_speech_polish_pipeline is None):
raise HTTPException(status_code=500, detail="Model not loaded")

def detect_language(text):
prediction = globals.language_model.predict(text, k=1)
lang_tag = prediction[0][0].rsplit("_")[-2]
if tag_is_valid(lang_tag):
lang_name = Language.get(lang_tag).display_name("en")
return lang_name
return lang_tag

def process_fn(text):
text = preprocess_data(text)
lang_name = detect_language(text)
if lang_name == "Polish":
predict = globals.hate_speech_polish_pipeline(text)
predict = gl.hate_speech_polish_pipeline(text)
else:
predict = globals.hate_speech_english_pipeline(text)
predict = gl.hate_speech_english_pipeline(text)

result = AnalysisResult(
labels=[predict[0]["label"]],
confidences=[confidence_output(predict[0]["score"])]
)
return result

results, generation_time = measure_execution_time(lambda: process_inputs(process_fn, request.text))
results, generation_time = measure_execution_time(
lambda: process_inputs(process_fn, request.text))

return TextResponse(
kind="hateSpeech",
Expand All @@ -349,18 +366,20 @@ async def get_clickbait(request: TextRequest):
Returns:
TextResponse: Clickbait detection results with labels and confidence scores.
"""
if globals.clickbait_pipeline is None:
if gl.clickbait_pipeline is None:
raise HTTPException(status_code=500, detail="Model not loaded")

def process_fn(text):
predict = globals.clickbait_pipeline(text)
text = preprocess_data(text)
predict = gl.clickbait_pipeline(text)
result = AnalysisResult(
labels=[predict[0]["label"]],
confidences=[confidence_output(predict[0]["score"])]
)
return result

results, generation_time = measure_execution_time(lambda: process_inputs(process_fn, request.text))
results, generation_time = measure_execution_time(
lambda: process_inputs(process_fn, request.text))

return TextResponse(
kind="clickbait",
Expand All @@ -369,17 +388,44 @@ def process_fn(text):
).json()


# TODO: Implement troll model
@app.post("/troll")
async def get_troll(request: TextRequest):
@app.post("/ai-detector", response_model=TextResponse,
dependencies=[Depends(auth.get_api_key)])
async def get_ai_bots(request: TextRequest):
"""
Detect troll content in the provided text(s).
    Detect whether the text is written by an AI bot or by a human.
Args:
request (TextRequest): Contains a list of text strings to analyze.
Returns:
        TextResponse: AI-detection results with labels and confidence scores.
"""
text = request.text[0]
return {"troll": text}
if (gl.ai_detector_pipeline_english is None):
raise HTTPException(status_code=500, detail="Model not loaded")

def process_fn(text):
text = preprocess_data(text)
prediction = gl.ai_detector_pipeline_english(text)
confidence = prediction[0]["score"]
if confidence > 0.5:
label="AI"
confidence = confidence_output(confidence)
else:
label="Human"
confidence = confidence_output(1-confidence)

result = AnalysisResult(
labels=[label],
confidences=[confidence]
)
return result

results, generation_time = measure_execution_time(
lambda: process_inputs(process_fn, request.text))


return TextResponse(
kind="aiDetector",
metadata=Metadata(generated_in=generation_time),
results=results
).json()
18 changes: 13 additions & 5 deletions nlp/src/models_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import src.globals as g
import src.utils as utils

from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from transformers import (AutoTokenizer, AutoModelForSequenceClassification, pipeline,
AutoModel)
from keybert import KeyBERT
from sentence_transformers import SentenceTransformer

Expand All @@ -26,10 +27,10 @@ def load_model(model_name: str):
for language in model_languages:
model_path = os.path.join("models", model_name, language)

if not os.path.exists(model_path) and model_name != "keywords":
if not os.path.exists(model_path):
raise RuntimeError("Model file not found")

if model_name != "language" and model_name != "keywords":
if model_name != "language":
try:
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(
Expand All @@ -49,8 +50,12 @@ def load_model(model_name: str):
g.sarcastic_pipeline = pipeline("text-classification", model=model,
tokenizer=tokenizer)
elif model_name == "sentiment":
g.sentiment_model = model
g.sentiment_tokenizer = tokenizer
if language == "english":
g.sentiment_model_english = model
g.sentiment_tokenizer = tokenizer
elif language == "polish":
g.sentiment_model_polish = pipeline("text-classification",
model=model, tokenizer=tokenizer)
elif model_name == "political":
g.political_pipeline = pipeline("text-classification", model=model,
tokenizer=tokenizer)
Expand All @@ -61,3 +66,6 @@ def load_model(model_name: str):
elif model_name == "clickbait":
g.clickbait_pipeline = pipeline("text-classification", model=model,
tokenizer=tokenizer)
elif model_name == "ai_detector":
g.ai_detector_pipeline_english = pipeline("text-classification",
model=model, tokenizer=tokenizer)
Loading

0 comments on commit c5b0c55

Please sign in to comment.