Commit
* Fix requirements.txt (#84)
* fix itsdangerous requirements
* pin itsdangerous requirements for all flask==1.1.1 servers
* increase timeout to 5s
* add logs
* increase to 100
* add first working version gpt_persona
* add sentence_ranking(not working)
* fix wrong endpoint
* create sentecnce ranker annotator
* rewrite text generation logic
* add comments to code
* fix gpt_persona fallback
* clean code, add train script, write tests
* add get_intents, remove dataset, add hyperparams
* add batch support
* fix: move files
* fix: merge
* fix: codestyle
* fix: remove sentence ranker
* feat: new distribution and rename skill
* feat: new annotator
* feat: relative persona extractor
* fix: codestyle
* fix: proxy
* fix: params
* fix: volumes
* fix: reqs
* fix: tests
* fix: tests relative sents extr
* fix: batching
* fix: codestyle
* fix: persona extractor tests
* fix: persona get
* fix: tests
* fix: imports
* fix: logs
* fix: docs
* fix: add midas
* fix: command
* feat: add to main dist and tests
* feat: remove infilling, add dialogpt persona based, docs
* fix: gpus
* fix: params
* fix: param
* fix: indent
* fix: remove infilling from tests

Co-authored-by: Andrii.Hura <54397922+AndriiHura@users.noreply.github.com>
Co-authored-by: mtalimanchuk <mtalimanchuk@gmail.com>
Co-authored-by: Dilyara Baymurzina <dilyara.rimovna@gmail.com>
1 parent 9374614, commit c3bb406
Showing 34 changed files with 2,575 additions and 49 deletions.
@@ -0,0 +1,18 @@
FROM python:3.8.4

ARG SERVICE_PORT
ENV SERVICE_PORT ${SERVICE_PORT}
ARG N_SENTENCES_OT_RETURN
ENV N_SENTENCES_OT_RETURN ${N_SENTENCES_OT_RETURN}

RUN mkdir /src

COPY ./annotators/relative_persona_extractor/ /src/
COPY ./common/ /src/common/

COPY annotators/relative_persona_extractor/requirements.txt /src/requirements.txt
RUN pip install -r /src/requirements.txt

WORKDIR /src

CMD gunicorn --workers=1 server:app --bind 0.0.0.0:8000
@@ -0,0 +1,5 @@
# Relative Persona Extractor

An annotator that uses Sentence Ranker to find the sentences in the bot's persona description that are most relevant to the current context.

The number of returned sentences is set through the `N_SENTENCES_OT_RETURN` environment variable in `docker-compose.yml`.
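The selection step described in this README can be sketched as follows. This is a minimal illustration using only the standard library, assuming the ranker has already produced one similarity score per persona sentence; the helper name `select_top_sentences`, the sample sentences, and the sample scores are hypothetical and not part of the annotator.

```python
def select_top_sentences(sentences, scores, n):
    """Return the n highest-scoring sentences (best first) and the top score."""
    # Sort sentence indices by score, highest first, and keep the first n.
    ranked = sorted(range(len(sentences)), key=lambda i: scores[i], reverse=True)
    top_ids = ranked[:n]
    if not top_ids:
        # Nothing to rank: mirror an empty result with zero similarity.
        return [], 0.0
    return [sentences[i] for i in top_ids], scores[top_ids[0]]


# Hypothetical persona sentences and scores, for illustration only.
persona = ["I like pizza.", "I have a dog.", "I watch movies."]
top, best = select_top_sentences(persona, [0.7, 0.1, 0.4], n=2)
print(top, best)
```

Here `n` plays the role of `N_SENTENCES_OT_RETURN`: raising it returns more persona sentences per context, in descending order of similarity.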
@@ -0,0 +1,9 @@
flask==1.1.1
itsdangerous==2.0.1
gunicorn==20.0.4
sentry-sdk==0.13.4
requests==2.22.0
click<=8.0.4
jinja2<=3.0.3
Werkzeug<=2.0.3
numpy==1.17.2
@@ -0,0 +1,73 @@
import logging
import requests
import time
from os import getenv

import numpy as np
import sentry_sdk
from flask import Flask, request, jsonify


sentry_sdk.init(getenv("SENTRY_DSN"))
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.DEBUG)
logger = logging.getLogger(__name__)
app = Flask(__name__)

SENTENCE_RANKER_SERVICE_URL = getenv("SENTENCE_RANKER_SERVICE_URL")
N_SENTENCES_OT_RETURN = int(getenv("N_SENTENCES_OT_RETURN"))
with open("common/persona_sentences.txt", "r") as f:
    PERSONA_SENTENCES = f.read().splitlines()
    PERSONA_SENTENCES = [x.strip() for x in PERSONA_SENTENCES if len(x.strip())]


def get_result(request):
    st_time = time.time()
    contexts = request.json["contexts"]
    result = []
    pairs = []
    context_ids = []

    for context_id, context in enumerate(contexts):
        str_context = " ".join(context)
        for sent in PERSONA_SENTENCES:
            pairs += [[str_context, sent]]
            context_ids += [context_id]
    context_ids = np.array(context_ids)
    try:
        scores = requests.post(SENTENCE_RANKER_SERVICE_URL, json={"sentence_pairs": pairs}, timeout=1.5).json()[0][
            "batch"
        ]
        scores = np.array(scores)
        for i, context in enumerate(contexts):
            curr_ids = np.where(context_ids == i)[0]
            most_relevant_sent_ids = np.argsort(scores[curr_ids])[::-1][:N_SENTENCES_OT_RETURN]
            curr_result = {
                "persona": [PERSONA_SENTENCES[_id] for _id in most_relevant_sent_ids],
                "max_similarity": scores[curr_ids][most_relevant_sent_ids[0]],
            }
            logger.info(f"Persona: {curr_result['persona']}")
            result += [curr_result]
    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        result = [{"persona": [], "max_similarity": 0.0}] * len(contexts)

    total_time = time.time() - st_time
    logger.info(f"relative-persona-extractor exec time: {total_time:.3f}s")
    return result


@app.route("/respond", methods=["POST"])
def respond():
    result = get_result(request)
    return jsonify(result)


@app.route("/respond_batch", methods=["POST"])
def respond_batch():
    result = get_result(request)
    return jsonify([{"batch": result}])


if __name__ == "__main__":
    app.run(debug=False, host="0.0.0.0", port=3000)
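
The batching logic in `get_result` flattens every (context, persona sentence) pair into a single ranker request and then splits the returned scores back out per context. A minimal, network-free sketch of that idea follows. The names `rank_personas`, `score_pairs`, and `toy_scorer` are hypothetical: `score_pairs` stands in for the HTTP call to the Sentence Ranker, and the word-overlap scorer is only a toy, not the real model.

```python
def rank_personas(contexts, persona_sentences, score_pairs, n):
    """Score all (context, sentence) pairs in one call, then regroup per context."""
    # One flat list of pairs: contexts vary slowest, persona sentences fastest.
    pairs = [[" ".join(ctx), sent] for ctx in contexts for sent in persona_sentences]
    scores = score_pairs(pairs) if pairs else []

    results = []
    step = len(persona_sentences)
    for i in range(len(contexts)):
        # Scores for context i occupy one contiguous chunk of the flat list.
        chunk = scores[i * step:(i + 1) * step]
        order = sorted(range(len(chunk)), key=lambda j: chunk[j], reverse=True)[:n]
        results.append({
            "persona": [persona_sentences[j] for j in order],
            "max_similarity": chunk[order[0]] if order else 0.0,
        })
    return results


def toy_scorer(pairs):
    """Toy stand-in for the ranker: fraction of sentence words found in the context."""
    out = []
    for context, sent in pairs:
        ctx_words = set(context.lower().split())
        words = sent.lower().split()
        out.append(sum(w in ctx_words for w in words) / len(words))
    return out


# Hypothetical data, for illustration only.
persona = ["i like pizza", "i have a dog"]
contexts = [["do you like pizza"], ["do you have a dog"]]
ranked = rank_personas(contexts, persona, toy_scorer, n=1)
print(ranked)
```

Passing the scorer in as a callable keeps the regrouping logic testable without a running ranker service; each context also gets its own result dictionary rather than a shared one.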
@@ -0,0 +1,45 @@
import requests


def main():
    url = "http://0.0.0.0:8133/respond"
    input_data = {
        "contexts": [
            [
                "How do you spend your spare time?",
                "I like to watch movies and eat pizza.",
                "Cool! What else do you like?",
            ],
            [
                "I like to go to the cinema on fridays",
                "great. how do you spend your spare time?",
                "I like to watch movies",
            ],
        ]
    }
    gold = [
        {
            "max_similarity": 0.6948127746582031,
            "persona": [
                "I like Italian food especially pasta and pizza.",
                "I like to watch football and basketball on TV.",
                "I like watching travel video blogs.",
            ],
        },
        {
            "max_similarity": 0.6451027989387512,
            "persona": [
                "I like watching travel video blogs.",
                "I like to watch football and basketball on TV.",
                "I like Italian food especially pasta and pizza.",
            ],
        },
    ]

    result = requests.post(url, json=input_data).json()
    assert result == gold, f"Got\n{result}\n, but expected:\n{gold}"
    print("Success!")


if __name__ == "__main__":
    main()
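
The test compares the `max_similarity` floats for exact equality, which may break if the ranker model, library versions, or hardware shift the scores slightly. One option, sketched here under the assumption that a small tolerance is acceptable, is to compare the persona lists exactly and the scores with `math.isclose`. The helper `results_match` and the tolerance value are illustrative and not part of the commit.

```python
import math


def results_match(result, gold, rel_tol=1e-4):
    """Compare persona lists exactly and similarity scores within a tolerance."""
    if len(result) != len(gold):
        return False
    for got, expected in zip(result, gold):
        if got["persona"] != expected["persona"]:
            return False
        if not math.isclose(got["max_similarity"], expected["max_similarity"], rel_tol=rel_tol):
            return False
    return True


# Hypothetical response that differs from the gold score only in the last digits.
gold = [{"persona": ["I like watching travel video blogs."], "max_similarity": 0.6451027989387512}]
result = [{"persona": ["I like watching travel video blogs."], "max_similarity": 0.6451028}]
print(results_match(result, gold))
```

The order of the persona sentences is still checked strictly, so a change in ranking would continue to fail the test.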
@@ -0,0 +1,3 @@
#!/bin/bash

python test.py