Skip to content

Commit

Permalink
improve injection code, make backwards compat explicit, make ssl-api …
Browse files Browse the repository at this point in the history
…explicit (#268)

* refactor: make injection code more readable and make backwards-compat explicit
* refactor: move ssl socket-wrapping code to ssl/socket.py
* refactor: convert MocketSSLContext.wrap_socket and wrap_bio to instance-methods
* refactor: MocketSSLSocket use proper ssl-context instead of urllib3
  • Loading branch information
betaboon authored Nov 25, 2024
1 parent 0da2722 commit a5b5e34
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 146 deletions.
149 changes: 60 additions & 89 deletions mocket/inject.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,32 @@
from __future__ import annotations

import contextlib
import os
import socket
import ssl
from types import ModuleType
from typing import Any

import urllib3

try: # pragma: no cover
from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3
_patches_restore: dict[tuple[ModuleType, str], Any] = {}

pyopenssl_override = True
except ImportError:
pyopenssl_override = False

def _patch(module: ModuleType, name: str, patched_value: Any) -> None:
with contextlib.suppress(KeyError):
original_value, module.__dict__[name] = module.__dict__[name], patched_value
_patches_restore[(module, name)] = original_value


def _restore(module: ModuleType, name: str) -> None:
if original_value := _patches_restore.pop((module, name)):
module.__dict__[name] = original_value


def enable(
namespace: str | None = None,
truesocket_recording_dir: str | None = None,
) -> None:
from mocket.mocket import Mocket
from mocket.socket import (
MocketSocket,
mock_create_connection,
Expand All @@ -27,99 +35,62 @@ def enable(
mock_gethostname,
mock_inet_pton,
mock_socketpair,
mock_urllib3_match_hostname,
)
from mocket.ssl.context import MocketSSLContext
from mocket.ssl.context import MocketSSLContext, mock_wrap_socket
from mocket.urllib3 import (
mock_match_hostname as mock_urllib3_match_hostname,
)
from mocket.urllib3 import (
mock_ssl_wrap_socket as mock_urllib3_ssl_wrap_socket,
)

patches = {
# stdlib: socket
(socket, "socket"): MocketSocket,
(socket, "create_connection"): mock_create_connection,
(socket, "getaddrinfo"): mock_getaddrinfo,
(socket, "gethostbyname"): mock_gethostbyname,
(socket, "gethostname"): mock_gethostname,
(socket, "inet_pton"): mock_inet_pton,
(socket, "SocketType"): MocketSocket,
(socket, "socketpair"): mock_socketpair,
# stdlib: ssl
(ssl, "SSLContext"): MocketSSLContext,
(ssl, "wrap_socket"): mock_wrap_socket, # python < 3.12.0
# urllib3
(urllib3.connection, "match_hostname"): mock_urllib3_match_hostname,
(urllib3.connection, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket,
(urllib3.util, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket,
(urllib3.util.ssl_, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket,
(urllib3.util.ssl_, "wrap_socket"): mock_urllib3_ssl_wrap_socket, # urllib3 < 2
}

for (module, name), new_value in patches.items():
_patch(module, name, new_value)

with contextlib.suppress(ImportError):
from urllib3.contrib.pyopenssl import extract_from_urllib3

extract_from_urllib3()

from mocket.mocket import Mocket

Mocket._namespace = namespace
Mocket._truesocket_recording_dir = truesocket_recording_dir

if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
# JSON dumps will be saved here
raise AssertionError

socket.socket = socket.__dict__["socket"] = MocketSocket
socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket
socket.SocketType = socket.__dict__["SocketType"] = MocketSocket
socket.create_connection = socket.__dict__["create_connection"] = (
mock_create_connection
)
socket.gethostname = socket.__dict__["gethostname"] = mock_gethostname
socket.gethostbyname = socket.__dict__["gethostbyname"] = mock_gethostbyname
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = mock_getaddrinfo
socket.socketpair = socket.__dict__["socketpair"] = mock_socketpair
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = MocketSSLContext.wrap_socket
ssl.SSLContext = ssl.__dict__["SSLContext"] = MocketSSLContext
socket.inet_pton = socket.__dict__["inet_pton"] = mock_inet_pton
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
MocketSSLContext.wrap_socket
)
urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
"ssl_wrap_socket"
] = MocketSSLContext.wrap_socket
urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
MocketSSLContext.wrap_socket
)
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
"ssl_wrap_socket"
] = MocketSSLContext.wrap_socket
urllib3.connection.match_hostname = urllib3.connection.__dict__[
"match_hostname"
] = mock_urllib3_match_hostname
if pyopenssl_override: # pragma: no cover
# Take out the pyopenssl version - use the default implementation
extract_from_urllib3()


def disable() -> None:
for module, name in list(_patches_restore.keys()):
_restore(module, name)

with contextlib.suppress(ImportError):
from urllib3.contrib.pyopenssl import inject_into_urllib3

inject_into_urllib3()

from mocket.mocket import Mocket
from mocket.socket import (
true_create_connection,
true_getaddrinfo,
true_gethostbyname,
true_gethostname,
true_inet_pton,
true_socket,
true_socketpair,
true_urllib3_match_hostname,
)
from mocket.ssl.context import (
true_ssl_context,
true_ssl_wrap_socket,
true_urllib3_ssl_wrap_socket,
true_urllib3_wrap_socket,
)

socket.socket = socket.__dict__["socket"] = true_socket
socket._socketobject = socket.__dict__["_socketobject"] = true_socket
socket.SocketType = socket.__dict__["SocketType"] = true_socket
socket.create_connection = socket.__dict__["create_connection"] = (
true_create_connection
)
socket.gethostname = socket.__dict__["gethostname"] = true_gethostname
socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo
socket.socketpair = socket.__dict__["socketpair"] = true_socketpair
if true_ssl_wrap_socket:
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket
ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context
socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
true_urllib3_wrap_socket
)
urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
"ssl_wrap_socket"
] = true_urllib3_ssl_wrap_socket
urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
true_urllib3_ssl_wrap_socket
)
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
"ssl_wrap_socket"
] = true_urllib3_ssl_wrap_socket
urllib3.connection.match_hostname = urllib3.connection.__dict__[
"match_hostname"
] = true_urllib3_match_hostname
Mocket.reset()
if pyopenssl_override: # pragma: no cover
# Put the pyopenssl version back in place
inject_into_urllib3()
11 changes: 0 additions & 11 deletions mocket/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from types import TracebackType
from typing import Any, Type

