-
Notifications
You must be signed in to change notification settings - Fork 4.6k
/
rest.py
210 lines (172 loc) · 6.85 KB
/
rest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import asyncio
import copy
import inspect
import json
import logging
import structlog
from asyncio import Queue, CancelledError
from sanic import Blueprint, response
from sanic.request import Request
from sanic.response import HTTPResponse, ResponseStream
from typing import Text, Dict, Any, Optional, Callable, Awaitable, NoReturn, Union
import rasa.utils.endpoints
from rasa.core.channels.channel import (
InputChannel,
CollectingOutputChannel,
UserMessage,
)
# Structured (structlog) logger used for the channel's event-style log lines.
structlogger = structlog.get_logger()
# Standard-library logger named after this module.
logger = logging.getLogger(__name__)
class RestInput(InputChannel):
    """A custom http input channel.

    This implementation is the basis for a custom implementation of a chat
    frontend. You can customize this to send messages to Rasa and
    retrieve responses from the assistant.
    """

    @classmethod
    def name(cls) -> Text:
        """Name of the channel; also the fallback ``input_channel`` value."""
        return "rest"

    @staticmethod
    async def on_message_wrapper(
        on_new_message: Callable[[UserMessage], Awaitable[Any]],
        text: Text,
        queue: Queue,
        sender_id: Text,
        input_channel: Text,
        metadata: Optional[Dict[Text, Any]],
    ) -> None:
        """Handles one user message, pushing bot responses onto ``queue``.

        Responses are put on ``queue`` through a ``QueueOutputChannel``. When
        handling finishes, successfully or not, the ``"DONE"`` sentinel is put
        on the queue so that the consumer knows to stop reading.

        Args:
            on_new_message: callback that processes the user message.
            text: message text.
            queue: queue that receives the bot responses and the sentinel.
            sender_id: id of the message sender.
            input_channel: input channel name.
            metadata: optional metadata sent with the message.

        Raises:
            Whatever ``on_new_message`` raises, re-raised after the sentinel
            has been queued.
        """
        collector = QueueOutputChannel(queue)
        message = UserMessage(
            text, collector, sender_id, input_channel=input_channel, metadata=metadata
        )
        try:
            await on_new_message(message)
        finally:
            # Always signal completion. Without this, a failure in
            # ``on_new_message`` leaves the consumer blocked on ``queue.get()``
            # forever and the streaming response never terminates.
            await queue.put("DONE")

    async def _extract_sender(self, req: Request) -> Optional[Text]:
        """Returns the ``sender`` field of the JSON request body, if any."""
        return req.json.get("sender", None)

    # noinspection PyMethodMayBeStatic
    def _extract_message(self, req: Request) -> Optional[Text]:
        """Returns the ``message`` field of the JSON request body, if any."""
        return req.json.get("message", None)

    def _extract_input_channel(self, req: Request) -> Text:
        """Returns ``input_channel`` from the body, falling back to ``name()``."""
        return req.json.get("input_channel") or self.name()

    def get_metadata(self, request: Request) -> Optional[Dict[Text, Any]]:
        """Extracts additional information from the incoming request.

        Implementing this function is not required. However, it can be used to extract
        metadata from the request. The return value is passed on to the
        ``UserMessage`` object and stored in the conversation tracker.

        Args:
            request: incoming request with the message of the user

        Returns:
            Metadata which was extracted from the request.
        """
        return request.json.get("metadata", None)

    def stream_response(
        self,
        on_new_message: Callable[[UserMessage], Awaitable[None]],
        text: Text,
        sender_id: Text,
        input_channel: Text,
        metadata: Optional[Dict[Text, Any]],
    ) -> Callable[[Any], Awaitable[None]]:
        """Builds the coroutine that streams bot responses to the client.

        Used when the ``stream`` query option is enabled. Each bot message is
        written to the response as one JSON document per line.

        Args:
            on_new_message: callback that processes the user message.
            text: message text.
            sender_id: message sender_id.
            input_channel: input channel name.
            metadata: optional metadata sent with the message.

        Returns:
            A coroutine function taking the streaming response object and
            writing the bot messages to it.
        """

        async def stream(resp: Any) -> None:
            q: Queue = Queue()
            task = asyncio.ensure_future(
                self.on_message_wrapper(
                    on_new_message, text, q, sender_id, input_channel, metadata
                )
            )
            while True:
                result = await q.get()
                if result == "DONE":
                    break
                await resp.write(json.dumps(result) + "\n")
            # Re-raises any exception from message handling. The sentinel is
            # always queued, so this point is reached even on failure.
            await task

        return stream

    def blueprint(
        self, on_new_message: Callable[[UserMessage], Awaitable[None]]
    ) -> Blueprint:
        """Groups the collection of endpoints used by rest channel."""
        # ``inspect.getmodule`` may return None; fall back to no module name.
        module_name = getattr(inspect.getmodule(self), "__name__", None)
        custom_webhook = Blueprint(
            "custom_webhook_{}".format(type(self).__name__),
            module_name,
        )

        # noinspection PyUnusedLocal
        @custom_webhook.route("/", methods=["GET"])
        async def health(request: Request) -> HTTPResponse:
            return response.json({"status": "ok"})

        @custom_webhook.route("/webhook", methods=["POST"])
        async def receive(request: Request) -> Union[ResponseStream, HTTPResponse]:
            sender_id = await self._extract_sender(request)
            text = self._extract_message(request)
            should_use_stream = rasa.utils.endpoints.bool_arg(
                request, "stream", default=False
            )
            input_channel = self._extract_input_channel(request)
            metadata = self.get_metadata(request)

            if should_use_stream:
                return response.stream(
                    self.stream_response(
                        on_new_message, text, sender_id, input_channel, metadata
                    ),
                    content_type="text/event-stream",
                )

            collector = CollectingOutputChannel()
            # noinspection PyBroadException
            try:
                await on_new_message(
                    UserMessage(
                        text,
                        collector,
                        sender_id,
                        input_channel=input_channel,
                        metadata=metadata,
                        headers=request.headers,
                    )
                )
            except CancelledError:
                # Logged as a timeout; whatever was collected so far is still
                # returned to the client below.
                structlogger.error(
                    "rest.message.received.timeout", text=copy.deepcopy(text)
                )
            except Exception:
                structlogger.exception(
                    "rest.message.received.failure", text=copy.deepcopy(text)
                )
            return response.json(collector.messages)

        return custom_webhook
class QueueOutputChannel(CollectingOutputChannel):
    """Output channel that collects sent messages in an ``asyncio.Queue``.

    (It doesn't send them anywhere; it just puts them on the queue so that a
    consumer can read them as they are produced.)
    """

    # FIXME: this is breaking Liskov substitution principle
    # and would require some user-facing refactoring to address
    messages: Queue  # type: ignore[assignment]

    @classmethod
    def name(cls) -> Text:
        """Name of QueueOutputChannel."""
        return "queue"

    def __init__(self, message_queue: Optional[Queue] = None) -> None:
        """Creates the channel.

        Args:
            message_queue: queue that outgoing messages are put on. If it is
                omitted (``None``), a new ``Queue`` is created.
        """
        super().__init__()
        # Compare against None explicitly. A truthiness test would silently
        # replace a caller-supplied queue-like object that happens to be empty.
        self.messages = Queue() if message_queue is None else message_queue

    def latest_output(self) -> NoReturn:
        """Not supported: messages can't be inspected without consuming them.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("A queue doesn't allow to peek at messages.")

    async def _persist_message(self, message: Dict[Text, Any]) -> None:
        """Puts an outgoing message on the queue."""
        await self.messages.put(message)