Skip to content

Commit

Permalink
Major refactor #2
Browse files Browse the repository at this point in the history
  • Loading branch information
raul-marquez-csa committed Feb 1, 2024
1 parent fca19cd commit 93394a8
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 94 deletions.
24 changes: 0 additions & 24 deletions src/python_testing/mdns_discovery/exceptions.py

This file was deleted.

169 changes: 99 additions & 70 deletions src/python_testing/mdns_discovery/mdns_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class MdnsServiceType(Enum):

class MdnsDiscovery:

DEFAULT_DISCOVERY_DURATION_SEC = 3
DISCOVERY_TIMEOUT_SEC = 15

def __init__(self):
"""
Expand All @@ -88,19 +88,78 @@ def __init__(self):
- get_commissionable_service_info()
- get_operational_service_info()
- get_border_router_service_info()
- discover()
- _discover()
"""
# An instance of Zeroconf to manage mDNS operations.
# This is used for handling the low-level details of mDNS.
self._zc = Zeroconf()
self._zc = Zeroconf(ip_version=IPVersion.V6Only)

# A dictionary to store discovered services.
# As services are discovered, they are added to this dictionary.
self._discovered_services = {}

self._service_types = []

self._event = asyncio.Event()

async def get_commissioner_service(self, log_output: bool = False,
discovery_timeout_sec: float = DISCOVERY_TIMEOUT_SEC
) -> MdnsServiceInfo:
return await self._get_service(MdnsServiceType.COMMISSIONER, log_output, discovery_timeout_sec)

async def get_commissionable_service(self, log_output: bool = False,
discovery_timeout_sec: float = DISCOVERY_TIMEOUT_SEC
) -> MdnsServiceInfo:
return await self._get_service(MdnsServiceType.COMMISSIONABLE, log_output, discovery_timeout_sec)

async def get_operational_service(self, service_name: str = None,
service_type: str = None,
discovery_timeout_sec: float = DISCOVERY_TIMEOUT_SEC,
log_output: bool = False
) -> Optional[MdnsServiceInfo]:
# Validation to ensure both or none of the parameters are provided
if (service_name is None) != (service_type is None):
raise ValueError("Both service_name and service_type must be provided together or not at all.")

mdns_service_info = None

if service_name is None and service_type is None:
mdns_service_info = await self._get_service(MdnsServiceType.OPERATIONAL, log_output, discovery_timeout_sec)
else:
print(f"Looking for MDNS service: Type: {service_type} - Name {service_name}")

# Get service info
service_info = AsyncServiceInfo(service_type, service_name)
is_discovered = await service_info.async_request(self._zc, 3000)

if is_discovered:
mdns_service_info = self._to_mdns_service_info_class(service_info)

self._discovered_services = {}
self._discovered_services[service_type] = [mdns_service_info]

if log_output:
self._log_output()

return mdns_service_info

async def discover(self,
discovery_duration_sec: float = DEFAULT_DISCOVERY_DURATION_SEC,
log_output: bool = True) -> dict:
async def get_border_router_service(self, log_output: bool = False,
discovery_timeout_sec: float = DISCOVERY_TIMEOUT_SEC
) -> MdnsServiceInfo:
return await self._get_service(MdnsServiceType.BORDER_ROUTER, log_output, discovery_timeout_sec)

async def get_all_services(self, log_output: bool = False,
discovery_timeout_sec: float = DISCOVERY_TIMEOUT_SEC
) -> Dict[str, List[MdnsServiceInfo]]:
return await self._discover(all_services=True,
discovery_timeout_sec=discovery_timeout_sec,
log_output=log_output)

# Private methods
async def _discover(self,
discovery_timeout_sec: float = DISCOVERY_TIMEOUT_SEC,
all_services: bool = False,
log_output: bool = False) -> None:
"""
Asynchronously discovers network services of specified types using mDNS and returns the discovered services.
Expand All @@ -109,7 +168,7 @@ async def discover(self,
processed.
Args:
discovery_duration_sec (float): The duration in seconds for which the discovery process should wait
discovery_timeout_sec (float): The duration in seconds for which the discovery process should wait
to allow for service announcements. If not provided, defaults to a
predetermined duration (e.g., 3 seconds).
log_output (bool): Logs discovered services to the console in JSON format, defaults to True.
Expand All @@ -123,24 +182,27 @@ async def discover(self,
dictionary represents a service type, and each value is a list of MdnsServiceInfo objects, each
containing details of a discovered service.
"""
service_types = list(await AsyncZeroconfServiceTypes.async_find())
self._zc = Zeroconf(ip_version=IPVersion.V6Only)
self._event.clear()

