Skip to content

Commit

Permalink
dialoGPT Persona (#185)
Browse files Browse the repository at this point in the history
* Fix requirements.txt (#84)

* fix itsdangerous requirements

* pin itsdangerous requirements for all flask==1.1.1 servers

* increase timeout to 5s

* add logs

* increase to 100

* add first working version gpt_persona

* add sentence_ranking(not working)

* fix wrong endpoint

* create sentecnce ranker annotator

* rewrite text generation logic

* add comments to code

* fix gpt_persona fallback

* clean code, add train script, write tests

* add get_intents, remove dataset, add hyperparams

* add batch support

* fix: move files

* fix: merge

* fix: codestyle

* fix: remove sentence ranker

* feat: new distribution and rename skill

* feat: new annotator

* feat: relative persona extractor

* fix: codestyle

* fix: proxy

* fix: params

* fix: volumes

* fix: reqs

* fix: tests

* fix: tests relative sents extr

* fix: batching

* fix: codestyle

* fix: persona extractor tests

* fix: persona get

* fix: tests

* fix: imports

* fix: logs

* fix: docs

* fix: add midas

* fix: command

* feat: add to main dist and tests

* feat: remove infilling, add dialogpt persona based, docs

* fix: gpus

* fix: params

* fix: param

* fix: indent

* fix: remove infilling from tests

Co-authored-by: Andrii.Hura <54397922+AndriiHura@users.noreply.github.com>
Co-authored-by: mtalimanchuk <mtalimanchuk@gmail.com>
Co-authored-by: Dilyara Baymurzina <dilyara.rimovna@gmail.com>
  • Loading branch information
4 people authored Oct 10, 2022
1 parent 9374614 commit c3bb406
Show file tree
Hide file tree
Showing 34 changed files with 2,575 additions and 49 deletions.
1 change: 1 addition & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ DIALOGPT_SERVICE_URL=http://dialogpt:8091/respond
DIALOGPT_CONTINUE_SERVICE_URL=http://dialogpt:8125/continue
PROMPT_STORYGPT_SERVICE_URL=http://prompt-storygpt:8127/respond
STORYGPT_SERVICE_URL=http://storygpt:8126/respond
SENTENCE_RANKER_SERVICE_URL=http://sentence-ranker:8128/respond
80 changes: 51 additions & 29 deletions README.md

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions annotators/relative_persona_extractor/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
FROM python:3.8.4

ARG SERVICE_PORT
ENV SERVICE_PORT ${SERVICE_PORT}
ARG N_SENTENCES_OT_RETURN
ENV N_SENTENCES_OT_RETURN ${N_SENTENCES_OT_RETURN}

RUN mkdir /src

COPY ./annotators/relative_persona_extractor/ /src/
COPY ./common/ /src/common/

COPY annotators/relative_persona_extractor/requirements.txt /src/requirements.txt
RUN pip install -r /src/requirements.txt

WORKDIR /src

CMD gunicorn --workers=1 server:app --bind 0.0.0.0:8000
5 changes: 5 additions & 0 deletions annotators/relative_persona_extractor/README.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Relative Persona Extractor

An annotator that utilizes Sentence Ranker to find the most relevant to the current context sentences from the bot's persona description.

The number of returned sentences is given as an environmental variable using `N_SENTENCES_OT_RETURN` in `docker-compose.yml`.
9 changes: 9 additions & 0 deletions annotators/relative_persona_extractor/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
flask==1.1.1
itsdangerous==2.0.1
gunicorn==20.0.4
sentry-sdk==0.13.4
requests==2.22.0
click<=8.0.4
jinja2<=3.0.3
Werkzeug<=2.0.3
numpy==1.17.2
73 changes: 73 additions & 0 deletions annotators/relative_persona_extractor/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import logging
import requests
import time
from os import getenv

import numpy as np
import sentry_sdk
from flask import Flask, request, jsonify


sentry_sdk.init(getenv("SENTRY_DSN"))
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.DEBUG)
logger = logging.getLogger(__name__)
app = Flask(__name__)

SENTENCE_RANKER_SERVICE_URL = getenv("SENTENCE_RANKER_SERVICE_URL")
N_SENTENCES_OT_RETURN = int(getenv("N_SENTENCES_OT_RETURN"))
with open("common/persona_sentences.txt", "r") as f:
PERSONA_SENTENCES = f.read().splitlines()
PERSONA_SENTENCES = [x.strip() for x in PERSONA_SENTENCES if len(x.strip())]


def get_result(request):
st_time = time.time()
contexts = request.json["contexts"]
result = []
pairs = []
context_ids = []

for context_id, context in enumerate(contexts):
str_context = " ".join(context)
for sent in PERSONA_SENTENCES:
pairs += [[str_context, sent]]
context_ids += [context_id]
context_ids = np.array(context_ids)
try:
scores = requests.post(SENTENCE_RANKER_SERVICE_URL, json={"sentence_pairs": pairs}, timeout=1.5).json()[0][
"batch"
]
scores = np.array(scores)
for i, context in enumerate(contexts):
curr_ids = np.where(context_ids == i)[0]
most_relevant_sent_ids = np.argsort(scores[curr_ids])[::-1][:N_SENTENCES_OT_RETURN]
curr_result = {
"persona": [PERSONA_SENTENCES[_id] for _id in most_relevant_sent_ids],
"max_similarity": scores[curr_ids][most_relevant_sent_ids[0]],
}
logger.info(f"Persona: {curr_result['persona']}")
result += [curr_result]
except Exception as exc:
logger.exception(exc)
sentry_sdk.capture_exception(exc)
result = [{"persona": [], "max_similarity": 0.0}] * len(contexts)

total_time = time.time() - st_time
logger.info(f"relative-persona-extractor exec time: {total_time:.3f}s")
return result


@app.route("/respond", methods=["POST"])
def respond():
result = get_result(request)
return jsonify(result)


@app.route("/respond_batch", methods=["POST"])
def respond_batch():
result = get_result(request)
return jsonify([{"batch": result}])


if __name__ == "__main__":
app.run(debug=False, host="0.0.0.0", port=3000)
45 changes: 45 additions & 0 deletions annotators/relative_persona_extractor/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import requests


def main():
url = "http://0.0.0.0:8133/respond"
input_data = {
"contexts": [
[
"How do you spend your spare time?",
"I like to watch movies and eat pizza.",
"Cool! What else do you like?",
],
[
"I like to go to the cinema on fridays",
"great. how do you spend your spare time?",
"I like to watch movies",
],
]
}
gold = [
{
"max_similarity": 0.6948127746582031,
"persona": [
"I like Italian food especially pasta and pizza.",
"I like to watch football and basketball on TV.",
"I like watching travel video blogs.",
],
},
{
"max_similarity": 0.6451027989387512,
"persona": [
"I like watching travel video blogs.",
"I like to watch football and basketball on TV.",
"I like Italian food especially pasta and pizza.",
],
},
]

result = requests.post(url, json=input_data).json()
assert result == gold, print(result)
print("Success!")


if __name__ == "__main__":
main()
3 changes: 3 additions & 0 deletions annotators/relative_persona_extractor/test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python test.py
7 changes: 6 additions & 1 deletion assistant_dists/dream/cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,13 @@ services:
dialogpt:
environment:
CUDA_VISIBLE_DEVICES: ""
infilling:
sentence-ranker:
environment:
DEVICE: cpu
CUDA_VISIBLE_DEVICES: ""
dialogpt-persona-based:
environment:
DEVICE: cpu
CUDA_VISIBLE_DEVICES: ""
intent-catcher:
environment:
Expand Down
18 changes: 15 additions & 3 deletions assistant_dists/dream/dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -410,11 +410,23 @@ services:
- "./services/dialogpt:/src"
ports:
- 8125:8125
infilling:
sentence-ranker:
volumes:
- "./services/infilling:/src"
- "./services/sentence_ranker:/src"
ports:
- 8122:8122
- 8128:8128
dialogpt-persona-based:
volumes:
- "./services/dialogpt_persona_based:/src"
- "./common:/src/common"
ports:
- 8131:8131
relative-persona-extractor:
volumes:
- "./annotators/relative_persona_extractor:/src"
- "./common:/src/common"
ports:
- 8133:8133
dff-template-skill:
volumes:
- "./skills/dff_template_skill:/src"
Expand Down
56 changes: 48 additions & 8 deletions assistant_dists/dream/docker-compose.override.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ services:
dff-funfact-skill:8104, dff-bot-persona-skill:8105, news-api-annotator:8112,
dff-gossip-skill:8109, dff-wiki-skill:8111, dff-gaming-skill:8115, topic-recommendation:8113,
user-persona-extractor:8114, wiki-facts:8116, dff-music-skill:8099, entity-detection:8103, dff-art-skill:8117,
midas-predictor:8121, dialogpt:8125, storygpt:8126, prompt-storygpt:8127, infilling:8122, dff-template-skill:8120"
midas-predictor:8121, dialogpt:8125, storygpt:8126, prompt-storygpt:8127, dff-template-skill:8120"
WAIT_HOSTS_TIMEOUT: ${WAIT_TIMEOUT:-480}
convers-evaluator-annotator:
env_file: [.env]
Expand Down Expand Up @@ -1187,22 +1187,63 @@ services:
reservations:
memory: 2G

infilling:
dialogpt-persona-based:
env_file: [ .env ]
build:
context: ./services/infilling/
args:
SERVICE_PORT: 8122
command: flask run -h 0.0.0.0 -p 8122
SERVICE_PORT: 8131
SERVICE_NAME: dialogpt_persona_based
PRETRAINED_MODEL_NAME_OR_PATH: dim/dialogpt-medium-persona-chat
MAX_PERSONA_SENTENCES: 3
context: .
dockerfile: ./services/dialogpt_persona_based/Dockerfile
command: flask run -h 0.0.0.0 -p 8131
environment:
- CUDA_VISIBLE_DEVICES=0
- FLASK_APP=server
deploy:
resources:
limits:
memory: 2G
reservations:
memory: 2G

relative-persona-extractor:
env_file: [ .env ]
build:
args:
SERVICE_PORT: 8133
SERVICE_NAME: relative_persona_extractor
N_SENTENCES_OT_RETURN: 3
context: .
dockerfile: ./annotators/relative_persona_extractor/Dockerfile
command: flask run -h 0.0.0.0 -p 8133
environment:
- FLASK_APP=server
deploy:
resources:
limits:
memory: 80M
reservations:
memory: 80M

sentence-ranker:
env_file: [ .env ]
build:
args:
SERVICE_PORT: 8128
PRETRAINED_MODEL_NAME_OR_PATH: sentence-transformers/bert-base-nli-mean-tokens
context: ./services/sentence_ranker/
command: flask run -h 0.0.0.0 -p 8128
environment:
- CUDA_VISIBLE_DEVICES=0
- FLASK_APP=server
deploy:
resources:
limits:
memory: 2.5G # ?
memory: 3G
reservations:
memory: 2.5G # ?
memory: 3G

storygpt:
env_file: [ .env ]
Expand Down Expand Up @@ -1261,4 +1302,3 @@ services:
reservations:
memory: 128M
version: '3.7'

6 changes: 5 additions & 1 deletion assistant_dists/dream/gpu1.yml
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,14 @@ services:
restart: unless-stopped
environment:
- CUDA_VISIBLE_DEVICES=9
infilling:
dialogpt-persona-based:
restart: unless-stopped
environment:
- CUDA_VISIBLE_DEVICES=7
sentence-ranker:
restart: unless-stopped
environment:
- CUDA_VISIBLE_DEVICES=9
dff-template-skill:
restart: unless-stopped
version: '3.7'
28 changes: 27 additions & 1 deletion assistant_dists/dream/pipeline_conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,20 @@
"annotators.midas_classification",
"annotators.combined_classification"
]
}
},
"relative_persona_extractor": {
"connector": {
"protocol": "http",
"timeout": 2,
"url": "http://relative-persona-extractor:8133/respond"
},
"dialog_formatter": "state_formatters.dp_formatters:context_formatter_dialog",
"response_formatter": "state_formatters.dp_formatters:simple_formatter_service",
"state_manager_method": "add_annotation",
"previous_services": [
"annotators.spelling_preprocessing"
]
}
},
"skill_selectors": {
"rule_based_selector": {
Expand Down Expand Up @@ -977,6 +990,19 @@
],
"state_manager_method": "add_hypothesis"
},
"dialogpt_persona_based": {
"connector": {
"protocol": "http",
"timeout": 3,
"url": "http://dialogpt-persona-based:8131/respond"
},
"dialog_formatter": "state_formatters.dp_formatters:persona_bot_formatter",
"response_formatter": "state_formatters.dp_formatters:skill_with_attributes_formatter_service",
"previous_services": [
"skill_selectors"
],
"state_manager_method": "add_hypothesis"
},
"dff_template_skill": {
"connector": {
"protocol": "http",
Expand Down
24 changes: 21 additions & 3 deletions assistant_dists/dream/proxy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -576,14 +576,32 @@ services:
- PROXY_PASS=dream.deeppavlov.ai:8125
- PORT=8125

infilling:
sentence-ranker:
command: [ "nginx", "-g", "daemon off;" ]
build:
context: dp/proxy/
dockerfile: Dockerfile
environment:
- PROXY_PASS=dream.deeppavlov.ai:8122
- PORT=8122
- PROXY_PASS=dream.deeppavlov.ai:8128
- PORT=8128

dialogpt-persona-based:
command: [ "nginx", "-g", "daemon off;" ]
build:
context: dp/proxy/
dockerfile: Dockerfile
environment:
- PROXY_PASS=dream.deeppavlov.ai:8131
- PORT=8125

relative-persona-extractor:
command: [ "nginx", "-g", "daemon off;" ]
build:
context: dp/proxy/
dockerfile: Dockerfile
environment:
- PROXY_PASS=dream.deeppavlov.ai:8133
- PORT=8133

dff-short-story-skill:
command: [ "nginx", "-g", "daemon off;" ]
Expand Down
7 changes: 5 additions & 2 deletions assistant_dists/dream/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,11 @@ services:
prompt-storygpt:
environment:
- CUDA_VISIBLE_DEVICES=8
infilling:
dialogpt-persona-based:
environment:
- CUDA_VISIBLE_DEVICES=7
- CUDA_VISIBLE_DEVICES=9
sentence-ranker:
environment:
- CUDA_VISIBLE_DEVICES=9
dff-template-skill:
version: '3.7'
Loading

0 comments on commit c3bb406

Please sign in to comment.