import urllib3.connection
from typing_extensions import Self

from mocket.compat import decode_from_bytes, encode_to_bytes
Expand All @@ -27,14 +26,8 @@
)
from mocket.utils import hexdump, hexload

true_create_connection = socket.create_connection
true_getaddrinfo = socket.getaddrinfo
true_gethostbyname = socket.gethostbyname
true_gethostname = socket.gethostname
true_inet_pton = socket.inet_pton
true_socket = socket.socket
true_socketpair = socket.socketpair
true_urllib3_match_hostname = urllib3.connection.match_hostname


xxh32 = None
Expand Down Expand Up @@ -84,10 +77,6 @@ def mock_socketpair(*args, **kwargs):
return _socket.socketpair(*args, **kwargs)


def mock_urllib3_match_hostname(*args: Any) -> None:
return None


def _hash_request(h, req):
return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest()

Expand Down
60 changes: 17 additions & 43 deletions mocket/ssl/context.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,10 @@
from __future__ import annotations

import contextlib
import ssl
from typing import Any

import urllib3.util.ssl_

from mocket.socket import MocketSocket
from mocket.ssl.socket import MocketSSLSocket

true_ssl_context = ssl.SSLContext

true_ssl_wrap_socket = None
true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket
true_urllib3_wrap_socket = None

with contextlib.suppress(ImportError):
# from Py3.12 it's only under SSLContext
from ssl import wrap_socket as ssl_wrap_socket

