diff --git a/.gitignore b/.gitignore index e5d3544..b80bb0d 100644 --- a/.gitignore +++ b/.gitignore @@ -180,3 +180,6 @@ home-assistant.log.fault home-assistant_v2.db home-assistant_v2.db-shm home-assistant_v2.db-wal + +# Ignore Misc +*.old \ No newline at end of file diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index aa924b6..8acba7b 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -6,7 +6,7 @@ stages: variables: PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip" PRE_COMMIT_HOME: "/.cache/pre-commit" - PYTHON_IMAGE: python:3.12 + PYTHON_IMAGE: python:3.13 include: - local: "/template/.python-container.yml" @@ -24,7 +24,15 @@ check_formatting: paths: - $PRE_COMMIT_HOME script: - - pre-commit run --all-files + - SKIP=hassfest pre-commit run --all-files + +hassfest: + stage: lint + image: + name: ghcr.io/home-assistant/hassfest + entrypoint: [""] + script: + - /usr/src/homeassistant/script/hassfest/docker/entrypoint.sh pytest: stage: test diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5a6abfa..b5cad67 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,14 +4,6 @@ repos: hooks: - id: pyupgrade args: [--py37-plus] - - repo: https://github.com/psf/black - rev: "24.3.0" - hooks: - - id: black - args: - - --safe - - --quiet - files: ^((homeassistant|script|tests)/.+)?[^/]+\.py$ - repo: https://github.com/codespell-project/codespell rev: v2.2.6 hooks: @@ -28,7 +20,7 @@ repos: additional_dependencies: - flake8-docstrings==1.5.0 - pydocstyle==5.0.2 - files: ^(homeassistant|script|tests)/.+\.py$ + files: ^(custom_components|script|tests)/.+\.py$ - repo: https://github.com/PyCQA/bandit rev: 1.7.8 hooks: @@ -37,22 +29,36 @@ repos: - --quiet - --format=custom - --configfile=tests/bandit.yaml - files: ^(homeassistant|script|tests)/.+\.py$ + files: ^(custom_components|script|tests)/.+\.py$ - repo: https://github.com/pre-commit/mirrors-isort rev: v5.10.1 hooks: - id: isort - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v5.0.0 hooks: - id: check-executables-have-shebangs stages: [manual] - id: check-json + - id: check-yaml + exclude: ^example/frontend.yaml + - id: pretty-format-json + args: + - --autofix + - --top-keys=domain,title,name - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.3.3 + rev: v0.8.4 hooks: # Run the linter. - id: ruff # Run the formatter. - id: ruff-format + + - repo: local + hooks: + - id: hassfest + name: hassfest + language: script + entry: hooks/hassfest.sh + files: ^(custom_components/.+/(icons|manifest|strings)\.json|custom_components/.+/translations/.+\.json|custom_components/.+/(quality_scale)\.yaml|custom_components/brands/.*\.json|custom_components/.+/services\.yaml|script/hassfest/(?!metadata|mypy_config).+\.py|requirements.+\.txt)$ \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json index 1e88bf8..3169ad6 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -1,41 +1,54 @@ { - "version": "0.2.0", - "configurations": [ - { - "name": "Home Assistant", - "type": "debugpy", - "request": "launch", - "module": "homeassistant", - "justMyCode": false, - "args": [ - "--debug", - "-c", - "." - ] - }, - { - "name": "Home Assistant (skip pip)", - "type": "debugpy", - "request": "launch", - "module": "homeassistant", - "justMyCode": false, - "args": [ - "--debug", - "-c", - ".", - "--skip-pip" - ] - }, - { - "name": "Home Assistant: Changed tests", - "type": "debugpy", - "request": "launch", - "module": "pytest", - "justMyCode": false, - "args": [ - "--timeout=10", - "--picked" - ] - } - ] -} \ No newline at end of file + "configurations": [ + { + "name": "Home Assistant", + "args": [ + "--debug", + "-c", + "." + ], + "env": { + "GTFS_REALTIME_SHOW_MEMORY_USE": "on" + }, + "justMyCode": false, + "module": "homeassistant", + "request": "launch", + "type": "debugpy" + }, + { + "name": "Home Assistant (skip pip)", + "args": [ + "--debug", + "-c", + ".", + "--skip-pip" + ], + "justMyCode": false, + "module": "homeassistant", + "request": "launch", + "type": "debugpy" + }, + { + "name": "Home Assistant: Changed tests", + "args": [ + "--timeout=10", + "--picked" + ], + "justMyCode": false, + "module": "pytest", + "request": "launch", + "type": "debugpy" + }, + { + "name": "Python: Debug Tests", + "console": "integratedTerminal", + "justMyCode": false, + "purpose": [ + "debug-test" + ], + "request": "launch", + "type": "debugpy" + } + ], + "version": "0.2.0" +} diff --git a/custom_components/__init__.py b/custom_components/__init__.py index e69de29..c8e6d5d 100644 --- a/custom_components/__init__.py +++ b/custom_components/__init__.py @@ -0,0 +1 @@ +"""GTFS Realtime HomeAssistant Custom Component.""" diff --git a/custom_components/gtfs_realtime/__init__.py b/custom_components/gtfs_realtime/__init__.py index 510a054..9b874ff 100644 --- a/custom_components/gtfs_realtime/__init__.py +++ b/custom_components/gtfs_realtime/__init__.py @@ -1,28 +1,30 @@ """The GTFS Realtime integration.""" + # GTFS Station Stop Feed Subject serves as the data hub for the integration from datetime import timedelta +import logging from typing import Any from gtfs_station_stop.feed_subject import FeedSubject from homeassistant.config_entries import ConfigEntry from homeassistant.const import Platform from homeassistant.core import HomeAssistant +from homeassistant.helpers.selector import TextSelector, TextSelectorConfig import voluptuous as vol from custom_components.gtfs_realtime.config_flow import DOMAIN_SCHEMA from .const import ( - CAL_DB, + CLEAR_STATIC_FEEDS, CONF_API_KEY, CONF_GTFS_STATIC_DATA, CONF_ROUTE_ICONS, + CONF_STATIC_SOURCES_UPDATE_FREQUENCY, + CONF_STATIC_SOURCES_UPDATE_FREQUENCY_DEFAULT, CONF_URL_ENDPOINTS, - COORDINATOR_REALTIME, DOMAIN, - RTI_DB, - SSI_DB, - TI_DB, + REFRESH_STATIC_FEEDS, ) from .coordinator import GtfsRealtimeCoordinator @@ -33,34 +35,106 @@ extra=vol.ALLOW_EXTRA, ) +type GtfsRealtimeConfigEntry = ConfigEntry[GtfsRealtimeCoordinator] + +_LOGGER = logging.getLogger(__name__) + -async def _async_create_gtfs_update_hub(hass: HomeAssistant, config: dict[str, Any]): +def create_gtfs_update_hub( + hass: HomeAssistant, config: dict[str, Any] +) -> GtfsRealtimeCoordinator: + """Create the Update Coordinator.""" hub = FeedSubject( config[CONF_URL_ENDPOINTS], headers={"api_key": config[CONF_API_KEY]} ) - route_icons = config.get(CONF_ROUTE_ICONS) # optional - # Attempt to perform an update to verify configuration - await hub.async_update() - coordinator_realtime = GtfsRealtimeCoordinator( - hass, hub, config[CONF_GTFS_STATIC_DATA], static_timedelta=timedelta(hours=24) - ) - # Update the static data for the coordinator before the first update - await coordinator_realtime.async_update_static_data() - hass.data[DOMAIN] = { - COORDINATOR_REALTIME: coordinator_realtime, - CAL_DB: coordinator_realtime.calendar, - SSI_DB: coordinator_realtime.station_stop_info_db, - TI_DB: coordinator_realtime.trip_info_db, - RTI_DB: coordinator_realtime.route_info_db, - CONF_ROUTE_ICONS: route_icons, + route_icons: str | None = config.get(CONF_ROUTE_ICONS) # optional + + static_timedelta = { + uri: timedelta(**timedelta_dict) + for uri, timedelta_dict in config[CONF_STATIC_SOURCES_UPDATE_FREQUENCY].items() } - if CONF_ROUTE_ICONS in config: - hass.data[DOMAIN][CONF_ROUTE_ICONS] = config[CONF_ROUTE_ICONS] - return True + # if the value is 0, it is likely user input errors due to a bug in config flow UI, so coerce it to the default + for value in static_timedelta.values(): + if value == timedelta(seconds=0): + value = timedelta(hours=CONF_STATIC_SOURCES_UPDATE_FREQUENCY_DEFAULT) + return GtfsRealtimeCoordinator( + hass, + hub, + config[CONF_GTFS_STATIC_DATA], + static_timedelta=static_timedelta, + route_icons=route_icons, + ) -async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: +async def async_setup_entry( + hass: HomeAssistant, entry: GtfsRealtimeConfigEntry +) -> bool: """Set up GTFS Realtime Feed Subject for use by all sensors.""" - await _async_create_gtfs_update_hub(hass, entry.data) + coordinator: GtfsRealtimeCoordinator = create_gtfs_update_hub(hass, entry.data) + await coordinator.async_config_entry_first_refresh() + entry.runtime_data = coordinator + + async def handle_refresh_static_feeds(call): + """Handle service action to refresh static feeds.""" + entry.runtime_data.static_update_targets = set(call.data["gtfs_static_data"]) + await entry.runtime_data.async_update_static_data() + + async def handle_clear_static_feeds(call): + """Handle service action to clear static feeds.""" + await entry.runtime_data.async_update_static_data(clear_old_data=True) + + hass.services.async_register( + DOMAIN, + REFRESH_STATIC_FEEDS, + handle_refresh_static_feeds, + vol.Schema( + { + vol.Optional( + CONF_GTFS_STATIC_DATA, + default=entry.runtime_data.gtfs_static_zip, + description=( + {"suggested_value": ["https://"]} + if len(entry.runtime_data.gtfs_static_zip) == 0 + else {} + ), + ): TextSelector(TextSelectorConfig(multiline=False, multiple=True)), + } + ), + ) + + hass.services.async_register(DOMAIN, CLEAR_STATIC_FEEDS, handle_clear_static_feeds) + await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) return True + + +async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Unload GTFS config entry.""" + return await hass.config_entries.async_unload_platforms(entry, PLATFORMS) + + +async def async_migrate_entry( + hass: HomeAssistant, entry: GtfsRealtimeConfigEntry +) -> bool: + """Migrate old entry.""" + _LOGGER.debug( + "Migrating configuration from version %s.%s", + entry.version, + entry.minor_version, + ) + if entry.version > 1: + return False + if entry.version == 1: + new_data = {**entry.data} + new_data[CONF_STATIC_SOURCES_UPDATE_FREQUENCY] = {} + for uri in new_data[CONF_GTFS_STATIC_DATA]: + _LOGGER.debug( + f"Static data source {uri} set to update on interval of {timedelta(seconds=CONF_STATIC_SOURCES_UPDATE_FREQUENCY_DEFAULT)}" + ) + new_data[CONF_STATIC_SOURCES_UPDATE_FREQUENCY][uri] = { + "hours": CONF_STATIC_SOURCES_UPDATE_FREQUENCY_DEFAULT + } + hass.config_entries.async_update_entry( + entry, data=new_data, version=2, minor_version=0 + ) + return True diff --git a/custom_components/gtfs_realtime/binary_sensor.py b/custom_components/gtfs_realtime/binary_sensor.py index 70ea625..9039c6c 100644 --- a/custom_components/gtfs_realtime/binary_sensor.py +++ b/custom_components/gtfs_realtime/binary_sensor.py @@ -13,19 +13,12 @@ from homeassistant.core import HomeAssistant, callback import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_platform import AddEntitiesCallback -from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.update_coordinator import CoordinatorEntity import voluptuous as vol -from .const import ( - CONF_ROUTE_IDS, - COORDINATOR_REALTIME, - DESCRIPTION_PRETTY, - DOMAIN, - HEADER_PRETTY, - ROUTE_ID, - STOP_ID, -) +from custom_components.gtfs_realtime import GtfsRealtimeConfigEntry + +from .const import CONF_ROUTE_IDS, ROUTE_ID, STOP_ID from .coordinator import GtfsRealtimeCoordinator PLATFORM_SCHEMA = BINARY_SENSOR_PLATFORM_SCHEMA.extend( @@ -40,32 +33,31 @@ async def async_setup_entry( hass: HomeAssistant, - config: ConfigType, + entry: GtfsRealtimeConfigEntry, add_entities: AddEntitiesCallback, - discovery_info: DiscoveryInfoType | None = None, ) -> None: """Set up the sensor platform.""" - coordinator: GtfsRealtimeCoordinator = hass.data[DOMAIN][COORDINATOR_REALTIME] - if discovery_info is None: - if CONF_ROUTE_IDS in config.data: - add_entities( - [ - AlertSensor( - coordinator, - RouteStatus(route_id, coordinator.hub), - hass.config.language, - None, - ) - for route_id in config.data[CONF_ROUTE_IDS] - ] - ) + coordinator: GtfsRealtimeCoordinator = entry.runtime_data + if CONF_ROUTE_IDS in entry.data: + add_entities( + [ + AlertSensor( + coordinator, + RouteStatus(route_id, coordinator.hub), + hass.config.language, + None, + ) + for route_id in entry.data[CONF_ROUTE_IDS] + ] + ) class AlertSensor(BinarySensorEntity, CoordinatorEntity): """Representation of a Station GTFS Realtime Alert Sensor.""" - CLEAN_ALERT_DATA = {HEADER_PRETTY: "", DESCRIPTION_PRETTY: ""} + CLEAN_ALERT_DATA = {"header_0": "", "description_0": ""} + _attr_translation_key = "alert_descriptions" _attr_device_class = BinarySensorDeviceClass.PROBLEM def __init__( @@ -95,6 +87,7 @@ def extra_state_attributes(self) -> dict[str, str]: return self._alert_detail def update(self) -> None: + """Update state from coordinator data.""" alerts = self.informed_entity.alerts self._alert_detail = {} if len(alerts) == 0: @@ -102,14 +95,15 @@ def update(self) -> None: elif len(alerts) > 0: self._attr_is_on = True for i, alert in enumerate(alerts): - self._alert_detail[f"{HEADER_PRETTY}{f" {i + 1}" if i > 0 else ""}"] = ( - alert.header_text.get(self.language, "") + self._alert_detail[f"header_{i + 1}"] = alert.header_text.get( + self.language, "" + ) + self._alert_detail[f"description_{i + 1}"] = alert.description_text.get( + self.language, "" ) - self._alert_detail[ - f"{DESCRIPTION_PRETTY}{f" {i + 1}" if i > 0 else ""}" - ] = alert.description_text.get(self.language, "") self.async_write_ha_state() @callback def _handle_coordinator_update(self) -> None: + """Handle coordinator update callback.""" self.update() diff --git a/custom_components/gtfs_realtime/config_flow.py b/custom_components/gtfs_realtime/config_flow.py index 1468305..5173b58 100644 --- a/custom_components/gtfs_realtime/config_flow.py +++ b/custom_components/gtfs_realtime/config_flow.py @@ -1,3 +1,6 @@ +"""Config Flow for GTFS Realtime.""" + +import asyncio import json import logging from typing import Any @@ -6,9 +9,15 @@ from gtfs_station_stop.route_info import RouteInfoDatabase from gtfs_station_stop.static_database import async_factory from gtfs_station_stop.station_stop_info import LocationType, StationStopInfoDatabase -from homeassistant import config_entries +from homeassistant.config_entries import ConfigFlow +from homeassistant.data_entry_flow import SectionConfig, section import homeassistant.helpers.config_validation as cv from homeassistant.helpers.selector import ( + DurationSelector, + DurationSelectorConfig, + NumberSelector, + NumberSelectorConfig, + NumberSelectorMode, SelectOptionDict, SelectSelector, SelectSelectorConfig, @@ -23,12 +32,19 @@ CONF_API_KEY, CONF_ARRIVAL_LIMIT, CONF_GTFS_PROVIDER, + CONF_GTFS_PROVIDER_ID, CONF_GTFS_STATIC_DATA, + CONF_MINOR_VERSION, CONF_ROUTE_ICONS, CONF_ROUTE_IDS, + CONF_SELECT_AT_LEAST_ONE_STOP_OR_ROUTE, + CONF_STATIC_SOURCES_UPDATE_FREQUENCY, + CONF_STATIC_SOURCES_UPDATE_FREQUENCY_DEFAULT, CONF_STOP_IDS, CONF_URL_ENDPOINTS, + CONF_VERSION, DOMAIN, + FEEDS_URL, ) DOMAIN_SCHEMA = vol.Schema( @@ -40,52 +56,54 @@ } ) - _LOGGER = logging.getLogger(__name__) -class GtfsRealtimeConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): - """Config flow for GTFS Realtime""" +class GtfsRealtimeConfigFlow(ConfigFlow, domain=DOMAIN): + """Config flow for GTFS Realtime.""" - VERSION = 1 - FEEDS_URL = "https://gist.githubusercontent.com/bcpearce/cc60c18f4022c4a11c460c5ccd2ec158/raw/feeds.json" - feeds = {} + VERSION = CONF_VERSION + MINOR_VERSION = CONF_MINOR_VERSION + feeds: dict[str, str] = {} def __init__(self) -> None: """Initialize config flow.""" - self.device_config: dict[str, Any] = {} + self.hub_config: dict[str, Any] = {} - async def get_feeds(): + async def _get_feeds(): async with aiohttp.ClientSession() as session: - async with session.get(GtfsRealtimeConfigFlow.FEEDS_URL) as response: + async with session.get(FEEDS_URL) as response: if response.status >= 200 and response.status < 400: GtfsRealtimeConfigFlow.feeds = json.loads(await response.text()) async def async_step_user(self, user_input=None): + """User initiated Config Flow.""" errors = {} if user_input is not None: - if CONF_GTFS_PROVIDER in user_input: + if CONF_GTFS_PROVIDER_ID in user_input: return await self.async_step_choose_static_and_realtime_feeds( user_input ) + else: + errors["base"] = "unexpected_user_input" # Update the feeds, typically this will be from an externally hosted # file so it may be kept up to date without requiring updates to this repository. # It can also be monkey patched to support testing. try: - await GtfsRealtimeConfigFlow.get_feeds() + await GtfsRealtimeConfigFlow._get_feeds() except Exception as e: # do not allow errors to propagate, this is for convenience - _LOGGER.error("Failed to get preconfigured feeds") - errors["base"] = f"Failed to get preconfigured feeds: {e}" + _LOGGER.error("failed_preconfigured_feeds") + errors["base"] = f"failed_preconfigured_feeds: {e}" - options = {"_": "Other - Enter Manually"} + options = {"_": "..."} for k, v in GtfsRealtimeConfigFlow.feeds.items(): options[k] = v["name"] data_schema = vol.Schema( { - vol.Required(CONF_GTFS_PROVIDER): SelectSelector( + vol.Required(CONF_GTFS_PROVIDER_ID): SelectSelector( SelectSelectorConfig( mode=SelectSelectorMode.DROPDOWN, options=[ @@ -101,26 +119,28 @@ async def async_step_user(self, user_input=None): ) async def async_step_choose_static_and_realtime_feeds( - self, user_input: dict[str, str], errors: dict[str, str] = {} + self, user_input: dict[str, str] = {}, errors: dict[str, str] = {} ): + """Select Static and Realtime Feed URIs.""" if ( CONF_GTFS_STATIC_DATA in user_input and CONF_URL_ENDPOINTS in user_input and not errors ): - self.device_config = self.device_config | user_input - return await self.async_step_choose_informed_entities(user_input) - gtfs_provider = user_input.get(CONF_GTFS_PROVIDER) - realtime_feeds = [""] - static_feeds = [""] - route_icons = "" - self.device_config[CONF_GTFS_PROVIDER] = "Manual" - feed_data = GtfsRealtimeConfigFlow.feeds.get(gtfs_provider) - if feed_data is not None: - realtime_feeds = list(feed_data["realtime_feeds"].values()) - static_feeds = list(feed_data["static_feeds"].values()) - route_icons = feed_data.get("route_icons", route_icons) - self.device_config[CONF_GTFS_PROVIDER] = feed_data["name"] + self.hub_config = self.hub_config | user_input + return await self.async_step_choose_informed_entities() + gtfs_provider_id = user_input.get(CONF_GTFS_PROVIDER_ID) + self.hub_config[CONF_GTFS_PROVIDER] = "Manual" + feed_data = GtfsRealtimeConfigFlow.feeds.get(gtfs_provider_id, {}) + realtime_feeds: list[str] = list( + feed_data.get("realtime_feeds", {"_": [""]}).values() + ) + static_feeds: list[str] = list( + feed_data.get("static_feeds", {"blank_user_entry": [""]}).values() + ) + route_icons: str = feed_data.get("route_icons", "") + self.hub_config[CONF_GTFS_PROVIDER] = feed_data.get("name", "") + self.hub_config[CONF_GTFS_PROVIDER_ID] = gtfs_provider_id data_schema = vol.Schema( { @@ -128,9 +148,11 @@ async def async_step_choose_static_and_realtime_feeds( vol.Optional( CONF_URL_ENDPOINTS, default=realtime_feeds, - description={"suggested_value": ["https://"]} - if realtime_feeds == [""] - else {}, + description=( + {"suggested_value": ["https://"]} + if realtime_feeds == [""] + else {} + ), ): TextSelector( TextSelectorConfig( multiline=False, @@ -141,9 +163,11 @@ async def async_step_choose_static_and_realtime_feeds( vol.Optional( CONF_GTFS_STATIC_DATA, default=static_feeds, - description={"suggested_value": ["https://"]} - if static_feeds == [""] - else {}, + description=( + {"suggested_value": ["https://"]} + if static_feeds == [""] + else {} + ), ): TextSelector(TextSelectorConfig(multiline=False, multiple=True)), vol.Optional( CONF_ROUTE_ICONS, @@ -157,32 +181,25 @@ async def async_step_choose_static_and_realtime_feeds( errors=errors, ) - async def async_step_choose_informed_entities(self, user_input): - if CONF_ROUTE_IDS in user_input or CONF_STOP_IDS in user_input: - self.device_config = self.device_config | user_input - return self.async_create_entry( - title=user_input[CONF_GTFS_PROVIDER], - data=self.device_config, - ) - errors = {} - headers = {} - if user_input.get(CONF_API_KEY): - headers["api_key"] = user_input[CONF_API_KEY] - try: - ssi_db = await async_factory( - StationStopInfoDatabase, - *user_input[CONF_GTFS_STATIC_DATA], - headers=headers, - ) - rt_db = await async_factory( - RouteInfoDatabase, *user_input[CONF_GTFS_STATIC_DATA], headers=headers - ) - except Exception as e: - errors["base"] = str(e) - return await self.async_step_choose_static_and_realtime_feeds( - {CONF_GTFS_PROVIDER: self.device_config.get(CONF_GTFS_PROVIDER)}, errors + async def _get_route_options(self, headers={}) -> list[SelectOptionDict]: + route_db = await async_factory( + RouteInfoDatabase, *self.hub_config[CONF_GTFS_STATIC_DATA], headers=headers + ) + return [ + SelectOptionDict( + value=k, + label=f"{k}: {route_db.route_infos[k].long_name or route_db.route_infos[k].short_name}", ) - stops = [ + for k in route_db.route_infos.keys() + ] + + async def _get_stop_options(self, headers={}) -> list[SelectOptionDict]: + ssi_db = await async_factory( + StationStopInfoDatabase, + *self.hub_config[CONF_GTFS_STATIC_DATA], + headers=headers, + ) + return [ SelectOptionDict( value=k, label=f"{v.name} {f' - {v.desc}' if v.desc is not None else ''} ({v.id})", @@ -190,23 +207,26 @@ async def async_step_choose_informed_entities(self, user_input): for k, v in ssi_db.station_stop_infos.items() if v.location_type == LocationType.STOP ] + + def _create_config_schema( + self, + stops: list[SelectOptionDict], + routes: list[SelectOptionDict], + selected_stops=None, + selected_routes=None, + ) -> vol.Schema: + """Populate the config schema with stops and routes to choose.""" data_schema = vol.Schema( { vol.Required( CONF_GTFS_PROVIDER, - default=self.device_config.get( + default=self.hub_config.get( CONF_GTFS_PROVIDER, "Generic GTFS Provider" ), ): cv.string, vol.Optional(CONF_ROUTE_IDS): SelectSelector( SelectSelectorConfig( - options=[ - SelectOptionDict( - value=k, - label=f"{k}: {rt_db.route_infos[k].long_name or rt_db.route_infos[k].short_name}", - ) - for k in rt_db.route_infos.keys() - ], + options=routes, mode=SelectSelectorMode.DROPDOWN, multiple=True, ) @@ -218,10 +238,79 @@ async def async_step_choose_informed_entities(self, user_input): multiple=True, ) ), - vol.Required(CONF_ARRIVAL_LIMIT, default=4): int, + vol.Required(CONF_ARRIVAL_LIMIT, default=4): NumberSelector( + NumberSelectorConfig(min=1, step=1, mode=NumberSelectorMode.BOX) + ), + CONF_STATIC_SOURCES_UPDATE_FREQUENCY: section( + vol.Schema( + { + vol.Required( + uri, + default={ + "hours": CONF_STATIC_SOURCES_UPDATE_FREQUENCY_DEFAULT + }, + ): DurationSelector( + DurationSelectorConfig( + allow_negative=False, + enable_day=True, + enable_millisecond=False, + ) + ) + for uri in self.hub_config[CONF_GTFS_STATIC_DATA] + } + ), + SectionConfig({"collapsed": True}), + ), } ) + return data_schema + + async def async_step_choose_informed_entities( + self, user_input: dict[str, str] | None = None + ): + """Select informed entities for sensor and binary_sensor platforms.""" + errors = {} + if user_input is not None: + if ( + len(user_input.get(CONF_ROUTE_IDS, [])) > 0 + or len(user_input.get(CONF_STOP_IDS, [])) > 0 + ): + self.hub_config |= user_input + # There appears to be a bug having the section for specific update intervals + # Default any missing ones here + for uri in self.hub_config[CONF_GTFS_STATIC_DATA]: + if uri not in self.hub_config.setdefault( + CONF_STATIC_SOURCES_UPDATE_FREQUENCY, {} + ): + self.hub_config[CONF_STATIC_SOURCES_UPDATE_FREQUENCY][uri] = { + "hours": CONF_STATIC_SOURCES_UPDATE_FREQUENCY_DEFAULT + } + return self.async_create_entry( + title=user_input.get(CONF_GTFS_PROVIDER, "generic_gtfs_provider"), + data=self.hub_config, + ) + else: + errors[CONF_ROUTE_IDS] = CONF_SELECT_AT_LEAST_ONE_STOP_OR_ROUTE + errors[CONF_STOP_IDS] = CONF_SELECT_AT_LEAST_ONE_STOP_OR_ROUTE + + headers = {} + if self.hub_config.get(CONF_API_KEY): + headers["api_key"] = user_input[CONF_API_KEY] + try: + stops, routes = await asyncio.gather( + self._get_stop_options(headers), self._get_route_options(headers) + ) + data_schema = self._create_config_schema(stops=stops, routes=routes) + except Exception as e: + errors["base"] = str(e) + return await self.async_step_choose_static_and_realtime_feeds( + {CONF_GTFS_PROVIDER_ID: self.hub_config.get(CONF_GTFS_PROVIDER_ID)}, + errors, + ) return self.async_show_form( - step_id="choose_informed_entities", data_schema=data_schema, errors=errors + step_id="choose_informed_entities", + data_schema=data_schema, + errors=errors, + last_step=True, ) diff --git a/custom_components/gtfs_realtime/const.py b/custom_components/gtfs_realtime/const.py index 0382836..2bab0de 100644 --- a/custom_components/gtfs_realtime/const.py +++ b/custom_components/gtfs_realtime/const.py @@ -3,15 +3,35 @@ DOMAIN = "gtfs_realtime" CONF_GTFS_PROVIDER = "gtfs_provider" +CONF_GTFS_PROVIDER_ID = "gtfs_provider_id" CONF_API_KEY = "api_key" CONF_GTFS_STATIC_DATA = "gtfs_static_data" +CONF_STATIC_SOURCES_UPDATE_FREQUENCY = "static_sources_update_frequency" +CONF_STATIC_SOURCES_UPDATE_FREQUENCY_DEFAULT = 2 # hours CONF_URL_ENDPOINTS = "url_endpoints" CONF_ROUTE_ICONS = "route_icons" CONF_ROUTE_IDS = "route_ids" CONF_STOP_IDS = "stop_ids" CONF_ARRIVAL_LIMIT = "arrival_limit" +CONF_VERSION = 2 +CONF_MINOR_VERSION = 0 + +# SERVICES +REFRESH_STATIC_FEEDS = "refresh_static_feeds" +CLEAR_STATIC_FEEDS = "clear_static_feeds" + +# ERRORS +CONF_SELECT_AT_LEAST_ONE_STOP_OR_ROUTE = "select_at_least_one_stop_or_route" + +FEEDS_URL = "https://gist.githubusercontent.com/bcpearce/cc60c18f4022c4a11c460c5ccd2ec158/raw/feeds.json" + STOP_ID = "stop_id" ROUTE_ID = "route_id" +TRIP_ID = "trip_id" +ROUTE_COLOR = "route_color" +ROUTE_TEXT_COLOR = "route_text_color" +HEADSIGN = "headsign" +ROUTE_TYPE = "route_type" SSI_DB = "station_stop_info_db" TI_DB = "trip_info_db" @@ -19,17 +39,5 @@ RTI_DB = "route_info_db" COORDINATOR_REALTIME = "coordinator_realtime" -COORDINATOR_STATIC = "coordinator_static" PLATFORM = "platform" - -"""These Constants are user facing.""" - -TRIP_ID_PRETTY = "Trip ID" -HEADSIGN_PRETTY = "Headsign" -ROUTE_COLOR_PRETTY = "Route Color" -ROUTE_TEXT_COLOR_PRETTY = "Route Text Color" -ROUTE_TYPE_PRETTY = "Route Type" - -HEADER_PRETTY = "Header" -DESCRIPTION_PRETTY = "Description" diff --git a/custom_components/gtfs_realtime/coordinator.py b/custom_components/gtfs_realtime/coordinator.py index 6aafca4..dd3071b 100644 --- a/custom_components/gtfs_realtime/coordinator.py +++ b/custom_components/gtfs_realtime/coordinator.py @@ -9,12 +9,16 @@ from gtfs_station_stop.calendar import Calendar from gtfs_station_stop.feed_subject import FeedSubject from gtfs_station_stop.route_info import RouteInfoDatabase +from gtfs_station_stop.route_status import RouteStatus from gtfs_station_stop.static_database import async_factory +from gtfs_station_stop.station_stop import StationStop from gtfs_station_stop.station_stop_info import StationStopInfoDatabase from gtfs_station_stop.trip_info import TripInfoDatabase from homeassistant.core import HomeAssistant from homeassistant.helpers.update_coordinator import DataUpdateCoordinator +from .const import CONF_ROUTE_ICONS, CONF_STATIC_SOURCES_UPDATE_FREQUENCY_DEFAULT + _LOGGER = logging.getLogger(__name__) @@ -25,13 +29,10 @@ def __init__( self, hass: HomeAssistant, feed_subject: FeedSubject, - gtfs_static_zip: Iterable[os.PathLike] | os.PathLike | None = None, + gtfs_static_zip: Iterable[os.PathLike] | os.PathLike = list[os.PathLike], **kwargs, ) -> None: """Initialize the GTFS Update Coordinator to notify all entities upon poll.""" - self.static_timedelta: timedelta = kwargs.get( - "static_timedelta", timedelta(hours=24) - ) self.realtime_timedelta: timedelta = kwargs.get( "realtime_timedelta", timedelta(seconds=60) ) @@ -41,138 +42,135 @@ def __init__( name="GTFS Realtime", update_interval=self.realtime_timedelta, ) - self.hub = feed_subject - self.gtfs_static_zip = gtfs_static_zip - self.calendar: Calendar | None = None - self.station_stop_info_db: StationStopInfoDatabase | None = None - self.trip_info_db: TripInfoDatabase | None = None - self.route_info_db: RouteInfoDatabase | None = None + self.static_timedelta: dict[os.PathLike, timedelta] = kwargs.get( + "static_timedelta", {} + ) self.kwargs = kwargs + self.hub: FeedSubject = feed_subject + self.station_stops: dict[str, StationStop] = {} + self.routes_statuses: dict[str, RouteStatus] = {} + self.gtfs_static_zip: Iterable[os.PathLike] | os.PathLike = gtfs_static_zip + self.calendar: Calendar = Calendar() + self.station_stop_info_db: StationStopInfoDatabase = StationStopInfoDatabase() + self.trip_info_db: TripInfoDatabase = TripInfoDatabase() + self.route_info_db: RouteInfoDatabase = RouteInfoDatabase() self.services_len: int = 0 self.stop_infos_len: int = 0 self.trip_infos_len: int = 0 self.route_infos_len: int = 0 - self.last_static_update: datetime | None = None - _LOGGER.info("Setup GTFS Realtime Update Coordinator") - _LOGGER.info(f"Realtime GTFS update interval {self.realtime_timedelta}") - _LOGGER.info(f"Static GTFS update interval {self.static_timedelta}") + self.route_icons: str | None = kwargs.get(CONF_ROUTE_ICONS) + self.static_update_targets: set[os.PathLike] = set(gtfs_static_zip) + self.last_static_update: dict[os.PathLike, datetime] = {} + _LOGGER.debug("Setup GTFS Realtime Update Coordinator") + _LOGGER.debug(f"Realtime GTFS update interval {self.realtime_timedelta}") + for uri, delta in self.static_timedelta.items(): + _LOGGER.info(f"Static GTFS update interval for {uri} is {delta}") async def _async_update_data(self): """Fetch data from API endpoint.""" + self.static_update_targets |= { + uri + for uri, last_update in self.last_static_update.items() + if datetime.now() - last_update + > self.static_timedelta.get( + uri, CONF_STATIC_SOURCES_UPDATE_FREQUENCY_DEFAULT + ) + } + await self.async_update_static_data() await self.hub.async_update() - # Update the static resource if it is past that timedelta - if ( - self.last_static_update is None - or datetime.now() - self.last_static_update > self.static_timedelta - ): - await self.async_update_static_data() - - async def async_update_static_data(self): + + async def _async_update_static_data(self): async with asyncio.TaskGroup() as tg: cal_db_task = tg.create_task( - async_factory(Calendar, *self.gtfs_static_zip, **self.kwargs) + async_factory(Calendar, *self.static_update_targets, **self.kwargs) ) ssi_db_task = tg.create_task( async_factory( - StationStopInfoDatabase, *self.gtfs_static_zip, **self.kwargs + StationStopInfoDatabase, *self.static_update_targets, **self.kwargs ) ) ti_db_task = tg.create_task( - async_factory(TripInfoDatabase, *self.gtfs_static_zip, **self.kwargs) + async_factory( + TripInfoDatabase, *self.static_update_targets, **self.kwargs + ) ) rti_db_task = tg.create_task( - async_factory(RouteInfoDatabase, *self.gtfs_static_zip, **self.kwargs) - ) - if any( - db is None - for db in ( - self.calendar, - self.station_stop_info_db, - self.trip_info_db, - self.route_info_db, - ) - ): - ( - self.calendar, - self.station_stop_info_db, - self.trip_info_db, - self.route_info_db, - ) = ( - cal_db_task.result(), - ssi_db_task.result(), - ti_db_task.result(), - rti_db_task.result(), - ) - self.services_len = len(self.calendar.services) - self.stop_infos_len = len(self.station_stop_info_db.station_stop_infos) - self.trip_infos_len = len(self.trip_info_db.trip_infos) - self.route_infos_len = len(self.route_info_db.route_infos) - _LOGGER.info( - f"GTFS Static Coordinator Initial Update Complete {self.gtfs_static_zip}" - ) - _LOGGER.info(f"GTFS Static Coordinator Services: {self.services_len}") - _LOGGER.info( - f"GTFS Static Coordinator Station Stop Infos: {self.stop_infos_len}" - ) - _LOGGER.info(f"GTFS Static Coordinator Trip Infos: {self.trip_infos_len}") - _LOGGER.info(f"GTFS Static Coordinator Route Infos: {self.route_infos_len}") - else: - old_services_len = self.services_len - self.calendar.services = ( - self.calendar.services | cal_db_task.result().services - ) - self.services_len = len(self.calendar.services) - - old_stop_infos_len = self.stop_infos_len - self.station_stop_info_db.station_stop_infos = ( - self.station_stop_info_db.station_stop_infos - | (ssi_db_task.result().station_stop_infos) - ) - self.stop_infos_len = len(self.station_stop_info_db.station_stop_infos) - - old_trip_infos_len = self.trip_infos_len - self.trip_info_db.trip_infos = ( - self.trip_info_db.trip_infos | ti_db_task.result().trip_infos - ) - self.trip_infos_len = len(self.trip_info_db.trip_infos) - - old_route_infos_len = self.route_infos_len - self.route_info_db.route_infos = ( - self.route_info_db.route_infos | rti_db_task.result().route_infos - ) - self.route_infos_len = len(self.route_info_db.route_infos) - - _LOGGER.info( - f"GTFS Static Coordinator Merge New Data Update Complete {self.gtfs_static_zip}" - ) - _LOGGER.info( - f"GTFS Static Coordinator Services: {old_services_len} -> {self.services_len}" - ) - _LOGGER.info( - f"GTFS Static Coordinator Station Stop Infos: {old_stop_infos_len} -> {self.stop_infos_len}" - ) - _LOGGER.info( - f"GTFS Static Coordinator Trip Infos: {old_trip_infos_len} -> {self.trip_infos_len}" - ) - _LOGGER.info( - f"GTFS Static Coordinator Route Infos: {old_route_infos_len} -> {self.route_infos_len}" + async_factory( + RouteInfoDatabase, *self.static_update_targets, **self.kwargs + ) ) - self.last_static_update = datetime.now() + return ( + cal_db_task.result(), + ssi_db_task.result(), + ti_db_task.result(), + rti_db_task.result(), + ) + async def async_update_static_data(self, clear_old_data=False): + """Update or clear static feeds and merge with existing databases.""" + # Check for clear old data to reset the databases + if clear_old_data: + self.calendar = Calendar() + self.station_stop_info_db = StationStopInfoDatabase() + self.trip_info_db = TripInfoDatabase() + self.route_info_db = RouteInfoDatabase() + _LOGGER.debug("GTFS Static data cleared") + elif not self.static_update_targets: + return + + cal_db, ssi_db, ti_db, rti_db = await self._async_update_static_data() + + old_services_len = self.services_len + self.calendar.services |= cal_db.services + self.services_len = len(self.calendar.services) + + old_stop_infos_len = self.stop_infos_len + self.station_stop_info_db.station_stop_infos |= ssi_db.station_stop_infos + self.stop_infos_len = len(self.station_stop_info_db.station_stop_infos) + + old_trip_infos_len = self.trip_infos_len + self.trip_info_db.trip_infos |= ti_db.trip_infos + self.trip_infos_len = len(self.trip_info_db.trip_infos) + + old_route_infos_len = self.route_infos_len + self.route_info_db.route_infos |= rti_db.route_infos + self.route_infos_len = len(self.route_info_db.route_infos) + + _LOGGER.debug( + f"GTFS Coordinator Merge New Static Data Update Complete {self.gtfs_static_zip}" + ) + _LOGGER.debug( + f"GTFS Coordinator Services: {old_services_len} -> {self.services_len}" + ) + _LOGGER.debug( + f"GTFS Coordinator Station Stop Infos: {old_stop_infos_len} -> {self.stop_infos_len}" + ) + _LOGGER.debug( + f"GTFS Coordinator Trip Infos: {old_trip_infos_len} -> {self.trip_infos_len}" + ) + _LOGGER.debug( + f"GTFS Coordinator Route Infos: {old_route_infos_len} -> {self.route_infos_len}" + ) -class GtfsStaticCoordinator(DataUpdateCoordinator): - """GTFS Static Update Coordinator. Polls Static Data Endpoints for new data on a slower basis.""" + for target in self.static_update_targets: + self.last_static_update.setdefault(target, datetime.now()) + self.static_update_targets.clear() - def __init__( - self, - hass: HomeAssistant, - ) -> None: - """Initialize the GTFS Update Coordinator to notify all entities upon poll.""" - super().__init__( - hass, _LOGGER, name="GTFS Static", update_interval=timedelta(days=1) - ) - # Save the resource path to reload periodically - _LOGGER.info("Setup GTFS Static Update Coordinator") + if os.environ.get("GTFS_REALTIME_SHOW_MEMORY_USE", "off") == "on": + try: + from pympler import asizeof - async def _async_update_data(self): - """Fetch data from API endpoint.""" + _LOGGER.debug( + f"Calendar using {asizeof.asizeof(self.calendar) / 2**20:.2f} MB" + ) + _LOGGER.debug( + f"Stations using {asizeof.asizeof(self.station_stop_info_db) / 2**20:.2f} MB" + ) + _LOGGER.debug( + f"Trips using {asizeof.asizeof(self.trip_info_db) / 2**20:.2f} MB" + ) + _LOGGER.debug( + f"Routes using {asizeof.asizeof(self.route_info_db) / 2**20:.2f} MB" + ) + except ImportError: + """Failed to import pympler for memory usage stats. When using environment variable GTFS_REALTIME_SHOW_MEMORY_USE=on, pympler must be added to the python environment running homeassitant.""" diff --git a/custom_components/gtfs_realtime/icons.json b/custom_components/gtfs_realtime/icons.json new file mode 100644 index 0000000..e0e3be4 --- /dev/null +++ b/custom_components/gtfs_realtime/icons.json @@ -0,0 +1,10 @@ +{ + "services": { + "clear_static_feeds": { + "service": "mdi:delete-empty" + }, + "refresh_static_feeds": { + "service": "mdi:calendar-refresh" + } + } +} diff --git a/custom_components/gtfs_realtime/manifest.json b/custom_components/gtfs_realtime/manifest.json index 4a292d1..2f92398 100644 --- a/custom_components/gtfs_realtime/manifest.json +++ b/custom_components/gtfs_realtime/manifest.json @@ -7,10 +7,11 @@ "config_flow": true, "dependencies": [], "documentation": "https://github.com/bcpearce/homeassistant-gtfs-realtime", + "integration_type": "hub", "iot_class": "cloud_polling", "issue_tracker": "https://github.com/bcpearce/homeassistant-gtfs-realtime/issues", "requirements": [ "gtfs_station_stop==0.8.0" ], - "version": "0.1.7" -} \ No newline at end of file + "version": "0.2.0" +} diff --git a/custom_components/gtfs_realtime/sensor.py b/custom_components/gtfs_realtime/sensor.py index 9f338c5..b9ce7b0 100644 --- a/custom_components/gtfs_realtime/sensor.py +++ b/custom_components/gtfs_realtime/sensor.py @@ -2,14 +2,10 @@ from __future__ import annotations -import os - from gtfs_station_stop.arrival import Arrival -from gtfs_station_stop.calendar import Calendar -from gtfs_station_stop.route_info import RouteInfoDatabase, RouteType +from gtfs_station_stop.route_info import RouteInfo, RouteType from gtfs_station_stop.station_stop import StationStop -from gtfs_station_stop.station_stop_info import StationStopInfo, StationStopInfoDatabase -from gtfs_station_stop.trip_info import TripInfoDatabase +from gtfs_station_stop.station_stop_info import StationStopInfo from homeassistant.components.sensor import ( PLATFORM_SCHEMA as SENSOR_PLATFORM_SCHEMA, SensorDeviceClass, @@ -20,27 +16,21 @@ from homeassistant.core import HomeAssistant, callback import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_platform import AddEntitiesCallback -from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.update_coordinator import CoordinatorEntity import voluptuous as vol +from custom_components.gtfs_realtime import GtfsRealtimeConfigEntry + from .const import ( - CAL_DB, CONF_ARRIVAL_LIMIT, - CONF_ROUTE_ICONS, CONF_STOP_IDS, - COORDINATOR_REALTIME, - DOMAIN, - HEADSIGN_PRETTY, - ROUTE_COLOR_PRETTY, + HEADSIGN, + ROUTE_COLOR, ROUTE_ID, - ROUTE_TEXT_COLOR_PRETTY, - ROUTE_TYPE_PRETTY, - RTI_DB, - SSI_DB, + ROUTE_TEXT_COLOR, + ROUTE_TYPE, STOP_ID, - TI_DB, - TRIP_ID_PRETTY, + TRIP_ID, ) from .coordinator import GtfsRealtimeCoordinator @@ -51,36 +41,20 @@ async def async_setup_entry( hass: HomeAssistant, - config: ConfigType, + entry: GtfsRealtimeConfigEntry, add_entities: AddEntitiesCallback, - discovery_info: DiscoveryInfoType | None = None, ) -> None: """Set up the sensor platform.""" - coordinator: GtfsRealtimeCoordinator = hass.data[DOMAIN][COORDINATOR_REALTIME] - if discovery_info is None: - if CONF_STOP_IDS in config.data: - ssi_db: StationStopInfoDatabase = hass.data[DOMAIN][SSI_DB] - ti_db: TripInfoDatabase = hass.data[DOMAIN][TI_DB] - cal_db: Calendar = hass.data[DOMAIN][CAL_DB] - rti_db: RouteInfoDatabase = hass.data[DOMAIN][RTI_DB] - arrival_limit: int = config.data[CONF_ARRIVAL_LIMIT] - route_icons: os.PathLike = hass.data[DOMAIN].get(CONF_ROUTE_ICONS) - arrival_sensors = [] - for i in range(arrival_limit): - for stop_id in config.data[CONF_STOP_IDS]: - arrival_sensors.append( - ArrivalSensor( - coordinator, - StationStop(stop_id, coordinator.hub), - i, - ssi_db[stop_id], - ti_db, - cal_db, - rti_db, - route_icons=route_icons, - ) - ) - add_entities(arrival_sensors, update_before_add=True) + coordinator: GtfsRealtimeCoordinator = entry.runtime_data + if CONF_STOP_IDS in entry.data: + arrival_limit: int = int(round(entry.data[CONF_ARRIVAL_LIMIT])) + arrival_sensors = [] + for i in range(arrival_limit): + for stop_id in entry.data[CONF_STOP_IDS]: + arrival_sensors.append( + ArrivalSensor(coordinator=coordinator, stop_id=stop_id, idx=i) + ) + add_entities(arrival_sensors, update_before_add=True) class ArrivalSensor(SensorEntity, CoordinatorEntity): @@ -102,25 +76,17 @@ class ArrivalSensor(SensorEntity, CoordinatorEntity): def __init__( self, coordinator: GtfsRealtimeCoordinator, - station_stop: StationStop, + stop_id: str, idx: int, - station_stop_info: StationStopInfo | None = None, - trip_info_db: TripInfoDatabase | None = None, - calendar_db: Calendar | None = None, - route_info_db: RouteInfoDatabase | None = None, - route_icons: os.PathLike | None = None, ) -> None: """Initialize the sensor.""" # Required super().__init__(coordinator) - self.station_stop = station_stop + self.station_stop = coordinator.station_stops.setdefault( + stop_id, StationStop(stop_id, coordinator.hub) + ) self._idx = idx - # Allowed to be `None` - self.station_stop_info = station_stop_info - self.trip_info_db = trip_info_db - self.calendar_db = calendar_db - self.route_icons = route_icons - self.route_info_db = route_info_db + self.coordinator = coordinator self.route_type = RouteType.UNKNOWN self._name = f"{self._idx + 1}: {self._get_station_ref()}" @@ -130,11 +96,14 @@ def __init__( self._arrival_detail: dict[str, str] = {} def _get_station_ref(self): - return ( - self.station_stop_info.name - if self.station_stop_info is not None - else self.station_stop.id + station_stop_info: StationStopInfo = ( + self.coordinator.station_stop_info_db.station_stop_infos.get( + self.station_stop.id + ) ) + if station_stop_info is not None: + return station_stop_info.name + return self.station_stop.id @property def name(self) -> str: @@ -148,21 +117,25 @@ def extra_state_attributes(self) -> dict[str, str]: @property def entity_picture(self) -> str | None: + """Provide the entity picture from a URL.""" return ( - str(self.route_icons).format( + str(self.coordinator.route_icons).format( self._arrival_detail[ROUTE_ID], - self._arrival_detail.get(ROUTE_COLOR_PRETTY, "%230039A6"), - self._arrival_detail.get(ROUTE_TEXT_COLOR_PRETTY, "%23FFFFFF"), + self._arrival_detail.get(ROUTE_COLOR, "%230039A6"), + self._arrival_detail.get(ROUTE_TEXT_COLOR, "%23FFFFFF"), ) - if self.route_icons and self._arrival_detail.get(ROUTE_ID) is not None + if self.coordinator.route_icons + and self._arrival_detail.get(ROUTE_ID) is not None else None ) @property def icon(self) -> str: + """Provide the icon.""" return self.__class__.ICON_DICT.get(self.route_type, "mdi:bus-clock") def update(self) -> None: + """Update state from coordinator data.""" time_to_arrivals = sorted(self.station_stop.get_time_to_arrivals()) self._arrival_detail = {} if len(time_to_arrivals) > self._idx: @@ -171,24 +144,22 @@ def update(self) -> None: time_to_arrival.time, 0 ) # do not allow negative numbers self._arrival_detail[ROUTE_ID] = time_to_arrival.route - if self.trip_info_db is not None: - trip_info = self.trip_info_db.get_close_match( - time_to_arrival.trip, self.calendar_db + if self.coordinator.trip_info_db is not None: + trip_info = self.coordinator.trip_info_db.get_close_match( + time_to_arrival.trip, self.coordinator.calendar ) if trip_info is not None: - self._arrival_detail[HEADSIGN_PRETTY] = trip_info.trip_headsign - self._arrival_detail[TRIP_ID_PRETTY] = trip_info.trip_id - if self.route_info_db is not None: - route_info = self.route_info_db.get(time_to_arrival.route) + self._arrival_detail[HEADSIGN] = trip_info.trip_headsign + self._arrival_detail[TRIP_ID] = trip_info.trip_id + if self.coordinator.route_info_db is not None: + route_info: RouteInfo = self.coordinator.route_info_db.get( + time_to_arrival.route + ) if route_info is not None: - self._arrival_detail[ROUTE_COLOR_PRETTY] = route_info.color - self._arrival_detail[ROUTE_TEXT_COLOR_PRETTY] = ( - route_info.text_color - ) + self._arrival_detail[ROUTE_COLOR] = route_info.color + self._arrival_detail[ROUTE_TEXT_COLOR] = route_info.text_color self.route_type = route_info.type - self._arrival_detail[ROUTE_TYPE_PRETTY] = ( - route_info.type.pretty_name() - ) + self._arrival_detail[ROUTE_TYPE] = route_info.type.pretty_name() else: self._attr_native_value = None self.async_write_ha_state() diff --git a/custom_components/gtfs_realtime/services.yaml b/custom_components/gtfs_realtime/services.yaml new file mode 100644 index 0000000..cec93c8 --- /dev/null +++ b/custom_components/gtfs_realtime/services.yaml @@ -0,0 +1,8 @@ +refresh_static_feeds: + fields: + gtfs_static_data: + required: false + selector: + text: + type: url + multiple: true \ No newline at end of file diff --git a/custom_components/gtfs_realtime/translations/en.json b/custom_components/gtfs_realtime/translations/en.json index 79eebc7..3388f46 100644 --- a/custom_components/gtfs_realtime/translations/en.json +++ b/custom_components/gtfs_realtime/translations/en.json @@ -1,42 +1,109 @@ { - "title": "GTFS Realtime", - "config": { - "step": { - "user": { - "title": "Select a GTFS Provider", - "data": { - "gtfs_provider": "GTFS Provider" - }, - "description": "Select a GTFS Provider" - }, - "choose_static_and_realtime_feeds": { - "data": { - "api_key": "API Key (if required)", - "url_endpoints": "Feed URL", - "gtfs_static_data": "GTFS Static Feeds (file or URL)", - "route_icons": "Route Icons format URL" - }, - "data_description": { - "url_endpoints": "Feed URLs for realtime GTFS Data", - "gtfs_static_data": "GTFS static feed zip file. Can be local file or URL from your provider.", - "route_icons": "URL to a route-icons provider containing an svg image file for a given route. The string can contain up to 3 str.format() compatible formatters for [route_id], [route_color], and [route_text_color] respectively. If your provider gives these colors as HTML hex, you may need to add an html-escaped '#' preceeding the input." - }, - "title": "Select GTFS Realtime and Static Feed URLs.", - "description": "This form will be pre-populated if you selected a provider in the previous step. Feeds you do not require can be removed to improve sensor update performance." - }, - "choose_informed_entities": { - "data": { - "gtfs_provider": "GTFS Provider Name", - "route_ids": "Route ID", - "stop_ids": "Stop ID", - "arrival_limit": "Arrival Limit" - }, - "data_description": { - "route_ids": "Route ID for a GTFS entity to receive service alerts.", - "stop_ids": "Stop ID for a GTFS entity to receive arrival data and service alerts" - }, - "description": "Configure GTFS parameters." - } + "title": "GTFS Realtime", + "common": { + "generic_gtfs_provider": "Generic GTFS Provider", + "manual_gtfs_provider_name": "Other - Enter Manually" + }, + "config": { + "error": { + "failed_preconfigured_feeds": "Failed to get preconfigured feeds", + "select_at_least_one_stop_or_route": "Must select at least one stop or route for GTFS updates.", + "unexpected_user_input": "Unexpected user Input" + }, + "step": { + "choose_informed_entities": { + "title": "Select Route and Stop IDs to create sensor and binary sensor entities.", + "data": { + "arrival_limit": "Arrival Limit", + "gtfs_provider": "GTFS Provider Name", + "route_ids": "Route ID", + "stop_ids": "Stop ID" + }, + "data_description": { + "route_ids": "Route ID for a GTFS entity to receive service alerts.", + "stop_ids": "Stop ID for a GTFS entity to receive arrival data and service alerts" + }, + "description": "Configure GTFS parameters.", + "sections": { + "static_sources_update_frequency": { + "name": "Static Data Update Frequency", + "description": "Set the duration of time between each update for static data. Check with your GTFS provider for the expected frequency." + } } + }, + "choose_static_and_realtime_feeds": { + "title": "Select GTFS Realtime and Static Feed URLs.", + "data": { + "api_key": "API Key (if required)", + "gtfs_static_data": "GTFS Static Feeds (file or URL)", + "route_icons": "Route Icons format URL", + "url_endpoints": "Feed URL" + }, + "data_description": { + "gtfs_static_data": "GTFS static feed zip file. Can be local file or URL from your provider.", + "route_icons": "URL to a route-icons provider containing an svg image file for a given route. The string can contain up to 3 str.format() compatible formatters for [route_id], [route_color], and [route_text_color] respectively. If your provider gives these colors as HTML hex, you may need to add an html-escaped '#' preceeding the input.", + "url_endpoints": "Feed URLs for realtime GTFS Data" + }, + "description": "This form will be pre-populated if you selected a provider in the previous step. Feeds you do not require can be removed to improve sensor update performance." + }, + "reconfigure": { + "title": "Reconfigure GTFS parameters.", + "description": "This will force update to static data." + }, + "user": { + "title": "Select a GTFS Provider", + "data": { + "gtfs_provider_id": "GTFS Provider" + }, + "description": "Select GTFS Provider" + } } -} \ No newline at end of file + }, + "entity_component": { + "binary_sensor": { + "state_attributes": { + "description": { + "name": "Description {i+1}" + }, + "header": { + "name": "Header {i+1}" + } + } + }, + "sensor": { + "state_attributes": { + "headsign": { + "name": "Headsign" + }, + "route_color": { + "name": "Route Color" + }, + "route_id": { + "name": "Route ID" + }, + "route_text_color": { + "name": "Route Text Color" + }, + "trip_id": { + "name": "Trip ID" + } + } + } + }, + "services": { + "clear_static_feeds": { + "name": "Clear Static Feeds", + "description": "Clear static GTFS schedule data. Data will be pulled again on the configured intervals for each source." + }, + "refresh_static_feeds": { + "name": "Refresh Static Feeds", + "description": "Pulls static GTFS schedule data from the data provider URIs.", + "fields": { + "gtfs_static_data": { + "name": "URL or File Path", + "description": "Static data sources to update. If blank, all Static Resources will be updated." + } + } + } + } +} diff --git a/hooks/hassfest.sh b/hooks/hassfest.sh new file mode 100755 index 0000000..05e854a --- /dev/null +++ b/hooks/hassfest.sh @@ -0,0 +1,2 @@ +#!/usr/bin/sh +docker run --rm -v ./custom_components://github/workspace ghcr.io/home-assistant/hassfest \ No newline at end of file diff --git a/requirements.test.txt b/requirements.test.txt index 8ab2037..782175c 100644 --- a/requirements.test.txt +++ b/requirements.test.txt @@ -1,7 +1,9 @@ codecov==2.1.13 -freezegun==1.4.0 +freezegun==1.5.1 gtfs-station-stop>=0.8.0 pre-commit==3.6.2 -pytest==8.0.2 -pytest-cov==4.1.0 -pytest-homeassistant-custom-component>=0.13.109 +pytest==8.3.3 +pytest-cov==6.0.0 +pytest-homeassistant-custom-component>=0.13.193 +Pympler==1.1 +ruff==0.8.4 \ No newline at end of file diff --git a/resources/feeds.json b/resources/feeds.json new file mode 100644 index 0000000..3b93b4d --- /dev/null +++ b/resources/feeds.json @@ -0,0 +1,108 @@ +{ + "boston_mbta": { + "name": "Boston MBTA", + "realtime_feeds": { + "alerts": "https://cdn.mbta.com/realtime/Alerts.pb", + "trip_update": "https://cdn.mbta.com/realtime/TripUpdates.pb", + "vehicle_positions": "https://cdn.mbta.com/realtime/VehiclePositions.pb" + }, + "static_feeds": { + "regular": "https://cdn.mbta.com/MBTA_GTFS.zip" + } + }, + "nyc_bus": { + "name": "NYC Busses", + "realtime_feeds": { + "alerts": "https://gtfsrt.prod.obanyc.com/alerts?key=", + "trip_update": "https://gtfsrt.prod.obanyc.com/tripUpdates?key=", + "vehicle_positions": "https://gtfsrt.prod.obanyc.com/vehiclePositions?key=" + }, + "static_feeds": { + "bronx": "http://web.mta.info/developers/data/nyct/bus/google_transit_bronx.zip", + "brooklyn": "http://web.mta.info/developers/data/nyct/bus/google_transit_brooklyn.zip", + "manhattan": "http://web.mta.info/developers/data/nyct/bus/google_transit_manhattan.zip", + "mta_bus": "http://web.mta.info/developers/data/busco/google_transit.zip", + "queens": "http://web.mta.info/developers/data/nyct/bus/google_transit_queens.zip", + "staten_island": "http://web.mta.info/developers/data/nyct/bus/google_transit_staten_island.zip" + } + }, + "nyc_ferry": { + "name": "NYC Ferry", + "realtime_feeds": { + "alerts": "http://nycferry.connexionz.net/rtt/public/utility/gtfsrealtime.aspx/alert", + "trip_update": "http://nycferry.connexionz.net/rtt/public/utility/gtfsrealtime.aspx/tripupdate" + }, + "static_feeds": { + "regular": "http://nycferry.connexionz.net/rtt/public/utility/gtfs.aspx" + } + }, + "nyc_long_island_railroad": { + "name": "Long Island Railroad", + "realtime_feeds": { + "all": "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/lirr%2Fgtfs-lirr" + }, + "static_feeds": { + "regular": "https://rrgtfsfeeds.s3.amazonaws.com/google_transit.zip" + } + }, + "nyc_metro_north_railroad": { + "name": "Metro North Railroad", + "realtime_feeds": { + "all": "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/mnr%2Fgtfs-mnr" + }, + "static_feeds": { + "regular": "http://web.mta.info/developers/data/mnr/google_transit.zip" + } + }, + "nyc_subway": { + "name": "NYC Subway", + "realtime_feeds": { + "1234567S": "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs", + "ACE": "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-ace", + "BDFM": "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-bdfm", + "G": "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-g", + "JZ": "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-jz", + "L": "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-l", + "NQRW": "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-nqrw", + "SIR": "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-si", + "alerts": "https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/camsys%2Fsubway-alerts" + }, + "route_icons": "https://raw.githubusercontent.com/bcpearce/homeassistant-gtfs-realtime/main/resources/NYCT_Bullets/{}.svg", + "static_feeds": { + "regular": "http://web.mta.info/developers/data/nyct/subway/google_transit.zip", + "supplemented": "http://web.mta.info/developers/files/google_transit_supplemented.zip" + } + }, + "sf_bart": { + "name": "Bay Area Rapid Transit", + "realtime_feeds": { + "alerts": "http://api.bart.gov/gtfsrt/alerts.aspx", + "trip_update": "http://api.bart.gov/gtfsrt/tripupdate.aspx" + }, + "static_feeds": { + "regular": "https://www.bart.gov/dev/schedules/google_transit.zip" + } + }, + "washington_metro_area_bus": { + "name": "Washington Metropolitan Area Transit Authroity Bus", + "realtime_feeds": { + "alerts": "https://api.wmata.com/gtfs/bus-gtfsrt-alerts.pb", + "trip_update": "https://api.wmata.com/gtfs/bus-gtfsrt-tripupdates.pb", + "vehicle_positions": "https://api.wmata.com/gtfs/bus-gtfsrt-vehiclepositions.pb" + }, + "static_feeds": { + "regular": "https://api.wmata.com/gtfs/bus-gtfs-static.zip" + } + }, + "washington_metro_area_rail": { + "name": "Washington Metropolitan Area Transit Authority Rail", + "realtime_feeds": { + "alerts": "https://api.wmata.com/gtfs/rail-gtfsrt-alerts.pb", + "trip_update": "https://api.wmata.com/gtfs/rail-gtfsrt-tripupdates.pb", + "vehicle_positions": "https://api.wmata.com/gtfs/rail-gtfsrt-vehiclepositions.pb" + }, + "static_feeds": { + "regular": "https://api.wmata.com/gtfs/rail-gtfs-static.zip" + } + } +} diff --git a/setup.cfg b/setup.cfg index fa702aa..bd0bc6d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,6 +14,7 @@ omit = [pytest] addopts = --allow-hosts=127.0.0.1,localhost +timeout = 60 [tool:pytest] testpaths = tests diff --git a/tests/conftest.py b/tests/conftest.py index e2ccc69..c71dcbb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,50 @@ """Fixtures for testing.""" +import json + import pytest +from pytest_homeassistant_custom_component.common import MockConfigEntry + +from custom_components.gtfs_realtime.config_flow import DOMAIN @pytest.fixture(autouse=True) def auto_enable_custom_integrations(enable_custom_integrations): """Enable custom integration that will be tested.""" yield + + +@pytest.fixture(name="entry_v1") +def create_config_entry_v1(): + """Fixture for entry version 1.""" + yield MockConfigEntry( + entry_id="mock_config_v1", + domain=DOMAIN, + version=1, + minor_version=0, + data={"url_endpoints": ["https://gtfs.example.com/feed"]}, + ) + + +@pytest.fixture(name="entry_v1_full") +def create_config_entry_v1_full(): + """Fixture with full mock data for entry version 1.""" + with open("tests/fixtures/config_entry_v1_full.json") as f: + conf = json.load(f) + yield MockConfigEntry(**conf) + + +@pytest.fixture(name="entry_v2_full") +def create_config_entry_v2_full(): + """Fixture with full mock data for entry version 2.""" + with open("tests/fixtures/config_entry_v2_full.json") as f: + conf = json.load(f) + yield MockConfigEntry(**conf) + + +@pytest.fixture(name="entry_v2_nodialout") +def create_config_entry_v2_nodialout(): + """Fixture with mock data for entry version 2 with limited URLs to access.""" + with open("tests/fixtures/config_entry_v2_nodialout.json") as f: + conf = json.load(f) + yield MockConfigEntry(**conf) diff --git a/tests/fixtures/config_entry_v1_full.json b/tests/fixtures/config_entry_v1_full.json new file mode 100644 index 0000000..4e5ca48 --- /dev/null +++ b/tests/fixtures/config_entry_v1_full.json @@ -0,0 +1,40 @@ +{ + "domain": "gtfs_realtime", + "title": "Entry V1 Full", + "data": { + "api_key": "", + "arrival_limit": 4, + "gtfs_provider": "Entry V1 Mock", + "gtfs_static_data": [ + "https://example.com/gtfs1.zip", + "https://example.com/gtfs2.zip" + ], + "route_icons": "https://icons.example.com/{}.svg", + "route_ids": [ + "1", + "2", + "3" + ], + "stop_ids": [ + "101N", + "102S" + ], + "url_endpoints": [ + "https://api-endpoint.example.com/rt1", + "https://api-endpoint.example.com/rt2", + "https://api-endpoint.example.com/rt3", + "https://api-endpoint.example.com/rt4", + "https://api-endpoint.example.com/rt5" + ] + }, + "disabled_by": null, + "discovery_keys": {}, + "entry_id": "ENTRYV1FULL", + "minor_version": 1, + "options": {}, + "pref_disable_new_entities": false, + "pref_disable_polling": false, + "source": "user", + "unique_id": null, + "version": 1 +} diff --git a/tests/fixtures/config_entry_v2_full.json b/tests/fixtures/config_entry_v2_full.json new file mode 100644 index 0000000..5bd1483 --- /dev/null +++ b/tests/fixtures/config_entry_v2_full.json @@ -0,0 +1,48 @@ +{ + "domain": "gtfs_realtime", + "title": "Entry V2 Full", + "data": { + "api_key": "", + "arrival_limit": 4, + "gtfs_provider": "Entry V2 Mock", + "gtfs_static_data": [ + "https://example.com/gtfs1.zip", + "https://example.com/gtfs2.zip" + ], + "route_icons": "https://icons.example.com/{}.svg", + "route_ids": [ + "1", + "2", + "3" + ], + "static_sources_update_frequency": { + "https://example.com/gtfs1.zip": { + "hours": 2 + }, + "https://example.com/gtfs2.zip": { + "days": 10 + } + }, + "stop_ids": [ + "101N", + "102S" + ], + "url_endpoints": [ + "https://api-endpoint.example.com/rt1", + "https://api-endpoint.example.com/rt2", + "https://api-endpoint.example.com/rt3", + "https://api-endpoint.example.com/rt4", + "https://api-endpoint.example.com/rt5" + ] + }, + "disabled_by": null, + "discovery_keys": {}, + "entry_id": "ENTRYV2FULL", + "minor_version": 0, + "options": {}, + "pref_disable_new_entities": false, + "pref_disable_polling": false, + "source": "user", + "unique_id": null, + "version": 2 +} diff --git a/tests/fixtures/config_entry_v2_nodialout.json b/tests/fixtures/config_entry_v2_nodialout.json new file mode 100644 index 0000000..f8f70b3 --- /dev/null +++ b/tests/fixtures/config_entry_v2_nodialout.json @@ -0,0 +1,37 @@ +{ + "domain": "gtfs_realtime", + "title": "Entry V2 No Dialout", + "data": { + "api_key": "", + "arrival_limit": 4, + "gtfs_provider": "Entry V2 Mock", + "gtfs_static_data": [], + "route_ids": [ + "1", + "2", + "3" + ], + "static_sources_update_frequency": {}, + "stop_ids": [ + "101N", + "102S" + ], + "url_endpoints": [ + "https://api-endpoint.example.com/rt1", + "https://api-endpoint.example.com/rt2", + "https://api-endpoint.example.com/rt3", + "https://api-endpoint.example.com/rt4", + "https://api-endpoint.example.com/rt5" + ] + }, + "disabled_by": null, + "discovery_keys": {}, + "entry_id": "ENTRYV2NODIALOUT", + "minor_version": 0, + "options": {}, + "pref_disable_new_entities": false, + "pref_disable_polling": false, + "source": "user", + "unique_id": null, + "version": 2 +} diff --git a/tests/test_binary_sensor.py b/tests/test_binary_sensor.py index 9e23459..33348eb 100644 --- a/tests/test_binary_sensor.py +++ b/tests/test_binary_sensor.py @@ -1,59 +1,38 @@ """Test sensor.""" -import datetime +from unittest.mock import AsyncMock, patch -from gtfs_station_stop.alert import Alert -from gtfs_station_stop.feed_subject import FeedSubject -from gtfs_station_stop.route_status import RouteStatus +from gtfs_station_stop.calendar import Calendar +from gtfs_station_stop.route_info import RouteInfoDatabase +from gtfs_station_stop.station_stop_info import StationStopInfoDatabase +from gtfs_station_stop.trip_info import TripInfoDatabase +from homeassistant.const import STATE_OFF from homeassistant.core import HomeAssistant -import pytest +from pytest_homeassistant_custom_component.common import MockConfigEntry -from custom_components.gtfs_realtime.binary_sensor import AlertSensor -from custom_components.gtfs_realtime.coordinator import GtfsRealtimeCoordinator - -@pytest.fixture -def alert_sensor(hass: HomeAssistant) -> AlertSensor: - """Fixture for a basic alert sensor.""" - feed_subject = FeedSubject([]) - route_status = RouteStatus("1", feed_subject) - route_status.alerts = [ - Alert(datetime.datetime.max, {"en": "Alert"}, {"en": "This is an Alert"}), - Alert( - datetime.datetime.max, - {"en": "Another Alert"}, - {"en": "This is another Alert"}, +async def test_setup_binary_sensors( + hass: HomeAssistant, entry_v2_nodialout: MockConfigEntry +): + """Test setting up binary sensors for integration.""" + with ( + patch( + "custom_components.gtfs_realtime.coordinator.GtfsRealtimeCoordinator._async_update_data", + new_callable=AsyncMock, ), - ] - - async def noop(): - pass - - alert_sensor = AlertSensor( - GtfsRealtimeCoordinator(hass, feed_subject), route_status, "en" - ) - alert_sensor.async_write_ha_state = noop - return alert_sensor - - -def test_create_entity(alert_sensor): - """Tests entity construction.""" - # Created by the fixture - assert alert_sensor.state == "off" - assert "1" in alert_sensor.name - - -def test_update(alert_sensor): - """ - Tests calling the update method on the sensor. - - This will latch the data in station_stop into the hass platform. - """ - alert_sensor.update() - assert alert_sensor.state == "on" - assert alert_sensor.extra_state_attributes["Header"] == "Alert" - assert alert_sensor.extra_state_attributes["Header 2"] == "Another Alert" - assert alert_sensor.extra_state_attributes["Description"] == "This is an Alert" - assert ( - alert_sensor.extra_state_attributes["Description 2"] == "This is another Alert" - ) + patch( + "custom_components.gtfs_realtime.coordinator.GtfsRealtimeCoordinator._async_update_static_data", + new_callable=AsyncMock, + return_value=( + Calendar(), + StationStopInfoDatabase(), + TripInfoDatabase(), + RouteInfoDatabase(), + ), + ), + ): + entry_v2_nodialout.add_to_hass(hass) + assert await hass.config_entries.async_setup(entry_v2_nodialout.entry_id) + await hass.async_block_till_done() + assert hass.states.get("binary_sensor.1_service_alerts").state == STATE_OFF + assert hass.states.get("binary_sensor.2_service_alerts").state == STATE_OFF diff --git a/tests/test_config_flow.py b/tests/test_config_flow.py index 07603b9..032a576 100644 --- a/tests/test_config_flow.py +++ b/tests/test_config_flow.py @@ -1,59 +1,231 @@ """Test Config Flow.""" +from unittest.mock import patch + +from aiohttp.web import HTTPNotFound from homeassistant import config_entries +from homeassistant.config_entries import ConfigFlowResult from homeassistant.core import HomeAssistant -from homeassistant.data_entry_flow import RESULT_TYPE_FORM +from homeassistant.data_entry_flow import FlowResultType +from homeassistant.helpers.selector import SelectOptionDict import pytest +from pytest_homeassistant_custom_component.common import MockConfigEntry from custom_components.gtfs_realtime.config_flow import GtfsRealtimeConfigFlow -from custom_components.gtfs_realtime.const import CONF_GTFS_STATIC_DATA, DOMAIN +from custom_components.gtfs_realtime.const import ( + CONF_GTFS_PROVIDER, + CONF_GTFS_PROVIDER_ID, + CONF_GTFS_STATIC_DATA, + CONF_ROUTE_IDS, + CONF_STOP_IDS, + CONF_URL_ENDPOINTS, + DOMAIN, +) @pytest.fixture def flow(): """Fixture that constructs the flow.""" + with patch.object(GtfsRealtimeConfigFlow, "_get_feeds", return_value={}): + flow = GtfsRealtimeConfigFlow() + yield flow + + +@pytest.fixture +def example_gtfs_feed_data(): + """Fixture for example provider data.""" + yield { + CONF_GTFS_PROVIDER: "Example GTFS Provider", + CONF_GTFS_PROVIDER_ID: "example_gtfs_provider", + CONF_URL_ENDPOINTS: [ + "https://gtfs.example.com/rt1", + "https://gtfs.example.com/rt2", + ], + CONF_GTFS_STATIC_DATA: ["https://gtfs.example.com/static1.zip"], + } + + +@pytest.fixture +def good_routes_response_patch(): + """Fixture for good feed response for pre-populating routes.""" + yield patch.object( + GtfsRealtimeConfigFlow, + "_get_route_options", + return_value=[SelectOptionDict(value="X", label="Route X")], + ) + - # feeds call is now no-op - def no_feeds(): - return {} +@pytest.fixture +def good_stops_response_patch(): + """Fixture for good feed response for pre-populating stops.""" + yield patch.object( + GtfsRealtimeConfigFlow, + "_get_stop_options", + return_value=[SelectOptionDict(value="A", label="Route A")], + ) - GtfsRealtimeConfigFlow.get_feeds = no_feeds - return GtfsRealtimeConfigFlow() + +@pytest.fixture +def bad_routes_response_patch(): + """Fixture for bad feed response for pre-populating routes.""" + yield patch.object( + GtfsRealtimeConfigFlow, + "_get_route_options", + side_effect=HTTPNotFound(), + ) -@pytest.mark.skip("Fixture might be failing this.") -async def test_form(hass: HomeAssistant, flow): +@pytest.fixture +def bad_stops_response_patch(): + """Fixture for bad feed response for pre-populating stops.""" + yield patch.object( + GtfsRealtimeConfigFlow, + "_get_stop_options", + side_effect=HTTPNotFound(), + ) + + +@pytest.fixture +def example_gtfs_informed_entities_data(): + """Fixture for informed entities data.""" + yield {CONF_ROUTE_IDS: ["X", "Y", "Z"], CONF_STOP_IDS: ["A", "B", "C"]} + + +async def test_form(hass: HomeAssistant, flow) -> None: """Test we get the form.""" - result = await hass.config_entries.flow.async_init( + result: ConfigFlowResult = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) - assert result["type"] == RESULT_TYPE_FORM - assert result["errors"] is None + assert result["type"] == FlowResultType.FORM + assert result["errors"] == {} -async def test_step_user(flow): +async def test_step_user(flow: GtfsRealtimeConfigFlow) -> None: """Test User Setup through Config Flow.""" - await flow.async_step_user(None) + result: ConfigFlowResult = await flow.async_step_user(None) + assert result["type"] == FlowResultType.FORM + assert result["errors"] == {} + # check feeds were acquired + GtfsRealtimeConfigFlow._get_feeds.assert_called() -async def test_user_step_get_feeds_fails(flow): - """Test that failures getting feeds do not break config flow.""" +async def test_step_user_input_manual_provider(flow: GtfsRealtimeConfigFlow) -> None: + """Test User input selection 'Manual' GTFS provider.""" + result: ConfigFlowResult = await flow.async_step_user( + user_input={CONF_GTFS_PROVIDER_ID: "_"} + ) + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "choose_static_and_realtime_feeds" - def fail(): - raise RuntimeError("This must fail.") - GtfsRealtimeConfigFlow.get_feeds = fail - flow_result = await flow.async_step_user(None) - assert "base" in flow_result["errors"] +async def test_user_step_get_feeds_fails() -> None: + """Test that failures getting feeds do not break config flow.""" + with patch.object(GtfsRealtimeConfigFlow, "_get_feeds", side_effect=HTTPNotFound()): + flow = GtfsRealtimeConfigFlow() + result: ConfigFlowResult = await flow.async_step_user(None) + assert "base" in result["errors"] + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "user" -async def test_step_choose_static_and_realtime_feeds(flow): +async def test_step_choose_static_and_realtime_feeds_no_prefill( + flow: GtfsRealtimeConfigFlow, +) -> None: """Test config flow for choosing static and realtime feeds.""" - await flow.async_step_choose_static_and_realtime_feeds({}) + result: ConfigFlowResult = await flow.async_step_choose_static_and_realtime_feeds() + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "choose_static_and_realtime_feeds" + + +async def test_step_choose_static_and_realtime_feeds_prefilled( + flow: GtfsRealtimeConfigFlow, + example_gtfs_feed_data, + good_stops_response_patch, + good_routes_response_patch, +): + """Test that choosing feeds pre-fills from the previous step.""" + with good_stops_response_patch, good_routes_response_patch: + await flow.async_step_choose_static_and_realtime_feeds(example_gtfs_feed_data) + # hub will be configured with the provider name and ID + assert flow.hub_config[CONF_GTFS_PROVIDER_ID] == "example_gtfs_provider" + + +async def test_step_choose_informed_entities( + flow: GtfsRealtimeConfigFlow, + example_gtfs_feed_data, + good_stops_response_patch, + good_routes_response_patch, +) -> None: + """Test config flow for choosing informed entities.""" + flow.hub_config |= example_gtfs_feed_data + + with good_stops_response_patch, good_routes_response_patch: + result: ConfigFlowResult = await flow.async_step_choose_informed_entities() + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "choose_informed_entities" -async def test_step_choose_informed_entities(flow): +async def test_step_choose_informed_entities_shows_feed_selector_if_data_pull_fails( + flow: GtfsRealtimeConfigFlow, + example_gtfs_feed_data, + bad_stops_response_patch, + bad_routes_response_patch, +) -> None: """Test config flow for choosing informed entities.""" - await flow.async_step_choose_informed_entities( - user_input={CONF_GTFS_STATIC_DATA: []} + flow.hub_config |= example_gtfs_feed_data + + with bad_stops_response_patch, bad_routes_response_patch: + result: ConfigFlowResult = await flow.async_step_choose_informed_entities() + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "choose_static_and_realtime_feeds" + assert result["errors"] + + +async def test_step_choose_informed_entities_good_data( + flow: GtfsRealtimeConfigFlow, + example_gtfs_feed_data, + example_gtfs_informed_entities_data, +) -> None: + """Test config flow finish after choosing informed entities.""" + flow.hub_config |= example_gtfs_feed_data + result: ConfigFlowResult = await flow.async_step_choose_informed_entities( + user_input=example_gtfs_informed_entities_data + ) + # creates an entry + assert result["type"] == FlowResultType.CREATE_ENTRY + + +async def test_step_choose_informed_entities_no_entities( + flow: GtfsRealtimeConfigFlow, + example_gtfs_feed_data, + good_stops_response_patch, + good_routes_response_patch, +) -> None: + """Test config flow finish after choosing informed entities.""" + flow.hub_config |= example_gtfs_feed_data + with good_stops_response_patch, good_routes_response_patch: + result: ConfigFlowResult = await flow.async_step_choose_informed_entities( + user_input={CONF_ROUTE_IDS: [], CONF_STOP_IDS: []} + ) + # creates an entry + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "choose_informed_entities" + + +@pytest.mark.skip(reason="Need to Reimplement Reconfiguration Step") +async def test_step_reconfigure(hass: HomeAssistant, entry_v1: MockConfigEntry) -> None: + """Test Reconfigure.""" + entry_v1.add_to_hass(hass) + old_entry_data = entry_v1.data.copy() + result = await entry_v1.start_reconfigure_flow(hass) + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "reconfigure" + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], {"gtfs_static_data_update_frequency_hours": 15} ) + assert result["type"] == FlowResultType.ABORT + assert result["reason"] == "reconfigure_successful" + entry = hass.config_entries.async_get_entry(entry_v1.entry_id) + assert entry.data == {**old_entry_data, **entry.data} + assert entry.data["gtfs_static_data_update_frequency_hours"] == 15 diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py new file mode 100644 index 0000000..f6c71e7 --- /dev/null +++ b/tests/test_coordinator.py @@ -0,0 +1,135 @@ +"""Test Coordinator.""" + +from datetime import date, timedelta +from unittest.mock import AsyncMock, patch + +from freezegun.api import FrozenDateTimeFactory +from gtfs_station_stop.calendar import Calendar, Service, ServiceDays +from gtfs_station_stop.feed_subject import FeedSubject +from gtfs_station_stop.route_info import RouteInfo, RouteInfoDatabase +from gtfs_station_stop.station_stop_info import StationStopInfo, StationStopInfoDatabase +from gtfs_station_stop.trip_info import TripInfo, TripInfoDatabase +from homeassistant.const import STATE_UNKNOWN +from homeassistant.core import HomeAssistant +import pytest +from pytest_homeassistant_custom_component.common import ( + MockConfigEntry, + async_fire_time_changed, +) + +from custom_components.gtfs_realtime.const import ( + CLEAR_STATIC_FEEDS, + DOMAIN, + REFRESH_STATIC_FEEDS, +) +from custom_components.gtfs_realtime.coordinator import GtfsRealtimeCoordinator + + +@pytest.fixture +async def entry_v2_and_coordinator_patch( + hass: HomeAssistant, freezer: FrozenDateTimeFactory, entry_v2_full: MockConfigEntry +): + """Fixture for testing using config version 2.""" + with ( + patch( + "gtfs_station_stop.feed_subject.FeedSubject.async_update", + new_callable=AsyncMock, + return_value=None, + ) as hub_update_patch, + patch( + "custom_components.gtfs_realtime.coordinator.GtfsRealtimeCoordinator._async_update_static_data", + new_callable=AsyncMock, + return_value=( + Calendar(), + StationStopInfoDatabase(), + TripInfoDatabase(), + RouteInfoDatabase(), + ), + ), + ): + assert hub_update_patch.call_count == 0 + entry_v2_full.add_to_hass(hass) + assert await hass.config_entries.async_setup(entry_v2_full.entry_id) + await hass.async_block_till_done() + assert hub_update_patch.call_count >= 1 + assert entry_v2_full.runtime_data.last_static_update + yield entry_v2_full, hub_update_patch + + +def test_coordinator_construction(hass: HomeAssistant): + """Smoke test for creating a coordinator.""" + GtfsRealtimeCoordinator(hass, feed_subject=FeedSubject([])) + + +async def test_update_static_data( + hass: HomeAssistant, freezer: FrozenDateTimeFactory, entry_v2_and_coordinator_patch +): + """Test updates through the coordinator.""" + entry_v2, coordinator_patch = entry_v2_and_coordinator_patch + freezer.tick(timedelta(seconds=60)) + async_fire_time_changed(hass) + await hass.async_block_till_done() + + # Check second call to update on gtfs1 with 2 hour refresh + start_call_count: int = coordinator_patch.call_count + freezer.tick(timedelta(hours=2.1)) + async_fire_time_changed(hass) + await hass.async_block_till_done() + + assert coordinator_patch.call_count == start_call_count + 1 + + # Check second call to update on gtfs2 with 10 day refresh + # Find a way to track each source independently + start_call_count: int = coordinator_patch.call_count + freezer.tick(timedelta(days=10)) + async_fire_time_changed(hass) + await hass.async_block_till_done() + assert coordinator_patch.call_count == start_call_count + 1 + + sensor_state = hass.states.get("sensor.1_101n") + assert sensor_state.state == STATE_UNKNOWN + + +async def test_clear_static_data_service( + hass: HomeAssistant, freezer: FrozenDateTimeFactory, entry_v2_and_coordinator_patch +): + """Test clearing static data.""" + entry_v2, _ = entry_v2_and_coordinator_patch + coordinator: GtfsRealtimeCoordinator = entry_v2.runtime_data + coordinator.calendar.services["Normal"] = Service( + "X", + ServiceDays.no_service(), + start=date(year=2024, month=12, day=1), + end=date(year=2024, month=12, day=31), + ) + coordinator.station_stop_info_db.station_stop_infos["Stop"] = StationStopInfo( + {"stop_id": "Stop"} + ) + coordinator.route_info_db.route_infos["Route"] = RouteInfo( + { + "route_id": "Route", + "route_long_name": "Long Route Name", + "route_type": "1", + } + ) + coordinator.trip_info_db.trip_infos["Trip"] = TripInfo( + {"trip_id": "Trip", "route_id": "Route", "service_id": "Normal"} + ) + await hass.services.async_call(DOMAIN, CLEAR_STATIC_FEEDS, blocking=True) + await hass.async_block_till_done() + assert len(coordinator.calendar.services) == 0 + assert len(coordinator.station_stop_info_db.station_stop_infos) == 0 + assert len(coordinator.route_info_db.route_infos) == 0 + assert len(coordinator.trip_info_db.trip_infos) == 0 + + +@pytest.mark.skip("May require debouncing") +async def test_refresh_static_data_service( + hass: HomeAssistant, entry_v2_and_coordinator_patch +): + """Test refreshing static data.""" + _, coordinator_patch = entry_v2_and_coordinator_patch + before_call_count = coordinator_patch.call_count + await hass.services.async_call(DOMAIN, REFRESH_STATIC_FEEDS, blocking=True) + await hass.async_block_till_done() + assert coordinator_patch.call_count == before_call_count + 1 diff --git a/tests/test_init.py b/tests/test_init.py index 133269c..a1b5184 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -1,19 +1,65 @@ """Test component setup.""" +from datetime import timedelta +from unittest.mock import patch + from homeassistant.core import HomeAssistant -from homeassistant.setup import async_setup_component +from pytest_homeassistant_custom_component.common import MockConfigEntry from custom_components.gtfs_realtime.const import ( - CONF_API_KEY, CONF_GTFS_STATIC_DATA, - CONF_URL_ENDPOINTS, - DOMAIN, + CONF_STATIC_SOURCES_UPDATE_FREQUENCY, + CONF_STATIC_SOURCES_UPDATE_FREQUENCY_DEFAULT, ) -async def test_async_setup(hass: HomeAssistant) -> None: +async def test_lifecycle(hass: HomeAssistant, entry_v2_full) -> None: """Test the component gets setup.""" - test_config = { - DOMAIN: {CONF_API_KEY: "", CONF_URL_ENDPOINTS: [], CONF_GTFS_STATIC_DATA: []} - } - assert await async_setup_component(hass, DOMAIN, test_config) is True + entry_v2_full.add_to_hass(hass) + await hass.config_entries.async_setup(entry_v2_full.entry_id) + await hass.async_block_till_done() + assert await hass.config_entries.async_remove(entry_v2_full.entry_id) + await hass.async_block_till_done() + + +async def test_migrate_from_v1( + hass: HomeAssistant, + entry_v1_full: MockConfigEntry, +) -> None: + """Test Migration From Version 1.""" + + with ( + patch( + "gtfs_station_stop.feed_subject.FeedSubject.async_update", return_value=None + ), + patch( + "custom_components.gtfs_realtime.coordinator.GtfsRealtimeCoordinator.async_update_static_data", + return_value=None, + ), + ): + entry_v1_full.add_to_hass(hass) + + assert await hass.config_entries.async_setup(entry_v1_full.entry_id) + + await hass.async_block_till_done() + + # Adds default time for each static data url + updated_entry = hass.config_entries.async_get_entry(entry_v1_full.entry_id) + + static_data_uris = [ + "https://example.com/gtfs1.zip", + "https://example.com/gtfs2.zip", + ] + assert set(static_data_uris) == set(updated_entry.data.get(CONF_GTFS_STATIC_DATA)) + + for uri in static_data_uris: + # update everything to the default + timedelta_dict = updated_entry.data.get( + CONF_STATIC_SOURCES_UPDATE_FREQUENCY + ).get(uri) + assert ( + timedelta_dict.get("hours") == CONF_STATIC_SOURCES_UPDATE_FREQUENCY_DEFAULT + ) + assert timedelta(**timedelta_dict) == timedelta( + hours=CONF_STATIC_SOURCES_UPDATE_FREQUENCY_DEFAULT + ) diff --git a/tests/test_sensor.py b/tests/test_sensor.py index c332c20..ab4584a 100644 --- a/tests/test_sensor.py +++ b/tests/test_sensor.py @@ -1,61 +1,123 @@ """Test sensor.""" -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta +from unittest.mock import AsyncMock, patch -from freezegun import freeze_time +from freezegun.api import FrozenDateTimeFactory from gtfs_station_stop.arrival import Arrival -from gtfs_station_stop.feed_subject import FeedSubject -from gtfs_station_stop.station_stop import StationStop +from gtfs_station_stop.calendar import Calendar +from gtfs_station_stop.route_info import RouteInfoDatabase +from gtfs_station_stop.station_stop_info import StationStopInfoDatabase from gtfs_station_stop.trip_info import TripInfoDatabase +from homeassistant.const import STATE_UNKNOWN from homeassistant.core import HomeAssistant import pytest +from pytest_homeassistant_custom_component.common import ( + MockConfigEntry, + async_fire_time_changed, +) from custom_components.gtfs_realtime.coordinator import GtfsRealtimeCoordinator -from custom_components.gtfs_realtime.sensor import ArrivalSensor -NOW = datetime(2024, 3, 17, 23, 0, 0).replace(tzinfo=timezone.utc) +async def test_setup_sensors(hass: HomeAssistant, entry_v2_nodialout: MockConfigEntry): + """Test setting ups sensors in integration.""" + with ( + patch( + "custom_components.gtfs_realtime.coordinator.GtfsRealtimeCoordinator._async_update_data", + new_callable=AsyncMock, + ), + patch( + "custom_components.gtfs_realtime.coordinator.GtfsRealtimeCoordinator._async_update_static_data", + new_callable=AsyncMock, + return_value=( + Calendar(), + StationStopInfoDatabase(), + TripInfoDatabase(), + RouteInfoDatabase(), + ), + ), + ): + entry_v2_nodialout.add_to_hass(hass) + assert await hass.config_entries.async_setup(entry_v2_nodialout.entry_id) + await hass.async_block_till_done() + assert hass.states.get("sensor.4_101n").state == STATE_UNKNOWN -@pytest.fixture -def arrival_sensor(hass: HomeAssistant) -> ArrivalSensor: - """Fixture for a basic arrival sensor.""" - feed_subject = FeedSubject([]) - station_stop = StationStop("STATION", feed_subject) - station_stop.arrivals = [ - Arrival((NOW + timedelta(minutes=24)).timestamp(), "A", "A_trip"), - Arrival((NOW + timedelta(minutes=36)).timestamp(), "B", "B_trip"), - ] - async def noop(): - pass +@pytest.mark.skip("Find out how to mock coordinator updates") +async def test_sensor_update( + hass: HomeAssistant, + freezer: FrozenDateTimeFactory, + entry_v2_nodialout: MockConfigEntry, +): + """Test the sensors update every minute with realtime data.""" - arrival_sensor = ArrivalSensor( - GtfsRealtimeCoordinator(hass, feed_subject), station_stop, 0 - ) - arrival_sensor.async_write_ha_state = noop - return arrival_sensor + def coordinator_update_side_effects(): + coordinator = hass.config_entries.async_get_known_entry( + entry_v2_nodialout.entry_id + ) + for x in range(4): + arrivals = { + "101N": [ + Arrival(datetime.now() + timedelta(minutes=4 - x), "A", ""), + Arrival(datetime.now() + timedelta(minutes=6 - x), "B", ""), + Arrival(datetime.now() + timedelta(minutes=8 - x), "C", ""), + ], + "102S": [ + Arrival(datetime.now() + timedelta(minutes=9 - x), "X", ""), + Arrival(datetime.now() + timedelta(minutes=13 - x), "Y", ""), + Arrival(datetime.now() + timedelta(minutes=17 - x), "Z", ""), + ], + } + for id, stop in coordinator.station_stops.items(): + stop.arrivals = next(arrivals)[id] + yield + pytest.fail("Tests should not call the update more than 4 times") + with ( + patch( + "custom_components.gtfs_realtime.coordinator.GtfsRealtimeCoordinator._async_update_data", + new_callable=AsyncMock, + side_effect=coordinator_update_side_effects, + ), + patch( + "custom_components.gtfs_realtime.coordinator.GtfsRealtimeCoordinator._async_update_static_data", + new_callable=AsyncMock, + return_value=( + Calendar(), + StationStopInfoDatabase(), + TripInfoDatabase(), + RouteInfoDatabase(), + ), + ), + ): + entry_v2_nodialout.add_to_hass(hass) + assert await hass.config_entries.async_setup(entry_v2_nodialout.entry_id) + await hass.async_block_till_done() -def test_create_entity(arrival_sensor): - """Tests entity construction.""" - # Created by the fixture - assert arrival_sensor.state is None - assert arrival_sensor.name == "1: STATION" + coordinator: GtfsRealtimeCoordinator = entry_v2_nodialout.runtime_data + await coordinator.async_refresh() + freezer.tick(timedelta(minutes=1)) + async_fire_time_changed(hass) + await hass.async_block_till_done() + # Find out why sensor state is not updating + assert hass.states.get("sensor.1_101n").state == 4 + # assert hass.states.get("sensor.2_101n").state == 6 + # assert hass.states.get("sensor.3_101n").state == 8 + assert hass.states.get("sensor.4_101n").state == STATE_UNKNOWN + # assert hass.states.get("sensor.1_102s").state == 9 + # assert hass.states.get("sensor.2_102s").state == 13 + # assert hass.states.get("sensor.3_102s").state == 17 + assert hass.states.get("sensor.4_102s").state == STATE_UNKNOWN - -@freeze_time(NOW) -def test_update(arrival_sensor): - """ - Tests calling the update method on the sensor. - - This will latch the data in station_stop into the hass platform. - """ - arrival_sensor.update() - assert arrival_sensor.state == pytest.approx(24) - - -@freeze_time(NOW) -def test_update_trip_info_not_found(arrival_sensor): - """Tests that missing trip info still updates the state.""" - arrival_sensor.trip_info_db = TripInfoDatabase() - arrival_sensor.update() + freezer.tick(timedelta(minutes=10)) + async_fire_time_changed(hass) + await hass.async_block_till_done() + # assert hass.states.get("sensor.1_101n").state == 3 + # assert hass.states.get("sensor.2_101n").state == 5 + # assert hass.states.get("sensor.3_101n").state == 7 + assert hass.states.get("sensor.4_101n").state == STATE_UNKNOWN + # assert hass.states.get("sensor.1_102s").state == 8 + # assert hass.states.get("sensor.2_102s").state == 12 + # assert hass.states.get("sensor.3_102s").state == 16 + assert hass.states.get("sensor.4_102s").state == STATE_UNKNOWN