Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialoGPT Persona #185

Merged
merged 67 commits into from
Oct 10, 2022
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
72842e3
Fix requirements.txt (#84)
AndriiHura Jan 24, 2022
a67e15c
fix itsdangerous requirements
mtalimanchuk Feb 18, 2022
0f8ef0e
pin itsdangerous requirements for all flask==1.1.1 servers
mtalimanchuk Feb 18, 2022
6f0684a
Merge pull request #102 from deepmipt/fix/combined-classification-fla…
mtalimanchuk Feb 18, 2022
d237711
Merge pull request #103 from deepmipt/dev
dilyararimovna Feb 18, 2022
e990264
Merge pull request #107 from deepmipt/dev
dilyararimovna Mar 2, 2022
3208f71
Merge pull request #119 from deepmipt/dev
dilyararimovna Mar 11, 2022
ab44553
Merge pull request #123 from deepmipt/dev
dilyararimovna Mar 18, 2022
1c9a463
Merge pull request #137 from deepmipt/dev
dilyararimovna Apr 8, 2022
f8e4a59
Merge pull request #145 from deepmipt/dev
dilyararimovna Apr 30, 2022
48872a6
Merge pull request #150 from deepmipt/dev
dilyararimovna May 4, 2022
ed42f0c
Merge pull request #153 from deepmipt/dev
dilyararimovna May 5, 2022
30f290c
Merge pull request #155 from deepmipt/dev
dilyararimovna May 6, 2022
de510bc
Merge pull request #158 from deepmipt/dev
dilyararimovna May 11, 2022
ab2dcbd
Merge pull request #165 from deepmipt/dev
dilyararimovna May 27, 2022
525783a
Merge pull request #174 from deepmipt/dev
dilyararimovna Jun 27, 2022
7e87a36
Merge pull request #177 from deepmipt/dev
dilyararimovna Jun 30, 2022
3c2169d
increase timeout to 5s
dmitrymailk Jul 5, 2022
0e9f1bb
add logs
dmitrymailk Jul 5, 2022
49e3270
increase to 100
dmitrymailk Jul 5, 2022
f286d94
add first working version gpt_persona
dmitrymailk Jul 13, 2022
57849b6
add sentence_ranking(not working)
dmitrymailk Jul 14, 2022
ff686d9
fix wrong endpoint
dmitrymailk Jul 14, 2022
36af140
create sentecnce ranker annotator
dmitrymailk Jul 14, 2022
ba674a6
rewrite text generation logic
dmitrymailk Jul 14, 2022
5cb5695
add comments to code
dmitrymailk Jul 14, 2022
3d25e53
fix gpt_persona fallback
dmitrymailk Jul 15, 2022
563f643
clean code, add train script, write tests
dmitrymailk Jul 20, 2022
b0d7f2b
add get_intents, remove dataset, add hyperparams
dmitrymailk Jul 24, 2022
d8d1499
add batch support
dmitrymailk Jul 26, 2022
2d66aa4
fix: move files
dilyararimovna Aug 29, 2022
dad1c10
fix: move files
dilyararimovna Aug 29, 2022
f409638
fix: merge
dilyararimovna Aug 29, 2022
43973f9
fix: codestyle
dilyararimovna Aug 29, 2022
3783b76
fix: remove sentence ranker
dilyararimovna Aug 29, 2022
c2811fe
feat: new distribution and rename skill
dilyararimovna Aug 29, 2022
353379c
feat: new annotator
dilyararimovna Aug 29, 2022
ab88168
feat: relative persona extractor
dilyararimovna Aug 29, 2022
a601237
fix: codestyle
dilyararimovna Aug 29, 2022
051356d
fix: proxy
dilyararimovna Aug 29, 2022
d19cce1
Merge branch 'persona_bot' of https://github.com/dmitrymailk/dream in…
dilyararimovna Aug 29, 2022
ea82137
fix: params
dilyararimovna Aug 29, 2022
b791fb6
fix: volumes
dilyararimovna Aug 29, 2022
270326f
fix: reqs
dilyararimovna Aug 29, 2022
1569dfc
fix: tests
dilyararimovna Aug 29, 2022
ca5042b
fix: tests relative sents extr
dilyararimovna Aug 29, 2022
a42faa6
fix: batching
dilyararimovna Aug 30, 2022
56b72e3
fix: codestyle
dilyararimovna Aug 30, 2022
fd257db
fix: persona extractor tests
dilyararimovna Aug 30, 2022
9f2eed5
fix: persona get
dilyararimovna Aug 30, 2022
bcd32bf
fix: tests
dilyararimovna Aug 30, 2022
9142095
fix: imports
dilyararimovna Aug 30, 2022
1ddfa2b
fix: logs
dilyararimovna Aug 31, 2022
e3f43f6
fix: docs
dilyararimovna Aug 31, 2022
3b9baff
fix: add midas
dilyararimovna Aug 31, 2022
222071b
Merge remote-tracking branch 'origin/dev' into persona_bot
dilyararimovna Sep 16, 2022
f35fda6
fix: command
dilyararimovna Sep 16, 2022
9a8e663
Merge remote-tracking branch 'origin/dev' into persona_bot
dilyararimovna Sep 20, 2022
4c7dfdd
feat: add to main dist and tests
dilyararimovna Sep 20, 2022
28c7dc4
feat: remove infilling, add dialogpt persona based, docs
dilyararimovna Sep 21, 2022
aefd444
fix: gpus
dilyararimovna Sep 26, 2022
84ca2fd
fix: params
dilyararimovna Sep 27, 2022
191764c
fix: param
dilyararimovna Sep 27, 2022
6a7453f
Merge remote-tracking branch 'origin/dev' into persona_bot
dilyararimovna Sep 27, 2022
4cedb68
fix: indent
dilyararimovna Sep 27, 2022
fa51ce6
fix: remove infilling from tests
dilyararimovna Sep 28, 2022
999f185
Merge remote-tracking branch 'origin/dev' into persona_bot
dilyararimovna Oct 7, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 144 additions & 108 deletions README.md

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions annotators/sentence_ranker/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# syntax=docker/dockerfile:experimental

FROM pytorch/pytorch:1.5-cuda10.1-cudnn7-runtime

WORKDIR /src

ARG PRETRAINED_MODEL_NAME_OR_PATH
ENV PRETRAINED_MODEL_NAME_OR_PATH ${PRETRAINED_MODEL_NAME_OR_PATH}
ARG SERVICE_PORT
ENV SERVICE_PORT ${SERVICE_PORT}


COPY ./requirements.txt /src/requirements.txt
COPY ./persona_sentences.txt /src/persona_sentences.txt
RUN pip install -r /src/requirements.txt

RUN python -c "from sentence_transformers import SentenceTransformer;SentenceTransformer('${PRETRAINED_MODEL_NAME_OR_PATH}')"

COPY . /src

CMD gunicorn --workers=1 server:app -b 0.0.0.0:${SERVICE_PORT} --timeout=300
1 change: 1 addition & 0 deletions annotators/sentence_ranker/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cpu time = 0.036 sec
dilyararimovna marked this conversation as resolved.
Show resolved Hide resolved
26 changes: 26 additions & 0 deletions annotators/sentence_ranker/persona_sentences.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
My name is Dream.
I am a chatbot.
My work is to talk to people about everything.
I like to talk to people about their lives.
I have a lot of friends.
My friends are all the people I talk to.
I become friends with everyone.
I love people. I can not play sport.
I like to watch football and basketball on TV.
My favorite sport is football.
I like watching people swimming.
I love dogs so much.
Dogs are the best friends.
I do not like cats at all.
I do not have a dog yet.
I would love to have a dog in future.
I like Italian food especially pasta and pizza.
My favorite food is ice-cream.
I hate onion.
I like travelling.
I can not travel physically.
I like visiting interesting places virtually.
I love to walk on Paris streets with Google Maps.
I like watching travel video blogs.
I adore watching wild animals.
I am scared of spiders and snakes.
11 changes: 11 additions & 0 deletions annotators/sentence_ranker/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
transformers==4.20.1
flask==1.1.1
itsdangerous==2.0.1
gunicorn==19.9.0
requests==2.22.0
sentry-sdk[flask]==0.14.1
healthcheck==1.3.3
jinja2<=3.0.3
Werkzeug<=2.0.3
sentence-transformers==2.2.2
huggingface-hub==0.4.0
128 changes: 128 additions & 0 deletions annotators/sentence_ranker/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from sentence_transformers import SentenceTransformer
import logging
import time
import os

import sentry_sdk
import torch
from flask import Flask, request, jsonify
from sentry_sdk.integrations.flask import FlaskIntegration
sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])

