Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialoGPT Persona #185

Merged
merged 67 commits into from
Oct 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
72842e3
Fix requirements.txt (#84)
AndriiHura Jan 24, 2022
a67e15c
fix itsdangerous requirements
mtalimanchuk Feb 18, 2022
0f8ef0e
pin itsdangerous requirements for all flask==1.1.1 servers
mtalimanchuk Feb 18, 2022
6f0684a
Merge pull request #102 from deepmipt/fix/combined-classification-fla…
mtalimanchuk Feb 18, 2022
d237711
Merge pull request #103 from deepmipt/dev
dilyararimovna Feb 18, 2022
e990264
Merge pull request #107 from deepmipt/dev
dilyararimovna Mar 2, 2022
3208f71
Merge pull request #119 from deepmipt/dev
dilyararimovna Mar 11, 2022
ab44553
Merge pull request #123 from deepmipt/dev
dilyararimovna Mar 18, 2022
1c9a463
Merge pull request #137 from deepmipt/dev
dilyararimovna Apr 8, 2022
f8e4a59
Merge pull request #145 from deepmipt/dev
dilyararimovna Apr 30, 2022
48872a6
Merge pull request #150 from deepmipt/dev
dilyararimovna May 4, 2022
ed42f0c
Merge pull request #153 from deepmipt/dev
dilyararimovna May 5, 2022
30f290c
Merge pull request #155 from deepmipt/dev
dilyararimovna May 6, 2022
de510bc
Merge pull request #158 from deepmipt/dev
dilyararimovna May 11, 2022
ab2dcbd
Merge pull request #165 from deepmipt/dev
dilyararimovna May 27, 2022
525783a
Merge pull request #174 from deepmipt/dev
dilyararimovna Jun 27, 2022
7e87a36
Merge pull request #177 from deepmipt/dev
dilyararimovna Jun 30, 2022
3c2169d
increase timeout to 5s
dmitrymailk Jul 5, 2022
0e9f1bb
add logs
dmitrymailk Jul 5, 2022
49e3270
increase to 100
dmitrymailk Jul 5, 2022
f286d94
add first working version gpt_persona
dmitrymailk Jul 13, 2022
57849b6
add sentence_ranking(not working)
dmitrymailk Jul 14, 2022
ff686d9
fix wrong endpoint
dmitrymailk Jul 14, 2022
36af140
create sentecnce ranker annotator
dmitrymailk Jul 14, 2022
ba674a6
rewrite text generation logic
dmitrymailk Jul 14, 2022
5cb5695
add comments to code
dmitrymailk Jul 14, 2022
3d25e53
fix gpt_persona fallback
dmitrymailk Jul 15, 2022
563f643
clean code, add train script, write tests
dmitrymailk Jul 20, 2022
b0d7f2b
add get_intents, remove dataset, add hyperparams
dmitrymailk Jul 24, 2022
d8d1499
add batch support
dmitrymailk Jul 26, 2022
2d66aa4
fix: move files
dilyararimovna Aug 29, 2022
dad1c10
fix: move files
dilyararimovna Aug 29, 2022
f409638
fix: merge
dilyararimovna Aug 29, 2022
43973f9
fix: codestyle
dilyararimovna Aug 29, 2022
3783b76
fix: remove sentence ranker
dilyararimovna Aug 29, 2022
c2811fe
feat: new distribution and rename skill
dilyararimovna Aug 29, 2022
353379c
feat: new annotator
dilyararimovna Aug 29, 2022
ab88168
feat: relative persona extractor
dilyararimovna Aug 29, 2022
a601237
fix: codestyle
dilyararimovna Aug 29, 2022
051356d
fix: proxy
dilyararimovna Aug 29, 2022
d19cce1
Merge branch 'persona_bot' of https://github.com/dmitrymailk/dream in…
dilyararimovna Aug 29, 2022
ea82137
fix: params
dilyararimovna Aug 29, 2022
b791fb6
fix: volumes
dilyararimovna Aug 29, 2022
270326f
fix: reqs
dilyararimovna Aug 29, 2022
1569dfc
fix: tests
dilyararimovna Aug 29, 2022
ca5042b
fix: tests relative sents extr
dilyararimovna Aug 29, 2022
a42faa6
fix: batching
dilyararimovna Aug 30, 2022
56b72e3
fix: codestyle
dilyararimovna Aug 30, 2022
fd257db
fix: persona extractor tests
dilyararimovna Aug 30, 2022
9f2eed5
fix: persona get
dilyararimovna Aug 30, 2022
bcd32bf
fix: tests
dilyararimovna Aug 30, 2022
9142095
fix: imports
dilyararimovna Aug 30, 2022
1ddfa2b
fix: logs
dilyararimovna Aug 31, 2022
e3f43f6
fix: docs
dilyararimovna Aug 31, 2022
3b9baff
fix: add midas
dilyararimovna Aug 31, 2022
222071b
Merge remote-tracking branch 'origin/dev' into persona_bot
dilyararimovna Sep 16, 2022
f35fda6
fix: command
dilyararimovna Sep 16, 2022
9a8e663
Merge remote-tracking branch 'origin/dev' into persona_bot
dilyararimovna Sep 20, 2022
4c7dfdd
feat: add to main dist and tests
dilyararimovna Sep 20, 2022
28c7dc4
feat: remove infilling, add dialogpt persona based, docs
dilyararimovna Sep 21, 2022
aefd444
fix: gpus
dilyararimovna Sep 26, 2022
84ca2fd
fix: params
dilyararimovna Sep 27, 2022
191764c
fix: param
dilyararimovna Sep 27, 2022
6a7453f
Merge remote-tracking branch 'origin/dev' into persona_bot
dilyararimovna Sep 27, 2022
4cedb68
fix: indent
dilyararimovna Sep 27, 2022
fa51ce6
fix: remove infilling from tests
dilyararimovna Sep 28, 2022
999f185
Merge remote-tracking branch 'origin/dev' into persona_bot
dilyararimovna Oct 7, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ DIALOGPT_SERVICE_URL=http://dialogpt:8091/respond
DIALOGPT_CONTINUE_SERVICE_URL=http://dialogpt:8125/continue
PROMPT_STORYGPT_SERVICE_URL=http://prompt-storygpt:8127/respond
STORYGPT_SERVICE_URL=http://storygpt:8126/respond
SENTENCE_RANKER_SERVICE_URL=http://sentence-ranker:8128/respond
80 changes: 51 additions & 29 deletions README.md

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions annotators/relative_persona_extractor/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
FROM python:3.8.4

ARG SERVICE_PORT
ENV SERVICE_PORT ${SERVICE_PORT}
ARG N_SENTENCES_OT_RETURN
ENV N_SENTENCES_OT_RETURN ${N_SENTENCES_OT_RETURN}

RUN mkdir /src

COPY ./annotators/relative_persona_extractor/ /src/
COPY ./common/ /src/common/

COPY annotators/relative_persona_extractor/requirements.txt /src/requirements.txt
RUN pip install -r /src/requirements.txt

WORKDIR /src

CMD gunicorn --workers=1 server:app --bind 0.0.0.0:8000
5 changes: 5 additions & 0 deletions annotators/relative_persona_extractor/README.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Relative Persona Extractor

An annotator that utilizes Sentence Ranker to find the most relevant to the current context sentences from the bot's persona description.

The number of returned sentences is given as an environmental variable using `N_SENTENCES_OT_RETURN` in `docker-compose.yml`.
9 changes: 9 additions & 0 deletions annotators/relative_persona_extractor/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
flask==1.1.1
itsdangerous==2.0.1
gunicorn==20.0.4
sentry-sdk==0.13.4
requests==2.22.0
click<=8.0.4
jinja2<=3.0.3
Werkzeug<=2.0.3
numpy==1.17.2
73 changes: 73 additions & 0 deletions annotators/relative_persona_extractor/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import logging
import requests
import time
from os import getenv

import numpy as np
import sentry_sdk
from flask import Flask, request, jsonify


sentry_sdk.init(getenv("SENTRY_DSN"))
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.DEBUG)
logger = logging.getLogger(__name__)
app = Flask(__name__)