if all_services:
self._service_types = list(await AsyncZeroconfServiceTypes.async_find())

print(f"Browsing for MDNS service(s) of type: {self._service_types}")

aiobrowser = AsyncServiceBrowser(zeroconf=self._zc,
type_=service_types,
type_=self._service_types,
handlers=[self._on_service_state_change]
)
# Wait for discovery
await asyncio.sleep(discovery_duration_sec)
if aiobrowser is not None:

try:
await asyncio.wait_for(self._event.wait(), timeout=discovery_timeout_sec)
except asyncio.TimeoutError:
print(f"MDNS service discovery timed out after {discovery_timeout_sec} seconds.")
finally:
await aiobrowser.async_cancel()

if log_output:
# Log discovered services in JSON format
converted_services = {key: [asdict(item) for item in value] for key, value in self._discovered_services.items()}
json_str = json.dumps(converted_services, indent=4)
print(json_str)

return self._discovered_services
self._log_output()

def _on_service_state_change(
self,
Expand All @@ -165,6 +227,7 @@ def _on_service_state_change(
None: This method does not return any value.
"""
if state_change.value == ServiceStateChange.Added.value:
self._event.set()
asyncio.ensure_future(self._query_service_info(
zeroconf,
service_type,
Expand Down Expand Up @@ -193,14 +256,14 @@ async def _query_service_info(self, zeroconf: Zeroconf, service_type: str, servi
service_info.async_clear_cache()

if is_service_discovered:
mdns_service_info = self._to_mdns_service_class(service_info)
mdns_service_info = self._to_mdns_service_info_class(service_info)

if service_type not in self._discovered_services:
self._discovered_services[service_type] = [mdns_service_info]
else:
self._discovered_services[service_type].append(mdns_service_info)

def _to_mdns_service_class(self, service_info: AsyncServiceInfo) -> MdnsServiceInfo:
def _to_mdns_service_info_class(self, service_info: AsyncServiceInfo) -> MdnsServiceInfo:
"""
Converts an AsyncServiceInfo object into a MdnsServiceInfo data class.
Expand All @@ -227,53 +290,19 @@ def _to_mdns_service_class(self, service_info: AsyncServiceInfo) -> MdnsServiceI

return mdns_service_info

def get_commissioner_service_info(self) -> List[Dict[str, any]]:
"""
Retrieves the service information for Commissioner services.
Returns:
list: A list of discovered Commissioner services.
"""
return self._get_service_info(MdnsServiceType.COMMISSIONER)

def get_commissionable_service_info(self) -> List[Dict[str, any]]:
"""
Retrieves the service information for Commissionable services.
Returns:
list: A list of discovered Commissionable services.
"""
return self._get_service_info(MdnsServiceType.COMMISSIONABLE)

async def get_operational_service_info(self, service_name: str, service_type: str) -> Optional[MdnsServiceInfo]:
"""
Asynchronously retrieves service information for a specified service name and service_type.
Args:
service_name (str): The name of the service to discover.
service_type (str): The service type of the service to discover.
Returns:
MdnsServiceInfo | None: The discovered service information if discovered, otherwise None.
Raises:
ValueError: If either 'service_name' or 'service_type' is None.
"""
# Validate arguments to ensure they are not None
if service_name is None or service_type is None:
raise ValueError("Neither 'service_name' nor 'service_type' can be None.")

# Get service info
service_info = AsyncServiceInfo(type, service_name)
is_discovered = await service_info.async_request(self._zc, 3000)

return self._to_mdns_service_class(service_info) if is_discovered else None

def get_border_router_service_info(self) -> List[Dict[str, any]]:
"""
Retrieves the service information for Border Router services.
async def _get_service(self, service_type: MdnsServiceType,
log_output: bool = False,
discovery_timeout_sec: float = DISCOVERY_TIMEOUT_SEC
) -> MdnsServiceInfo:
mdns_service_info = None
self._service_types = [service_type.value]
await self._discover(log_output=log_output, discovery_timeout_sec=discovery_timeout_sec)
if service_type.value in self._discovered_services:
mdns_service_info = self._discovered_services[service_type.value][0]

return mdns_service_info

Returns:
list: A list of discovered Border Router services.
"""
return self._get_service_info(MdnsServiceType.BORDER_ROUTER)
def _log_output(self) -> str:
converted_services = {key: [asdict(item) for item in value] for key, value in self._discovered_services.items()}
json_str = json.dumps(converted_services, indent=4)
print(json_str)

0 comments on commit 93394a8

Please sign in to comment.