class SentenceRanker:
def __init__(self,
persona_sentences=None,
sentence_model=None
):
"""ranks a person's sentences based on context

Args:
persona_sentences (List[str]): a list of sentences constituting a complete person. Defaults to None.
sentence_model (SentenceTransformer): model for translating a sentence into a vector. Defaults to None.
"""
self.persona_sentences = persona_sentences
self.sentence_model = sentence_model
self.sentence_embeddings = self.sentence_model.encode(
persona_sentences,
convert_to_tensor=True
)
# for caching similar queries
self.ranked_sentences = {}

def rank_sentences(self, query, k):
"""returns top k sentences that are similar to query

Args:
query (str): sentence on the basis of which we are looking for similar
k (int): the number of sentences returned. Defaults to 5.

Returns:
List[List[str], float]: ranked sentences and maximum cosine distance among all
"""
key = f"{query}_{k}"
if self.ranked_sentences.get(key, False):
return self.ranked_sentences[key]

user_sentence_embeddings = self.sentence_model.encode(query, convert_to_tensor=True)

cos_sim_ranks = self.cos_sim(
user_sentence_embeddings,
self.sentence_embeddings
)

top_indices = torch.argsort(cos_sim_ranks, descending=True)
max_similarity = float(cos_sim_ranks[top_indices][0])
top_indices = list(top_indices[:k].cpu().numpy())
similar_sentences = [self.persona_sentences[idx] for idx in top_indices]
self.ranked_sentences[key] = similar_sentences, max_similarity
return [similar_sentences, max_similarity]

