Feat/dialogpt en (#136)
* feat: dialogpt

* feat: dialogpt en

* fix: add dialogpt

* fix: new version

* fix: codestyle

* more configs

* more configs

* fix: emotion skill less confident

* fix: formatting

* fix: intent catcher

* fix: check for dialogpt

* fix: codestyle

* fix: signature of the function
dilyararimovna authored Apr 29, 2022
1 parent f28d81f commit ca14416
Showing 17 changed files with 195 additions and 5 deletions.
2 changes: 1 addition & 1 deletion annotators/IntentCatcher/data/intent_phrases.json
@@ -559,7 +559,7 @@
    },
    "what_can_you_do": {
        "phrases": [
-            "((so )){0,1}(tell me ){0,1}what can you do",
+            "((so )){0,1}(tell me ){0,1}what (else ){0,1}can you do",
            "(tell me ){0,1}what are you able to do",
            "(tell me ){0,1}what you can do",
            "what are your ((skills)|(abilities)|(features))",
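
The updated pattern makes an optional (else ) group part of the phrase, so follow-ups like "so what else can you do" now hit the what_can_you_do intent as well. A quick standalone check of the new regex with plain re (IntentCatcher itself presumably adds its own anchoring and normalization):

import re

# Updated phrase pattern from intent_phrases.json, tried in isolation.
pattern = re.compile(r"((so )){0,1}(tell me ){0,1}what (else ){0,1}can you do")

for utt in ["what can you do", "so what else can you do", "tell me what else can you do"]:
    assert pattern.fullmatch(utt), utt
print("all phrasings match")
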
4 changes: 3 additions & 1 deletion assistant_dists/dream/cpu.yml
@@ -43,4 +43,6 @@ services:
    environment:
      DEVICE: cpu
      CUDA_VISIBLE_DEVICES: ""
  dialogpt:
    environment:
      CUDA_VISIBLE_DEVICES: ""
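
An empty CUDA_VISIBLE_DEVICES hides every GPU from the container, so the torch.cuda.is_available() checks in the new server.py (below) fall through and the model runs on CPU. A minimal sketch of the effect; note the variable must be set before torch initializes CUDA:

import os

os.environ["CUDA_VISIBLE_DEVICES"] = ""  # hide all GPUs before CUDA is initialized
import torch

print(torch.cuda.is_available())  # False, so the .to("cuda") branches are skipped
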
5 changes: 5 additions & 0 deletions assistant_dists/dream/dev.yml
@@ -394,6 +394,11 @@ services:
- "./annotators/midas_predictor:/src"
ports:
- 8121:8121
dialogpt:
volumes:
- "./services/dialogpt:/src"
ports:
- 8125:8125
dff-template-skill:
volumes:
- "./skills/dff_template_skill:/src"
21 changes: 20 additions & 1 deletion assistant_dists/dream/docker-compose.override.yml
@@ -19,7 +19,7 @@ services:
        dff-funfact-skill:8104, dff-bot-persona-skill:8105, news-api-annotator:8112,
        dff-gossip-skill:8109, dff-wiki-skill:8111, dff-gaming-skill:8115, topic-recommendation:8113,
        user-persona-extractor:8114, wiki-facts:8116, dff-music-skill:8099, entity-detection:8103, dff-art-skill:8117,
-        midas-predictor:8121, dff-template-skill:8120"
+        midas-predictor:8121, dialogpt:8125, dff-template-skill:8120"
      WAIT_HOSTS_TIMEOUT: ${WAIT_TIMEOUT:-480}
  convers-evaluator-annotator:
    env_file: [.env]
@@ -1136,6 +1136,25 @@ services:
          memory: 50M
        reservations:
          memory: 50M

  dialogpt:
    env_file: [ .env ]
    build:
      args:
        SERVICE_PORT: 8125
        SERVICE_NAME: dialogpt
        PRETRAINED_MODEL_NAME_OR_PATH: microsoft/DialoGPT-small
      context: ./services/dialogpt/
    command: flask run -h 0.0.0.0 -p 8125
    environment:
      - CUDA_VISIBLE_DEVICES=0
      - FLASK_APP=server
    deploy:
      resources:
        limits:
          memory: 2G
        reservations:
          memory: 2G

  dff-template-skill:
    env_file: [.env]
4 changes: 4 additions & 0 deletions assistant_dists/dream/gpu1.yml
@@ -183,6 +183,10 @@ services:
    restart: unless-stopped
  midas-predictor:
    restart: unless-stopped
  dialogpt:
    restart: unless-stopped
    environment:
      - CUDA_VISIBLE_DEVICES=9
  dff-template-skill:
    restart: unless-stopped
version: '3.7'
13 changes: 13 additions & 0 deletions assistant_dists/dream/pipeline_conf.json
@@ -951,6 +951,19 @@
            ],
            "state_manager_method": "add_hypothesis"
        },
        "dialogpt": {
            "connector": {
                "protocol": "http",
                "timeout": 2,
                "url": "http://dialogpt:8125/respond"
            },
            "dialog_formatter": "state_formatters.dp_formatters:last_utt_and_history_dialog",
            "response_formatter": "state_formatters.dp_formatters:skill_with_attributes_formatter_service",
            "previous_services": [
                "skill_selectors"
            ],
            "state_manager_method": "add_hypothesis"
        },
        "dff_template_skill": {
            "connector": {
                "protocol": "http",
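
The connector gives dialogpt two seconds to return hypotheses from http://dialogpt:8125/respond. Judging by server.py below, the last_utt_and_history_dialog formatter must supply an utterances_histories field; a hypothetical sketch of the request the agent ends up sending:

import requests

# Assumed payload shape, inferred from request.json.get("utterances_histories")
# in server.py: one utterance history (a list of strings) per dialog in the batch.
payload = {"utterances_histories": [["hi", "hi. how are you?", "good! and you?"]]}

resp = requests.post("http://dialogpt:8125/respond", json=payload, timeout=2)
print(resp.json())  # e.g. [["I'm good too!", 0.9]], a (response, confidence) pair per dialog
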
10 changes: 9 additions & 1 deletion assistant_dists/dream/proxy.yml
@@ -584,7 +584,15 @@ services:
    environment:
      - PROXY_PASS=dream.deeppavlov.ai:8121
      - PORT=8121

  dialogpt:
    command: [ "nginx", "-g", "daemon off;" ]
    build:
      context: dp/proxy/
      dockerfile: Dockerfile
    environment:
      - PROXY_PASS=dream.deeppavlov.ai:8125

  dff-template-skill:
    command: [ "nginx", "-g", "daemon off;" ]
    build:
3 changes: 3 additions & 0 deletions assistant_dists/dream/test.yml
@@ -123,5 +123,8 @@ services:
  midas-predictor:
    environment:
      - CUDA_VISIBLE_DEVICES=6
  dialogpt:
    environment:
      - CUDA_VISIBLE_DEVICES=6
  dff-template-skill:
version: '3.7'
20 changes: 20 additions & 0 deletions services/dialogpt/Dockerfile
@@ -0,0 +1,20 @@
# syntax=docker/dockerfile:experimental

FROM pytorch/pytorch:1.5-cuda10.1-cudnn7-runtime

WORKDIR /src

ARG PRETRAINED_MODEL_NAME_OR_PATH
ENV PRETRAINED_MODEL_NAME_OR_PATH ${PRETRAINED_MODEL_NAME_OR_PATH}
ARG SERVICE_PORT
ENV SERVICE_PORT ${SERVICE_PORT}

COPY ./requirements.txt /src/requirements.txt
RUN pip install -r /src/requirements.txt

RUN python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('${PRETRAINED_MODEL_NAME_OR_PATH}');"
RUN python -c "from transformers import AutoModelForCausalLM; AutoModelForCausalLM.from_pretrained('${PRETRAINED_MODEL_NAME_OR_PATH}');"

COPY . /src

CMD gunicorn --workers=1 server:app -b 0.0.0.0:${SERVICE_PORT} --timeout=300
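
Running the two python -c lines at build time bakes the tokenizer and weights into the image, so the identical from_pretrained() calls in server.py start from the local Hugging Face cache instead of downloading at container startup. A sketch of how to verify that, assuming transformers' offline mode is honored by the pinned 4.6.0 release:

import os

os.environ["TRANSFORMERS_OFFLINE"] = "1"  # forbid any calls to the model hub
from transformers import AutoTokenizer

# Resolves from the cache populated at build time; would raise if the
# RUN python -c download steps had been skipped.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
print("tokenizer loaded offline")
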
3 changes: 3 additions & 0 deletions services/dialogpt/README.md
@@ -0,0 +1,3 @@
GPU RAM = 1 GB
CPU time = 0.15 sec
GPU time = 0.05 sec
9 changes: 9 additions & 0 deletions services/dialogpt/requirements.txt
@@ -0,0 +1,9 @@
transformers==4.6.0
flask==1.1.1
itsdangerous==2.0.1
gunicorn==19.9.0
requests==2.22.0
sentry-sdk[flask]==0.14.1
healthcheck==1.3.3
jinja2<=3.0.3
Werkzeug<=2.0.3
82 changes: 82 additions & 0 deletions services/dialogpt/server.py
@@ -0,0 +1,82 @@
import logging
import time
import os

import sentry_sdk
import torch
from flask import Flask, request, jsonify
from sentry_sdk.integrations.flask import FlaskIntegration
from transformers import AutoModelForCausalLM, AutoTokenizer

sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])


logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH")
logging.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")
DEFAULT_CONFIDENCE = 0.9
ZERO_CONFIDENCE = 0.0
MAX_HISTORY_DEPTH = 3

try:
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    model = AutoModelForCausalLM.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    if torch.cuda.is_available():
        model.to("cuda")
        logger.info("dialogpt is set to run on cuda")

    logger.info("dialogpt is ready")
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise e

app = Flask(__name__)
logging.getLogger("werkzeug").setLevel("WARNING")


def generate_response(context, model, tokenizer):
    # DialoGPT is trained on EOS-separated dialog turns, so each utterance in
    # the (truncated) history is encoded with a trailing EOS token.
    encoded_context = []
    for uttr in context[-MAX_HISTORY_DEPTH:]:
        encoded_context += [tokenizer.encode(uttr + tokenizer.eos_token, return_tensors="pt")]
    bot_input_ids = torch.cat(encoded_context, dim=-1)

    with torch.no_grad():
        if torch.cuda.is_available():
            bot_input_ids = bot_input_ids.to("cuda")
        chat_history_ids = model.generate(
            bot_input_ids, do_sample=True, max_length=50, top_k=3, pad_token_id=tokenizer.eos_token_id
        )
        if torch.cuda.is_available():
            chat_history_ids = chat_history_ids.cpu()
    # generate() returns prompt + continuation; decode only the new tokens.
    return tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1] :][0], skip_special_tokens=True)


@app.route("/respond", methods=["POST"])
def respond():
    st_time = time.time()
    contexts = request.json.get("utterances_histories", [])

    try:
        responses = []
        confidences = []
        for context in contexts:
            response = generate_response(context, model, tokenizer)
            # drop too short responses by zeroing their confidence
            if len(response) > 3:
                responses += [response]
                confidences += [DEFAULT_CONFIDENCE]
            else:
                responses += [""]
                confidences += [ZERO_CONFIDENCE]
    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        responses = [""] * len(contexts)
        confidences = [ZERO_CONFIDENCE] * len(contexts)

    total_time = time.time() - st_time
    logger.info(f"dialogpt exec time: {total_time:.3f}s")
    return jsonify(list(zip(responses, confidences)))
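
model.generate() returns the prompt and its continuation as one sequence, which is why generate_response slices off the first bot_input_ids.shape[-1] tokens before decoding. A toy illustration of that slicing with made-up token ids (no model needed):

import torch

bot_input_ids = torch.tensor([[10, 11, 12, 13, 14, 15]])  # encoded dialog history
chat_history_ids = torch.tensor([[10, 11, 12, 13, 14, 15, 40, 41, 42]])  # history + new reply

# Drop the prompt tokens, exactly as server.py does before tokenizer.decode().
new_tokens = chat_history_ids[:, bot_input_ids.shape[-1]:]
print(new_tokens)  # tensor([[40, 41, 42]])
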
17 changes: 17 additions & 0 deletions services/dialogpt/test.py
@@ -0,0 +1,17 @@
import requests


def test_respond():
    url = "http://0.0.0.0:8125/respond"

    contexts = [["hi", "hi. how are you?"], ["let's chat about movies", "cool. what movies do you like?"]]
    gold_result = [["I'm good, how are you?", 0.9], ["I like the new one.", 0.9]]
    result = requests.post(url, json={"utterances_histories": contexts}).json()
    # Generation is sampled (do_sample=True), so check that every response is
    # non-empty with positive confidence rather than comparing exact strings.
    # Note: all(...) is required here; asserting a bare non-empty list always passes.
    assert all(
        len(sample[0]) > 0 and sample[1] > 0.0 for sample in result
    ), f"Got\n{result}\n, but expected something like:\n{gold_result}"
    print("Success")


if __name__ == "__main__":
    test_respond()
3 changes: 3 additions & 0 deletions services/dialogpt/test.sh
@@ -0,0 +1,3 @@
#!/bin/bash

python test.py
1 change: 1 addition & 0 deletions skill_selectors/rule_based_selector/connector.py
@@ -192,6 +192,7 @@ async def send(self, payload: Dict, callback: Callable):
skills_for_uttr.append("personal_info_skill")
skills_for_uttr.append("meta_script_skill")
skills_for_uttr.append("dummy_skill")
skills_for_uttr.append("dialogpt")
if len(dialog["utterances"]) < 20:
skills_for_uttr.append("dff_friendship_skill")

1 change: 1 addition & 0 deletions tests/dream/assert_test_dialogs.py
@@ -43,6 +43,7 @@
"dff_travel_skill",
"dff_wiki_skill",
"game_cooperative_skill",
"dialogpt",
],
}

2 changes: 1 addition & 1 deletion tests/runtests.sh
@@ -148,7 +148,7 @@ if [[ "$MODE" == "test_skills" || "$MODE" == "all" ]]; then
      fact-random fact-retrieval dff-intent-responder-skill badlisted-words \
      dff-gossip-skill dff-wiki-skill topic-recommendation dff-science-skill personal-info-skill \
      user-persona-extractor small-talk-skill wiki-facts dff-art-skill dff-funfact-skill \
-      meta-script-skill spelling-preprocessing dff-gaming-skill \
+      meta-script-skill spelling-preprocessing dff-gaming-skill dialogpt \
      dff-music-skill dff-bot-persona-skill entity-detection midas-predictor; do

    echo "Run tests for $container"
