-
Notifications
You must be signed in to change notification settings - Fork 77
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
dialoGPT Persona #185
dialoGPT Persona #185
Changes from 27 commits
72842e3
a67e15c
0f8ef0e
6f0684a
d237711
e990264
3208f71
ab44553
1c9a463
f8e4a59
48872a6
ed42f0c
30f290c
de510bc
ab2dcbd
525783a
7e87a36
3c2169d
0e9f1bb
49e3270
f286d94
57849b6
ff686d9
36af140
ba674a6
5cb5695
3d25e53
563f643
b0d7f2b
d8d1499
2d66aa4
dad1c10
f409638
43973f9
3783b76
c2811fe
353379c
ab88168
a601237
051356d
d19cce1
ea82137
b791fb6
270326f
1569dfc
ca5042b
a42faa6
56b72e3
fd257db
9f2eed5
bcd32bf
9142095
1ddfa2b
e3f43f6
3b9baff
222071b
f35fda6
9a8e663
4c7dfdd
28c7dc4
aefd444
84ca2fd
191764c
6a7453f
4cedb68
fa51ce6
999f185
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# syntax=docker/dockerfile:experimental | ||
|
||
FROM pytorch/pytorch:1.5-cuda10.1-cudnn7-runtime | ||
|
||
WORKDIR /src | ||
|
||
ARG PRETRAINED_MODEL_NAME_OR_PATH | ||
ENV PRETRAINED_MODEL_NAME_OR_PATH ${PRETRAINED_MODEL_NAME_OR_PATH} | ||
ARG SERVICE_PORT | ||
ENV SERVICE_PORT ${SERVICE_PORT} | ||
# ARG N_HYPOTHESES_TO_GENERATE | ||
# ENV N_HYPOTHESES_TO_GENERATE ${N_HYPOTHESES_TO_GENERATE} | ||
|
||
|
||
COPY ./requirements.txt /src/requirements.txt | ||
COPY ./persona_sentences.txt /src/persona_sentences.txt | ||
RUN pip install -r /src/requirements.txt | ||
|
||
RUN python -c "from sentence_transformers import SentenceTransformer;SentenceTransformer('${PRETRAINED_MODEL_NAME_OR_PATH}')" | ||
|
||
COPY . /src | ||
|
||
CMD gunicorn --workers=1 server:app -b 0.0.0.0:${SERVICE_PORT} --timeout=300 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
I used to be ADDICTED to Diet Coke. | ||
I hate puzzles. | ||
I was in a commercial for my aunt and uncle’s store. | ||
I have a scar on my left hand from where I cat bit me when I was babysitting when I was 12. | ||
I used to try to sleep in my splits when I was a dancer because I wanted to have perfect flexibility. | ||
I love being outside, but I hate smelling like outside. | ||
I am afraid of jack-in-the-boxes even though I know when they are going to pop out. | ||
In second grade, I fell off the counter I was standing on trying to get a mug for milk and I had to wear an ice pack in my pants to school for a couple weeks. | ||
For the third grade play, the part I wanted the most was to play the mom. | ||
I almost went to the University of Iowa to major in English and be on the dance team. | ||
The high school I graduated from, I only went to for my senior year. | ||
My sister, Caroline, and I get asked if we are twins all the time, and when people find out we aren’t actually twins, they always think she’s older. | ||
I’m really close to my Dad. | ||
When my parents first met Cam’s parents, they kept saying how Cam’s mom looked so familiar, but couldn’t figure out where they’d seen her before. | ||
Cam and I planned for him to ask me out on November 15th because 15 was his favorite number. | ||
When I was younger I was OBSESSED with Jesse McCartney in the way that some people nowadays are obsessed with One Direction. | ||
I either get told I look like Hilary Duff or Leighton Meester. I’m ok with either of those!! | ||
A lot of people can’t figure out what nationality I am. Some people say I look 100% American and some people think I’m Asian. I am 50% Honduran and my mom is from the states. | ||
No, I am not fluent in Spanish. | ||
I haven’t had un-painted toenails in 6 years. I just don’t like the way my toenails look without polish haha!! | ||
I was a drill team officer in high school and it was one of the most challenging and rewarding positions. | ||
I’m a literature nerd. I LOVE Shakespeare, Jane Austen, Charlotte Bronte, and anything Greek mythology. | ||
I ALMOST got certified to be a Zumba instructor, but then I was transferring so I didn’t. | ||
Oddly enough, I think my favorite food ever is tuna salad. | ||
I’m obsessed with libraries and churches. When I’m in a new town, I like to wander around those if I can find them! | ||
My favorite place in the world is Eureka Springs, Arkansas. | ||
I didn’t get a cell phone until my freshman year of high school, and I didn’t get an iPhone until I was a junior in high school. | ||
If I could have any job in the world, I would own a coffee shop that also sold art, flowers, and stationary. | ||
I plan on working for myself at some point in my life. | ||
I’m insanely good with state capitols. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
transformers==4.20.1 | ||
flask==1.1.1 | ||
itsdangerous==2.0.1 | ||
gunicorn==19.9.0 | ||
requests==2.22.0 | ||
sentry-sdk[flask]==0.14.1 | ||
healthcheck==1.3.3 | ||
jinja2<=3.0.3 | ||
Werkzeug<=2.0.3 | ||
sentence-transformers==2.2.2 | ||
huggingface-hub==0.4.0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
from typing import List, Tuple | ||
from sentence_transformers import SentenceTransformer | ||
import logging | ||
import time | ||
import os | ||
|
||
import sentry_sdk | ||
import torch | ||
from flask import Flask, request, jsonify | ||
from sentry_sdk.integrations.flask import FlaskIntegration | ||
sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()]) | ||
|
||
class SentenceRanker: | ||
def __init__(self, | ||
persona_sentences=None, | ||
sentence_model=None | ||
): | ||
"""_summary_ | ||
|
||
Args: | ||
persona_sentences (List[str]): список предложений составляющие полную персону. Defaults to None. | ||
sentence_model (SentenceTransformer): модель для перевода предложения в вектор. Defaults to None. | ||
""" | ||
self.persona_sentences = persona_sentences | ||
self.sentence_model = sentence_model | ||
self.sentence_embeddings = self.sentence_model.encode( | ||
persona_sentences, | ||
convert_to_tensor=True | ||
) | ||
# для кеширования похожих запросов | ||
self.ranked_sentences = {} | ||
|
||
def rank_sentences(self, query, k): | ||
"""возвращает топ k предложений которые похожи на query | ||
|
||
Args: | ||
query (str): предложение, на основе которого ищем похожие | ||
k (int): количество возвращаемых предложений. Defaults to 5. | ||
|
||
Returns: | ||
List[List[str], float]: отранжированные предложения и максимальное косинусное расстояние среди всех | ||
""" | ||
key = f"{query}_{k}" | ||
if self.ranked_sentences.get(key, False): | ||
return self.ranked_sentences[key] | ||
|
||
user_sentence_embeddings = self.sentence_model.encode(query, convert_to_tensor=True) | ||
|
||
cos_sim_ranks = self.cos_sim( | ||
user_sentence_embeddings, | ||
self.sentence_embeddings | ||
) | ||
|
||
top_indices = torch.argsort(cos_sim_ranks, descending=True) | ||
max_similarity = float(cos_sim_ranks[top_indices][0]) | ||
top_indices = list(top_indices[:k].cpu().numpy()) | ||
similar_sentences = [self.persona_sentences[idx] for idx in top_indices] | ||
self.ranked_sentences[key] = similar_sentences, max_similarity | ||
return [similar_sentences, max_similarity] | ||
|
||
def cos_sim(self, a, b): | ||
"""возвращает косинусное расстояние | ||
|
||
K - количество предложений для сравнения | ||
N - размерность возвращаемого вектора | ||
Args: | ||
a (torch.FloatTensor): shape (1, N) | ||
b (torch.FloatTensor): shape (K, N) | ||
|
||
Returns: | ||
torch.FloatTensor: shape (1, K) тензор с косинусными расстояниями | ||
""" | ||
a_norm = torch.nn.functional.normalize(a, p=2, dim=1) | ||
b_norm = torch.nn.functional.normalize(b, p=2, dim=1) | ||
return torch.sum(a_norm * b_norm, dim=1) | ||
|
||
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO) | ||
# logging.getLogger("werkzeug").setLevel("INFO") | ||
logger = logging.getLogger(__name__) | ||
|
||
PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH") | ||
# logging.info(f'PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}') | ||
DEFAULT_CONFIDENCE = 10 | ||
N_HYPOTHESES_TO_GENERATE = int(os.environ.get("N_HYPOTHESES_TO_GENERATE", 1)) | ||
ZERO_CONFIDENCE = 0.0 | ||
MAX_HISTORY_DEPTH = 3 | ||
TOP_SIMILAR_SENTENCES = 5 | ||
|
||
try: | ||
sentence_model = SentenceTransformer(PRETRAINED_MODEL_NAME_OR_PATH) | ||
|
||
persona = open("./persona_sentences.txt").read() | ||
persona_sentences = persona.split("\n") | ||
persona_sentences = [item for item in persona_sentences if len(item) > 0] | ||
|
||
sentence_ranker = SentenceRanker( | ||
persona_sentences=persona_sentences, | ||
sentence_model=sentence_model | ||
) | ||
logger.info("sentence_ranker is ready") | ||
except Exception as e: | ||
sentry_sdk.capture_exception(e) | ||
logger.exception(e) | ||
raise e | ||
|
||
app = Flask(__name__) | ||
|
||
@app.route("/response", methods=["POST"]) | ||
dilyararimovna marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def respond(): | ||
try: | ||
dialogs = request.json.get("dialogs", []) | ||
dilyararimovna marked this conversation as resolved.
Show resolved
Hide resolved
|
||
process_result = [] | ||
# берем последнюю реплику, затем забираем у нее результат работы аннотатора sentseg | ||
context_str = dialogs[0]["human_utterances"][-1]["annotations"]['sentseg']['punct_sent'] | ||
max_likelihood_sentences, max_sentence_similarity = sentence_ranker.rank_sentences( | ||
[context_str], | ||
k=TOP_SIMILAR_SENTENCES | ||
) | ||
|
||
process_result.append([ | ||
max_likelihood_sentences, | ||
max_sentence_similarity | ||
]) | ||
|
||
except Exception as exc: | ||
logger.exception(exc) | ||
sentry_sdk.capture_exception(exc) | ||
|
||
return jsonify( | ||
process_result | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,10 @@ services: | |
environment: | ||
DEVICE: cpu | ||
CUDA_VISIBLE_DEVICES: "" | ||
dialogpt-persona: | ||
environment: | ||
DEVICE: cpu | ||
CUDA_VISIBLE_DEVICES: "" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sentence ranker тоже |
||
intent-catcher: | ||
environment: | ||
DEVICE: cpu | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,12 +2,10 @@ services: | |
agent: | ||
command: sh -c 'bin/wait && python -m deeppavlov_agent.run -ch http_client -pl assistant_dists/dream_mini/pipeline_conf.json --cors' | ||
environment: | ||
WAIT_HOSTS: "convers-evaluator-annotator:8004, dff-program-y-skill:8008, sentseg:8011, convers-evaluation-selector:8009, | ||
dff-intent-responder-skill:8012, intent-catcher:8014, badlisted-words:8018, | ||
spelling-preprocessing:8074, dialogpt:8125" | ||
WAIT_HOSTS: "convers-evaluator-annotator:8004, dff-program-y-skill:8008, sentseg:8011, convers-evaluation-selector:8009, dff-intent-responder-skill:8012, intent-catcher:8014, badlisted-words:8018, spelling-preprocessing:8074, dialogpt:8125, dialogpt-persona:8131, sentence-ranker:8130" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. верни в нескольтко строк - там удобнее смотртеь, пожалуйста |
||
WAIT_HOSTS_TIMEOUT: ${WAIT_TIMEOUT:-480} | ||
convers-evaluator-annotator: | ||
env_file: [.env] | ||
env_file: [ .env ] | ||
build: | ||
args: | ||
CONFIG: conveval.json | ||
|
@@ -27,7 +25,7 @@ services: | |
memory: 2G | ||
|
||
dff-program-y-skill: | ||
env_file: [.env] | ||
env_file: [ .env ] | ||
build: | ||
args: | ||
SERVICE_PORT: 8008 | ||
|
@@ -43,9 +41,8 @@ services: | |
reservations: | ||
memory: 1024M | ||
|
||
|
||
sentseg: | ||
env_file: [.env] | ||
env_file: [ .env ] | ||
build: | ||
context: ./annotators/SentSeg/ | ||
command: flask run -h 0.0.0.0 -p 8011 | ||
|
@@ -59,18 +56,18 @@ services: | |
memory: 1.5G | ||
|
||
convers-evaluation-selector: | ||
env_file: [.env] | ||
env_file: [ .env ] | ||
build: | ||
args: | ||
TAG_BASED_SELECTION: 1 | ||
TAG_BASED_SELECTION: 0 | ||
CALL_BY_NAME_PROBABILITY: 0.5 | ||
PROMPT_PROBA: 0.3 | ||
PROMPT_PROBA: 0 | ||
ACKNOWLEDGEMENT_PROBA: 0.3 | ||
PRIORITIZE_WITH_REQUIRED_ACT: 1 | ||
PRIORITIZE_NO_DIALOG_BREAKDOWN: 0 | ||
PRIORITIZE_WITH_SAME_TOPIC_ENTITY: 1 | ||
PRIORITIZE_WITH_SAME_TOPIC_ENTITY: 0 | ||
IGNORE_DISLIKED_SKILLS: 0 | ||
GREETING_FIRST: 1 | ||
GREETING_FIRST: 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. вернуть параметры на старые |
||
RESTRICTION_FOR_SENSITIVE_CASE: 1 | ||
PRIORITIZE_PROMTS_WHEN_NO_SCRIPTS: 0 | ||
ADD_ACKNOWLEDGMENTS_IF_POSSIBLE: 1 | ||
|
@@ -105,15 +102,15 @@ services: | |
memory: 128M | ||
|
||
intent-catcher: | ||
env_file: [.env] | ||
env_file: [ .env ] | ||
build: | ||
context: . | ||
dockerfile: ./annotators/IntentCatcherTransformers/Dockerfile | ||
args: | ||
SERVICE_PORT: 8014 | ||
CONFIG_NAME: intents_model_dp_config.json | ||
INTENT_PHRASES_PATH: intent_phrases.json | ||
command: python -m flask run -h 0.0.0.0 -p 8014 | ||
command: python -m flask run -h 0.0.0.0 -p 8014 | ||
environment: | ||
- FLASK_APP=server | ||
- CUDA_VISIBLE_DEVICES=0 | ||
|
@@ -125,7 +122,7 @@ services: | |
memory: 3.5G | ||
|
||
badlisted-words: | ||
env_file: [.env] | ||
env_file: [ .env ] | ||
build: | ||
context: annotators/BadlistedWordsDetector/ | ||
command: flask run -h 0.0.0.0 -p 8018 | ||
|
@@ -139,7 +136,7 @@ services: | |
memory: 256M | ||
|
||
spelling-preprocessing: | ||
env_file: [.env] | ||
env_file: [ .env ] | ||
build: | ||
context: ./annotators/spelling_preprocessing/ | ||
command: flask run -h 0.0.0.0 -p 8074 | ||
|
@@ -172,4 +169,42 @@ services: | |
reservations: | ||
memory: 2G | ||
|
||
dialogpt-persona: | ||
env_file: [ .env ] | ||
build: | ||
args: | ||
SERVICE_PORT: 8131 | ||
SERVICE_NAME: dialogpt_persona | ||
PRETRAINED_MODEL_NAME_OR_PATH: dim/dialogpt-medium-persona-chat | ||
context: ./services/dialogpt_persona/ | ||
command: flask run -h 0.0.0.0 -p 8131 | ||
environment: | ||
- CUDA_VISIBLE_DEVICES=0 | ||
- FLASK_APP=server | ||
deploy: | ||
resources: | ||
limits: | ||
memory: 2G | ||
reservations: | ||
memory: 2G | ||
|
||
sentence-ranker: | ||
env_file: [ .env ] | ||
build: | ||
args: | ||
SERVICE_PORT: 8130 | ||
SERVICE_NAME: sentence_ranker | ||
PRETRAINED_MODEL_NAME_OR_PATH: 'sentence-transformers/nli-distilroberta-base-v2' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. можно без кавычек |
||
context: ./annotators/sentence_ranker/ | ||
command: flask run -h 0.0.0.0 -p 8130 | ||
environment: | ||
- CUDA_VISIBLE_DEVICES=0 | ||
- FLASK_APP=server | ||
deploy: | ||
resources: | ||
limits: | ||
memory: 1G | ||
reservations: | ||
memory: 1G | ||
|
||
version: '3.7' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
англ