Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow extensions to access type modifiers #684

Merged
merged 12 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions lib/postgrex/extensions/array.ex
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ defmodule Postgrex.Extensions.Array do
data::binary>> = binary

# decode_list/2 defined by TypeModule
type =
case type do
{extension, sub_oids, sub_types} -> {extension, sub_oids, sub_types, var!(mod)}
extension -> {extension, var!(mod)}
end

flat = decode_list(data, type)

unquote(__MODULE__).decode(dims, flat)
Expand Down
7 changes: 7 additions & 0 deletions lib/postgrex/extensions/multirange.ex
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@ defmodule Postgrex.Extensions.Multirange do
quote location: :keep do
<<len::int32(), data::binary-size(len)>>, [_oid], [type] ->
<<_num_ranges::int32(), ranges::binary>> = data

# decode_list/2 defined by TypeModule
type =
case type do
{extension, sub_oids, sub_types} -> {extension, sub_oids, sub_types, nil}
extension -> {extension, nil}
end

bound_decoder = &decode_list(&1, type)
unquote(__MODULE__).decode(ranges, bound_decoder, [])
end
Expand Down
7 changes: 7 additions & 0 deletions lib/postgrex/extensions/range.ex
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,14 @@ defmodule Postgrex.Extensions.Range do
quote location: :keep do
<<len::int32(), binary::binary-size(len)>>, [_oid], [type] ->
<<flags, data::binary>> = binary

# decode_list/2 defined by TypeModule
type =
case type do
{extension, sub_oids, sub_types} -> {extension, sub_oids, sub_types, nil}
extension -> {extension, nil}
end

case decode_list(data, type) do
[upper, lower] ->
unquote(__MODULE__).decode(flags, [lower, upper])
Expand Down
26 changes: 15 additions & 11 deletions lib/postgrex/extensions/timestamp.ex
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ defmodule Postgrex.Extensions.Timestamp do
@min_year -4_713
@plus_infinity 9_223_372_036_854_775_807
@minus_infinity -9_223_372_036_854_775_808
@default_precision 6

def init(opts), do: Keyword.get(opts, :allow_infinite_timestamps, false)

Expand All @@ -28,7 +29,7 @@ defmodule Postgrex.Extensions.Timestamp do
def decode(infinity?) do
quote location: :keep do
<<8::int32(), microsecs::int64()>> ->
unquote(__MODULE__).microsecond_to_elixir(microsecs, unquote(infinity?))
unquote(__MODULE__).microsecond_to_elixir(microsecs, var!(mod), unquote(infinity?))
end
end

Expand All @@ -51,32 +52,35 @@ defmodule Postgrex.Extensions.Timestamp do
<<8::int32(), secs * 1_000_000 + usec::int64()>>
end

def microsecond_to_elixir(@plus_infinity, infinity?) do
def microsecond_to_elixir(@plus_infinity, _precision, infinity?) do
if infinity?, do: :inf, else: raise_infinity("infinity")
end

def microsecond_to_elixir(@minus_infinity, infinity?) do
def microsecond_to_elixir(@minus_infinity, _precision, infinity?) do
if infinity?, do: :"-inf", else: raise_infinity("-infinity")
end

def microsecond_to_elixir(microsecs, _infinity) do
split(microsecs)
def microsecond_to_elixir(microsecs, precision, _infinity) do
split(microsecs, precision)
end

defp split(microsecs) when microsecs < 0 and rem(microsecs, 1_000_000) != 0 do
defp split(microsecs, precision) when microsecs < 0 and rem(microsecs, 1_000_000) != 0 do
secs = div(microsecs, 1_000_000) - 1
microsecs = 1_000_000 + rem(microsecs, 1_000_000)
split(secs, microsecs)
split(secs, microsecs, precision)
end

defp split(microsecs) do
defp split(microsecs, precision) do
secs = div(microsecs, 1_000_000)
microsecs = rem(microsecs, 1_000_000)
split(secs, microsecs)
split(secs, microsecs, precision)
end

defp split(secs, microsecs) do
NaiveDateTime.from_gregorian_seconds(secs + @gs_epoch, {microsecs, 6})
defp split(secs, microsecs, precision) do
# use the default precision if the precision modifier from postgres is -1 (this means no precision specified)
# or if the precision is missing because we are in a super type which does not give us the sub-type's modifier
precision = if precision in [-1, nil], do: @default_precision, else: precision
NaiveDateTime.from_gregorian_seconds(secs + @gs_epoch, {microsecs, precision})
end

defp raise_infinity(type) do
Expand Down
55 changes: 33 additions & 22 deletions lib/postgrex/protocol.ex
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ defmodule Postgrex.Protocol do
defp set_search_path_recv(s, status, buffer) do
case msg_recv(s, :infinity, buffer) do
{:ok, msg_row_desc(fields: fields), buffer} ->
{[@text_type_oid], ["search_path"]} = columns(fields)
{[@text_type_oid], ["search_path"], _} = columns(fields)
set_search_path_recv(s, status, buffer)

{:ok, msg_data_row(), buffer} ->
Expand Down Expand Up @@ -1014,7 +1014,7 @@ defmodule Postgrex.Protocol do
) do
case msg_recv(s, :infinity, buffer) do
{:ok, msg_row_desc(fields: fields), buffer} ->
{[@text_type_oid], ["transaction_read_only"]} = columns(fields)
{[@text_type_oid], ["transaction_read_only"], _} = columns(fields)
check_target_server_type_recv(s, status, buffer)

