Skip to content

Commit

Permalink
PYTHON-4731 - Explicitly close all MongoClients opened during tests (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
NoahStapp authored Sep 17, 2024
1 parent fb51c11 commit 7395102
Show file tree
Hide file tree
Showing 73 changed files with 1,608 additions and 1,520 deletions.
1 change: 0 additions & 1 deletion pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,6 @@ def __del__(self) -> None:
),
ResourceWarning,
stacklevel=2,
source=self,
)
except AttributeError:
pass
Expand Down
1 change: 0 additions & 1 deletion pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,6 @@ def __del__(self) -> None:
),
ResourceWarning,
stacklevel=2,
source=self,
)
except AttributeError:
pass
Expand Down
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,6 @@ filterwarnings = [
"module:please use dns.resolver.Resolver.resolve:DeprecationWarning",
# https://github.com/dateutil/dateutil/issues/1314
"module:datetime.datetime.utc:DeprecationWarning:dateutil",
# TODO: Remove both of these in https://jira.mongodb.org/browse/PYTHON-4731
"ignore:Unclosed AsyncMongoClient*",
"ignore:Unclosed MongoClient*",
]
markers = [
"auth_aws: tests that rely on pymongo-auth-aws",
Expand Down
189 changes: 180 additions & 9 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
from __future__ import annotations

import asyncio
import base64
import contextlib
import gc
import multiprocessing
import os
Expand All @@ -27,7 +25,6 @@
import sys
import threading
import time
import traceback
import unittest
import warnings
from asyncio import iscoroutinefunction
Expand All @@ -54,6 +51,8 @@
sanitize_reply,
)

from pymongo.uri_parser import parse_uri

try:
import ipaddress

Expand All @@ -80,6 +79,12 @@
_IS_SYNC = True


def _connection_string(h):
if h.startswith(("mongodb://", "mongodb+srv://")):
return h
return f"mongodb://{h!s}"


class ClientContext:
client: MongoClient

Expand Down Expand Up @@ -230,6 +235,9 @@ def _init_client(self):
if not self._check_user_provided():
_create_user(self.client.admin, db_user, db_pwd)

if self.client:
self.client.close()

