Skip to content

Commit

Permalink
#20 Implement rule action feature (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
tuanvumaihuynh authored Oct 22, 2024
1 parent 7945187 commit 3a2cc19
Show file tree
Hide file tree
Showing 43 changed files with 1,910 additions and 36 deletions.
4 changes: 2 additions & 2 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ services:
- VIOT_REDIS_PORT=6379
- VIOT_AUTH_JWT_SECRET=secret
- VIOT_EMQX_API_URL=http://node1.viot:18083/api/v5
- VIOT_EMQX_API_KEY=dev1
- VIOT_EMQX_SECRET_KEY=dev1
- VIOT_EMQX_API_KEY=ABC1234
- VIOT_EMQX_SECRET_KEY=ec7122ff-c6c6-4b6b-91eb-5593d7b09437
- VIOT_EMQX_MQTT_WHITELIST_FILE_PATH=./mqtt_whitelist
- VIOT_CELERY_REDIS_SERVER=redis
- VIOT_CELERY_REDIS_PORT=6379
Expand Down
2 changes: 2 additions & 0 deletions viot/app/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def setup_modules() -> None:
from app.module.device_data.module import DeviceDataModule
from app.module.email.module import EmailModule
from app.module.emqx.module import EmqxModule
from app.module.rule_action.module import RuleActionModule
from app.module.team.module import TeamModule

injector.binder.install(DatabaseModule)
Expand All @@ -68,6 +69,7 @@ def setup_modules() -> None:
injector.binder.install(DeviceModule)
injector.binder.install(DeviceDataModule)
injector.binder.install(EmqxModule)
injector.binder.install(RuleActionModule)


def register_middleware(app: FastAPI) -> None:
Expand Down
2 changes: 2 additions & 0 deletions viot/app/common/exception/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .base import (
BadRequestException,
InternalServerException,
NotFoundException,
PermissionDeniedException,
UnauthorizedException,
Expand All @@ -12,4 +13,5 @@
"PermissionDeniedException",
"UnauthorizedException",
"ViotException",
"InternalServerException",
]
10 changes: 5 additions & 5 deletions viot/app/common/exception/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ class ViotException(Exception):
pass


class ViotHttpException(ViotException):
class InternalServerException(ViotException):
STATUS_CODE = 500

