Skip to content

Commit

Permalink
Replace deprecated IPDB with NDB
Browse files Browse the repository at this point in the history
  • Loading branch information
gab-arrobo committed Feb 5, 2024
1 parent a3d4534 commit 3a22664
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 67 deletions.
39 changes: 19 additions & 20 deletions conf/route_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Dict, List, Optional, Tuple

from pybess.bess import *
from pyroute2 import IPDB, IPRoute
from pyroute2 import NDB, IPRoute
from scapy.all import ICMP, IP, send

LOG_FORMAT = "%(asctime)s %(levelname)s %(message)s"
Expand Down Expand Up @@ -308,7 +308,7 @@ class RouteController:
def __init__(
self,
bess_controller: BessController,
ipdb: IPDB,
ndb: NDB,
ipr: IPRoute,
interfaces: List[str],
):
Expand All @@ -319,7 +319,7 @@ def __init__(
bess_controller (BessController):
Controller for BESS (Berkeley Extensible Software Switch).
route_parser (RouteEntryParser): Parser for route entries.
ipdb (IPDB): IP database to manage IP configurations.
ndb (NDB): database to manage Network configurations.
ipr (IPRoute): IP routing control object.
Attributes:
Expand All @@ -336,7 +336,7 @@ def __init__(

self._lock = Lock()

self._ipdb = ipdb
self._ndb = ndb
self._ipr = ipr
self._bess_controller = bess_controller
self._ping_missing_thread = Thread(
Expand All @@ -348,7 +348,7 @@ def __init__(
def register_callbacks(self) -> None:
"""Register callback function."""
logger.info("Registering netlink event listener callback...")
self._event_callback = self._ipdb.register_callback(
self._event_callback = self._ndb.register_callback(
self._netlink_event_listener
)

Expand All @@ -373,7 +373,7 @@ def add_new_route_entry(self, route_entry: RouteEntry) -> None:
Args:
route_entry (RouteEntry): The route entry.
"""
if not (next_hop_mac := fetch_mac(self._ipdb, route_entry.next_hop_ip)):
if not (next_hop_mac := fetch_mac(self._ndb, route_entry.next_hop_ip)):
logger.info(
"mac address of the next hop %s is not stored in ARP table. Probing...",
route_entry.next_hop_ip,
Expand Down Expand Up @@ -603,12 +603,12 @@ def _get_gate_idx(self, route_entry: RouteEntry, module_name: str) -> int:
return self._module_gate_count_cache[module_name]

def _netlink_event_listener(
self, ipdb: IPDB, netlink_message: dict, action: str
self, ndb: NDB, netlink_message: dict, action: str
) -> None:
"""Listens for netlink events and handles them.
Args:
ipdb (IPDB): The IPDB object.
ndb (NDB): The NDB object.
netlink_message (dict): The netlink message.
action (str): The action.
"""
Expand All @@ -629,7 +629,7 @@ def _netlink_event_listener(
def cleanup(self, number: int) -> None:
"""Unregisters the netlink event listener callback and exits."""
logger.info("Received: %i Exiting", number)
self._ipdb.unregister_callback(self._event_callback)
self._ndb.unregister_callback(self._event_callback)
logger.info("Unregistered netlink event listener callback")
sys.exit()

Expand Down Expand Up @@ -667,7 +667,7 @@ def _parse_route_entry_msg(self, route_entry: dict) -> Optional[RouteEntry]:
if not attr_dict.get(KEY_INTERFACE):
return None
interface_index = int(attr_dict.get(KEY_INTERFACE))
interface = self._ipdb.interfaces[interface_index].ifname
interface = self._ndb.interfaces[interface_index]["ifname"]
if interface not in self._interfaces:
return None

Expand Down Expand Up @@ -736,26 +736,25 @@ def send_ping(neighbor_ip):
send(IP(dst=neighbor_ip) / ICMP())


def fetch_mac(ipdb: IPDB, target_ip: str) -> Optional[str]:
"""Fetches the MAC address of the target IP from the ARP table using IPDB.
def fetch_mac(ndb: NDB, target_ip: str) -> Optional[str]:
"""Fetches the MAC address of the target IP from the ARP table using NDB.
Args:
ipdb (IPDB): The IPDB object.
ndb (NDB): The NDB object.
target_ip (str): The target IP address.
Returns:
Optional[str]: The MAC address of the target IP.
"""
neighbors = ipdb.nl.get_neighbours(dst=target_ip)
neighbors = ndb.neighbours.dump()
for neighbor in neighbors:
attrs = dict(neighbor["attrs"])
if attrs.get(KEY_NETWORK_LAYER_DEST_ADDR, "") == target_ip:
if neighbor["dst"] == target_ip:
logger.info(
"Mac address found for %s, Mac: %s",
target_ip,
attrs.get(KEY_LINK_LAYER_ADDRESS, ""),
neighbor["lladdr"],
)
return attrs.get(KEY_LINK_LAYER_ADDRESS, "")
return neighbor["lladdr"]
logger.info("Mac address not found for %s", target_ip)
return None

Expand Down Expand Up @@ -800,11 +799,11 @@ def register_signal_handlers(route_controller: RouteController) -> None:
if __name__ == "__main__":
interface_arg, ip_arg, port_arg = parse_args()
ipr = IPRoute()
ipdb = IPDB()
ndb = NDB()
bess_controller = BessController(ip_arg, port_arg)
route_controller = RouteController(
bess_controller=bess_controller,
ipdb=ipdb,
ndb=ndb,
ipr=ipr,
interfaces=interface_arg,
)
Expand Down
114 changes: 78 additions & 36 deletions conf/test_route_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import unittest
from unittest.mock import MagicMock, Mock, patch

from pyroute2 import IPDB # type: ignore[import]
from pyroute2 import NDB

sys.modules["pybess.bess"] = MagicMock()

Expand Down Expand Up @@ -50,7 +50,7 @@ def test_given_valid_ip_when_validate_ipv4_then_returns_true(self):
def test_given_invalid_ip_when_validate_ipv4_then_returns_false(self):
self.assertFalse(validate_ipv4("192.168.300.1"))

def test_given_invalid_ip_when_validate_ipv4_then_returns_false(self):
def test_given_invalid_ip_when_validate_ipv6_then_returns_false(self):
self.assertFalse(validate_ipv4("::1"))
self.assertFalse(validate_ipv4(""))

Expand All @@ -67,27 +67,39 @@ def test_given_valid_mac_when_mac_to_hex_then_return_hex_string_representation(
self.assertEqual(mac_to_hex("00:1a:2b:3c:4d:5e"), "001A2B3C4D5E")

def test_given_known_destination_when_fetch_mac_then_returns_mac(self):
ipdb = IPDB()
ipdb.nl.get_neighbours = lambda dst, **kwargs: [
{"attrs": [("NDA_DST", dst), ("NDA_LLADDR", "00:1a:2b:3c:4d:5e")]}
]
self.assertEqual(fetch_mac(ipdb, "192.168.1.1"), "00:1a:2b:3c:4d:5e")
ndb = Mock()
kwargs = {
"ifindex": 1,
"dst": "192.168.1.1",
"lladdr": "00:1a:2b:3c:4d:5e"
}
neighbour = ndb.neighbours.create(**kwargs)
neighbour.commit()
ndb.neighbours.dump.return_value = [kwargs]
self.assertEqual(fetch_mac(ndb, "192.168.1.1"), "00:1a:2b:3c:4d:5e")

def test_given_unkonw_destination_when_fetch_mac_then_returns_none(self):
ipdb = IPDB()
ipdb.nl.get_neighbours = lambda dst, **kwargs: []
self.assertIsNone(fetch_mac(ipdb, "192.168.1.1"))
ndb = Mock()
kwargs = {
"ifindex": 1,
"dst": "192.168.1.1",
"lladdr": None
}
neighbour = ndb.neighbours.create(**kwargs)
neighbour.commit()
ndb.neighbours.dump.return_value = [kwargs]
self.assertIsNone(fetch_mac(ndb, "192.168.1.1"))


class TestRouteController(unittest.TestCase):
def setUp(self):
self.mock_bess_controller = Mock(BessControllerMock)
self.ipdb = Mock()
self.ndb = Mock()
self.ipr = Mock()
interfaces = ["access", "core"]
self.route_controller = RouteController(
self.mock_bess_controller,
self.ipdb,
self.ndb,
interfaces=interfaces,
ipr=self.ipr,
)
Expand All @@ -105,9 +117,14 @@ def add_route_entry(
mock_fetch_mac,
) -> None:
"""Adds a new route entry using the route controller."""
self.ipdb.nl.get_neighbours = lambda dst, **kwargs: [
{"attrs": [("NDA_DST", dst), ("NDA_LLADDR", "00:1a:2b:3c:4d:5e")]}
]
kwargs = {
"ifindex": 1,
"dst": "192.168.1.1",
"lladdr": "00:1a:2b:3c:4d:5e"
}
neighbour = self.ndb.neighbours.create(**kwargs)
neighbour.commit()
self.ndb.neighbours.dump.return_value = [kwargs]
mock_get_update_module_name.return_value = "merge_module"
mock_get_route_module_name.return_value = "route_module"
mock_get_merge_module_name.return_value = "update_module"
Expand All @@ -116,7 +133,7 @@ def add_route_entry(
return route_entry

def test_given_valid_route_message_when_parse_message_then_parses_message(self):
self.ipdb.interfaces = {2: Mock(ifname="core")}
self.ndb.interfaces = {2: {"ifname": "core"}}
example_route_entry = {
"family": 2,
"dst_len": 24,
Expand All @@ -141,13 +158,13 @@ def test_given_valid_route_message_when_parse_message_then_parses_message(self):
self.assertIsInstance(result, RouteEntry)
self.assertEqual(result.dest_prefix, "192.168.1.0")
self.assertEqual(result.next_hop_ip, "172.31.48.1")
self.assertEqual(result.interface, self.ipdb.interfaces[2].ifname)
self.assertEqual(result.interface, self.ndb.interfaces[2]["ifname"])
self.assertEqual(result.prefix_len, 24)

def test_given_valid_route_message_and_dst_len_is_zero_when_parse_message_then_parses_message_as_default_route(
self,
):
self.ipdb.interfaces = {2: Mock(ifname="core")}
self.ndb.interfaces = {2: {"ifname": "core"}}
example_route_entry = {
"family": 2,
"dst_len": 0,
Expand All @@ -171,11 +188,11 @@ def test_given_valid_route_message_and_dst_len_is_zero_when_parse_message_then_p
self.assertIsInstance(result, RouteEntry)
self.assertEqual(result.dest_prefix, "0.0.0.0")
self.assertEqual(result.next_hop_ip, "172.31.48.1")
self.assertEqual(result.interface, self.ipdb.interfaces[2].ifname)
self.assertEqual(result.interface, self.ndb.interfaces[2]["ifname"])
self.assertEqual(result.prefix_len, 0)

def test_given_invalid_route_message_when_parse_message_then_returns_none(self):
self.ipdb.interfaces = {2: Mock(ifname="not the needed interface")}
self.ndb.interfaces = {2: {"ifname": "not the needed interface"}}
example_route_entry = {
"family": 2,
"flags": 0,
Expand All @@ -202,7 +219,14 @@ def test_given_new_route_when_add_new_route_entry_and_mac_not_known_then_destina
self,
mock_send_ping,
):
self.ipdb.nl.get_neighbours = lambda dst, **kwargs: []
kwargs = {
"ifindex": 1,
"dst": "192.168.1.1",
"lladdr": None
}
neighbour = self.ndb.neighbours.create(**kwargs)
neighbour.commit()
self.ndb.neighbours.dump.return_value = [kwargs]
route_entry = RouteEntry(
next_hop_ip="1.2.3.4",
interface="random_interface",
Expand All @@ -215,9 +239,14 @@ def test_given_new_route_when_add_new_route_entry_and_mac_not_known_then_destina
def test_given_valid_new_route_when_add_new_route_entry_and_mac_known_then_route_is_added_in_bess(
self,
):
self.ipdb.nl.get_neighbours = lambda dst, **kwargs: [
{"attrs": [("NDA_DST", dst), ("NDA_LLADDR", "00:1a:2b:3c:4d:5e")]}
]
kwargs = {
"ifindex": 1,
"dst": "192.168.1.1",
"lladdr": "00:1a:2b:3c:4d:5e"
}
neighbour = self.ndb.neighbours.create(**kwargs)
neighbour.commit()
self.ndb.neighbours.dump.return_value = [kwargs]
mock_routes = [{"event": "RTM_NEWROUTE"}, {"event": "OTHER_ACTION"}]
self.ipr.get_routes.return_value = mock_routes
route_entry = RouteEntry(
Expand All @@ -227,14 +256,20 @@ def test_given_valid_new_route_when_add_new_route_entry_and_mac_known_then_route
prefix_len=24,
)
self.route_controller.add_new_route_entry(route_entry)
self.mock_bess_controller.add_route_to_module()
self.mock_bess_controller.add_route_to_module.assert_called_once()

def test_given_valid_new_route_when_add_new_route_entry_and_mac_known_and_neighbor_not_known_then_update_module_is_created_and_modules_are_linked(
self,
):
self.ipdb.nl.get_neighbours = lambda dst, **kwargs: [
{"attrs": [("NDA_DST", dst), ("NDA_LLADDR", "00:1a:2b:3c:4d:5e")]}
]
kwargs = {
"ifindex": 1,
"dst": "192.168.1.1",
"lladdr": "00:1a:2b:3c:4d:5e"
}
neighbour = self.ndb.neighbours.create(**kwargs)
neighbour.commit()
self.ndb.neighbours.dump.return_value = [kwargs]
mock_routes = [{"event": "RTM_NEWROUTE"}, {"event": "OTHER_ACTION"}]
self.ipr.get_routes.return_value = mock_routes
route_entry = RouteEntry(
Expand All @@ -244,7 +279,9 @@ def test_given_valid_new_route_when_add_new_route_entry_and_mac_known_and_neighb
prefix_len=24,
)
self.route_controller.add_new_route_entry(route_entry)
self.mock_bess_controller.create_module()
self.mock_bess_controller.create_module.assert_called()
self.mock_bess_controller.link_modules()
self.mock_bess_controller.link_modules.assert_called()

@patch.object(RouteController, "add_new_route_entry")
Expand All @@ -265,7 +302,7 @@ def test_given_new_route_when_bootstrap_routes_then_add_new_entry_is_called(
{"event": "OTHER_ACTION"},
]
self.ipr.get_routes.return_value = mock_routes
self.ipdb.interfaces = {2: Mock(ifname="core")}
self.ndb.interfaces = {2: {"ifname": "core"}}
valid_route_entry = RouteEntry(
next_hop_ip="1.2.3.4",
interface="core",
Expand Down Expand Up @@ -322,7 +359,7 @@ def test_given_new_route_and_invalid_message_when_bootstrap_routes_then_add_new_
def test_given_netlink_message_when_rtm_newroute_event_then_add_new_route_entry_is_called(
self, mock_add_new_route_entry
):
self.ipdb.interfaces = {2: Mock(ifname="core")}
self.ndb.interfaces = {2: {"ifname": "core"}}
example_route_entry = {
"family": 2,
"dst_len": 24,
Expand All @@ -344,7 +381,7 @@ def test_given_netlink_message_when_rtm_newroute_event_then_add_new_route_entry_
"event": "RTM_NEWROUTE",
}
self.route_controller._netlink_event_listener(
self.ipdb, example_route_entry, "RTM_NEWROUTE"
self.ndb, example_route_entry, "RTM_NEWROUTE"
)
mock_add_new_route_entry.assert_called()

Expand Down Expand Up @@ -398,7 +435,7 @@ def test_given_existing_neighbor_and_route_count_is_one_when_delete_route_entry_
def test_given_netlink_message_when_rtm_delroute_event_then_delete_route_entry_is_called(
self, mock_delete_route_entry
):
self.ipdb.interfaces = {2: Mock(ifname="core")}
self.ndb.interfaces = {2: {"ifname": "core"}}
example_route_entry = {
"family": 2,
"dst_len": 24,
Expand All @@ -420,7 +457,7 @@ def test_given_netlink_message_when_rtm_delroute_event_then_delete_route_entry_i
"event": "RTM_DELROUTE",
}
self.route_controller._netlink_event_listener(
self.ipdb, example_route_entry, "RTM_DELROUTE"
self.ndb, example_route_entry, "RTM_DELROUTE"
)
mock_delete_route_entry.assert_called()

Expand All @@ -429,9 +466,14 @@ def test_given_new_neighbor_in_unresolved_when_add_unresolved_new_neighbor_then_
self,
_,
):
self.ipdb.nl.get_neighbours = lambda dst, **kwargs: [
{"attrs": [("NDA_DST", dst), ("NDA_LLADDR", "00:1a:2b:3c:4d:5e")]}
]
kwargs = {
"ifindex": 1,
"dst": "192.168.1.1",
"lladdr": "00:1a:2b:3c:4d:5e"
}
neighbour = self.ndb.neighbours.create(**kwargs)
neighbour.commit()
self.ndb.neighbours.dump.return_value = [kwargs]
mock_netlink_msg = {
"attrs": {
"NDA_DST": "1.2.3.4",
Expand All @@ -455,6 +497,6 @@ def test_given_netlink_message_when_rtm_newneigh_event_then_add_unresolved_new_n
self, mock_add_unresolved_new_neighbor
):
self.route_controller._netlink_event_listener(
self.ipdb, "new neighbour message", "RTM_NEWNEIGH"
self.ndb, "new neighbour message", "RTM_NEWNEIGH"
)
mock_add_unresolved_new_neighbor.assert_called()
Loading

0 comments on commit 3a22664

Please sign in to comment.