SENTENCE_RANKER_SERVICE_URL = getenv("SENTENCE_RANKER_SERVICE_URL")
N_SENTENCES_OT_RETURN = int(getenv("N_SENTENCES_OT_RETURN"))
with open("common/persona_sentences.txt", "r") as f:
PERSONA_SENTENCES = f.read().splitlines()
PERSONA_SENTENCES = [x.strip() for x in PERSONA_SENTENCES if len(x.strip())]


def get_result(request):
st_time = time.time()
contexts = request.json["contexts"]
result = []
pairs = []
context_ids = []

for context_id, context in enumerate(contexts):
str_context = " ".join(context)
for sent in PERSONA_SENTENCES:
pairs += [[str_context, sent]]
context_ids += [context_id]
context_ids = np.array(context_ids)
try:
scores = requests.post(SENTENCE_RANKER_SERVICE_URL, json={"sentence_pairs": pairs}, timeout=1.5).json()[0][
"batch"
]
scores = np.array(scores)
for i, context in enumerate(contexts):
curr_ids = np.where(context_ids == i)[0]
most_relevant_sent_ids = np.argsort(scores[curr_ids])[::-1][:N_SENTENCES_OT_RETURN]
curr_result = {
"persona": [PERSONA_SENTENCES[_id] for _id in most_relevant_sent_ids],
"max_similarity": scores[curr_ids][most_relevant_sent_ids[0]],
}
logger.info(f"Persona: {curr_result['persona']}")
result += [curr_result]
except Exception as exc:
logger.exception(exc)
sentry_sdk.capture_exception(exc)
result = [{"persona": [], "max_similarity": 0.0}] * len(contexts)

total_time = time.time() - st_time
logger.info(f"relative-persona-extractor exec time: {total_time:.3f}s")
return result


@app.route("/respond", methods=["POST"])
def respond():
result = get_result(request)
return jsonify(result)


@app.route("/respond_batch", methods=["POST"])
def respond_batch():
result = get_result(request)
return jsonify([{"batch": result}])


if __name__ == "__main__":
app.run(debug=False, host="0.0.0.0", port=3000)
45 changes: 45 additions & 0 deletions annotators/relative_persona_extractor/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import requests


def main():
url = "http://0.0.0.0:8133/respond"
input_data = {
"contexts": [
[
"How do you spend your spare time?",
"I like to watch movies and eat pizza.",
"Cool! What else do you like?",
],
[
"I like to go to the cinema on fridays",
"great. how do you spend your spare time?",
"I like to watch movies",
],
]
}
gold = [
{
"max_similarity": 0.6948127746582031,
"persona": [
"I like Italian food especially pasta and pizza.",
"I like to watch football and basketball on TV.",
"I like watching travel video blogs.",
],
},
{
"max_similarity": 0.6451027989387512,
"persona": [
"I like watching travel video blogs.",
"I like to watch football and basketball on TV.",
"I like Italian food especially pasta and pizza.",
],
},
]

result = requests.post(url, json=input_data).json()
assert result == gold, print(result)
print("Success!")


if __name__ == "__main__":
main()
3 changes: 3 additions & 0 deletions annotators/relative_persona_extractor/test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python test.py
7 changes: 6 additions & 1 deletion assistant_dists/dream/cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,13 @@ services:
dialogpt:
environment:
CUDA_VISIBLE_DEVICES: ""
infilling:
sentence-ranker:
environment:
DEVICE: cpu
CUDA_VISIBLE_DEVICES: ""
dialogpt-persona-based:
environment:
DEVICE: cpu
CUDA_VISIBLE_DEVICES: ""
intent-catcher:
environment:
Expand Down
18 changes: 15 additions & 3 deletions assistant_dists/dream/dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -410,11 +410,23 @@ services:
- "./services/dialogpt:/src"
ports:
- 8125:8125
infilling:
sentence-ranker:
volumes:
- "./services/infilling:/src"
- "./services/sentence_ranker:/src"
ports:
- 8122:8122
- 8128:8128
dialogpt-persona-based:
volumes:
- "./services/dialogpt_persona_based:/src"
- "./common:/src/common"
ports:
- 8131:8131
relative-persona-extractor:
volumes:
- "./annotators/relative_persona_extractor:/src"
- "./common:/src/common"
ports:
- 8133:8133
dff-template-skill:
volumes:
- "./skills/dff_template_skill:/src"
Expand Down
56 changes: 48 additions & 8 deletions assistant_dists/dream/docker-compose.override.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ services:
dff-funfact-skill:8104, dff-bot-persona-skill:8105, news-api-annotator:8112,
dff-gossip-skill:8109, dff-wiki-skill:8111, dff-gaming-skill:8115, topic-recommendation:8113,
user-persona-extractor:8114, wiki-facts:8116, dff-music-skill:8099, entity-detection:8103, dff-art-skill:8117,
midas-predictor:8121, dialogpt:8125, storygpt:8126, prompt-storygpt:8127, infilling:8122, dff-template-skill:8120"
midas-predictor:8121, dialogpt:8125, storygpt:8126, prompt-storygpt:8127, dff-template-skill:8120"
WAIT_HOSTS_TIMEOUT: ${WAIT_TIMEOUT:-480}
convers-evaluator-annotator:
env_file: [.env]
Expand Down Expand Up @@ -1187,22 +1187,63 @@ services:
reservations:
memory: 2G

infilling:
dialogpt-persona-based:
env_file: [ .env ]
build:
context: ./services/infilling/
args:
SERVICE_PORT: 8122
command: flask run -h 0.0.0.0 -p 8122
SERVICE_PORT: 8131
SERVICE_NAME: dialogpt_persona_based
PRETRAINED_MODEL_NAME_OR_PATH: dim/dialogpt-medium-persona-chat
MAX_PERSONA_SENTENCES: 3
context: .
dockerfile: ./services/dialogpt_persona_based/Dockerfile
command: flask run -h 0.0.0.0 -p 8131
environment:
- CUDA_VISIBLE_DEVICES=0
- FLASK_APP=server
deploy:
resources:
limits:
memory: 2G
reservations:
memory: 2G

relative-persona-extractor:
env_file: [ .env ]
build:
args:
SERVICE_PORT: 8133
SERVICE_NAME: relative_persona_extractor
N_SENTENCES_OT_RETURN: 3
context: .
dockerfile: ./annotators/relative_persona_extractor/Dockerfile
command: flask run -h 0.0.0.0 -p 8133
environment:
- FLASK_APP=server
deploy:
resources:
limits:
memory: 80M
reservations:
memory: 80M

sentence-ranker:
env_file: [ .env ]
build:
args:
SERVICE_PORT: 8128
PRETRAINED_MODEL_NAME_OR_PATH: sentence-transformers/bert-base-nli-mean-tokens
context: ./services/sentence_ranker/
command: flask run -h 0.0.0.0 -p 8128
environment:
- CUDA_VISIBLE_DEVICES=0
- FLASK_APP=server
deploy:
resources:
limits:
memory: 2.5G # ?
memory: 3G
reservations:
memory: 2.5G # ?
memory: 3G

storygpt:
env_file: [ .env ]
Expand Down Expand Up @@ -1261,4 +1302,3 @@ services:
reservations:
memory: 128M
version: '3.7'

6 changes: 5 additions & 1 deletion assistant_dists/dream/gpu1.yml
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,14 @@ services:
restart: unless-stopped
environment:
- CUDA_VISIBLE_DEVICES=9
infilling:
dialogpt-persona-based:
restart: unless-stopped
environment:
- CUDA_VISIBLE_DEVICES=7
sentence-ranker:
restart: unless-stopped
environment:
- CUDA_VISIBLE_DEVICES=9
dff-template-skill:
restart: unless-stopped
version: '3.7'
28 changes: 27 additions & 1 deletion assistant_dists/dream/pipeline_conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,20 @@
"annotators.midas_classification",
"annotators.combined_classification"
]
}
},
"relative_persona_extractor": {
"connector": {
"protocol": "http",
"timeout": 2,
"url": "http://relative-persona-extractor:8133/respond"
},
"dialog_formatter": "state_formatters.dp_formatters:context_formatter_dialog",
"response_formatter": "state_formatters.dp_formatters:simple_formatter_service",
"state_manager_method": "add_annotation",
"previous_services": [
"annotators.spelling_preprocessing"
]
}
},
"skill_selectors": {
"rule_based_selector": {
Expand Down Expand Up @@ -977,6 +990,19 @@
],
"state_manager_method": "add_hypothesis"
},
"dialogpt_persona_based": {
"connector": {
"protocol": "http",
"timeout": 3,
"url": "http://dialogpt-persona-based:8131/respond"
},
"dialog_formatter": "state_formatters.dp_formatters:persona_bot_formatter",
"response_formatter": "state_formatters.dp_formatters:skill_with_attributes_formatter_service",
"previous_services": [
"skill_selectors"
],
"state_manager_method": "add_hypothesis"
},
"dff_template_skill": {
"connector": {
"protocol": "http",
Expand Down
24 changes: 21 additions & 3 deletions assistant_dists/dream/proxy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -576,14 +576,32 @@ services:
- PROXY_PASS=dream.deeppavlov.ai:8125
- PORT=8125

infilling:
sentence-ranker:
command: [ "nginx", "-g", "daemon off;" ]
build:
context: dp/proxy/
dockerfile: Dockerfile
environment:
- PROXY_PASS=dream.deeppavlov.ai:8122
- PORT=8122
- PROXY_PASS=dream.deeppavlov.ai:8128
- PORT=8128

dialogpt-persona-based:
command: [ "nginx", "-g", "daemon off;" ]
build:
context: dp/proxy/
dockerfile: Dockerfile
environment:
- PROXY_PASS=dream.deeppavlov.ai:8131
- PORT=8125

relative-persona-extractor:
command: [ "nginx", "-g", "daemon off;" ]
build:
context: dp/proxy/
dockerfile: Dockerfile
environment:
- PROXY_PASS=dream.deeppavlov.ai:8133
- PORT=8133

dff-short-story-skill:
command: [ "nginx", "-g", "daemon off;" ]
Expand Down
7 changes: 5 additions & 2 deletions assistant_dists/dream/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,11 @@ services:
prompt-storygpt:
environment:
- CUDA_VISIBLE_DEVICES=8
infilling:
dialogpt-persona-based:
environment:
- CUDA_VISIBLE_DEVICES=7
- CUDA_VISIBLE_DEVICES=9
sentence-ranker:
environment:
- CUDA_VISIBLE_DEVICES=9
dff-template-skill:
version: '3.7'
Loading