def cos_sim(self, a, b):
"""returns the cosine distance

K - number of sentences to compare
N - dimension of the returned vector
Args:
a (torch.FloatTensor): shape (1, N)
b (torch.FloatTensor): shape (K, N)

Returns:
torch.FloatTensor: shape (1, K) tensor with cosine distances
"""
a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
return torch.sum(a_norm * b_norm, dim=1)

logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)

logger = logging.getLogger(__name__)

PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH")
TOP_SIMILAR_SENTENCES = 5

try:
sentence_model = SentenceTransformer(PRETRAINED_MODEL_NAME_OR_PATH)

persona = open("./persona_sentences.txt").read()
persona_sentences = persona.split("\n")
persona_sentences = [item.strip() for item in persona_sentences if len(item) > 0]

sentence_ranker = SentenceRanker(
persona_sentences=persona_sentences,
sentence_model=sentence_model
)
logger.info("sentence_ranker is ready")
except Exception as e:
sentry_sdk.capture_exception(e)
logger.exception(e)
raise e

app = Flask(__name__)

@app.route("/response", methods=["POST"])
dilyararimovna marked this conversation as resolved.
Show resolved Hide resolved
def respond():
try:
start_time = time.time()
dialogs = request.json.get("dialogs", [])
dilyararimovna marked this conversation as resolved.
Show resolved Hide resolved
process_result = []
# take the last replica, then take the result of the sentseg annotator from it
last_utterance = dialogs[0]["human_utterances"][-1]["annotations"]['sentseg']['punct_sent']
max_likelihood_sentences, max_sentence_similarity = sentence_ranker.rank_sentences(
[last_utterance],
k=TOP_SIMILAR_SENTENCES
)

process_result.append([
max_likelihood_sentences,
max_sentence_similarity
])

except Exception as exc:
logger.exception(exc)
sentry_sdk.capture_exception(exc)

total_time = time.time() - start_time
logger.info(f"sentence_ranker exec time: {total_time:.3f}s")
return jsonify(
process_result
)
32 changes: 32 additions & 0 deletions annotators/sentence_ranker/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os
import requests
import json

def test_respond():
url = "http://0.0.0.0:8130/response"

# [0]["human_utterances"][-1]["annotations"]['sentseg']['punct_sent']
test_data = {
"dialogs": [
{
"human_utterances": [
{
"annotations": {
"sentseg": {
"punct_sent": "Hi. Do you like onions?"
dilyararimovna marked this conversation as resolved.
Show resolved Hide resolved

}
}
}
]
}
]
}

result = requests.post(url, json=test_data).json()
assert len(result[0][0]) > 0, "Empty response"
print("Success")


if __name__ == "__main__":
test_respond()
8 changes: 8 additions & 0 deletions assistant_dists/dream_mini/cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ services:
environment:
DEVICE: cpu
CUDA_VISIBLE_DEVICES: ""
dialogpt-persona:
environment:
DEVICE: cpu
CUDA_VISIBLE_DEVICES: ""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sentence ranker тоже

sentence-ranker:
environment:
DEVICE: cpu
CUDA_VISIBLE_DEVICES: ""
intent-catcher:
environment:
DEVICE: cpu
Expand Down
13 changes: 13 additions & 0 deletions assistant_dists/dream_mini/dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,17 @@ services:
- "./services/dialogpt:/src"
ports:
- 8125:8125

dialogpt-persona:
volumes:
- "./services/dialogpt_persona:/src"
ports:
- 8131:8131

sentence-ranker:
volumes:
- "./annotators/sentence_ranker:/src"
ports:
- 8130:8130

version: "3.7"
Loading