Skip to content

Commit

Permalink
Merge pull request #469 from canton7/feature/custom-serial
Browse files Browse the repository at this point in the history
Work around bug in pyserial's PosixPollSerial
  • Loading branch information
canton7 authored Nov 29, 2023
2 parents f652c2a + 8a83906 commit 171b359
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 124 deletions.
2 changes: 1 addition & 1 deletion custom_components/foxess_modbus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from homeassistant.helpers.typing import UNDEFINED
from slugify import slugify

from .client.modbus_client import ModbusClient
from .const import ADAPTER_ID
from .const import ADAPTER_WAS_MIGRATED
from .const import CONFIG_SAVE_TIME
Expand All @@ -40,7 +41,6 @@
from .const import UNIQUE_ID_PREFIX
from .inverter_adapters import ADAPTERS
from .inverter_profiles import inverter_connection_type_profile_from_config
from .modbus_client import ModbusClient
from .modbus_controller import ModbusController
from .services import update_charge_period_service
from .services import websocket_api
Expand Down
113 changes: 113 additions & 0 deletions custom_components/foxess_modbus/client/custom_modbus_tcp_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import logging
import select
import socket
import time
from typing import Any
from typing import cast

from pymodbus.client import ModbusTcpClient
from pymodbus.exceptions import ConnectionException

_LOGGER = logging.getLogger(__name__)


class CustomModbusTcpClient(ModbusTcpClient):
"""Custom ModbusTcpClient subclass with some hacks"""

