fix: dialogpt (#223)
* fix context; cut unfinished responses

* fix context

* codestyle fix
yashkens authored Nov 25, 2022
1 parent 528e7c4 commit be73916
Showing 1 changed file with 29 additions and 2 deletions.
services/dialogpt/server.py
@@ -2,6 +2,7 @@
 import json
 import os
 import time
+import re
 
 import sentry_sdk
 import torch
@@ -21,7 +22,8 @@
 logging.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")
 DEFAULT_CONFIDENCE = 0.9
 ZERO_CONFIDENCE = 0.0
-MAX_HISTORY_DEPTH = 3
+MAX_HISTORY_DEPTH = 2
+smiles_pattern = re.compile(r":[)(DpP3]")
 with open(CONFIG_NAME, "r") as f:
     generation_params = json.load(f)
     generation_params["num_return_sequences"] = N_HYPOTHESES_TO_GENERATE
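Note (not part of the diff): the new smiles_pattern matches a colon followed by one of ) ( D p P 3, i.e. the two-character emoticons that the response-cutting logic below treats as a valid ending. A quick standalone check:

import re

smiles_pattern = re.compile(r":[)(DpP3]")

for candidate in [":)", ":(", ":D", ":P", ":3", ":o", ";)"]:
    print(candidate, bool(smiles_pattern.match(candidate)))
# :), :(, :D, :P and :3 match; :o and ;) do not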
@@ -45,7 +47,12 @@
 
 def generate_responses(context, model, tokenizer, continue_last_uttr=False):
     encoded_context = []
-    for uttr in context[-MAX_HISTORY_DEPTH:-1]:
+
+    history_depth = MAX_HISTORY_DEPTH
+    if len(context[-1].split()) > 3:
+        history_depth = MAX_HISTORY_DEPTH - 1
+
+    for uttr in context[-history_depth:-1]:
         encoded_context += [tokenizer.encode(uttr + " " + tokenizer.eos_token, return_tensors="pt")]
     if continue_last_uttr:
         encoded_context += [tokenizer.encode(context[-1] + " ", return_tensors="pt")]
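To see the new trimming in isolation, here is a minimal sketch with made-up utterances (MAX_HISTORY_DEPTH = 2 as above): a short final utterance keeps one previous turn of history, while a final utterance longer than three words drops the history entirely.

MAX_HISTORY_DEPTH = 2

def trimmed_history(context):
    # mirrors the slicing in generate_responses
    history_depth = MAX_HISTORY_DEPTH
    if len(context[-1].split()) > 3:
        history_depth = MAX_HISTORY_DEPTH - 1
    return context[-history_depth:-1]

print(trimmed_history(["how are you?", "fine"]))                   # ['how are you?']
print(trimmed_history(["how are you?", "fine, and how are you?"]))  # []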
@@ -64,6 +71,24 @@ def generate_responses(context, model, tokenizer, continue_last_uttr=False):
     return outputs
 
 
+def cut_response(response):
+    # if the response ends with a smile, it is already finished
+    if smiles_pattern.match(response[-2:]):
+        return response
+
+    leftover = re.split(r"[.!?]", response)[-1]
+    if leftover:
+        # a response ending with punctuation leaves an empty leftover, so only unfinished tails are cut
+        response = response[: -len(leftover)]
+
+    # keep a trailing smile from being cut off with the leftover
+    smile = ""
+    if smiles_pattern.match(leftover.strip()[:2]):
+        smile = " " + leftover.strip()[:2]
+    response += smile
+    return response.strip()
+
+
 @app.route("/respond", methods=["POST"])
 def respond():
     st_time = time.time()
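A few illustrative calls to the new cut_response (made-up strings; the behavior follows from the code above). An unfinished trailing clause is cut at the last sentence boundary, a closing smile counts as a finished response, and a response with no sentence-ending punctuation at all collapses to an empty string, which the len(response) > 3 filter below then discards.

print(cut_response("I like that movie. What do you"))  # 'I like that movie.'
print(cut_response("Sounds great :)"))                 # 'Sounds great :)'  (smile ending, returned as is)
print(cut_response("Sure. :) if you want to go"))      # 'Sure. :)'  (trailing smile preserved)
print(cut_response("no ending punctuation here"))      # ''  (dropped later as too short)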
@@ -77,6 +102,7 @@ def respond():
     curr_confidences = []
     outputs = generate_responses(context, model, tokenizer)
     for response in outputs:
+        response = cut_response(response)
         if len(response) > 3:
             # drop too short responses
             curr_responses += [response]
@@ -110,6 +136,7 @@ def continue_last_uttr():
     curr_responses = []
     outputs = generate_responses(context, model, tokenizer, continue_last_uttr=True)
     for response in outputs:
+        response = cut_response(response)
         if len(response) > 3:
             # drop too short responses
             curr_responses += [response]
