diff --git a/aiosip/application.py b/aiosip/application.py index 24d1d32..5f6a769 100644 --- a/aiosip/application.py +++ b/aiosip/application.py @@ -13,7 +13,7 @@ from collections import MutableMapping from . import __version__ -from .dialog import Dialog +from .dialog import Dialog, DialogRequest from .dialplan import BaseDialplan from .protocol import UDP, TCP, WS from .peers import UDPConnector, TCPConnector, WSConnector @@ -95,40 +95,8 @@ async def _call_route(self, peer, route, msg): for middleware_factory in reversed(self._middleware): route = await middleware_factory(route) - app = self - call_id = msg.headers['Call-ID'] - - # TODO: refactor - class Request: - def __init__(self): - self.app = app - self.dialog = None - - def _create_dialog(self, dialog_factory=Dialog, **kwargs): - if not self.dialog: - self.dialog = peer._create_dialog( - method=msg.method, - from_details=Contact.from_header(msg.headers['To']), - to_details=Contact.from_header(msg.headers['From']), - call_id=call_id, - inbound=True, - dialog_factory=dialog_factory, - **kwargs - ) - return self.dialog - - async def prepare(self, status_code, *args, **kwargs): - dialog = self._create_dialog() - - await dialog.reply(msg, status_code, *args, **kwargs) - if status_code >= 300: - await dialog.close() - return None - - return dialog - - request = Request() - await route(request, msg) + request = DialogRequest(self, msg, peer) + await route(request) async def _dispatch(self, protocol, msg, addr): call_id = msg.headers['Call-ID'] diff --git a/aiosip/dialog.py b/aiosip/dialog.py index 11795f3..7d8af87 100644 --- a/aiosip/dialog.py +++ b/aiosip/dialog.py @@ -8,6 +8,7 @@ from . import utils from .auth import Auth +from .contact import Contact from .message import Request, Response from .transaction import UnreliableTransaction @@ -118,9 +119,9 @@ def ack(self, msg, headers=None, *args, **kwargs): ack = self._prepare_request('ACK', cseq=msg.cseq, to_details=msg.to_details, headers=headers, *args, **kwargs) self.peer.send_message(ack) - async def unauthorized(self, msg): + async def unauthorized(self, msg, **kwargs): self._nonce = utils.gen_str(10) - headers = CIMultiDict() + headers = kwargs.get('headers', CIMultiDict()) headers['WWW-Authenticate'] = str(Auth(nonce=self._nonce, algorithm='md5', realm='sip')) await self.reply(msg, status_code=401, headers=headers) @@ -459,3 +460,52 @@ async def close(self, timeout=None): self._close() self._close() + + +class DialogRequest: + + def __init__(self, app, message, peer): + self.app = app + self.peer = peer + self.message = message + self.dialog = None + + def _create_dialog(self, dialog_factory=Dialog, **kwargs): + if not self.dialog: + self.dialog = self.peer._create_dialog( + method=self.message.method, + from_details=Contact.from_header(self.message.headers['To']), + to_details=Contact.from_header(self.message.headers['From']), + call_id=self.message['Call-ID'], + inbound=True, + dialog_factory=dialog_factory, + **kwargs + ) + return self.dialog + + async def prepare(self, status_code, *args, **kwargs): + dialog = self._create_dialog() + + await dialog.reply(self.message, status_code, *args, **kwargs) + if status_code >= 300: + await dialog.close() + return None + + return dialog + + async def unauthorized(self, *args, **kwargs): + dialog = self._create_dialog() + await dialog.unauthorized(self.message, *args, **kwargs) + return dialog + + @property + def headers(self): + return self.message.headers + + @property + def payload(self): + return self.message.payload + + @property + def method(self): + return self.message.method diff --git a/tests/test_sip_scenario.py b/tests/test_sip_scenario.py index ca74656..03376fd 100644 --- a/tests/test_sip_scenario.py +++ b/tests/test_sip_scenario.py @@ -13,8 +13,8 @@ async def resolve(self, *args, **kwargs): await super().resolve(*args, **kwargs) return self.subscribe - async def subscribe(self, request, msg): - expires = int(msg.headers['Expires']) + async def subscribe(self, request): + expires = int(request.headers['Expires']) dialog = await request.prepare(status_code=200, headers={'Expires': expires}) await asyncio.sleep(0.1) @@ -67,12 +67,9 @@ async def resolve(self, *args, **kwargs): await super().resolve(*args, **kwargs) return self.subscribe - async def subscribe(self, request, message): - dialog = request._create_dialog() - - received_messages.append(message) - assert not dialog.validate_auth(message, password) - await dialog.unauthorized(message) + async def subscribe(self, request): + received_messages.append(request) + dialog = await request.unauthorized() async for message in dialog: received_messages.append(message) @@ -117,11 +114,9 @@ async def resolve(self, *args, **kwargs): await super().resolve(*args, **kwargs) return self.subscribe - async def subscribe(self, request, message): - dialog = request._create_dialog() - - received_messages.append(message) - await dialog.unauthorized(message) + async def subscribe(self, request): + received_messages.append(request) + dialog = await request.unauthorized() async for message in dialog: received_messages.append(message) @@ -169,12 +164,12 @@ async def resolve(self, *args, **kwargs): await super().resolve(*args, **kwargs) return self.invite - async def invite(self, request, message): + async def invite(self, request): dialog = await request.prepare(status_code=100) await asyncio.sleep(0.1) - await dialog.reply(message, status_code=180) + await dialog.reply(request, status_code=180) await asyncio.sleep(0.1) - await dialog.reply(message, status_code=200) + await dialog.reply(request, status_code=200) call_established.set_result(None) async for message in dialog: @@ -230,11 +225,11 @@ async def resolve(self, *args, **kwargs): elif kwargs['method'] == 'CANCEL': return self.cancel - async def subscribe(self, request, message): + async def subscribe(self, request): pending_subscription.cancel() - async def cancel(self, request, message): - cancel_future.set_result(message) + async def cancel(self, request): + cancel_future.set_result(request) app = aiosip.Application(loop=loop) server_app = aiosip.Application(loop=loop, dialplan=Dialplan()) diff --git a/tests/test_sip_server.py b/tests/test_sip_server.py index aa3c3fb..7459735 100644 --- a/tests/test_sip_server.py +++ b/tests/test_sip_server.py @@ -11,9 +11,9 @@ async def resolve(self, *args, **kwargs): return self.on_subscribe - async def on_subscribe(self, request, message): + async def on_subscribe(self, request): await request.prepare(status_code=200) - callback_complete.set_result(message) + callback_complete.set_result(request) app = aiosip.Application(loop=loop) @@ -77,7 +77,7 @@ async def resolve(self, *args, **kwargs): return self.on_subscribe - async def on_subscribe(self, request, message): + async def on_subscribe(self, request): raise RuntimeError('Test error') app = aiosip.Application(loop=loop)