Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: dialogpt #223

Merged
merged 4 commits into from
Nov 25, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions services/dialogpt/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import os
import time
import re

import sentry_sdk
import torch
Expand All @@ -21,7 +22,8 @@
logging.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")
DEFAULT_CONFIDENCE = 0.9
ZERO_CONFIDENCE = 0.0
MAX_HISTORY_DEPTH = 3
MAX_HISTORY_DEPTH = 2
smiles_pattern = re.compile(r":[)(DpP3]")
with open(CONFIG_NAME, "r") as f:
generation_params = json.load(f)
generation_params["num_return_sequences"] = N_HYPOTHESES_TO_GENERATE
Expand All @@ -45,7 +47,12 @@

def generate_responses(context, model, tokenizer, continue_last_uttr=False):
encoded_context = []
for uttr in context[-MAX_HISTORY_DEPTH:-1]:

history_depth = MAX_HISTORY_DEPTH
if len(context[-1].split()) > 3:
history_depth = MAX_HISTORY_DEPTH - 1

for uttr in context[-history_depth:-1]:
encoded_context += [tokenizer.encode(uttr + " " + tokenizer.eos_token, return_tensors="pt")]
if continue_last_uttr:
encoded_context += [tokenizer.encode(context[-1] + " ", return_tensors="pt")]
Expand All @@ -64,6 +71,24 @@ def generate_responses(context, model, tokenizer, continue_last_uttr=False):
return outputs


def cut_response(response):
    """Trim a generated response down to its last complete sentence.

    A response that already ends with a two-character emoticon (e.g. ":)",
    ":D") is considered finished and returned untouched. Otherwise the
    fragment after the final '.', '!' or '?' is dropped; if that fragment
    itself began with an emoticon, the emoticon is re-attached.
    """
    # An emoticon right at the end means the utterance is complete.
    if re.match(r":[)(DpP3]", response[-2:]):
        return response

    *_, tail = re.split(r"[.!?]", response)
    if tail:
        # Remove the unfinished fragment after the last sentence terminator.
        # (A response that ends in punctuation yields an empty tail.)
        response = response[: -len(tail)]

    # Keep an emoticon that led the fragment we just cut off.
    head = tail.strip()[:2]
    if re.match(r":[)(DpP3]", head):
        response = response + " " + head
    return response.strip()


@app.route("/respond", methods=["POST"])
def respond():
st_time = time.time()
Expand All @@ -77,6 +102,7 @@ def respond():
curr_confidences = []
outputs = generate_responses(context, model, tokenizer)
for response in outputs:
response = cut_response(response)
if len(response) > 3:
# drop too short responses
curr_responses += [response]
Expand Down Expand Up @@ -110,6 +136,7 @@ def continue_last_uttr():
curr_responses = []
outputs = generate_responses(context, model, tokenizer, continue_last_uttr=True)
for response in outputs:
response = cut_response(response)
if len(response) > 3:
# drop too short responses
curr_responses += [response]
Expand Down