def __init__(self, delay_on_connect: int | None, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._delay_on_connect = delay_on_connect

def connect(self) -> bool:
was_connected = self.socket is not None
if not was_connected:
_LOGGER.debug("Connecting to %s", self.params)
is_connected = cast(bool, super().connect())
# pymodbus doesn't disable Nagle's algorithm. This slows down reads quite substantially as the
# TCP stack waits to see if we're going to send anything else. Disable it ourselves.
if not was_connected and is_connected:
assert self.socket is not None
self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
if self._delay_on_connect is not None:
time.sleep(self._delay_on_connect)
return is_connected

# Replacement of ModbusTcpClient to use poll rather than select, see
# https://github.com/nathanmarlor/foxess_modbus/issues/275
def recv(self, size: int) -> bytes:
"""Read data from the underlying descriptor."""
super(ModbusTcpClient, self).recv(size)
if not self.socket:
raise ConnectionException(str(self))

# socket.recv(size) waits until it gets some data from the host but
# not necessarily the entire response that can be fragmented in
# many packets.
# To avoid split responses to be recognized as invalid
# messages and to be discarded, loops socket.recv until full data
# is received or timeout is expired.
# If timeout expires returns the read data, also if its length is
# less than the expected size.
self.socket.setblocking(0)

# In the base method this is 'timeout = self.comm_params.timeout', but that changed from 'self.params.timeout'
# in 3.4.1. So we don't have a consistent way to access the timeout.
# However, this just mirrors what we set, which is the default of 3s. So use that.
# Annoyingly 3.4.1
timeout = 3

# If size isn't specified read up to 4096 bytes at a time.
if size is None:
recv_size = 4096
else:
recv_size = size

data: list[bytes] = []
data_length = 0
time_ = time.time()
end = time_ + timeout
poll = select.poll()
# We don't need to call poll.unregister, since we're deallocing the poll. register just adds the socket to a
# dict owned by the poll object (the underlying syscall has no concept of register/unregister, and just takes an
# array of fds to poll). If we hit a disconnection the socket.fileno() becomes -1 anyway, so unregistering will
# fail
poll.register(self.socket, select.POLLIN)
while recv_size > 0:
poll_res = poll.poll(end - time_)
# We expect a single-element list if this succeeds, or an empty list if it timed out
if len(poll_res) > 0:
if (recv_data := self.socket.recv(recv_size)) == b"":
return self._handle_abrupt_socket_close( # type: ignore[no-any-return]
size, data, time.time() - time_
)
data.append(recv_data)
data_length += len(recv_data)
time_ = time.time()

# If size isn't specified continue to read until timeout expires.
if size:
recv_size = size - data_length

# Timeout is reduced also if some data has been received in order
# to avoid infinite loops when there isn't an expected response
# size and the slave sends noisy data continuously.
if time_ > end:
break

return b"".join(data)

# Replacement of ModbusTcpClient to use poll rather than select, see
# https://github.com/nathanmarlor/foxess_modbus/issues/275
def _check_read_buffer(self) -> bytes | None:
"""Check read buffer."""
time_ = time.time()
end = time_ + self.params.timeout
data = None

assert self.socket is not None
poll = select.poll()
poll.register(self.socket, select.POLLIN)
poll_res = poll.poll(end - time_)
if len(poll_res) > 0:
data = self.socket.recv(1024)
return data
135 changes: 17 additions & 118 deletions ...components/foxess_modbus/modbus_client.py → ...nts/foxess_modbus/client/modbus_client.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,16 @@
import asyncio
import logging
import os
import select
import socket
import time
from typing import Any
from typing import Callable
from typing import Type
from typing import TypeVar
from typing import cast

import serial
from homeassistant.core import HomeAssistant
from pymodbus.client import ModbusSerialClient
from pymodbus.client import ModbusTcpClient
from pymodbus.client import ModbusUdpClient
from pymodbus.exceptions import ConnectionException
from pymodbus.pdu import ModbusResponse
from pymodbus.register_read_message import ReadHoldingRegistersResponse
from pymodbus.register_read_message import ReadInputRegistersResponse
Expand All @@ -24,121 +20,21 @@
from pymodbus.transaction import ModbusRtuFramer
from pymodbus.transaction import ModbusSocketFramer

from .common.register_type import RegisterType
from .const import LAN
from .const import RTU_OVER_TCP
from .const import SERIAL
from .const import TCP
from .const import UDP
from .inverter_adapters import InverterAdapter
from .. import client
from ..common.register_type import RegisterType
from ..const import LAN
from ..const import RTU_OVER_TCP
from ..const import SERIAL
from ..const import TCP
from ..const import UDP
from ..inverter_adapters import InverterAdapter
from .custom_modbus_tcp_client import CustomModbusTcpClient

_LOGGER = logging.getLogger(__name__)

T = TypeVar("T")


class CustomModbusTcpClient(ModbusTcpClient):
"""Custom ModbusTcpClient subclass with some hacks"""

def __init__(self, delay_on_connect: int | None, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._delay_on_connect = delay_on_connect

def connect(self) -> bool:
was_connected = self.socket is not None
if not was_connected:
_LOGGER.debug("Connecting to %s", self.params)
is_connected = cast(bool, super().connect())
# pymodbus doesn't disable Nagle's algorithm. This slows down reads quite substantially as the
# TCP stack waits to see if we're going to send anything else. Disable it ourselves.
if not was_connected and is_connected:
assert self.socket is not None
self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
if self._delay_on_connect is not None:
time.sleep(self._delay_on_connect)
return is_connected

# Replacement of ModbusTcpClient to use poll rather than select, see
# https://github.com/nathanmarlor/foxess_modbus/issues/275
def recv(self, size: int) -> bytes:
"""Read data from the underlying descriptor."""
super(ModbusTcpClient, self).recv(size)
if not self.socket:
raise ConnectionException(str(self))

# socket.recv(size) waits until it gets some data from the host but
# not necessarily the entire response that can be fragmented in
# many packets.
# To avoid split responses to be recognized as invalid
# messages and to be discarded, loops socket.recv until full data
# is received or timeout is expired.
# If timeout expires returns the read data, also if its length is
# less than the expected size.
self.socket.setblocking(0)

# In the base method this is 'timeout = self.comm_params.timeout', but that changed from 'self.params.timeout'
# in 3.4.1. So we don't have a consistent way to access the timeout.
# However, this just mirrors what we set, which is the default of 3s. So use that.
# Annoyingly 3.4.1
timeout = 3

# If size isn't specified read up to 4096 bytes at a time.
if size is None:
recv_size = 4096
else:
recv_size = size

data: list[bytes] = []
data_length = 0
time_ = time.time()
end = time_ + timeout
poll = select.poll()
# We don't need to call poll.unregister, since we're deallocing the poll. register just adds the socket to a
# dict owned by the poll object (the underlying syscall has no concept of register/unregister, and just takes an
# array of fds to poll). If we hit a disconnection the socket.fileno() becomes -1 anyway, so unregistering will
# fail
poll.register(self.socket, select.POLLIN)
while recv_size > 0:
poll_res = poll.poll(end - time_)
# We expect a single-element list if this succeeds, or an empty list if it timed out
if len(poll_res) > 0:
if (recv_data := self.socket.recv(recv_size)) == b"":
return self._handle_abrupt_socket_close( # type: ignore[no-any-return]
size, data, time.time() - time_
)
data.append(recv_data)
data_length += len(recv_data)
time_ = time.time()

# If size isn't specified continue to read until timeout expires.
if size:
recv_size = size - data_length

# Timeout is reduced also if some data has been received in order
# to avoid infinite loops when there isn't an expected response
# size and the slave sends noisy data continuously.
if time_ > end:
break

return b"".join(data)

# Replacement of ModbusTcpClient to use poll rather than select, see
# https://github.com/nathanmarlor/foxess_modbus/issues/275
def _check_read_buffer(self) -> bytes | None:
"""Check read buffer."""
time_ = time.time()
end = time_ + self.params.timeout
data = None

assert self.socket is not None
poll = select.poll()
poll.register(self.socket, select.POLLIN)
poll_res = poll.poll(end - time_)
if len(poll_res) > 0:
data = self.socket.recv(1024)
return data


_CLIENTS: dict[str, dict[str, Any]] = {
SERIAL: {
"client": ModbusSerialClient,
Expand All @@ -158,6 +54,8 @@ def _check_read_buffer(self) -> bytes | None:
},
}

serial.protocol_handler_packages.append(client.__name__)


class ModbusClient:
"""Modbus"""
Expand All @@ -179,12 +77,13 @@ def __init__(self, hass: HomeAssistant, protocol: str, adapter: InverterAdapter,
"delay_on_connect": 1 if adapter.connection_type == LAN else None,
}

# If PosixPollSerial is supported, use that. This uses poll rather than select, which means we don't break when
# there are more than 1024 fds. See #457.
# If our custom PosixPollSerial hack is supported, use that. This uses poll rather than select, which means we
# don't break when there are more than 1024 fds. See #457.
# Only supported on posix, see https://github.com/pyserial/pyserial/blob/7aeea35429d15f3eefed10bbb659674638903e3a/serial/__init__.py#L31
# This ties into the call to serial.protocol_handler_packages.append above, and means that pyserial will find
# our protocol_pollserial module, and the Serial class inside, when we use the prefix pollserial://
if protocol == SERIAL and os.name == "posix":
# https://pyserial.readthedocs.io/en/latest/url_handlers.html#alt
config["port"] = f"alt://{config['port']}?class=PosixPollSerial"
config["port"] = f"pollserial://{config['port']}"

# Some serial devices need a short delay after polling. Also do this for the inverter, just
# in case it helps.
Expand Down
76 changes: 76 additions & 0 deletions custom_components/foxess_modbus/client/protocol_pollserial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
Custom protocol handler for pyserial, which uses poll but doesn't have
https://github.com/pyserial/pyserial/issues/617
"""

import os
import select
from enum import Enum

import serial
from serial import serialposix
from serial.serialutil import PortNotOpenError
from serial.serialutil import SerialException
from serial.serialutil import Timeout


class _PollResult(Enum):
TIMEOUT = 0
ABORT = 1
DATA = 2


class Serial(serialposix.Serial):
"""
From https://github.com/pyserial/pyserial/blob/7aeea35429d15f3eefed10bbb659674638903e3a/serial/serialposix.py,
but with https://github.com/pyserial/pyserial/pull/618 applied
"""

@serial.Serial.port.setter # type: ignore
def port(self, value: str) -> None:
if value is not None:
serial.Serial.port.__set__(self, value.removeprefix("pollserial://"))

def read(self, size: int = 1) -> bytes:
"""\
Read size bytes from the serial port. If a timeout is set it may
return less characters as requested. With no timeout it will block
until the requested number of bytes is read.
"""
if not self.is_open:
raise PortNotOpenError()
read = bytearray()
timeout = Timeout(self._timeout)
poll = select.poll()
poll.register(self.fd, select.POLLIN | select.POLLERR | select.POLLHUP | select.POLLNVAL)
poll.register(self.pipe_abort_read_r, select.POLLIN | select.POLLERR | select.POLLHUP | select.POLLNVAL)
if size > 0:
while len(read) < size:
# wait until device becomes ready to read (or something fails)
result = _PollResult.TIMEOUT # In case poll returns an empty list
for fd, event in poll.poll(None if timeout.is_infinite else (timeout.time_left() * 1000)):
if fd == self.pipe_abort_read_r:
os.read(self.pipe_abort_read_r, 1000)
result = _PollResult.ABORT
break
if event & (select.POLLERR | select.POLLHUP | select.POLLNVAL):
raise SerialException("device reports error (poll)")
result = _PollResult.DATA

if result == _PollResult.DATA:
buf = os.read(self.fd, size - len(read))
read.extend(buf)
if (
result == _PollResult.TIMEOUT
or result == _PollResult.ABORT
or timeout.expired()
or (self._inter_byte_timeout is not None and self._inter_byte_timeout > 0)
and not buf
):
break # early abort on timeout
return bytes(read)


# This needs to have a very particular name, as it's registered by string in modbus_client
assert Serial.__module__ == "custom_components.foxess_modbus.client.protocol_pollserial"
assert Serial.__name__ == "Serial"
Loading

0 comments on commit 171b359

Please sign in to comment.