def __init__(
Expand All @@ -20,7 +20,7 @@ def __init__(
super().__init__(message)


class BadRequestException(ViotHttpException):
class BadRequestException(InternalServerException):
STATUS_CODE = 400

def __init__(
Expand All @@ -29,7 +29,7 @@ def __init__(
super().__init__(code=code, message=message)


class UnauthorizedException(ViotHttpException):
class UnauthorizedException(InternalServerException):
STATUS_CODE = 401

def __init__(
Expand All @@ -42,7 +42,7 @@ def __init__(
)


class PermissionDeniedException(ViotHttpException):
class PermissionDeniedException(InternalServerException):
STATUS_CODE = 403

def __init__(
Expand All @@ -51,7 +51,7 @@ def __init__(
super().__init__(code=code, message=message)


class NotFoundException(ViotHttpException):
class NotFoundException(InternalServerException):
STATUS_CODE = 404

def __init__(
Expand Down
10 changes: 5 additions & 5 deletions viot/app/common/exception/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fastapi.exceptions import RequestValidationError

from app.common.dto import ErrorDto
from app.common.exception.base import ViotHttpException
from app.common.exception.base import InternalServerException
from app.common.exception.constant import MessageError
from app.common.fastapi.serializer import JSONResponse
from app.config import app_settings
Expand All @@ -18,7 +18,7 @@ def register_exception_handlers(app: FastAPI) -> None:
async def handle_request_validation_error( # type: ignore
request: Request, exc: RequestValidationError
) -> JSONResponse[ErrorDto]:
logger.debug(f"Validation error: {exc.errors()}")
logger.info(f"Validation error: {exc.errors()}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=ErrorDto(
Expand All @@ -28,11 +28,11 @@ async def handle_request_validation_error( # type: ignore
),
)

@app.exception_handler(ViotHttpException)
@app.exception_handler(InternalServerException)
async def handle_viot_http_exception( # type: ignore
request: Request, exc: ViotHttpException
request: Request, exc: InternalServerException
) -> JSONResponse[ErrorDto]:
logger.debug(f"Viot HTTP exception: {exc.code} - {exc.message}")
logger.info(f"HTTP exception: {exc.code} - {exc.message}")
return JSONResponse(
status_code=exc.STATUS_CODE,
content=ErrorDto(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""rule action
Revision ID: d39146d49608
Revises: e4598ec5a65e
Create Date: 2024-10-17 16:13:33.806797
"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "d39146d49608"
down_revision: str | None = "e4598ec5a65e"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"rules",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("name", sa.TEXT(), nullable=False),
sa.Column("description", sa.TEXT(), nullable=True),
sa.Column("enable", sa.Boolean(), nullable=False),
sa.Column("event_type", sa.SMALLINT(), nullable=False),
sa.Column("sql", sa.TEXT(), nullable=False),
sa.Column("topic", sa.TEXT(), nullable=False),
sa.Column(
"condition", postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), nullable=False
),
sa.Column("device_id", sa.UUID(), nullable=False),
sa.Column("team_id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(
["device_id"], ["devices.id"], onupdate="CASCADE", ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["team_id"], ["teams.id"], onupdate="CASCADE", ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"actions",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("name", sa.TEXT(), nullable=False),
sa.Column("description", sa.TEXT(), nullable=True),
sa.Column("action_type", sa.SMALLINT(), nullable=False),
sa.Column(
"config", postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), nullable=False
),
sa.Column("team_id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(["team_id"], ["teams.id"], onupdate="CASCADE", ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"rule_actions",
sa.Column("rule_id", sa.UUID(), nullable=False),
sa.Column("action_id", sa.UUID(), nullable=False),
sa.ForeignKeyConstraint(
["action_id"], ["actions.id"], onupdate="CASCADE", ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["rule_id"], ["rules.id"], onupdate="CASCADE", ondelete="CASCADE"),
sa.PrimaryKeyConstraint("rule_id", "action_id"),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("rule_actions")
op.drop_table("actions")
op.drop_table("rules")
# ### end Alembic commands ###
10 changes: 8 additions & 2 deletions viot/app/database/repository/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ class Page(msgspec.Struct, Generic[TModel]):


class PageableRepository(CrudRepository[TBaseModel, TPrimaryKey]):
async def find_all_with_paging(self, pageable: Pageable) -> Page[TBaseModel]:
async def find_all_with_paging(
self, pageable: Pageable, use_unique: bool = False
) -> Page[TBaseModel]:
# Base query for selecting items
base_query = select(self._model)

Expand All @@ -138,7 +140,11 @@ async def find_all_with_paging(self, pageable: Pageable) -> Page[TBaseModel]:
)

# Execute queries
items = (await self.session.execute(query)).scalars().all()
result = await self.session.execute(query)
if use_unique:
items = result.unique().scalars().all()
else:
items = result.scalars().all()
total_items = (await self.session.execute(count_query)).scalar_one()

return Page(
Expand Down
6 changes: 6 additions & 0 deletions viot/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from app.module.device_data.model.device_attribute import DeviceAttribute
from app.module.device_data.model.device_data import DeviceData
from app.module.device_data.model.device_data_latest import DeviceDataLatest
from app.module.rule_action.model.action import Action
from app.module.rule_action.model.rule import Rule
from app.module.rule_action.model.rule_action import RuleAction
from app.module.team.model.team import Team
from app.module.team.model.team_invitation import TeamInvitation

Expand All @@ -30,4 +33,7 @@
"DeviceAttribute",
"DeviceData",
"DeviceDataLatest",
"Rule",
"Action",
"RuleAction",
]
7 changes: 7 additions & 0 deletions viot/app/module/emqx/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from functools import lru_cache

from httpx import BasicAuth
from pydantic import computed_field
from pydantic_settings import BaseSettings, SettingsConfigDict


Expand All @@ -14,6 +16,11 @@ class EmqxSettings(BaseSettings):

MQTT_WHITELIST_FILE_PATH: str = ""

@computed_field # type: ignore
@property
def BASIC_AUTH(self) -> BasicAuth:
return BasicAuth(username=self.API_KEY, password=self.SECRET_KEY)


@lru_cache
def get_emqx_settings() -> EmqxSettings:
Expand Down
29 changes: 29 additions & 0 deletions viot/app/module/emqx/dto/emqx_rule_dto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Literal

from pydantic import BaseModel


class RepublishArgsDto(BaseModel):
topic: str
payload: str
qos: int
retain: bool
direct_dispatch: bool


class EmqxActionDto(BaseModel):
function: Literal["republish"]
args: RepublishArgsDto


class EmqxCreateRuleDto(BaseModel):
id: str
sql: str
actions: list[EmqxActionDto]
enable: bool


class EmqxUpdateRuleDto(BaseModel):
sql: str
actions: list[EmqxActionDto]
enable: bool
2 changes: 2 additions & 0 deletions viot/app/module/emqx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .controller.emqx_device_controller import EmqxDeviceController
from .service.emqx_auth_service import EmqxDeviceAuthService
from .service.emqx_event_service import EmqxEventService
from .service.emqx_rule_service import EmqxRuleService
from .service.mqtt_whitelist_service import MqttWhitelistService


Expand All @@ -11,5 +12,6 @@ def configure(self, binder: Binder) -> None:
binder.bind(EmqxDeviceAuthService, to=EmqxDeviceAuthService, scope=SingletonScope)
binder.bind(EmqxEventService, to=EmqxEventService, scope=SingletonScope)
binder.bind(MqttWhitelistService, to=MqttWhitelistService, scope=SingletonScope)
binder.bind(EmqxRuleService, to=EmqxRuleService, scope=SingletonScope)

binder.bind(EmqxDeviceController, to=EmqxDeviceController, scope=SingletonScope)
38 changes: 32 additions & 6 deletions viot/app/module/emqx/service/emqx_auth_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@

from injector import inject

from app.module.device.constants import DeviceStatus
from app.module.device.constants import DeviceStatus, DeviceType
from app.module.device.exception.device_exception import DeviceNotFoundException
from app.module.device.model.device import Device
from app.module.device.repository.device_repository import DeviceRepository
from app.module.device_data.constants import ConnectStatus
from app.module.device_data.model.connect_log import ConnectLog
from app.module.device_data.repository.connect_log_repository import ConnectLogRepository
from app.module.rule_action.constants import (
MQTT_DEVICE_ATTRIBUTES_TOPIC,
MQTT_DEVICE_DATA_TOPIC,
MQTT_SUB_DEVICE_ATTRIBUTES_TOPIC,
MQTT_SUB_DEVICE_DATA_TOPIC,
)

from ..dto.emqx_auth_dto import EmqxAuthenRequestDto, EmqxAuthenResponseDto
from ..exception.emqx_auth_exception import DeviceCredentialException, DeviceDisabledException
Expand Down Expand Up @@ -64,12 +70,9 @@ async def authenticate(self, *, request_dto: EmqxAuthenRequestDto) -> EmqxAuthen
device.last_connection = last_connection
device.status = DeviceStatus.ONLINE
connect_status = ConnectStatus.CONNECTED # Successful connection
device_acl = self._get_device_acl(device.device_type)

return EmqxAuthenResponseDto(
result="allow",
is_superuser=False,
acl=[],
)
return EmqxAuthenResponseDto(result="allow", is_superuser=False, acl=device_acl)

except (DeviceNotFoundException, DeviceCredentialException, DeviceDisabledException) as e:
# Re-raise exception
Expand All @@ -88,3 +91,26 @@ async def authenticate(self, *, request_dto: EmqxAuthenRequestDto) -> EmqxAuthen
ip=request_dto.ip_address,
)
)

def _get_device_acl(self, device_type: DeviceType) -> list[dict[str, str]]:
acls: list[dict[str, str]] = [
{"permission": "allow", "action": "publish", "topic": MQTT_DEVICE_DATA_TOPIC},
{"permission": "allow", "action": "publish", "topic": MQTT_DEVICE_ATTRIBUTES_TOPIC},
]
if device_type == DeviceType.GATEWAY:
acls.extend(
[
{
"permission": "allow",
"action": "publish",
"topic": MQTT_SUB_DEVICE_DATA_TOPIC,
},
{
"permission": "allow",
"action": "publish",
"topic": MQTT_SUB_DEVICE_ATTRIBUTES_TOPIC,
},
]
)

return acls
12 changes: 7 additions & 5 deletions viot/app/module/emqx/service/emqx_event_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from typing import Literal, TypedDict
from uuid import UUID

from httpx import AsyncClient, BasicAuth
from httpx import AsyncClient
from injector import inject

from app.common.exception import InternalServerException
from app.module.device_data.constants import ConnectStatus
from app.module.device_data.model.connect_log import ConnectLog
from app.module.device_data.repository.connect_log_repository import ConnectLogRepository
Expand Down Expand Up @@ -47,11 +48,12 @@ async def handle_device_disconnected(self, *, event: DeviceDisconnectedEventDto)
)

async def _subscribe_device_topics(self, device_id: UUID) -> None:
async with AsyncClient(
auth=BasicAuth(emqx_settings.API_KEY, emqx_settings.SECRET_KEY)
) as client:
async with AsyncClient(auth=emqx_settings.BASIC_AUTH) as client:
url: str = f"{emqx_settings.API_URL}/clients/{device_id}/subscribe/bulk"

topics: list[Subscription] = []

await client.post(url, json=topics)
result = await client.post(url, json=topics)

if result.status_code not in (200, 201):
raise InternalServerException(message="Error while subscribing to topics")
Loading

0 comments on commit 3a2cc19

Please sign in to comment.