Skip to content

Commit

Permalink
Streamline SSL experience (#677)
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim authored May 18, 2024
1 parent ef6bd2f commit de665e4
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 93 deletions.
67 changes: 10 additions & 57 deletions lib/postgrex.ex
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ defmodule Postgrex do
| {:connect_timeout, timeout}
| {:handshake_timeout, timeout}
| {:ping_timeout, timeout}
| {:ssl, boolean}
| {:ssl_opts, [:ssl.tls_client_option()]}
| {:ssl, boolean | [:ssl.tls_client_option()]}
| {:socket_options, [:gen_tcp.connect_option()]}
| {:prepare, :named | :unnamed}
| {:transactions, :strict | :naive}
Expand Down Expand Up @@ -118,11 +117,10 @@ defmodule Postgrex do
* `:idle_interval` - Ping connections after a period of inactivity in milliseconds.
Defaults to 1000ms;
* `:ssl` - Set to `true` if ssl should be used (default: `false`);
* `:ssl_opts` - A list of ssl options, see the
[`tls_client_option`](http://erlang.org/doc/man/ssl.html#type-tls_client_option)
from the ssl docs;
* `:ssl` - Enables SSL. Setting it to `true` enables SSL without server certificate verification,
which emits a warning. Instead, prefer to set it to a keyword list, with either
`:cacerts` or `:cacertfile` set to a CA trust store, to enable server certificate
verification. Defaults to `false`;
* `:socket_options` - Options to be given to the underlying socket
(applies to both TCP and UNIX sockets);
Expand Down Expand Up @@ -193,37 +191,13 @@ defmodule Postgrex do
## SSL client authentication
When connecting to Postgres or CockroachDB instances over SSL it is idiomatic to use
certificate authentication. Config files do not allowing passing functions,
so use the `init` callback of the Ecto supervisor.
In your Repository configuration:
config :app, App.Repo,
ssl: String.to_existing_atom(System.get_env("DB_SSL_ENABLED", "true")),
verify_ssl: true
And in App.Repo, set your `:ssl_opts`:
certificate authentication:
def init(_type, config) do
config =
if config[:verify_ssl] do
Keyword.put(config, :ssl_opts, my_ssl_opts(config[:hostname]))
else
config
end
[ssl: [cacertfile: System.get_env("DB_CA_CERT_FILE")]]
{:ok, config}
end
def my_ssl_opts(server) do
[
verify: :verify_peer,
cacertfile: System.get_env("DB_CA_CERT_FILE"),
server_name_indication: String.to_charlist(server),
customize_hostname_check: [match_fun: :public_key.pkix_verify_hostname_match_fun(:https)],
depth: 3
]
end
The server name indication (SNI) will be automatically set based on the `:hostname`
configuration, if one was provided. Other options, such as `depth: 3`, may be necessary
depending on the server.
## PgBouncer
Expand Down Expand Up @@ -271,27 +245,6 @@ defmodule Postgrex do
{"test-instance-N.xyz.eu-west-1.rds.amazonaws.com", 5432}
]
### Failover with SSL support
As specified in Erlang [:ssl.connect](https://erlang.org/doc/man/ssl.html#connect-3),
host verification using `:public_key.pkix_verify_hostname_match_fun(:https)`
requires that the ssl_opt `server_name_indication` is set, and for this reason,
the aforementioned `endpoints` list can become a three element tuple as:
endpoints: [
{
"test-instance-1.xyz.eu-west-1.rds.amazonaws.com",
5432,
[ssl_opts: [server_name_indication: String.to_charlist("test-instance-1.xyz.eu-west-1.rds.amazonaws.com")]]
},
(...),
{
"test-instance-2.xyz.eu-west-1.rds.amazonaws.com",
5432,
[ssl_opts: [server_name_indication: String.to_charlist("test-instance-2.xyz.eu-west-1.rds.amazonaws.com")]]
}
]
"""
@spec start_link([start_option]) :: {:ok, pid} | {:error, Postgrex.Error.t() | term}
def start_link(opts) do
Expand Down
83 changes: 62 additions & 21 deletions lib/postgrex/protocol.ex
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,29 @@ defmodule Postgrex.Protocol do
timeout = opts[:timeout] || @timeout
ping_timeout = Keyword.get(opts, :ping_timeout, timeout)
sock_opts = [send_timeout: timeout] ++ (opts[:socket_options] || [])
ssl? = opts[:ssl] || false
types_mod = Keyword.fetch!(opts, :types)
disconnect_on_error_codes = opts[:disconnect_on_error_codes] || []
target_server_type = opts[:target_server_type] || :any
disable_composite_types = opts[:disable_composite_types] || false

{ssl_opts, opts} =
case Keyword.pop(opts, :ssl, false) do
{false, opts} ->
{nil, opts}

{true, opts} ->
Logger.warning(
"setting ssl: true on your database connection offers only limited protection, " <>
"as the server's certificate is not verified. Set \"ssl: [cacertfile: path/to/file]\" instead"
)

# Read ssl_opts for backwards compatibility
Keyword.pop(opts, :ssl_opts, [])

{ssl_opts, opts} when is_list(ssl_opts) ->
{Keyword.merge(default_ssl_opts(), ssl_opts), opts}
end

transactions =
case opts[:transactions] || :naive do
:naive -> :naive
Expand Down Expand Up @@ -117,32 +134,49 @@ defmodule Postgrex.Protocol do
types_lock: nil,
prepare: prepare,
messages: [],
ssl: ssl?,
ssl: ssl_opts,
target_server_type: target_server_type,
search_path: opts[:search_path]
}

connect_endpoints(endpoints, sock_opts ++ @sock_opts, connect_timeout, s, status, [])
end

defp default_ssl_opts do
[
verify: :verify_peer,
customize_hostname_check: [
match_fun: :public_key.pkix_verify_hostname_match_fun(:https)
]
]
end

defp endpoints(opts) do
port = opts[:port] || 5432

case Keyword.fetch(opts, :socket) do
{:ok, file} ->
[{{:local, file}, 0, []}]
[{{:local, file}, 0}]

:error ->
case Keyword.fetch(opts, :socket_dir) do
{:ok, dir} ->
[{{:local, "#{dir}/.s.PGSQL.#{port}"}, 0, []}]
[{{:local, "#{dir}/.s.PGSQL.#{port}"}, 0}]

:error ->
case Keyword.fetch(opts, :endpoints) do
{:ok, endpoints} when is_list(endpoints) ->
Enum.map(endpoints, fn
{hostname, port} -> {to_charlist(hostname), port, []}
{hostname, port, extra_opts} -> {to_charlist(hostname), port, extra_opts}
{hostname, port} ->
{to_charlist(hostname), port}

{hostname, port, _extra_opts} ->
Logger.warning(
"Returning a triplet from :endpoints is deprecated, " <>
"the server name indicator is automatically set based on the hostname if SSL is enabled"
)

{to_charlist(hostname), port}
end)

{:ok, _} ->
Expand All @@ -151,7 +185,7 @@ defmodule Postgrex.Protocol do
:error ->
case Keyword.fetch(opts, :hostname) do
{:ok, hostname} ->
[{to_charlist(hostname), port, []}]
[{to_charlist(hostname), port}]

:error ->
raise ArgumentError,
Expand All @@ -163,16 +197,15 @@ defmodule Postgrex.Protocol do
end

defp connect_endpoints(
[{host, port, extra_opts} | remaining_endpoints],
[{host, port} | remaining_endpoints],
sock_opts,
timeout,
s,
%{opts: opts, types_mod: types_mod} = status,
previous_errors
) do
with {:ok, database} <- fetch_database(opts),
opts = Config.Reader.merge(opts, extra_opts),
status = %{status | types_key: if(types_mod, do: {host, port, database}), opts: opts},
status = %{status | types_key: if(types_mod, do: {host, port, database})},
{:ok, ret} <- connect_and_handshake(host, port, sock_opts, timeout, s, status) do
{:ok, ret}
else
Expand Down Expand Up @@ -216,7 +249,7 @@ defmodule Postgrex.Protocol do
defp connect_and_handshake(host, port, sock_opts, timeout, s, status) do
case connect(host, port, sock_opts, timeout, s) do
{:ok, s} ->
handshake(s, status)
handshake(host, s, status)

{:error, _} = error ->
error
Expand Down Expand Up @@ -682,13 +715,13 @@ defmodule Postgrex.Protocol do

## handshake

defp handshake(%{sock: {:gen_tcp, sock}, timeout: timeout} = s, status) do
defp handshake(host, %{sock: {:gen_tcp, sock}, timeout: timeout} = s, status) do
{:ok, peer} = :inet.peername(sock)
%{opts: opts} = status
handshake_timeout = Keyword.get(opts, :handshake_timeout, timeout)
timer = start_handshake_timer(handshake_timeout, sock)

case do_handshake(%{s | peer: peer}, status) do
case do_handshake(host, %{s | peer: peer}, status) do
{:ok, %{parameters: parameters} = s} ->
cancel_handshake_timer(timer)
ref = Postgrex.Parameters.insert(parameters)
Expand Down Expand Up @@ -733,22 +766,30 @@ defmodule Postgrex.Protocol do
:ok
end

defp do_handshake(s, %{ssl: true} = status), do: ssl(s, status)
defp do_handshake(s, %{ssl: false} = status), do: startup(s, status)
defp do_handshake(_host, s, %{ssl: nil} = status), do: startup(s, status)

defp do_handshake(host, s, %{ssl: ssl_opts} = status) do
ssl_opts =
if is_list(host),
do: Keyword.put_new(ssl_opts, :server_name_indication, host),
else: ssl_opts

ssl(s, status, ssl_opts)
end

## ssl

defp ssl(s, status) do
defp ssl(s, status, ssl_opts) do
case msg_send(s, msg_ssl_request(), "") do
:ok -> ssl_recv(s, status)
:ok -> ssl_recv(s, status, ssl_opts)
{:disconnect, _, _} = dis -> dis
end
end

defp ssl_recv(%{sock: {:gen_tcp, sock}} = s, status) do
defp ssl_recv(%{sock: {:gen_tcp, sock}} = s, status, ssl_opts) do
case :gen_tcp.recv(sock, 1, :infinity) do
{:ok, <<?S>>} ->
ssl_connect(s, status)
ssl_connect(s, status, ssl_opts)

{:ok, <<?N>>} ->
disconnect(s, %Postgrex.Error{message: "ssl not available"}, "")
Expand All @@ -770,8 +811,8 @@ defmodule Postgrex.Protocol do
end
end

defp ssl_connect(%{sock: {:gen_tcp, sock}, timeout: timeout} = s, status) do
case :ssl.connect(sock, status.opts[:ssl_opts] || [], timeout) do
defp ssl_connect(%{sock: {:gen_tcp, sock}, timeout: timeout} = s, status, ssl_opts) do
case :ssl.connect(sock, ssl_opts, timeout) do
{:ok, ssl_sock} ->
startup(%{s | sock: {:ssl, ssl_sock}}, status)

Expand Down
15 changes: 0 additions & 15 deletions test/login_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -88,21 +88,6 @@ defmodule LoginTest do
assert {:ok, %Postgrex.Result{}} = P.query(pid, "SELECT 123", [])
end

@tag :ssl
test "ssl with extra_ssl_opts in endpoints succeeds", context do
opts = [ssl: true, endpoints: [{"localhost", 5555, [ssl: [verify_peer: :none]]}]]
assert {:ok, pid} = P.start_link(opts ++ context[:options])
assert {:ok, %Postgrex.Result{}} = P.query(pid, "SELECT 123", [])
end

@tag :ssl
test "ssl with extra_ssl_opts in endpoints fails due to bad ssl_opt", context do
assert capture_log(fn ->
opts = [ssl: true, endpoints: [{"localhost", 5555, [ssl: [verify_peer: :foobar]]}]]
assert_start_and_killed(opts ++ context[:options])
end)
end

test "env var defaults", context do
assert {:ok, pid} = P.start_link(context[:options])
assert {:ok, %Postgrex.Result{}} = P.query(pid, "SELECT 123", [])
Expand Down

0 comments on commit de665e4

Please sign in to comment.