self.client = self._connect(
host,
port,
Expand All @@ -256,6 +264,8 @@ def _init_client(self):
if "setName" in hello:
self.replica_set_name = str(hello["setName"])
self.is_rs = True
if self.client:
self.client.close()
if self.auth_enabled:
# It doesn't matter which member we use as the seed here.
self.client = pymongo.MongoClient(
Expand Down Expand Up @@ -318,6 +328,7 @@ def _init_client(self):
hello = mongos_client.admin.command(HelloCompat.LEGACY_CMD)
if hello.get("msg") == "isdbgrid":
self.mongoses.append(next_address)
mongos_client.close()

def init(self):
with self.conn_lock:
Expand Down Expand Up @@ -537,12 +548,6 @@ def require_auth(self, func):
lambda: self.auth_enabled, "Authentication is not enabled on the server", func=func
)

def require_no_fips(self, func):
"""Run a test only if the host does not have FIPS enabled."""
return self._require(
lambda: not self.fips_enabled, "Test cannot run on a FIPS-enabled host", func=func
)

def require_no_auth(self, func):
"""Run a test only if the server is running without auth enabled."""
return self._require(
Expand Down Expand Up @@ -930,6 +935,172 @@ def _target() -> None:
self.fail(f"child timed out after {timeout}s (see traceback in logs): deadlock?")
self.assertEqual(proc.exitcode, 0)

@classmethod
def _unmanaged_async_mongo_client(
cls, host, port, authenticate=True, directConnection=None, **kwargs
):
"""Create a new client over SSL/TLS if necessary."""
host = host or client_context.host
port = port or client_context.port
client_options: dict = client_context.default_client_options.copy()
if client_context.replica_set_name and not directConnection:
client_options["replicaSet"] = client_context.replica_set_name
if directConnection is not None:
client_options["directConnection"] = directConnection
client_options.update(kwargs)

uri = _connection_string(host)
auth_mech = kwargs.get("authMechanism", "")
if client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
# Only add the default username or password if one is not provided.
res = parse_uri(uri)
if (
not res["username"]
and not res["password"]
and "username" not in client_options
and "password" not in client_options
):
client_options["username"] = db_user
client_options["password"] = db_pwd
client = MongoClient(uri, port, **client_options)
if client._options.connect:
client._connect()
return client

def _async_mongo_client(self, host, port, authenticate=True, directConnection=None, **kwargs):
"""Create a new client over SSL/TLS if necessary."""
host = host or client_context.host
port = port or client_context.port
client_options: dict = client_context.default_client_options.copy()
if client_context.replica_set_name and not directConnection:
client_options["replicaSet"] = client_context.replica_set_name
if directConnection is not None:
client_options["directConnection"] = directConnection
client_options.update(kwargs)

uri = _connection_string(host)
auth_mech = kwargs.get("authMechanism", "")
if client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
# Only add the default username or password if one is not provided.
res = parse_uri(uri)
if (
not res["username"]
and not res["password"]
and "username" not in client_options
and "password" not in client_options
):
client_options["username"] = db_user
client_options["password"] = db_pwd
client = MongoClient(uri, port, **client_options)
if client._options.connect:
client._connect()
self.addCleanup(client.close)
return client

@classmethod
def unmanaged_single_client_noauth(
cls, h: Any = None, p: Any = None, **kwargs: Any
) -> MongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return cls._unmanaged_async_mongo_client(
h, p, authenticate=False, directConnection=True, **kwargs
)

@classmethod
def unmanaged_single_client(
cls, h: Any = None, p: Any = None, **kwargs: Any
) -> MongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return cls._unmanaged_async_mongo_client(h, p, directConnection=True, **kwargs)

@classmethod
def unmanaged_rs_client(cls, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
"""Connect to the replica set and authenticate if necessary."""
return cls._unmanaged_async_mongo_client(h, p, **kwargs)

@classmethod
def unmanaged_rs_client_noauth(
cls, h: Any = None, p: Any = None, **kwargs: Any
) -> MongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs)

@classmethod
def unmanaged_rs_or_single_client_noauth(
cls, h: Any = None, p: Any = None, **kwargs: Any
) -> MongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs)

@classmethod
def unmanaged_rs_or_single_client(
cls, h: Any = None, p: Any = None, **kwargs: Any
) -> MongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return cls._unmanaged_async_mongo_client(h, p, **kwargs)

def single_client_noauth(
self, h: Any = None, p: Any = None, **kwargs: Any
) -> MongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return self._async_mongo_client(h, p, authenticate=False, directConnection=True, **kwargs)

def single_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
"""Make a direct connection, and authenticate if necessary."""
return self._async_mongo_client(h, p, directConnection=True, **kwargs)

def rs_client_noauth(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
"""Connect to the replica set. Don't authenticate."""
return self._async_mongo_client(h, p, authenticate=False, **kwargs)

def rs_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
"""Connect to the replica set and authenticate if necessary."""
return self._async_mongo_client(h, p, **kwargs)

def rs_or_single_client_noauth(
self, h: Any = None, p: Any = None, **kwargs: Any
) -> MongoClient[dict]:
"""Connect to the replica set if there is one, otherwise the standalone.
Like rs_or_single_client, but does not authenticate.
"""
return self._async_mongo_client(h, p, authenticate=False, **kwargs)

def rs_or_single_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[Any]:
"""Connect to the replica set if there is one, otherwise the standalone.
Authenticates if necessary.
"""
return self._async_mongo_client(h, p, **kwargs)

def simple_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient:
if not h and not p:
client = MongoClient(**kwargs)
else:
client = MongoClient(h, p, **kwargs)
self.addCleanup(client.close)
return client

@classmethod
def unmanaged_simple_client(cls, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient:
if not h and not p:
client = MongoClient(**kwargs)
else:
client = MongoClient(h, p, **kwargs)
return client

def disable_replication(self, client):
"""Disable replication on all secondaries."""
for h, p in client.secondaries:
secondary = self.single_client(h, p)
secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn")

def enable_replication(self, client):
"""Enable replication on all secondaries."""
for h, p in client.secondaries:
secondary = self.single_client(h, p)
secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off")


class UnitTest(PyMongoTestCase):
"""Async base class for TestCases that don't require a connection to MongoDB."""
Expand Down
Loading

0 comments on commit 7395102

Please sign in to comment.