diff --git a/lib/postgrex/extensions/interval.ex b/lib/postgrex/extensions/interval.ex index 7860eb70..5fa91b2e 100644 --- a/lib/postgrex/extensions/interval.ex +++ b/lib/postgrex/extensions/interval.ex @@ -3,7 +3,16 @@ defmodule Postgrex.Extensions.Interval do import Postgrex.BinaryUtils, warn: false use Postgrex.BinaryExtension, send: "interval_send" - def init(opts), do: Keyword.get(opts, :interval_decode_type, Postgrex.Interval) + def init(opts) do + case Keyword.get(opts, :interval_decode_type, Postgrex.Interval) do + type when type in [Postgrex.Interval, Duration] -> + type + + other -> + raise ArgumentError, + "#{inspect(other)} is not valid for `:interval_decode_type`. Please use either `Postgrex.Interval` or `Duration`" + end + end if Code.ensure_loaded?(Duration) do import Bitwise, warn: false @@ -43,48 +52,60 @@ defmodule Postgrex.Extensions.Interval do def decode(type) do quote location: :keep do <<16::int32(), microseconds::int64(), days::int32(), months::int32()>> -> - seconds = div(microseconds, 1_000_000) - microseconds = rem(microseconds, 1_000_000) - - case unquote(type) do - Postgrex.Interval -> - %Postgrex.Interval{ - months: months, - days: days, - secs: seconds, - microsecs: microseconds - } - - Duration -> - years = div(months, 12) - months = rem(months, 12) - weeks = div(days, 7) - days = rem(days, 7) - minutes = div(seconds, 60) - seconds = rem(seconds, 60) - hours = div(minutes, 60) - minutes = rem(minutes, 60) - type_mod = var!(mod) - precision = if type_mod, do: type_mod &&& unquote(@precision_mask) - - precision = - if precision in unquote(@unspecified_precision), - do: unquote(@default_precision), - else: precision - - Duration.new!( - year: years, - month: months, - week: weeks, - day: days, - hour: hours, - minute: minutes, - second: seconds, - microsecond: {microseconds, precision} - ) - end + precision = if var!(mod), do: var!(mod) &&& unquote(@precision_mask) + + unquote(__MODULE__).decode_interval( + microseconds, + days, + months, + precision, + unquote(type) + ) end end + + ## Helpers + + def decode_interval(microseconds, days, months, _precision, Postgrex.Interval) do + seconds = div(microseconds, 1_000_000) + microseconds = rem(microseconds, 1_000_000) + + %Postgrex.Interval{ + months: months, + days: days, + secs: seconds, + microsecs: microseconds + } + end + + def decode_interval(microseconds, days, months, precision, Duration) do + years = div(months, 12) + months = rem(months, 12) + weeks = div(days, 7) + days = rem(days, 7) + seconds = div(microseconds, 1_000_000) + microseconds = rem(microseconds, 1_000_000) + minutes = div(seconds, 60) + seconds = rem(seconds, 60) + hours = div(minutes, 60) + minutes = rem(minutes, 60) + + precision = + if precision in unquote(@unspecified_precision), + do: unquote(@default_precision), + else: precision + + Duration.new!( + year: years, + month: months, + week: weeks, + day: days, + hour: hours, + minute: minutes, + second: seconds, + microsecond: {microseconds, precision} + ) + end else def encode(_) do quote location: :keep do @@ -100,10 +121,22 @@ defmodule Postgrex.Extensions.Interval do def decode(_) do quote location: :keep do <<16::int32(), microseconds::int64(), days::int32(), months::int32()>> -> - seconds = div(microseconds, 1_000_000) - microseconds = rem(microseconds, 1_000_000) - %Postgrex.Interval{months: months, days: days, secs: seconds, microsecs: microseconds} + unquote(__MODULE__).decode_interval(microseconds, days, months, nil, Postgrex.Interval) end end + + ## Helpers + + def decode_interval(microseconds, days, months, _precision, Postgrex.Interval) do + seconds = div(microseconds, 1_000_000) + microseconds = rem(microseconds, 1_000_000) + + %Postgrex.Interval{ + months: months, + days: days, + secs: seconds, + microsecs: microseconds + } + end end end