diff --git a/changelog.d/12656.misc b/changelog.d/12656.misc new file mode 100644 index 000000000000..8a8743e614f3 --- /dev/null +++ b/changelog.d/12656.misc @@ -0,0 +1 @@ +Prevent memory leak from reoccurring when presence is disabled. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index d078162c2938..268481ec1963 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -659,27 +659,28 @@ def __init__(self, hs: "HomeServer"): ) now = self.clock.time_msec() - for state in self.user_to_current_state.values(): - self.wheel_timer.insert( - now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER - ) - self.wheel_timer.insert( - now=now, - obj=state.user_id, - then=state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT, - ) - if self.is_mine_id(state.user_id): + if self._presence_enabled: + for state in self.user_to_current_state.values(): self.wheel_timer.insert( - now=now, - obj=state.user_id, - then=state.last_federation_update_ts + FEDERATION_PING_INTERVAL, + now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER ) - else: self.wheel_timer.insert( now=now, obj=state.user_id, - then=state.last_federation_update_ts + FEDERATION_TIMEOUT, + then=state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT, ) + if self.is_mine_id(state.user_id): + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_federation_update_ts + FEDERATION_PING_INTERVAL, + ) + else: + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_federation_update_ts + FEDERATION_TIMEOUT, + ) # Set of users who have presence in the `user_to_current_state` that # have not yet been persisted @@ -804,6 +805,13 @@ async def _update_states( This is currently used to bump the max presence stream ID without changing any user's presence (see PresenceHandler.add_users_to_send_full_presence_to). """ + if not self._presence_enabled: + # We shouldn't get here if presence is disabled, but we check anyway + # to ensure that we don't a) send out presence federation and b) + # don't add things to the wheel timer that will never be handled. + logger.warning("Tried to update presence states when presence is disabled") + return + now = self.clock.time_msec() with Measure(self.clock, "presence_update_states"): @@ -1229,6 +1237,10 @@ async def set_state( ): raise SynapseError(400, "Invalid presence state") + # If presence is disabled, no-op + if not self.hs.config.server.use_presence: + return + user_id = target_user.to_string() prev_state = await self.current_state_for_user(user_id) diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index e108adc4604f..177e198e7e75 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -11,17 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Generic, List, TypeVar +import logging +from typing import Generic, Hashable, List, Set, TypeVar -T = TypeVar("T") +import attr +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=Hashable) -class _Entry(Generic[T]): - __slots__ = ["end_key", "queue"] - def __init__(self, end_key: int) -> None: - self.end_key: int = end_key - self.queue: List[T] = [] +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _Entry(Generic[T]): + end_key: int + elements: Set[T] = attr.Factory(set) class WheelTimer(Generic[T]): @@ -48,17 +51,27 @@ def insert(self, now: int, obj: T, then: int) -> None: then: When to return the object strictly after. """ then_key = int(then / self.bucket_size) + 1 + now_key = int(now / self.bucket_size) if self.entries: min_key = self.entries[0].end_key max_key = self.entries[-1].end_key + if min_key < now_key - 10: + # If we have ten buckets that are due and still nothing has + # called `fetch()` then we likely have a bug that is causing a + # memory leak. + logger.warning( + "Inserting into a wheel timer that hasn't been read from recently. Item: %s", + obj, + ) + if then_key <= max_key: # The max here is to protect against inserts for times in the past - self.entries[max(min_key, then_key) - min_key].queue.append(obj) + self.entries[max(min_key, then_key) - min_key].elements.add(obj) return - next_key = int(now / self.bucket_size) + 1 + next_key = now_key + 1 if self.entries: last_key = self.entries[-1].end_key else: @@ -71,7 +84,7 @@ def insert(self, now: int, obj: T, then: int) -> None: # to insert. This ensures there are no gaps. self.entries.extend(_Entry(key) for key in range(last_key, then_key + 1)) - self.entries[-1].queue.append(obj) + self.entries[-1].elements.add(obj) def fetch(self, now: int) -> List[T]: """Fetch any objects that have timed out @@ -84,11 +97,11 @@ def fetch(self, now: int) -> List[T]: """ now_key = int(now / self.bucket_size) - ret = [] + ret: List[T] = [] while self.entries and self.entries[0].end_key <= now_key: - ret.extend(self.entries.pop(0).queue) + ret.extend(self.entries.pop(0).elements) return ret def __len__(self) -> int: - return sum(len(entry.queue) for entry in self.entries) + return sum(len(entry.elements) for entry in self.entries)