diff --git a/amqp/transport.py b/amqp/transport.py index ec100e5f..41306811 100644 --- a/amqp/transport.py +++ b/amqp/transport.py @@ -535,6 +535,14 @@ def _wrap_socket_sni(self, sock, keyfile=None, certfile=None, except AttributeError: pass # ask forgiveness not permission + if ca_certs is None and context.verify_mode != ssl.CERT_NONE: + purpose = ( + ssl.Purpose.CLIENT_AUTH + if server_side + else ssl.Purpose.SERVER_AUTH + ) + context.load_default_certs(purpose) + sock = context.wrap_socket(**opts) return sock diff --git a/t/integration/test_rmq.py b/t/integration/test_rmq.py index 746aa65c..d89a2338 100644 --- a/t/integration/test_rmq.py +++ b/t/integration/test_rmq.py @@ -5,6 +5,7 @@ import pytest import amqp +from amqp import transport def get_connection( @@ -73,6 +74,32 @@ def test_tls_connect_fails(): connection.connect() +@pytest.mark.env('rabbitmq') +@pytest.mark.flaky(reruns=5, reruns_delay=2) +def test_tls_default_certs(): + # testing TLS connection against badssl.com with default certs + connection = transport.Transport( + host="tls-v1-2.badssl.com:1012", + ssl=True, + ) + assert type(connection) == transport.SSLTransport + connection.connect() + + +@pytest.mark.env('rabbitmq') +@pytest.mark.flaky(reruns=5, reruns_delay=2) +def test_tls_no_default_certs_fails(): + # testing TLS connection fails against badssl.com without default certs + connection = transport.Transport( + host="tls-v1-2.badssl.com:1012", + ssl={ + "ca_certs": 't/certs/ca_certificate.pem', + }, + ) + with pytest.raises(ssl.SSLError): + connection.connect() + + @pytest.mark.env('rabbitmq') class test_rabbitmq_operations(): diff --git a/t/unit/test_transport.py b/t/unit/test_transport.py index 7f8d78c9..f217fb65 100644 --- a/t/unit/test_transport.py +++ b/t/unit/test_transport.py @@ -1,6 +1,7 @@ import errno import os import re +import ssl import socket import struct from struct import pack @@ -639,112 +640,144 @@ def test_wrap_context(self): def test_wrap_socket_sni(self): # testing default values of _wrap_socket_sni() - sock = Mock() with patch('ssl.SSLContext') as mock_ssl_context_class: - wrap_socket_method_mock = mock_ssl_context_class().wrap_socket - wrap_socket_method_mock.return_value = sentinel.WRAPPED_SOCKET + sock = Mock() + context = mock_ssl_context_class() + context.wrap_socket.return_value = sentinel.WRAPPED_SOCKET ret = self.t._wrap_socket_sni(sock) - mock_ssl_context_class.load_cert_chain.assert_not_called() - mock_ssl_context_class.load_verify_locations.assert_not_called() - mock_ssl_context_class.set_ciphers.assert_not_called() - mock_ssl_context_class.verify_mode.assert_not_called() - wrap_socket_method_mock.assert_called_with( - sock=sock, - server_side=False, - do_handshake_on_connect=False, - suppress_ragged_eofs=True, - server_hostname=None - ) - assert ret == sentinel.WRAPPED_SOCKET + context.load_cert_chain.assert_not_called() + context.load_verify_locations.assert_not_called() + context.set_ciphers.assert_not_called() + context.verify_mode.assert_not_called() + + context.load_default_certs.assert_called_with( + ssl.Purpose.SERVER_AUTH + ) + context.wrap_socket.assert_called_with( + sock=sock, + server_side=False, + do_handshake_on_connect=False, + suppress_ragged_eofs=True, + server_hostname=None + ) + assert ret == sentinel.WRAPPED_SOCKET def test_wrap_socket_sni_certfile(self): # testing _wrap_socket_sni() with parameters certfile and keyfile with patch('ssl.SSLContext') as mock_ssl_context_class: - load_cert_chain_method_mock = \ - mock_ssl_context_class().load_cert_chain + sock = Mock() + context = mock_ssl_context_class() self.t._wrap_socket_sni( - Mock(), keyfile=sentinel.KEYFILE, certfile=sentinel.CERTFILE + sock, keyfile=sentinel.KEYFILE, certfile=sentinel.CERTFILE ) - load_cert_chain_method_mock.assert_called_with( - sentinel.CERTFILE, sentinel.KEYFILE - ) + context.load_default_certs.assert_called_with( + ssl.Purpose.SERVER_AUTH + ) + context.load_cert_chain.assert_called_with( + sentinel.CERTFILE, sentinel.KEYFILE + ) def test_wrap_socket_ca_certs(self): # testing _wrap_socket_sni() with parameter ca_certs with patch('ssl.SSLContext') as mock_ssl_context_class: - load_verify_locations_method_mock = \ - mock_ssl_context_class().load_verify_locations - self.t._wrap_socket_sni(Mock(), ca_certs=sentinel.CA_CERTS) + sock = Mock() + context = mock_ssl_context_class() + self.t._wrap_socket_sni(sock, ca_certs=sentinel.CA_CERTS) - load_verify_locations_method_mock.assert_called_with(sentinel.CA_CERTS) + context.load_default_certs.assert_not_called() + context.load_verify_locations.assert_called_with(sentinel.CA_CERTS) def test_wrap_socket_ciphers(self): # testing _wrap_socket_sni() with parameter ciphers with patch('ssl.SSLContext') as mock_ssl_context_class: - set_ciphers_method_mock = mock_ssl_context_class().set_ciphers - self.t._wrap_socket_sni(Mock(), ciphers=sentinel.CIPHERS) + sock = Mock() + context = mock_ssl_context_class() + set_ciphers_method_mock = context.set_ciphers + self.t._wrap_socket_sni(sock, ciphers=sentinel.CIPHERS) - set_ciphers_method_mock.assert_called_with(sentinel.CIPHERS) + set_ciphers_method_mock.assert_called_with(sentinel.CIPHERS) def test_wrap_socket_sni_cert_reqs(self): - # testing _wrap_socket_sni() with parameter cert_reqs + # testing _wrap_socket_sni() with parameter cert_reqs == ssl.CERT_NONE + with patch('ssl.SSLContext') as mock_ssl_context_class: + sock = Mock() + context = mock_ssl_context_class() + self.t._wrap_socket_sni(sock, cert_reqs=ssl.CERT_NONE) + + context.load_default_certs.assert_not_called() + assert context.verify_mode == ssl.CERT_NONE + + # testing _wrap_socket_sni() with parameter cert_reqs != ssl.CERT_NONE with patch('ssl.SSLContext') as mock_ssl_context_class: - self.t._wrap_socket_sni(Mock(), cert_reqs=sentinel.CERT_REQS) + sock = Mock() + context = mock_ssl_context_class() + self.t._wrap_socket_sni(sock, cert_reqs=sentinel.CERT_REQS) - assert mock_ssl_context_class().verify_mode == sentinel.CERT_REQS + context.load_default_certs.assert_called_with( + ssl.Purpose.SERVER_AUTH + ) + assert context.verify_mode == sentinel.CERT_REQS def test_wrap_socket_sni_setting_sni_header(self): # testing _wrap_socket_sni() without parameter server_hostname + # SSL module supports SNI with patch('ssl.SSLContext') as mock_ssl_context_class, \ patch('ssl.HAS_SNI', new=True): - self.t._wrap_socket_sni(Mock()) + sock = Mock() + context = mock_ssl_context_class() + self.t._wrap_socket_sni(sock) - assert mock_ssl_context_class().check_hostname is False + assert context.check_hostname is False # SSL module does not support SNI with patch('ssl.SSLContext') as mock_ssl_context_class, \ patch('ssl.HAS_SNI', new=False): - self.t._wrap_socket_sni(Mock()) + sock = Mock() + context = mock_ssl_context_class() + self.t._wrap_socket_sni(sock) - assert mock_ssl_context_class().check_hostname is False + assert context.check_hostname is False # testing _wrap_socket_sni() with parameter server_hostname - sock = Mock() + + # SSL module supports SNI with patch('ssl.SSLContext') as mock_ssl_context_class, \ patch('ssl.HAS_SNI', new=True): - # SSL module supports SNI - wrap_socket_method_mock = mock_ssl_context_class().wrap_socket + sock = Mock() + context = mock_ssl_context_class() self.t._wrap_socket_sni( sock, server_hostname=sentinel.SERVER_HOSTNAME ) - wrap_socket_method_mock.assert_called_with( - sock=sock, - server_side=False, - do_handshake_on_connect=False, - suppress_ragged_eofs=True, - server_hostname=sentinel.SERVER_HOSTNAME - ) - assert mock_ssl_context_class().check_hostname is True + context.wrap_socket.assert_called_with( + sock=sock, + server_side=False, + do_handshake_on_connect=False, + suppress_ragged_eofs=True, + server_hostname=sentinel.SERVER_HOSTNAME + ) + assert context.check_hostname is True + # SSL module does not support SNI with patch('ssl.SSLContext') as mock_ssl_context_class, \ patch('ssl.HAS_SNI', new=False): - # SSL module does not support SNI - wrap_socket_method_mock = mock_ssl_context_class().wrap_socket + sock = Mock() + context = mock_ssl_context_class() self.t._wrap_socket_sni( sock, server_hostname=sentinel.SERVER_HOSTNAME ) - wrap_socket_method_mock.assert_called_with( - sock=sock, - server_side=False, - do_handshake_on_connect=False, - suppress_ragged_eofs=True, - server_hostname=sentinel.SERVER_HOSTNAME - ) - assert mock_ssl_context_class().check_hostname is False + + context.wrap_socket.assert_called_with( + sock=sock, + server_side=False, + do_handshake_on_connect=False, + suppress_ragged_eofs=True, + server_hostname=sentinel.SERVER_HOSTNAME + ) + assert context.check_hostname is False def test_shutdown_transport(self): self.t.sock = None