From 67ae1c50506a6fbea6103f256f2cf9bd000f98c9 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 22 Nov 2023 16:55:38 +0800 Subject: [PATCH] feat: optimize db connections in thread --- api/services/completion_service.py | 41 ++++++++++++++++-------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/api/services/completion_service.py b/api/services/completion_service.py index 280bdf76963b01..249766236f61c2 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -232,7 +232,7 @@ def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_m logging.exception("Unknown Error in completion") PubHandler.pub_error(user, generate_task_id, e) finally: - db.session.commit() + db.session.remove() @classmethod def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user, @@ -242,22 +242,25 @@ def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_u def close_pubsub(): with flask_app.app_context(): - user = db.session.merge(detached_user) - - sleep_iterations = 0 - while sleep_iterations < timeout and worker_thread.is_alive(): - if sleep_iterations > 0 and sleep_iterations % 10 == 0: - PubHandler.ping(user, generate_task_id) - - time.sleep(1) - sleep_iterations += 1 - - if worker_thread.is_alive(): - PubHandler.stop(user, generate_task_id) - try: - pubsub.close() - except Exception: - pass + try: + user = db.session.merge(detached_user) + + sleep_iterations = 0 + while sleep_iterations < timeout and worker_thread.is_alive(): + if sleep_iterations > 0 and sleep_iterations % 10 == 0: + PubHandler.ping(user, generate_task_id) + + time.sleep(1) + sleep_iterations += 1 + + if worker_thread.is_alive(): + PubHandler.stop(user, generate_task_id) + try: + pubsub.close() + except Exception: + pass + finally: + db.session.remove() countdown_thread = threading.Thread(target=close_pubsub) countdown_thread.start() @@ -394,7 +397,7 @@ def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict logging.exception(e) raise finally: - db.session.commit() + db.session.remove() try: pubsub.unsubscribe(generate_channel) @@ -436,7 +439,7 @@ def generate() -> Generator: logging.exception(e) raise finally: - db.session.commit() + db.session.remove() try: pubsub.unsubscribe(generate_channel)