Commit
* feat: dialogpt
* feat: dialogpt en
* fix: add dialogpt
* fix: new version
* fix: codestyle
* more configs
* more configs
* fix: emotion skill less confident
* fix: formatting
* fix: intent catcher
* fix: check for dialogpt
* fix: codestyle
* fix: signature of the function
1 parent f28d81f · commit ca14416
Showing 17 changed files with 195 additions and 5 deletions.
@@ -43,4 +43,6 @@ services:
    environment:
      DEVICE: cpu
      CUDA_VISIBLE_DEVICES: ""

  dialogpt:
    environment:
      CUDA_VISIBLE_DEVICES: ""
@@ -0,0 +1,20 @@
# syntax=docker/dockerfile:experimental

FROM pytorch/pytorch:1.5-cuda10.1-cudnn7-runtime

WORKDIR /src

ARG PRETRAINED_MODEL_NAME_OR_PATH
ENV PRETRAINED_MODEL_NAME_OR_PATH ${PRETRAINED_MODEL_NAME_OR_PATH}
ARG SERVICE_PORT
ENV SERVICE_PORT ${SERVICE_PORT}

COPY ./requirements.txt /src/requirements.txt
RUN pip install -r /src/requirements.txt

RUN python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('${PRETRAINED_MODEL_NAME_OR_PATH}');"
RUN python -c "from transformers import AutoModelForCausalLM; AutoModelForCausalLM.from_pretrained('${PRETRAINED_MODEL_NAME_OR_PATH}');"

COPY . /src

CMD gunicorn --workers=1 server:app -b 0.0.0.0:${SERVICE_PORT} --timeout=300
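
The two RUN python -c steps above pre-download the tokenizer and model at build time, so the container does not fetch weights on startup. A minimal sketch of the same pre-download as a standalone Python script, assuming "microsoft/DialoGPT-medium" only as an example fallback; the actual model is supplied through the PRETRAINED_MODEL_NAME_OR_PATH build argument and is not pinned in this hunk.

# sketch: cache the tokenizer and model locally, mirroring the Dockerfile's
# build-time RUN python -c steps
import os

from transformers import AutoModelForCausalLM, AutoTokenizer

# the fallback value here is an assumed example, not part of the commit
model_name = os.getenv("PRETRAINED_MODEL_NAME_OR_PATH", "microsoft/DialoGPT-medium")
AutoTokenizer.from_pretrained(model_name)
AutoModelForCausalLM.from_pretrained(model_name)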
@@ -0,0 +1,3 @@
GPU RAM = 1Gb
cpu time = 0.15 sec
gpu time = 0.05 sec
@@ -0,0 +1,9 @@
transformers==4.6.0
flask==1.1.1
itsdangerous==2.0.1
gunicorn==19.9.0
requests==2.22.0
sentry-sdk[flask]==0.14.1
healthcheck==1.3.3
jinja2<=3.0.3
Werkzeug<=2.0.3
@@ -0,0 +1,82 @@
import logging
import os
import time

import sentry_sdk
import torch
from flask import Flask, jsonify, request
from sentry_sdk.integrations.flask import FlaskIntegration
from transformers import AutoModelForCausalLM, AutoTokenizer

sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])


logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH")
logging.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")
DEFAULT_CONFIDENCE = 0.9
ZERO_CONFIDENCE = 0.0
MAX_HISTORY_DEPTH = 3

try:
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    model = AutoModelForCausalLM.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    if torch.cuda.is_available():
        model.to("cuda")
        logger.info("dialogpt is set to run on cuda")

    logger.info("dialogpt is ready")
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise e

app = Flask(__name__)
logging.getLogger("werkzeug").setLevel("WARNING")


def generate_response(context, model, tokenizer):
    # encode the last MAX_HISTORY_DEPTH utterances, each terminated by the EOS token
    encoded_context = []
    for uttr in context[-MAX_HISTORY_DEPTH:]:
        encoded_context += [tokenizer.encode(uttr + tokenizer.eos_token, return_tensors="pt")]
    bot_input_ids = torch.cat(encoded_context, dim=-1)

    with torch.no_grad():
        if torch.cuda.is_available():
            bot_input_ids = bot_input_ids.to("cuda")
        chat_history_ids = model.generate(
            bot_input_ids, do_sample=True, max_length=50, top_k=3, pad_token_id=tokenizer.eos_token_id
        )
        if torch.cuda.is_available():
            chat_history_ids = chat_history_ids.cpu()
    # return only the newly generated continuation, without the input context
    return tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1] :][0], skip_special_tokens=True)


@app.route("/respond", methods=["POST"])
def respond():
    st_time = time.time()
    contexts = request.json.get("utterances_histories", [])

    try:
        responses = []
        confidences = []
        for context in contexts:
            response = generate_response(context, model, tokenizer)
            if len(response) > 3:
                responses += [response]
                confidences += [DEFAULT_CONFIDENCE]
            else:
                # drop too short responses
                responses += [""]
                confidences += [ZERO_CONFIDENCE]
    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        responses = [""] * len(contexts)
        confidences = [ZERO_CONFIDENCE] * len(contexts)

    total_time = time.time() - st_time
    logger.info(f"dialogpt exec time: {total_time:.3f}s")
    return jsonify(list(zip(responses, confidences)))
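
For reference, a minimal sketch of how a client queries this endpoint and reads back the (response, confidence) pairs; the host and port 8125 are taken from the test below and may differ in an actual deployment.

# sketch: call the /respond endpoint of a locally running dialogpt service
# (port 8125 is the value used in test.py; assumed here, not fixed by server.py)
import requests

contexts = [["hi", "hi. how are you?"]]
result = requests.post(
    "http://0.0.0.0:8125/respond", json={"utterances_histories": contexts}
).json()
for response, confidence in result:
    # empty responses come back with ZERO_CONFIDENCE and can be skipped
    print(confidence, repr(response))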
@@ -0,0 +1,17 @@
import requests


def test_respond():
    url = "http://0.0.0.0:8125/respond"

    contexts = [["hi", "hi. how are you?"], ["let's chat about movies", "cool. what movies do you like?"]]
    gold_result = [["I'm good, how are you?", 0.9], ["I like the new one.", 0.9]]
    result = requests.post(url, json={"utterances_histories": contexts}).json()
    assert all(
        len(sample[0]) > 0 and sample[1] > 0.0 for sample in result
    ), f"Got\n{result}\n, but expected:\n{gold_result}"
    print("Success")


if __name__ == "__main__":
    test_respond()
@@ -0,0 +1,3 @@
#!/bin/bash

python test.py
@@ -43,6 +43,7 @@
        "dff_travel_skill",
        "dff_wiki_skill",
        "game_cooperative_skill",
        "dialogpt",
    ],
}