{:ok, msg_data_row(values: values), buffer} ->
Expand Down Expand Up @@ -1611,8 +1611,9 @@ defmodule Postgrex.Protocol do
)
when ref == nil or protocol_types != query_types do
with {:ok, s, buffer} <- recv_parse(s, status, buffer),
{:ok, param_oids, result_oids, columns, s, buffer} <- recv_describe(s, status, buffer) do
describe(s, query, param_oids, result_oids, columns, buffer)
{:ok, param_oids, result_oids, result_mods, columns, s, buffer} <-
recv_describe(s, status, buffer) do
describe(s, query, param_oids, result_oids, result_mods, columns, buffer)
else
{:error, %Postgrex.Error{} = error, s, buffer} ->
{:error, %{error | query: query.statement}, s, buffer}
Expand All @@ -1626,13 +1627,13 @@ defmodule Postgrex.Protocol do
%Query{param_oids: param_oids, result_oids: result_oids, columns: columns} = query

with {:ok, s, buffer} <- recv_parse(s, status, buffer),
{:ok, ^param_oids, ^result_oids, ^columns, s, buffer} <-
{:ok, ^param_oids, ^result_oids, _, ^columns, s, buffer} <-
recv_describe(s, status, param_oids, buffer) do
query_put(s, query)
{:ok, query, s, buffer}
else
{:ok, ^param_oids, new_result_oids, new_columns, s, buffer} ->
redescribe(s, query, new_result_oids, new_columns, buffer)
{:ok, ^param_oids, new_result_oids, new_result_mods, new_columns, s, buffer} ->
redescribe(s, query, new_result_oids, new_result_mods, new_columns, buffer)

{:error, %Postgrex.Error{}, _, _} = error ->
error
Expand Down Expand Up @@ -1662,14 +1663,14 @@ defmodule Postgrex.Protocol do
defp recv_describe(s, status, param_oids \\ [], buffer) do
case msg_recv(s, :infinity, buffer) do
{:ok, msg_no_data(), buffer} ->
{:ok, param_oids, nil, nil, s, buffer}
{:ok, param_oids, nil, nil, nil, s, buffer}

{:ok, msg_parameter_desc(type_oids: param_oids), buffer} ->
recv_describe(s, status, param_oids, buffer)

{:ok, msg_row_desc(fields: fields), buffer} ->
{result_oids, columns} = columns(fields)
{:ok, param_oids, result_oids, columns, s, buffer}
{result_oids, columns, result_mods} = columns(fields)
{:ok, param_oids, result_oids, result_mods, columns, s, buffer}

{:ok, msg_too_many_parameters(len: len, max_len: max), buffer} ->
msg = "postgresql protocol can not handle #{len} parameters, the maximum is #{max}"
Expand All @@ -1688,10 +1689,10 @@ defmodule Postgrex.Protocol do
end
end

defp describe(s, query, param_oids, result_oids, columns, buffer) do
defp describe(s, query, param_oids, result_oids, result_mods, columns, buffer) do
case describe_params(s, query, param_oids) do
{:ok, query} ->
redescribe(s, query, result_oids, columns, buffer)
redescribe(s, query, result_oids, result_mods, columns, buffer)

{:reload, oids} ->
reload_describe_result(s, oids, result_oids, buffer)
Expand All @@ -1701,8 +1702,8 @@ defmodule Postgrex.Protocol do
end
end

defp redescribe(s, query, result_oids, columns, buffer) do
with {:ok, query} <- describe_result(s, query, result_oids, columns) do
defp redescribe(s, query, result_oids, result_mods, columns, buffer) do
with {:ok, query} <- describe_result(s, query, result_oids, result_mods, columns) do
query_put(s, query)
{:ok, query, s, buffer}
else
Expand All @@ -1715,8 +1716,9 @@ defmodule Postgrex.Protocol do
end

defp describe_params(%{types: types}, query, param_oids) do
with {:ok, param_info} <- fetch_type_info(param_oids, types),
{param_formats, param_types} = Enum.unzip(param_info) do
with {:ok, param_info} <- fetch_type_info(param_oids, types) do
{param_formats, param_types} = Enum.unzip(param_info)

query = %Query{
query
| param_oids: param_oids,
Expand Down Expand Up @@ -1745,7 +1747,7 @@ defmodule Postgrex.Protocol do
end
end

defp describe_result(%{types: types}, query, nil, nil) do
defp describe_result(%{types: types}, query, nil, nil, nil) do
query = %Query{
query
| ref: make_ref(),
Expand All @@ -1759,9 +1761,18 @@ defmodule Postgrex.Protocol do
{:ok, query}
end

defp describe_result(%{types: types}, query, result_oids, columns) do
with {:ok, result_info} <- fetch_type_info(result_oids, types),
{result_formats, result_types} = Enum.unzip(result_info) do
defp describe_result(%{types: types}, query, result_oids, result_mods, columns) do
with {:ok, result_info} <- fetch_type_info(result_oids, types) do
{result_formats, result_types} = Enum.unzip(result_info)

result_types =
result_types
|> Enum.zip(result_mods)
|> Enum.map(fn
{{extension, sub_oids, sub_types}, mod} -> {extension, sub_oids, sub_types, mod}
{extension, mod} -> {extension, mod}
end)

query = %Query{
query
| ref: make_ref(),
Expand Down Expand Up @@ -3186,8 +3197,8 @@ defmodule Postgrex.Protocol do

defp columns(fields) do
fields
|> Enum.map(fn row_field(type_oid: oid, name: name) -> {oid, name} end)
|> :lists.unzip()
|> Enum.map(fn row_field(type_oid: oid, type_mod: mod, name: name) -> {oid, name, mod} end)
|> :lists.unzip3()
end

defp column_names(fields) do
Expand Down
Loading
Loading