Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement SSLContext factory #1815

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/deployment.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ Options:
--ssl-ca-certs TEXT CA certificates file
--ssl-ciphers TEXT Ciphers to use (see stdlib ssl module's)
[default: TLSv1]
--ssl-context CALLABLE Custom ssl_context that returns
ssl.SSLContext to set on config
--header TEXT Specify custom default HTTP response headers
as a Name:Value pair
--version Display the uvicorn version and exit.
Expand Down
2 changes: 2 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ Options:
--ssl-ca-certs TEXT CA certificates file
--ssl-ciphers TEXT Ciphers to use (see stdlib ssl module's)
[default: TLSv1]
--ssl-context CALLABLE Custom ssl_context that returns
ssl.SSLContext to set on config
--header TEXT Specify custom default HTTP response headers
as a Name:Value pair
--version Display the uvicorn version and exit.
Expand Down
25 changes: 24 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import socket
import ssl
import sys
import typing
from pathlib import Path
Expand Down Expand Up @@ -305,10 +306,32 @@ def test_ssl_config(
ssl_keyfile=tls_ca_certificate_private_key_path,
)
config.load()

assert config.is_ssl is True


# ignore
def ssl_context():
context = ssl.SSLContext(int(ssl.PROTOCOL_TLS_SERVER)) # type: ignore
allowed_ciphers = (
"DEFAULT:!aNULL:!eNULL:!MD5:!3DES:!DES:!RC4:!IDEA:!SEED:!aDSS:!SRP:!PSK"
)
context.set_ciphers(allowed_ciphers)
list_options = [ssl.OP_NO_RENEGOTIATION]
for each_option in list_options:
context.options |= each_option
return context


def test_ssl_context() -> None:
config = Config(app=asgi_app, ssl_context=ssl_context)
config.load()
if config.ssl_context is not None:
assert ssl.PROTOCOL_TLS_SERVER is config.ssl_version
assert "TLSv1" in config.ssl_ciphers
if config.ssl is not None:
assert ssl.OP_NO_RENEGOTIATION in config.ssl.options


def test_ssl_config_combined(tls_certificate_key_and_chain_path: str) -> None:
config = Config(
app=asgi_app,
Expand Down
39 changes: 38 additions & 1 deletion uvicorn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,27 @@
logger = logging.getLogger("uvicorn.error")


def update_ssl_context(
ctx: ssl.SSLContext,
certfile: Optional[Union[str, os.PathLike]],
keyfile: Optional[Union[str, os.PathLike]],
password: Optional[str],
cert_reqs: int,
ca_certs: Optional[Union[str, os.PathLike]],
ciphers: Optional[str],
) -> ssl.SSLContext:
get_password = (lambda: password) if password else None
if certfile and keyfile:
ctx.load_cert_chain(certfile, keyfile, get_password)
if cert_reqs:
ctx.verify_mode = ssl.VerifyMode(cert_reqs)
if ca_certs:
ctx.load_verify_locations(ca_certs)
if ciphers:
ctx.set_ciphers(ciphers)
return ctx


def create_ssl_context(
certfile: str | os.PathLike[str],
keyfile: str | os.PathLike[str] | None,
Expand Down Expand Up @@ -223,7 +244,9 @@ def __init__(
ssl_cert_reqs: int = ssl.CERT_NONE,
ssl_ca_certs: str | None = None,
ssl_ciphers: str = "TLSv1",
ssl_context: Optional[Callable] = None,
headers: list[tuple[str, str]] | None = None,

factory: bool = False,
h11_max_incomplete_event_size: int | None = None,
):
Expand Down Expand Up @@ -267,6 +290,7 @@ def __init__(
self.ssl_cert_reqs = ssl_cert_reqs
self.ssl_ca_certs = ssl_ca_certs
self.ssl_ciphers = ssl_ciphers
self.ssl_context = ssl_context
self.headers: list[tuple[str, str]] = headers or []
self.encoded_headers: list[tuple[bytes, bytes]] = []
self.factory = factory
Expand Down Expand Up @@ -416,8 +440,20 @@ def configure_logging(self) -> None:
def load(self) -> None:
assert not self.loaded

if self.is_ssl:
if self.ssl_context:
self.ssl: Optional[ssl.SSLContext] = update_ssl_context(
self.ssl_context(),
keyfile=self.ssl_keyfile,
certfile=self.ssl_certfile,
password=self.ssl_keyfile_password,
cert_reqs=self.ssl_cert_reqs,
ca_certs=self.ssl_ca_certs,
ciphers=self.ssl_ciphers,
)

elif self.is_ssl and not self.ssl_context:
assert self.ssl_certfile

self.ssl: ssl.SSLContext | None = create_ssl_context(
keyfile=self.ssl_keyfile,
certfile=self.ssl_certfile,
Expand All @@ -427,6 +463,7 @@ def load(self) -> None:
ca_certs=self.ssl_ca_certs,
ciphers=self.ssl_ciphers,
)

else:
self.ssl = None

Expand Down
12 changes: 12 additions & 0 deletions uvicorn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,13 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
help="Ciphers to use (see stdlib ssl module's)",
show_default=True,
)
@click.option(
"--ssl-context",
type=typing.Callable,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can't be a callable. What I said was to load the function object from the path... Which I'm not sure if we really should make this available via CLI.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are messages I've written here that you ignored. I'm not reviewing this again without them addressed.

The code changes here don't fully take in consideration my previous comments.

default=None,
help="Custom ssl_context that returns ssl.SSLContext to set on config",
show_default=True,
)
@click.option(
"--header",
"headers",
Expand Down Expand Up @@ -409,6 +416,7 @@ def main(
ssl_cert_reqs: int,
ssl_ca_certs: str,
ssl_ciphers: str,
ssl_context: typing.Callable,
headers: list[str],
use_colors: bool,
app_dir: str,
Expand Down Expand Up @@ -458,6 +466,7 @@ def main(
ssl_cert_reqs=ssl_cert_reqs,
ssl_ca_certs=ssl_ca_certs,
ssl_ciphers=ssl_ciphers,
ssl_context=ssl_context,
headers=[header.split(":", 1) for header in headers], # type: ignore[misc]
use_colors=use_colors,
factory=factory,
Expand Down Expand Up @@ -510,9 +519,11 @@ def run(
ssl_cert_reqs: int = ssl.CERT_NONE,
ssl_ca_certs: str | None = None,
ssl_ciphers: str = "TLSv1",
ssl_context: typing.Optional[typing.Callable] = None,
headers: list[tuple[str, str]] | None = None,
use_colors: bool | None = None,
app_dir: str | None = None,

factory: bool = False,
h11_max_incomplete_event_size: int | None = None,
) -> None:
Expand Down Expand Up @@ -562,6 +573,7 @@ def run(
ssl_cert_reqs=ssl_cert_reqs,
ssl_ca_certs=ssl_ca_certs,
ssl_ciphers=ssl_ciphers,
ssl_context=ssl_context,
headers=headers,
use_colors=use_colors,
factory=factory,
Expand Down
Loading