true_ssl_wrap_socket = ssl_wrap_socket

with contextlib.suppress(ImportError):
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket

true_urllib3_wrap_socket = urllib3_wrap_socket


class _MocketSSLContext:
"""For Python 3.6 and newer."""
Expand Down Expand Up @@ -70,30 +49,16 @@ def dummy_method(*args: Any, **kwargs: Any) -> Any:
for m in self.DUMMY_METHODS:
setattr(self, m, dummy_method)

@staticmethod
def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSSLSocket:
ssl_socket = MocketSSLSocket()
ssl_socket._original_socket = sock

ssl_socket._true_socket = true_urllib3_ssl_wrap_socket(
sock._true_socket,
**kwargs,
)
ssl_socket._kwargs = kwargs

ssl_socket._timeout = sock._timeout

ssl_socket._host = sock._host
ssl_socket._port = sock._port
ssl_socket._address = sock._address

ssl_socket._io = sock._io
ssl_socket._entry = sock._entry

return ssl_socket
def wrap_socket(
self,
sock: MocketSocket,
*args: Any,
**kwargs: Any,
) -> MocketSSLSocket:
return MocketSSLSocket._create(sock, *args, **kwargs)

@staticmethod
def wrap_bio(
self,
incoming: Any, # _ssl.MemoryBIO
outgoing: Any, # _ssl.MemoryBIO
server_side: bool = False,
Expand All @@ -102,3 +67,12 @@ def wrap_bio(
ssl_obj = MocketSSLSocket()
ssl_obj._host = server_hostname
return ssl_obj


def mock_wrap_socket(
sock: MocketSocket,
*args: Any,
**kwargs: Any,
) -> MocketSSLSocket:
context = MocketSSLContext()
return context.wrap_socket(sock, *args, **kwargs)
32 changes: 32 additions & 0 deletions mocket/ssl/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,35 @@ def compression(self) -> str | None:

def unwrap(self) -> MocketSocket:
return self._original_socket

@classmethod
def _create(
cls,
sock: MocketSocket,
ssl_context: ssl.SSLContext | None = None,
server_hostname: str | None = None,
*args: Any,
**kwargs: Any,
) -> MocketSSLSocket:
ssl_socket = MocketSSLSocket()
ssl_socket._original_socket = sock
ssl_socket._true_socket = sock._true_socket

if ssl_context:
ssl_socket._true_socket = ssl_context.wrap_socket(
sock=ssl_socket._true_socket,
server_hostname=server_hostname,
)

ssl_socket._kwargs = kwargs

ssl_socket._timeout = sock._timeout

ssl_socket._host = sock._host
ssl_socket._port = sock._port
ssl_socket._address = sock._address

ssl_socket._io = sock._io
ssl_socket._entry = sock._entry

return ssl_socket
20 changes: 20 additions & 0 deletions mocket/urllib3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations

from typing import Any

from mocket.socket import MocketSocket
from mocket.ssl.context import MocketSSLContext
from mocket.ssl.socket import MocketSSLSocket


def mock_match_hostname(*args: Any) -> None:
return None


def mock_ssl_wrap_socket(
sock: MocketSocket,
*args: Any,
**kwargs: Any,
) -> MocketSSLSocket:
context = MocketSSLContext()
return context.wrap_socket(sock, *args, **kwargs)
4 changes: 2 additions & 2 deletions mocket/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ def get_mocketize(wrapper_: Callable) -> Callable:


__all__ = (
"MocketSocketCore",
"MocketMode",
"MocketSocketCore",
"SSL_PROTOCOL",
"get_mocketize",
"hexdump",
"hexload",
"get_mocketize",
)
2 changes: 1 addition & 1 deletion tests/test_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from mocket import Mocketizer, mocketize
from mocket.exceptions import StrictMocketException
from mocket.mockhttp import Entry, Response
from mocket.utils import MocketMode
from mocket.mode import MocketMode


@mocketize(strict_mode=True)
Expand Down

0 comments on commit a5b5e34

Please sign in to comment.