Feat/dialogpt en #136

Merged Apr 29, 2022 (17 commits; changes shown from 15 commits)

annotators/IntentCatcher/data/intent_phrases.json (2 changes: 1 addition & 1 deletion)

@@ -559,7 +559,7 @@
    },
    "what_can_you_do": {
        "phrases": [
-            "((so )){0,1}(tell me ){0,1}what can you do",
+            "((so )){0,1}(tell me ){0,1}what (else ){0,1}can you do",
            "(tell me ){0,1}what are you able to do",
            "(tell me ){0,1}what you can do",
            "what are your ((skills)|(abilities)|(features))",
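
The new (else ){0,1} group widens the intent to follow-up phrasings such as "what else can you do". A quick sanity check, treating the phrase template as a plain Python regular expression (an assumption; IntentCatcher's own template compilation may differ):

import re

pattern = re.compile("((so )){0,1}(tell me ){0,1}what (else ){0,1}can you do")
assert pattern.fullmatch("what can you do")
assert pattern.fullmatch("so tell me what else can you do")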

assistant_dists/dream/cpu.yml (4 changes: 3 additions & 1 deletion)

@@ -43,4 +43,6 @@ services:
    environment:
      DEVICE: cpu
      CUDA_VISIBLE_DEVICES: ""

  dialogpt:
    environment:
      CUDA_VISIBLE_DEVICES: ""

assistant_dists/dream/dev.yml (5 changes: 5 additions & 0 deletions)

@@ -394,6 +394,11 @@ services:
- "./annotators/midas_predictor:/src"
ports:
- 8121:8121
dialogpt:
volumes:
- "./services/dialogpt:/src"
ports:
- 8125:8125
dff-template-skill:
volumes:
- "./skills/dff_template_skill:/src"

assistant_dists/dream/docker-compose.override.yml (21 changes: 20 additions & 1 deletion)

@@ -19,7 +19,7 @@ services:
        dff-funfact-skill:8104, dff-bot-persona-skill:8105, news-api-annotator:8112,
        dff-gossip-skill:8109, dff-wiki-skill:8111, dff-gaming-skill:8115, topic-recommendation:8113,
        user-persona-extractor:8114, wiki-facts:8116, dff-music-skill:8099, entity-detection:8103, dff-art-skill:8117,
-        midas-predictor:8121, dff-template-skill:8120"
+        midas-predictor:8121, dialogpt:8125, dff-template-skill:8120"
      WAIT_HOSTS_TIMEOUT: ${WAIT_TIMEOUT:-480}
  convers-evaluator-annotator:
    env_file: [.env]

@@ -1136,6 +1136,25 @@ services:
          memory: 50M
        reservations:
          memory: 50M

  dialogpt:
    env_file: [ .env ]
    build:
      args:
        SERVICE_PORT: 8125
        SERVICE_NAME: dialogpt
        PRETRAINED_MODEL_NAME_OR_PATH: microsoft/DialoGPT-small
      context: ./services/dialogpt/
    command: flask run -h 0.0.0.0 -p 8125
    environment:
      - CUDA_VISIBLE_DEVICES=0
      - FLASK_APP=server
    deploy:
      resources:
        limits:
          memory: 2G
        reservations:
          memory: 2G

  dff-template-skill:
    env_file: [.env]

assistant_dists/dream/gpu1.yml (4 changes: 4 additions & 0 deletions)

@@ -183,6 +183,10 @@ services:
    restart: unless-stopped
  midas-predictor:
    restart: unless-stopped
  dialogpt:
    restart: unless-stopped
    environment:
      - CUDA_VISIBLE_DEVICES=9
  dff-template-skill:
    restart: unless-stopped
version: '3.7'

assistant_dists/dream/pipeline_conf.json (13 changes: 13 additions & 0 deletions)

@@ -951,6 +951,19 @@
            ],
            "state_manager_method": "add_hypothesis"
        },
        "dialogpt": {
            "connector": {
                "protocol": "http",
                "timeout": 2,
                "url": "http://dialogpt:8125/respond"
            },
            "dialog_formatter": "state_formatters.dp_formatters:last_utt_and_history_dialog",
            "response_formatter": "state_formatters.dp_formatters:skill_with_attributes_formatter_service",
            "previous_services": [
                "skill_selectors"
            ],
            "state_manager_method": "add_hypothesis"
        },
        "dff_template_skill": {
            "connector": {
                "protocol": "http",
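
The connector above posts each dialog's history to the service's /respond route and registers the reply as a hypothesis. A minimal request against a locally exposed instance (a sketch mirroring services/dialogpt/test.py below; the 8125 port mapping comes from dev.yml):

import requests

resp = requests.post(
    "http://0.0.0.0:8125/respond",
    json={"utterances_histories": [["hi", "hi. how are you?"]]},
)
print(resp.json())  # [[generated_reply, confidence]]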

assistant_dists/dream/proxy.yml (10 changes: 9 additions & 1 deletion)

@@ -584,7 +584,15 @@ services:
    environment:
      - PROXY_PASS=dream.deeppavlov.ai:8121
      - PORT=8121

  dialogpt:
    command: [ "nginx", "-g", "daemon off;" ]
    build:
      context: dp/proxy/
      dockerfile: Dockerfile
    environment:
      - PROXY_PASS=dream.deeppavlov.ai:8125

  dff-template-skill:
    command: [ "nginx", "-g", "daemon off;" ]
    build:

assistant_dists/dream/test.yml (3 changes: 3 additions & 0 deletions)

@@ -123,5 +123,8 @@ services:
  midas-predictor:
    environment:
      - CUDA_VISIBLE_DEVICES=6
  dialogpt:
    environment:
      - CUDA_VISIBLE_DEVICES=6
  dff-template-skill:
version: '3.7'

services/dialogpt/Dockerfile (20 changes: 20 additions & 0 deletions)

@@ -0,0 +1,20 @@
# syntax=docker/dockerfile:experimental

FROM pytorch/pytorch:1.5-cuda10.1-cudnn7-runtime

WORKDIR /src

ARG PRETRAINED_MODEL_NAME_OR_PATH
ENV PRETRAINED_MODEL_NAME_OR_PATH ${PRETRAINED_MODEL_NAME_OR_PATH}
ARG SERVICE_PORT
ENV SERVICE_PORT ${SERVICE_PORT}

COPY ./requirements.txt /src/requirements.txt
RUN pip install -r /src/requirements.txt

RUN python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('${PRETRAINED_MODEL_NAME_OR_PATH}');"
RUN python -c "from transformers import AutoModelForCausalLM; AutoModelForCausalLM.from_pretrained('${PRETRAINED_MODEL_NAME_OR_PATH}');"

COPY . /src

CMD gunicorn --workers=1 server:app -b 0.0.0.0:${SERVICE_PORT} --timeout=300

services/dialogpt/README.md (3 changes: 3 additions & 0 deletions)

@@ -0,0 +1,3 @@
GPU RAM = 1Gb
cpu time = 0.15 sec
gpu time = 0.05 sec

services/dialogpt/requirements.txt (9 changes: 9 additions & 0 deletions)

@@ -0,0 +1,9 @@
transformers==4.6.0
flask==1.1.1
itsdangerous==2.0.1
gunicorn==19.9.0
requests==2.22.0
sentry-sdk[flask]==0.14.1
healthcheck==1.3.3
jinja2<=3.0.3
Werkzeug<=2.0.3

services/dialogpt/server.py (83 changes: 83 additions & 0 deletions)

@@ -0,0 +1,83 @@
import logging
import time
import os

import sentry_sdk
import torch
from flask import Flask, request, jsonify
from sentry_sdk.integrations.flask import FlaskIntegration
from transformers import AutoModelForCausalLM, AutoTokenizer

sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])


logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH")
logging.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")
DEFAULT_CONFIDENCE = 0.9
ZERO_CONFIDENCE = 0.0
MAX_HISTORY_DEPTH = 3

try:
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    model = AutoModelForCausalLM.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    if torch.cuda.is_available():
        model.to("cuda")
        logger.info("dialogpt is set to run on cuda")

    logger.info("dialogpt is ready")
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise e

app = Flask(__name__)
logging.getLogger("werkzeug").setLevel("WARNING")


def generate_response(context):
    global model, tokenizer

[Review comment, Member]: Do we really need globals here?

[Review comment, Member]: It would be better to change the function signature to generate_response(model, tokenizer, context) and pass them in from the Flask view function.
    encoded_context = []
    for uttr in context[-MAX_HISTORY_DEPTH:]:
        encoded_context += [tokenizer.encode(uttr + tokenizer.eos_token, return_tensors="pt")]
    bot_input_ids = torch.cat(encoded_context, dim=-1)

    with torch.no_grad():
        if torch.cuda.is_available():
            bot_input_ids = bot_input_ids.to("cuda")
        chat_history_ids = model.generate(
            bot_input_ids, do_sample=True, max_length=50, top_k=3, pad_token_id=tokenizer.eos_token_id
        )
        if torch.cuda.is_available():
            chat_history_ids = chat_history_ids.cpu()
    return tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1] :][0], skip_special_tokens=True)


@app.route("/respond", methods=["POST"])
def respond():
    st_time = time.time()
    contexts = request.json.get("utterances_histories", [])

    try:
        responses = []
        confidences = []
        for context in contexts:
            response = generate_response(context)
            if len(response) > 3:
                # drop too short responses
                responses += [response]
                confidences += [DEFAULT_CONFIDENCE]
            else:
                responses += [""]
                confidences += [ZERO_CONFIDENCE]
    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        responses = [""] * len(contexts)
        confidences = [ZERO_CONFIDENCE] * len(contexts)

    total_time = time.time() - st_time
    logger.info(f"dialogpt exec time: {total_time:.3f}s")
    return jsonify(list(zip(responses, confidences)))
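
In generate_response, up to MAX_HISTORY_DEPTH turns are concatenated, each terminated by the EOS token, which is the input convention DialoGPT was trained on; everything the model generates past the input length is decoded as the new reply. For illustration (a sketch, assuming the small checkpoint):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
# Two turns become one EOS-joined sequence: "hi<|endoftext|>how are you?<|endoftext|>"
ids = tokenizer.encode("hi" + tokenizer.eos_token + "how are you?" + tokenizer.eos_token)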

services/dialogpt/test.py (17 changes: 17 additions & 0 deletions)

@@ -0,0 +1,17 @@
import requests


def test_respond():
    url = "http://0.0.0.0:8125/respond"

    contexts = [["hi", "hi. how are you?"], ["let's chat about movies", "cool. what movies do you like?"]]
    gold_result = [["I'm good, how are you?", 0.9], ["I like the new one.", 0.9]]
    result = requests.post(url, json={"utterances_histories": contexts}).json()
    assert all(
        len(sample[0]) > 0 and sample[1] > 0.0 for sample in result
    ), f"Got\n{result}\n, but expected:\n{gold_result}"
    print("Success")


if __name__ == "__main__":
    test_respond()

services/dialogpt/test.sh (3 changes: 3 additions & 0 deletions)

@@ -0,0 +1,3 @@
#!/bin/bash

python test.py

skill_selectors/rule_based_selector/connector.py (1 change: 1 addition & 0 deletions)

@@ -185,6 +185,7 @@ async def send(self, payload: Dict, callback: Callable):
skills_for_uttr.append("personal_info_skill")
skills_for_uttr.append("meta_script_skill")
skills_for_uttr.append("dummy_skill")
skills_for_uttr.append("dialogpt")
if len(dialog["utterances"]) < 20:
skills_for_uttr.append("dff_friendship_skill")

tests/dream/assert_test_dialogs.py (1 change: 1 addition & 0 deletions)

@@ -43,6 +43,7 @@
"dff_travel_skill",
"dff_wiki_skill",
"game_cooperative_skill",
"dialogpt"
],
}

tests/runtests.sh (2 changes: 1 addition & 1 deletion)

@@ -148,7 +148,7 @@ if [[ "$MODE" == "test_skills" || "$MODE" == "all" ]]; then
    fact-random fact-retrieval dff-intent-responder-skill badlisted-words \
    dff-gossip-skill dff-wiki-skill topic-recommendation dff-science-skill personal-info-skill \
    user-persona-extractor small-talk-skill wiki-facts dff-art-skill dff-funfact-skill \
-    meta-script-skill spelling-preprocessing dff-gaming-skill \
+    meta-script-skill spelling-preprocessing dff-gaming-skill dialogpt \
    dff-music-skill dff-bot-persona-skill entity-detection midas-predictor; do

    echo "Run tests for $container"