From 80e9548e51762e47ed01849a16cbc02f17ee95d7 Mon Sep 17 00:00:00 2001 From: Maksym Sobolyev Date: Sat, 27 Jul 2024 17:06:05 -0700 Subject: [PATCH] Hook up payload filtering into remote SDP update callback, so it stays effective for re-INVITEs. --- sippy/b2bua_radius.py | 52 ++++++++++++++++++++++++++----------------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/sippy/b2bua_radius.py b/sippy/b2bua_radius.py index dff8487..a8a14e7 100755 --- a/sippy/b2bua_radius.py +++ b/sippy/b2bua_radius.py @@ -148,6 +148,37 @@ def __init__(self, remote_ip, source, req_source, req_target, global_config, pas self.pass_headers = pass_headers self.req_source = req_source self.req_target = req_target + if '_allowed_pts' in self.global_config: + self.uaA.on_remote_sdp_change = self.filter_SDP + + def filter_SDP(self, body, done_cb): + try: + body.parse() + except SdpParseError as ex: + done_cb(None, ex=ex) + return + except Exception as ex: + exx = SdpParseError(f'{ex}') + exx.msg = 'Malformed SDP body' + exx.code = 400 + exx.__cause__ = ex + done_cb(None, ex=exx) + return + allowed_pts = self.global_config['_allowed_pts'] + for sect in body.content.sections: + mbody = sect.m_header + if mbody.transport.lower() not in self.rtpps_cls.AV_TRTYPES: + continue + old_len = len(mbody.formats) + _allowed_pts = [x if isinstance(x, int) else sect.getPTbyName(x) for x in allowed_pts] + mbody.formats = [x for x in mbody.formats if x in _allowed_pts] + if len(mbody.formats) == 0: + ex = SdpParseError() + done_cb(None, ex=ex) + return + if old_len > len(mbody.formats): + sect.optimize_a() + done_cb(body) def recvEvent(self, event, ua): if ua == self.uaA: @@ -161,27 +192,6 @@ def recvEvent(self, event, ua): self.uaA.recvEvent(CCEventFail((500, 'Internal Server Error (1)'), rtime = event.rtime)) self.state = CCStateDead return - if body != None and '_allowed_pts' in self.global_config: - try: - body.parse() - except: - self.uaA.recvEvent(CCEventFail((400, 'Malformed SDP Body'), rtime = event.rtime)) - self.state = CCStateDead - return - allowed_pts = self.global_config['_allowed_pts'] - for sect in body.content.sections: - mbody = sect.m_header - if mbody.transport.lower() not in self.rtpps_cls.AV_TRTYPES: - continue - old_len = len(mbody.formats) - _allowed_pts = [x if isinstance(x, int) else sect.getPTbyName(x) for x in allowed_pts] - mbody.formats = [x for x in mbody.formats if x in _allowed_pts] - if len(mbody.formats) == 0: - self.uaA.recvEvent(CCEventFail((488, 'Not Acceptable Here'))) - self.state = CCStateDead - return - if old_len > len(mbody.formats): - sect.optimize_a() if self.cld.startswith('nat-'): self.cld = self.cld[4:] if body != None: