Skip to content

Commit

Permalink
Merge pull request #106 from RasaHQ/ENG-680-DEFAULT_KEEP_ALIVE_TIMEOUT
Browse files Browse the repository at this point in the history
Fix connection to action server - [ENG 680]
  • Loading branch information
tmbo authored Nov 22, 2023
2 parents a4a1d17 + ebdb08d commit 199f6d0
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 157 deletions.
96 changes: 51 additions & 45 deletions rasa/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,53 +112,59 @@ async def _pull_model_and_fingerprint(

logger.debug(f"Requesting model from server {model_server.url}...")

try:
params = model_server.combine_parameters()
async with model_server.session.request(
"GET",
model_server.url,
timeout=DEFAULT_REQUEST_TIMEOUT,
headers=headers,
params=params,
) as resp:
if resp.status in [204, 304]:
logger.debug(
"Model server returned {} status code, "
"indicating that no new model is available. "
"Current fingerprint: {}"
"".format(resp.status, fingerprint)
)
return None
elif resp.status == 404:
logger.debug(
"Model server could not find a model at the requested "
"endpoint '{}'. It's possible that no model has been "
"trained, or that the requested tag hasn't been "
"assigned.".format(model_server.url)
)
return None
elif resp.status != 200:
logger.debug(
"Tried to fetch model from server, but server response "
"status code is {}. We'll retry later..."
"".format(resp.status)
async with model_server.session() as session:
try:
params = model_server.combine_parameters()
async with session.request(
"GET",
model_server.url,
timeout=DEFAULT_REQUEST_TIMEOUT,
headers=headers,
params=params,
) as resp:

if resp.status in [204, 304]:
logger.debug(
"Model server returned {} status code, "
"indicating that no new model is available. "
"Current fingerprint: {}"
"".format(resp.status, fingerprint)
)
return None
elif resp.status == 404:
logger.debug(
"Model server could not find a model at the requested "
"endpoint '{}'. It's possible that no model has been "
"trained, or that the requested tag hasn't been "
"assigned.".format(model_server.url)
)
return None
elif resp.status != 200:
logger.debug(
"Tried to fetch model from server, but server response "
"status code is {}. We'll retry later..."
"".format(resp.status)
)
return None

model_path = Path(model_directory) / resp.headers.get(
"filename", "model.tar.gz"
)
return None
model_path = Path(model_directory) / resp.headers.get(
"filename", "model.tar.gz"
with open(model_path, "wb") as file:
file.write(await resp.read())

logger.debug("Saved model to '{}'".format(os.path.abspath(model_path)))

# return the new fingerprint
return resp.headers.get("ETag")

except aiohttp.ClientError as e:
logger.debug(
"Tried to fetch model from server, but "
"couldn't reach server. We'll retry later... "
"Error: {}.".format(e)
)
with open(model_path, "wb") as file:
file.write(await resp.read())
logger.debug("Saved model to '{}'".format(os.path.abspath(model_path)))
# return the new fingerprint
return resp.headers.get("ETag")
except aiohttp.ClientError as e:
logger.debug(
"Tried to fetch model from server, but "
"couldn't reach server. We'll retry later... "
"Error: {}.".format(e)
)
return None
return None


async def _run_model_pulling_worker(model_server: EndpointConfig, agent: Agent) -> None:
Expand Down
2 changes: 2 additions & 0 deletions rasa/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

DEFAULT_LOCK_LIFETIME = 60 # in seconds

DEFAULT_KEEP_ALIVE_TIMEOUT = 120 # in seconds

BEARER_TOKEN_PREFIX = "Bearer "

# The lowest priority is intended to be used by machine learning policies.
Expand Down
48 changes: 3 additions & 45 deletions rasa/core/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
Optional,
Text,
Tuple,
TYPE_CHECKING,
Union,
Dict,
)
Expand All @@ -34,8 +33,6 @@
from sanic import Sanic
from asyncio import AbstractEventLoop

if TYPE_CHECKING:
from aiohttp import ClientSession

logger = logging.getLogger() # get the root logger

Expand Down Expand Up @@ -120,6 +117,7 @@ def configure_app(
request_timeout: Optional[int] = None,
server_listeners: Optional[List[Tuple[Callable, Text]]] = None,
use_uvloop: Optional[bool] = True,
keep_alive_timeout: int = constants.DEFAULT_KEEP_ALIVE_TIMEOUT,
) -> Sanic:
"""Run the agent."""
rasa.core.utils.configure_file_logging(
Expand All @@ -139,6 +137,7 @@ def configure_app(
else:
app = _create_app_without_api(cors)

app.config.KEEP_ALIVE_TIMEOUT = keep_alive_timeout
if _is_apple_silicon_system() or not use_uvloop:
app.config.USE_UVLOOP = False
# some library still sets the loop to uvloop, even if disabled for sanic
Expand Down Expand Up @@ -251,7 +250,7 @@ def serve_application(
partial(load_agent_on_start, model_path, endpoints, remote_storage),
"before_server_start",
)
app.register_listener(create_connection_pools, "after_server_start")

app.register_listener(close_resources, "after_server_stop")

number_of_workers = rasa.core.utils.number_of_sanic_workers(
Expand Down Expand Up @@ -313,44 +312,3 @@ async def close_resources(app: Sanic, _: AbstractEventLoop) -> None:
event_broker = current_agent.tracker_store.event_broker
if event_broker:
await event_broker.close()

action_endpoint = current_agent.action_endpoint
if action_endpoint:
await action_endpoint.session.close()

model_server = current_agent.model_server
if model_server:
await model_server.session.close()


async def create_connection_pools(app: Sanic, _: AbstractEventLoop) -> None:
"""Create connection pools for the agent's action server and model server."""
current_agent = getattr(app.ctx, "agent", None)
if not current_agent:
logger.debug("No agent found after server start.")
return None

create_action_endpoint_connection_pool(current_agent)
create_model_server_connection_pool(current_agent)

return None


def create_action_endpoint_connection_pool(agent: Agent) -> Optional["ClientSession"]:
"""Create a connection pool for the action endpoint."""
action_endpoint = agent.action_endpoint
if not action_endpoint:
logger.debug("No action endpoint found after server start.")
return None

return action_endpoint.session


def create_model_server_connection_pool(agent: Agent) -> Optional["ClientSession"]:
"""Create a connection pool for the model server."""
model_server = agent.model_server
if not model_server:
logger.debug("No model server endpoint found after server start.")
return None

return model_server.session
68 changes: 39 additions & 29 deletions rasa/utils/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import ssl
from functools import cached_property

import aiohttp
import logging
import os
from aiohttp.client_exceptions import ContentTypeError
from sanic.request import Request
Expand All @@ -11,10 +9,11 @@
from rasa.shared.exceptions import FileNotFoundException
import rasa.shared.utils.io
import rasa.utils.io
import structlog
from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT


logger = logging.getLogger(__name__)
structlogger = structlog.get_logger()


def read_endpoint_config(
Expand All @@ -32,9 +31,13 @@ def read_endpoint_config(

return EndpointConfig.from_dict(content[endpoint_type])
except FileNotFoundError:
logger.error(
"Failed to read endpoint configuration "
"from {}. No such file.".format(os.path.abspath(filename))
structlogger.error(
"endpoint.read.failed_no_such_file",
filename=os.path.abspath(filename),
event_info=(
"Failed to read endpoint configuration file - "
"the file was not found."
),
)
return None

Expand All @@ -56,9 +59,13 @@ def concat_url(base: Text, subpath: Optional[Text]) -> Text:
"""
if not subpath:
if base.endswith("/"):
logger.debug(
f"The URL '{base}' has a trailing slash. Please make sure the "
f"target server supports trailing slashes for this endpoint."
structlogger.debug(
"endpoint.concat_url.trailing_slash",
url=base,
event_info=(
"The URL has a trailing slash. Please make sure the "
"target server supports trailing slashes for this endpoint."
),
)
return base

Expand Down Expand Up @@ -95,7 +102,6 @@ def __init__(
self.cafile = cafile
self.kwargs = kwargs

@cached_property
def session(self) -> aiohttp.ClientSession:
"""Creates and returns a configured aiohttp client session."""
# create authentication parameters
Expand Down Expand Up @@ -164,23 +170,26 @@ async def request(
f"'{os.path.abspath(self.cafile)}' does not exist."
) from e

async with self.session.request(
method,
url,
headers=headers,
params=self.combine_parameters(kwargs),
compress=compress,
ssl=sslcontext,
**kwargs,
) as response:
if response.status >= 400:
raise ClientResponseError(
response.status, response.reason, await response.content.read()
)
try:
return await response.json()
except ContentTypeError:
return None
async with self.session() as session:
async with session.request(
method,
url,
headers=headers,
params=self.combine_parameters(kwargs),
compress=compress,
ssl=sslcontext,
**kwargs,
) as response:
if response.status >= 400:
raise ClientResponseError(
response.status,
response.reason,
await response.content.read(),
)
try:
return await response.json()
except ContentTypeError:
return None

@classmethod
def from_dict(cls, data: Dict[Text, Any]) -> "EndpointConfig":
Expand Down Expand Up @@ -263,7 +272,7 @@ def float_arg(
try:
return float(str(arg))
except (ValueError, TypeError):
logger.warning(f"Failed to convert '{arg}' to float.")
structlogger.warning("endpoint.float_arg.convert_failed", arg=arg, key=key)
return default


Expand Down Expand Up @@ -291,5 +300,6 @@ def int_arg(
try:
return int(str(arg))
except (ValueError, TypeError):
logger.warning(f"Failed to convert '{arg}' to int.")

structlogger.warning("endpoint.int_arg.convert_failed", arg=arg, key=key)
return default
3 changes: 0 additions & 3 deletions tests/core/test_run.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import warnings
from unittest.mock import Mock

import aiohttp
import pytest
from typing import Text

Expand Down Expand Up @@ -84,8 +83,6 @@ async def test_close_resources(loop: AbstractEventLoop):
broker = SQLEventBroker()
app = Mock()
app.ctx.agent.tracker_store.event_broker = broker
app.ctx.agent.action_endpoint.session = aiohttp.ClientSession()
app.ctx.agent.model_server.session = aiohttp.ClientSession()

with warnings.catch_warnings() as record:
await run.close_resources(app, loop)
Expand Down
Loading

0 comments on commit 199f6d0

Please sign in to comment.