From 4845ebc42a812b5263e8f677512ebc560cf81fc7 Mon Sep 17 00:00:00 2001
From: yashkens <42929295+yashkens@users.noreply.github.com>
Date: Fri, 25 Nov 2022 05:55:34 +0300
Subject: [PATCH] fix: dialogpt (#223)

* fix context; cut unfinished responses

* fix context

* codestyle fix
---
 services/dialogpt/server.py | 31 +++++++++++++++++++++++++++++--
 1 file changed, 29 insertions(+), 2 deletions(-)

diff --git a/services/dialogpt/server.py b/services/dialogpt/server.py
index 53c8dd9c63..6d161f2b80 100644
--- a/services/dialogpt/server.py
+++ b/services/dialogpt/server.py
@@ -2,6 +2,7 @@
 import json
 import os
 import time
+import re
 
 import sentry_sdk
 import torch
@@ -21,7 +22,8 @@
 logging.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")
 DEFAULT_CONFIDENCE = 0.9
 ZERO_CONFIDENCE = 0.0
-MAX_HISTORY_DEPTH = 3
+MAX_HISTORY_DEPTH = 2
+smiles_pattern = re.compile(r":[)(DpP3]")
 with open(CONFIG_NAME, "r") as f:
     generation_params = json.load(f)
 generation_params["num_return_sequences"] = N_HYPOTHESES_TO_GENERATE
@@ -45,7 +47,12 @@ def generate_responses(context, model, tokenizer, continue_last_uttr=False):
     encoded_context = []
-    for uttr in context[-MAX_HISTORY_DEPTH:-1]:
+
+    history_depth = MAX_HISTORY_DEPTH
+    if len(context[-1].split()) > 3:
+        history_depth = MAX_HISTORY_DEPTH - 1
+
+    for uttr in context[-history_depth:-1]:
         encoded_context += [tokenizer.encode(uttr + " " + tokenizer.eos_token, return_tensors="pt")]
     if continue_last_uttr:
         encoded_context += [tokenizer.encode(context[-1] + " ", return_tensors="pt")]
     else:
         encoded_context += [tokenizer.encode(context[-1] + " " + tokenizer.eos_token, return_tensors="pt")]
@@ -64,6 +71,24 @@
     return outputs
 
 
+def cut_response(response):
+    # if ends with a smile, it's finished
+    if smiles_pattern.match(response[-2:]):
+        return response
+
+    leftover = re.split(r"[.!?]", response)[-1]
+    if leftover:
+        # strings with no ending punctuation will be empty
+        response = response[: -len(leftover)]
+
+    # save smiles from cutting
+    smile = ""
+    if smiles_pattern.match(leftover.strip()[:2]):
+        smile = " " + leftover.strip()[:2]
+    response += smile
+    return response.strip()
+
+
 @app.route("/respond", methods=["POST"])
 def respond():
     st_time = time.time()
@@ -77,6 +102,7 @@
     curr_confidences = []
     outputs = generate_responses(context, model, tokenizer)
     for response in outputs:
+        response = cut_response(response)
         if len(response) > 3:
             # drop too short responses
             curr_responses += [response]
@@ -110,6 +136,7 @@ def continue_last_uttr():
     curr_responses = []
     outputs = generate_responses(context, model, tokenizer, continue_last_uttr=True)
     for response in outputs:
+        response = cut_response(response)
         if len(response) > 3:
             # drop too short responses
             curr_responses += [response]
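
Note on the new history handling: with MAX_HISTORY_DEPTH = 2, the slice
context[-history_depth:-1] encodes at most one previous utterance, and when
the last utterance is longer than three words the depth drops to 1, so no
history is encoded at all. A minimal sketch of the truncation logic (the
helper name and the example dialogs are invented for illustration):

    MAX_HISTORY_DEPTH = 2

    def truncate_context(context):
        # mirrors the patched logic in generate_responses():
        # a long last utterance drops one more turn of history
        history_depth = MAX_HISTORY_DEPTH
        if len(context[-1].split()) > 3:
            history_depth = MAX_HISTORY_DEPTH - 1
        # history turns encoded before the last utterance itself
        return context[-history_depth:-1]

    print(truncate_context(["hi", "hello!", "ok"]))
    # -> ['hello!']  (short last utterance keeps one history turn)
    print(truncate_context(["hi", "hello!", "tell me about your day"]))
    # -> []  (long last utterance: history is dropped entirely)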
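
The cut_response helper drops a trailing unfinished sentence while keeping
emoticons that the cut would otherwise destroy. A standalone sketch of its
behavior (the body is condensed from the patch; the input strings are
invented examples):

    import re

    smiles_pattern = re.compile(r":[)(DpP3]")

    def cut_response(response):
        # a response that already ends with a smile is considered finished
        if smiles_pattern.match(response[-2:]):
            return response
        # text after the last ., ! or ? is an unfinished sentence
        leftover = re.split(r"[.!?]", response)[-1]
        if leftover:
            response = response[: -len(leftover)]
        # rescue a smile that starts the cut-off tail
        smile = ""
        if smiles_pattern.match(leftover.strip()[:2]):
            smile = " " + leftover.strip()[:2]
        return (response + smile).strip()

    print(cut_response("I like dogs. What about yo"))  # -> I like dogs.
    print(cut_response("Sounds great! :)"))            # -> Sounds great! :)
    print(cut_response("Nice. :) maybe we cou"))       # -> Nice. :)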