Feat/dialogpt en #136
Merged
17 commits (changes shown from 15 commits), all by dilyararimovna:
2da74b3 feat: dialogpt
64ce0eb feat: dialogpt en
bb3c1f9 fix: add dialogpt
f61d507 fix: new version
ad23e85 fix: codestyle
91f5748 more configs
4e6b119 more configs
1cbd430 Merge remote-tracking branch 'origin/dev' into feat/dialogpt_en
a20244b Merge remote-tracking branch 'origin/dev' into feat/dialogpt_en
e13cc19 fix: emotion skill less confident
6b7c454 Merge remote-tracking branch 'origin/dev' into feat/dialogpt_en
78c4f14 fix: formatting
d0e7214 fix: intent catcher
df3f711 Merge remote-tracking branch 'origin/dev' into feat/dialogpt_en
b60221b fix: check for dialogpt
3baa162 fix: codestyle
93a193b fix: signature of the function
@@ -43,4 +43,6 @@ services:
    environment:
      DEVICE: cpu
      CUDA_VISIBLE_DEVICES: ""

  dialogpt:
    environment:
      CUDA_VISIBLE_DEVICES: ""
@@ -0,0 +1,20 @@
# syntax=docker/dockerfile:experimental

FROM pytorch/pytorch:1.5-cuda10.1-cudnn7-runtime

WORKDIR /src

ARG PRETRAINED_MODEL_NAME_OR_PATH
ENV PRETRAINED_MODEL_NAME_OR_PATH ${PRETRAINED_MODEL_NAME_OR_PATH}
ARG SERVICE_PORT
ENV SERVICE_PORT ${SERVICE_PORT}

COPY ./requirements.txt /src/requirements.txt
RUN pip install -r /src/requirements.txt

RUN python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('${PRETRAINED_MODEL_NAME_OR_PATH}');"
RUN python -c "from transformers import AutoModelForCausalLM; AutoModelForCausalLM.from_pretrained('${PRETRAINED_MODEL_NAME_OR_PATH}');"

COPY . /src

CMD gunicorn --workers=1 server:app -b 0.0.0.0:${SERVICE_PORT} --timeout=300
@@ -0,0 +1,3 @@
GPU RAM = 1Gb
cpu time = 0.15 sec
gpu time = 0.05 sec
@@ -0,0 +1,9 @@
transformers==4.6.0
flask==1.1.1
itsdangerous==2.0.1
gunicorn==19.9.0
requests==2.22.0
sentry-sdk[flask]==0.14.1
healthcheck==1.3.3
jinja2<=3.0.3
Werkzeug<=2.0.3
@@ -0,0 +1,83 @@
import logging
import time
import os

import sentry_sdk
import torch
from flask import Flask, request, jsonify
from sentry_sdk.integrations.flask import FlaskIntegration
from transformers import AutoModelForCausalLM, AutoTokenizer

sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])


logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH")
logging.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")
DEFAULT_CONFIDENCE = 0.9
ZERO_CONFIDENCE = 0.0
MAX_HISTORY_DEPTH = 3

try:
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    model = AutoModelForCausalLM.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    if torch.cuda.is_available():
        model.to("cuda")
        logger.info("dialogpt is set to run on cuda")

    logger.info("dialogpt is ready")
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise e

app = Flask(__name__)
logging.getLogger("werkzeug").setLevel("WARNING")


def generate_response(context):
    global model, tokenizer
    encoded_context = []
    for uttr in context[-MAX_HISTORY_DEPTH:]:
        encoded_context += [tokenizer.encode(uttr + tokenizer.eos_token, return_tensors="pt")]
    bot_input_ids = torch.cat(encoded_context, dim=-1)

    with torch.no_grad():
        if torch.cuda.is_available():
            bot_input_ids = bot_input_ids.to("cuda")
        chat_history_ids = model.generate(
            bot_input_ids, do_sample=True, max_length=50, top_k=3, pad_token_id=tokenizer.eos_token_id
        )
        if torch.cuda.is_available():
            chat_history_ids = chat_history_ids.cpu()
    return tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1] :][0], skip_special_tokens=True)


@app.route("/respond", methods=["POST"])
def respond():
    st_time = time.time()
    contexts = request.json.get("utterances_histories", [])

    try:
        responses = []
        confidences = []
        for context in contexts:
            response = generate_response(context)
            if len(response) > 3:
                # drop too short responses
                responses += [response]
                confidences += [DEFAULT_CONFIDENCE]
            else:
                responses += [""]
                confidences += [ZERO_CONFIDENCE]
    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        responses = [""] * len(contexts)
        confidences = [ZERO_CONFIDENCE] * len(contexts)

    total_time = time.time() - st_time
    logger.info(f"masked_lm exec time: {total_time:.3f}s")
    return jsonify(list(zip(responses, confidences)))
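The confidence logic inside `respond()` can be read in isolation. The following is a minimal sketch, not the service's code: `score_responses` and its `min_length` parameter are hypothetical names introduced here for illustration, while the two constants and the length threshold of 3 are taken from the diff.

```python
DEFAULT_CONFIDENCE = 0.9
ZERO_CONFIDENCE = 0.0


def score_responses(raw_responses, min_length=3):
    """Pair each reply with a confidence.

    Replies of `min_length` characters or fewer are replaced by an empty
    string with zero confidence, mirroring the `len(response) > 3` check.
    """
    scored = []
    for response in raw_responses:
        if len(response) > min_length:
            scored.append((response, DEFAULT_CONFIDENCE))
        else:
            scored.append(("", ZERO_CONFIDENCE))
    return scored


print(score_responses(["ok", "I like the new one."]))
# prints [('', 0.0), ('I like the new one.', 0.9)]
```

Since JSON has no tuple type, a client of the endpoint would see each pair as a two-element array such as `["", 0.0]`. Note also that in the handler above a single failing context blanks the whole batch, because the `try` wraps the entire loop.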
@@ -0,0 +1,17 @@
import requests


def test_respond():
    url = "http://0.0.0.0:8125/respond"

    contexts = [["hi", "hi. how are you?"], ["let's chat about movies", "cool. what movies do you like?"]]
    gold_result = [["I'm good, how are you?", 0.9], ["I like the new one.", 0.9]]
    result = requests.post(url, json={"utterances_histories": contexts}).json()
    assert [
        len(sample[0]) > 0 and sample[1] > 0.0 for sample in result
    ], f"Got\n{result}\n, but expected:\n{gold_result}"
    print("Success")


if __name__ == "__main__":
    test_respond()
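One thing worth flagging in this test: the `assert` receives a list built by a comprehension, and any non-empty list is truthy in Python, so the check passes even when every element is `False`. A possible stricter check is sketched below; `is_valid_result` is a hypothetical helper name, not part of the PR.

```python
def is_valid_result(result):
    """True only if there is at least one pair and every (text, confidence)
    pair has non-empty text and a positive confidence."""
    return len(result) > 0 and all(
        len(text) > 0 and confidence > 0.0 for text, confidence in result
    )


# A non-empty list is truthy even when all of its elements are False.
print(bool([False, False]))  # prints True

print(is_valid_result([["", 0.0], ["hello there", 0.9]]))  # prints False
print(is_valid_result([["I'm good", 0.9], ["hello there", 0.9]]))  # prints True
```

In the test, this would amount to writing `assert is_valid_result(result), ...` in place of asserting on the list itself. Because sampling is enabled in the service, comparing against `gold_result` verbatim would not be reliable, which is presumably why the test only checks shape.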
@@ -0,0 +1,3 @@
#!/bin/bash

python test.py
@@ -43,6 +43,7 @@
        "dff_travel_skill",
        "dff_wiki_skill",
        "game_cooperative_skill",
        "dialogpt"
    ],
}
Do we really need the globals?
It would be better to change the function signature to
generate_response(model, tokenizer, context):
and pass them in from inside the Flask function.
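The reviewer's suggestion can be sketched as follows. This is a simplified illustration under stated assumptions, not the change that was merged: tensors, device placement and sampling parameters are left out, plain lists stand in for token tensors, and `StubTokenizer` and `StubModel` are hypothetical stand-ins so the example runs without downloading a model.

```python
MAX_HISTORY_DEPTH = 3


def generate_response(model, tokenizer, context):
    """Generate a reply using explicitly passed dependencies (no globals)."""
    # Keep only the most recent utterances, each terminated by the EOS token.
    prompt = "".join(uttr + tokenizer.eos_token for uttr in context[-MAX_HISTORY_DEPTH:])
    input_ids = tokenizer.encode(prompt)
    output_ids = model.generate(input_ids)
    # The output starts with the prompt; keep only the newly generated part.
    return tokenizer.decode(output_ids[len(input_ids):])


class StubTokenizer:
    """Hypothetical stand-in: one token per character."""

    eos_token = "|"

    def encode(self, text):
        return list(text)

    def decode(self, ids):
        return "".join(ids)


class StubModel:
    """Hypothetical stand-in: appends a fixed reply to the prompt."""

    def generate(self, input_ids):
        return input_ids + list("fine, thanks")


reply = generate_response(StubModel(), StubTokenizer(), ["hi", "how are you?"])
print(reply)  # prints fine, thanks
```

With this signature the Flask handler would call `generate_response(model, tokenizer, context)` for each context, and the function can be exercised in unit tests with stand-ins like the ones above instead of a loaded model.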