diff --git a/README.md b/README.md index 2d0db3e..7e35e89 100644 --- a/README.md +++ b/README.md @@ -96,3 +96,54 @@ def deps do ] end ``` + +## DBConnection Support + +[DBConnection](https://github.com/elixir-ecto/db_connection) support is currently in experimental phase, setting it up is very similar to current implementation with the expection of configuration options and obtaining the same results will require an extra step: + +### Configuration: + +Setting a Module to hold the connection is very similar, but instead you'll use `Snowflex.DBConnection`: + +Example: + +```elixir +defmodule MyApp.SnowflakeConnection do + use Snowflex.DBConnection, + otp_app: :my_app, + timeout: :timer.minutes(5) +end +``` + +```elixir +config :my_app, MyApp.SnowflakeConnection, + pool_size: 5, # the connection pool size + worker: MyApp.CustomWorker, # defaults to Snowflex.DBConnection.Server + connection: [ + role: "PROD", + warehouse: System.get_env("SNOWFLAKE_POS_WH"), + uid: System.get_env("SNOWFLAKE_POS_UID"), + pwd: System.get_env("SNOWFLAKE_POS_PWD") + ] +``` + +### Usage: + +After setup, you can use your connection to query: + +```elixir +alias Snowflex.DBConnection.Result + +{:ok, %Result{} = result} = MyApp.SnowflakeConnection.execute("my query") +{:ok, %Result{} = result} = MyApp.SnowflakeConnection.execute("my query", ["my params"]) +``` + +As you can see we now receive an `{:ok, result}` tuple, to get results as expected with current implementation, we need to call `process_result/1`: + +```elixir +alias Snowflex.DBConnection.Result + +{:ok, %Result{} = result} = MyApp.SnowflakeConnection.execute("my query") + +[%{"col" => 1}, %{"col" => 2}] = SnowflakeDBConnection.process_result(result) +``` diff --git a/config/test.exs b/config/test.exs index db61e3a..4e9b744 100644 --- a/config/test.exs +++ b/config/test.exs @@ -6,3 +6,12 @@ config :snowflex, Snowflex.ConnectionTest.SnowflakeConnection, min: 1, max: 1 ] + +config :snowflex, Snowflex.DBConnectionTest.SnowflakeDBConnection, + worker: Snowflex.DBConnectionTest.MockWorker, + pool_size: 3, + connection: [ + server: "snowflex.us-east-8.snowflakecomputing.com", + role: "DEV", + warehouse: "CUSTOMER_DEV_WH" + ] diff --git a/lib/snowflex/db_connection.ex b/lib/snowflex/db_connection.ex new file mode 100644 index 0000000..849b34e --- /dev/null +++ b/lib/snowflex/db_connection.ex @@ -0,0 +1,65 @@ +defmodule Snowflex.DBConnection do + @moduledoc """ + Defines a Snowflake connection with DBConnection adapter. + + ## Definition + + When used, the connection expects the `:otp_app` option. You may also define a standard timeout. + This will default to 60 seconds. + + ``` + defmodule SnowflakeDBConnection do + use Snowflex.DBConnection, + otp_app: :my_app, + timeout: :timer.seconds(60) + end + ``` + """ + + alias Snowflex.DBConnection.{ + Protocol, + Query, + Result + } + + @doc false + defmacro __using__(opts) do + quote bind_quoted: [opts: opts] do + # setup compile time config + otp_app = Keyword.fetch!(opts, :otp_app) + timeout = Keyword.get(opts, :timeout, :timer.seconds(60)) + + @otp_app otp_app + @timeout timeout + @name __MODULE__ + + def child_spec(_) do + config = Application.get_env(@otp_app, __MODULE__, []) + connection = Keyword.get(config, :connection, []) + + opts = + Keyword.merge(config, + name: @name, + timeout: @timeout, + connection: connection + ) + + DBConnection.child_spec(Protocol, opts) + end + + def execute(statement, params \\ []) when is_binary(statement) and is_list(params) do + case prepare_execute("", statement, params) do + {:ok, _query, result} -> {:ok, result} + {:error, error} -> {:error, error} + end + end + + defdelegate process_result(result, opts \\ [map_nulls_to_nil?: true]), to: Result + + defp prepare_execute(name, statement, params, opts \\ []) do + query = %Query{name: name, statement: statement} + DBConnection.prepare_execute(@name, query, params, opts) + end + end + end +end diff --git a/lib/snowflex/db_connection/error.ex b/lib/snowflex/db_connection/error.ex new file mode 100644 index 0000000..bad2b17 --- /dev/null +++ b/lib/snowflex/db_connection/error.ex @@ -0,0 +1,30 @@ +defmodule Snowflex.DBConnection.Error do + @moduledoc """ + Defines an error returned from the ODBC adapter. + """ + + defexception [:message] + + @type t :: %__MODULE__{ + message: String.t() + } + + @spec exception(term()) :: t() + def exception({odbc_code, native_code, reason}) do + message = + to_string(reason) <> + " - ODBC_CODE: " <> + to_string(odbc_code) <> + " - SNOWFLAKE_CODE: " <> to_string(native_code) + + %__MODULE__{ + message: message + } + end + + def exception(message) do + %__MODULE__{ + message: to_string(message) + } + end +end diff --git a/lib/snowflex/db_connection/protocol.ex b/lib/snowflex/db_connection/protocol.ex new file mode 100644 index 0000000..ac81bdb --- /dev/null +++ b/lib/snowflex/db_connection/protocol.ex @@ -0,0 +1,149 @@ +defmodule Snowflex.DBConnection.Protocol do + use DBConnection + + require Logger + + alias Snowflex.DBConnection.{ + Query, + Result, + Server + } + + defstruct pid: nil, status: :idle, conn_opts: [], worker: Server + + @type state :: %__MODULE__{ + pid: pid(), + status: :idle, + conn_opts: Keyword.t(), + worker: Server | any() + } + + ## DBConnection Callbacks + + @impl DBConnection + def connect(opts) do + connection_args = Keyword.fetch!(opts, :connection) + worker = Keyword.get(opts, :worker, Server) + + {:ok, pid} = worker.start_link(opts) + + state = %__MODULE__{ + pid: pid, + status: :idle, + conn_opts: connection_args, + worker: worker + } + + {:ok, state} + end + + @impl DBConnection + def disconnect(_err, %{pid: pid}), do: Server.disconnect(pid) + + @impl DBConnection + def checkout(state), do: {:ok, state} + + @impl DBConnection + def ping(state) do + query = %Query{name: "ping", statement: "SELECT /* snowflex:heartbeat */ 1;"} + + case do_query(query, [], [], state) do + {:ok, _, _, new_state} -> {:ok, new_state} + {:error, reason, new_state} -> {:disconnect, reason, new_state} + end + end + + @impl DBConnection + def handle_prepare(query, _opts, state) do + {:ok, query, state} + end + + @impl DBConnection + def handle_execute(query, params, opts, state) do + do_query(query, params, opts, state) + end + + @impl DBConnection + def handle_status(_, %{status: {status, _}} = state), do: {status, state} + def handle_status(_, %{status: status} = state), do: {status, state} + + @impl DBConnection + def handle_close(_query, _opts, state) do + {:ok, %Result{}, state} + end + + ## Not implemented Callbacks + + @impl DBConnection + def handle_begin(_opts, _state) do + throw("not implemented") + end + + @impl DBConnection + def handle_commit(_opts, _state) do + throw("not implemented") + end + + @impl DBConnection + def handle_rollback(_opts, _state) do + throw("not implemented") + end + + @impl DBConnection + def handle_declare(_query, _params, _opts, _state) do + throw("not implemeted") + end + + @impl DBConnection + def handle_deallocate(_query, _cursor, _opts, _state) do + throw("not implemeted") + end + + @impl DBConnection + def handle_fetch(_query, _cursor, _opts, _state) do + throw("not implemeted") + end + + ## Helpers + + defp do_query(%Query{} = query, [], opts, %{worker: worker} = state) do + case worker.sql_query(state.pid, query.statement, opts) do + {:ok, result} -> + result = parse_result(result, query) + {:ok, query, result, state} + + {:error, reason} -> + {:error, reason, state} + end + end + + defp do_query(%Query{} = query, params, opts, %{worker: worker} = state) do + case worker.param_query(state.pid, query.statement, params, opts) do + {:ok, result} -> + result = parse_result(result, query) + {:ok, query, result, state} + + {:error, reason} -> + {:error, reason, state} + end + end + + defp parse_result({:selected, columns, rows, _}, query), + do: parse_result({:selected, columns, rows}, query) + + defp parse_result({:selected, columns, rows}, query) do + parse_result(columns, rows, query) + end + + defp parse_result(result, _query), do: result + + defp parse_result(columns, rows, query) do + %Result{ + columns: Enum.map(columns, &to_string(&1)), + rows: rows, + num_rows: Enum.count(rows), + success: true, + statement: query.statement + } + end +end diff --git a/lib/snowflex/db_connection/query.ex b/lib/snowflex/db_connection/query.ex new file mode 100644 index 0000000..86b0c54 --- /dev/null +++ b/lib/snowflex/db_connection/query.ex @@ -0,0 +1,29 @@ +defmodule Snowflex.DBConnection.Query do + defstruct [ + :ref, + :name, + :statement, + :columns, + :result_oids, + cache: :reference + ] + + defimpl DBConnection.Query do + def parse(query, _opts), do: query + def describe(query, _opts), do: query + def encode(_query, params, _opts), do: params + def decode(_query, result, _opts), do: result + end + + defimpl String.Chars do + alias Snowflex.DBConnection.Query + + def to_string(%{statement: statement}) do + case statement do + statement when is_binary(statement) -> IO.iodata_to_binary(statement) + statement when is_list(statement) -> IO.iodata_to_binary(statement) + %{statement: %Query{} = q} -> String.Chars.to_string(q) + end + end + end +end diff --git a/lib/snowflex/db_connection/result.ex b/lib/snowflex/db_connection/result.ex new file mode 100644 index 0000000..30a7248 --- /dev/null +++ b/lib/snowflex/db_connection/result.ex @@ -0,0 +1,62 @@ +defmodule Snowflex.DBConnection.Result do + defstruct columns: nil, + rows: nil, + num_rows: 0, + metadata: [], + messages: [], + statement: nil, + success: false + + @type t :: %__MODULE__{ + columns: [String.t()] | nil, + rows: [tuple()] | nil, + num_rows: integer(), + metadata: [map()], + messages: [map()], + statement: String.t() | nil, + success: boolean() + } + + def process_result(result, opts \\ []) + + def process_result(%__MODULE__{columns: columns, rows: rows}, opts) do + process_results({:selected, columns, rows}, opts) + end + + def process_result({:updated, _} = result, _opts), do: result + + ## Helpers + + defp process_results(data, opts) when is_list(data) do + Enum.map(data, &process_results(&1, opts)) + end + + defp process_results({:selected, headers, rows}, opts) do + map_nulls_to_nil? = Keyword.get(opts, :map_nulls_to_nil?, true) + + bin_headers = + headers + |> Enum.map(fn header -> header |> to_string() |> String.downcase() end) + |> Enum.with_index() + + Enum.map(rows, fn row -> + Enum.reduce(bin_headers, %{}, fn {col, index}, map -> + data = + row + |> elem(index) + |> to_string_if_charlist() + |> map_null_to_nil(map_nulls_to_nil?) + + Map.put(map, col, data) + end) + end) + end + + defp process_results({:updated, _} = results, _opts), do: results + + defp to_string_if_charlist(data) when is_list(data), do: to_string(data) + defp to_string_if_charlist(data), do: data + + defp map_null_to_nil(:null, true), do: nil + defp map_null_to_nil(data, _), do: data +end diff --git a/lib/snowflex/db_connection/server.ex b/lib/snowflex/db_connection/server.ex new file mode 100644 index 0000000..e8d2d10 --- /dev/null +++ b/lib/snowflex/db_connection/server.ex @@ -0,0 +1,193 @@ +defmodule Snowflex.DBConnection.Server do + @moduledoc """ + Adapter to Erlang's `:odbc` module. + + A GenServer that handles communication between Elixir and Erlang's `:odbc` module. + """ + + use GenServer + + require Logger + + alias Snowflex.DBConnection.Error + alias Snowflex.Params + alias Snowflex.Telemetry + + @timeout :timer.seconds(60) + + ## Public API + + @doc """ + Starts the connection process to the ODBC driver. + """ + @spec start_link(Keyword.t()) :: {:ok, pid()} + def start_link(opts) do + connection_args = Keyword.fetch!(opts, :connection) + conn_str = connection_string(connection_args) + + GenServer.start_link(__MODULE__, [{:conn_str, to_charlist(conn_str)} | opts]) + end + + @doc """ + Sends a query to the ODBC driver. + + `pid` is the `:odbc` process id + `statement` is the SQL query string + `opts` are options to be passed on to `:odbc` + """ + @spec sql_query(pid(), iodata(), Keyword.t()) :: + {:ok, {:selected, [binary()], [tuple()]}} + | {:ok, {:selected, [binary()], [tuple()], [{binary()}]}} + | {:ok, {:updated, non_neg_integer()}} + | {:error, Error.t()} + def sql_query(pid, statement, opts \\ []) do + if Process.alive?(pid) do + statement = IO.iodata_to_binary(statement) + timeout = Keyword.get(opts, :timeout, @timeout) + + GenServer.call(pid, {:sql_query, %{statement: statement}}, timeout) + else + {:error, %Error{message: :no_connection}} + end + end + + @doc """ + Sends a parametrized query to the ODBC driver. + + `pid` is the `:odbc` process id + `statement` is the SQL query string + `params` are the parameters to send with the SQL query + `opts` are options to be passed on to `:odbc` + """ + @spec param_query(pid(), iodata(), Keyword.t(), Keyword.t()) :: + {:ok, {:selected, [binary()], [tuple()]}} + | {:ok, {:selected, [binary()], [tuple()], [{binary()}]}} + | {:ok, {:updated, non_neg_integer()}} + | {:error, Error.t()} + def param_query(pid, statement, params, opts \\ []) do + if Process.alive?(pid) do + statement = IO.iodata_to_binary(statement) + timeout = Keyword.get(opts, :timeout, @timeout) + + GenServer.call(pid, {:param_query, %{statement: statement, params: params}}, timeout) + else + {:error, %Error{message: :no_connection}} + end + end + + @doc """ + Disconnects from the ODBC driver. + """ + @spec disconnect(pid()) :: :ok + def disconnect(pid) do + GenServer.stop(pid, :normal) + end + + ## GenServer callbacks + + @impl GenServer + def init(opts) do + send(self(), {:start, opts}) + + {:ok, %{backoff: :backoff.init(2, 60), state: :not_connected}} + end + + @impl GenServer + def handle_call({:sql_query, _query}, _from, %{state: :not_connected} = state) do + {:reply, {:error, :not_connected}, state} + end + + def handle_call({:sql_query, %{statement: statement}}, _from, %{pid: pid} = state) do + start_time = Telemetry.sql_start(%{query: statement}) + + result = + case :odbc.sql_query(pid, to_charlist(statement)) do + {:error, reason} -> + error = Error.exception(reason) + Logger.warn("Unable to execute query: #{error.message}") + + {:reply, {:error, error}, state} + + result -> + {:reply, {:ok, result}, state} + end + + Telemetry.sql_stop(start_time) + + result + end + + def handle_call({:param_query, _query}, _from, %{state: :not_connected} = state) do + {:reply, {:error, :not_connected}, state} + end + + def handle_call( + {:param_query, %{statement: statement, params: params}}, + _from, + %{pid: pid} = state + ) do + start_time = Telemetry.param_start(%{query: statement, params: params}) + params = Params.prepare(params) + + result = + case :odbc.param_query(pid, to_charlist(statement), params) do + {:error, reason} -> + error = Error.exception(reason) + Logger.warn("Unable to execute query: #{error.message}") + + {:reply, {:error, error}, state} + + result -> + {:reply, {:ok, result}, state} + end + + Telemetry.param_stop(start_time) + + result + end + + @impl GenServer + def handle_info({:start, opts}, %{backoff: backoff} = _state) do + connect_opts = + opts + |> Keyword.delete_first(:conn_str) + |> Keyword.put_new(:auto_commit, :on) + |> Keyword.put_new(:binary_strings, :on) + |> Keyword.put_new(:tuple_row, :on) + |> Keyword.put_new(:extended_errors, :on) + + case :odbc.connect(opts[:conn_str], connect_opts) do + {:ok, pid} -> + {:noreply, %{pid: pid, backoff: :backoff.succeed(backoff), state: :connected}} + + {:error, reason} -> + Logger.warn("Unable to connect to snowflake: #{inspect(reason)}") + + seconds = + backoff + |> :backoff.get() + |> :timer.seconds() + + Process.send_after(self(), {:start, opts}, seconds) + + {_, new_backoff} = :backoff.fail(backoff) + + {:noreply, %{backoff: new_backoff, state: :not_connected}} + end + end + + @impl GenServer + def terminate(_reason, %{state: :not_connected} = _state), do: :ok + def terminate(_reason, %{pid: pid} = _state), do: :odbc.disconnect(pid) + + ## Helpers + + defp connection_string(connection_args) do + driver = Application.get_env(:snowflex, :driver) + connection_args = [{:driver, driver} | connection_args] + + Enum.reduce(connection_args, "", fn {key, value}, acc -> + acc <> "#{key}=#{value};" + end) + end +end diff --git a/lib/snowflex/params.ex b/lib/snowflex/params.ex new file mode 100644 index 0000000..56e4282 --- /dev/null +++ b/lib/snowflex/params.ex @@ -0,0 +1,47 @@ +defmodule Snowflex.Params do + @moduledoc """ + Provides shared functions for parameter parsing + """ + + @string_types ~w( + sql_char + sql_wchar + sql_varchar + sql_wvarchar + sql_wlongvarchar + )a + + def prepare(params) do + Enum.map(params, &prepare_param/1) + end + + def prepare_param({type, values}) when not is_list(values) do + prepare_param({type, [values]}) + end + + def prepare_param({{type_atom, _size} = type, values}) when type_atom in @string_types do + {type, Enum.map(values, &null_or_charlist/1)} + end + + def prepare_param({type, values}) do + {type, Enum.map(values, &null_or_any/1)} + end + + ## Helpers + + defp null_or_charlist(nil) do + :null + end + + defp null_or_charlist(val) do + to_charlist(val) + end + + defp null_or_any(nil) do + :null + end + + defp null_or_any(any) do + any + end +end diff --git a/lib/snowflex/telemetry.ex b/lib/snowflex/telemetry.ex new file mode 100644 index 0000000..3145e61 --- /dev/null +++ b/lib/snowflex/telemetry.ex @@ -0,0 +1,50 @@ +defmodule Snowflex.Telemetry do + @moduledoc """ + Shared telemetry module for registering events + """ + + @sql_start [:snowflex, :sql_query, :start] + @sql_stop [:snowflex, :sql_query, :stop] + @param_start [:snowflex, :param_query, :start] + @param_stop [:snowflex, :param_query, :stop] + + @spec sql_start(map(), map()) :: integer() + def sql_start(metadata \\ %{}, measurements \\ %{}) do + start_time = System.monotonic_time() + + measurements = Map.merge(measurements, %{system_time: System.system_time()}) + + :telemetry.execute(@sql_start, measurements, metadata) + + start_time + end + + @spec sql_stop(integer(), map(), map()) :: :ok + def sql_stop(start_time, metadata \\ %{}, measurements \\ %{}) do + end_time = System.monotonic_time() + + measurements = Map.merge(measurements, %{duration: end_time - start_time}) + + :telemetry.execute(@sql_stop, measurements, metadata) + end + + @spec param_start(map(), map()) :: integer() + def param_start(metadata \\ %{}, measurements \\ %{}) do + start_time = System.monotonic_time() + + measurements = Map.merge(measurements, %{system_time: System.system_time()}) + + :telemetry.execute(@param_start, measurements, metadata) + + start_time + end + + @spec param_stop(integer(), map(), map()) :: :ok + def param_stop(start_time, metadata \\ %{}, measurements \\ %{}) do + end_time = System.monotonic_time() + + measurements = Map.merge(measurements, %{duration: end_time - start_time}) + + :telemetry.execute(@param_stop, measurements, metadata) + end +end diff --git a/lib/snowflex/worker.ex b/lib/snowflex/worker.ex index 810f522..46ea8c0 100644 --- a/lib/snowflex/worker.ex +++ b/lib/snowflex/worker.ex @@ -2,22 +2,14 @@ defmodule Snowflex.Worker do @moduledoc false require Logger + use GenServer + alias Snowflex.Params + alias Snowflex.Telemetry + @timeout :timer.seconds(60) @gc_delay_ms 5 - @string_types ~w( - sql_char - sql_wchar - sql_varchar - sql_wvarchar - sql_wlongvarchar - )a - - @sql_start [:snowflex, :sql_query, :start] - @sql_stop [:snowflex, :sql_query, :stop] - @param_start [:snowflex, :param_query, :start] - @param_stop [:snowflex, :param_query, :stop] def start_link(connection_args) do GenServer.start_link(__MODULE__, connection_args, []) @@ -45,35 +37,29 @@ defmodule Snowflex.Worker do @impl GenServer def handle_call({:sql_query, query}, _from, state) do - start_time = System.monotonic_time() - :telemetry.execute(@sql_start, %{system_time: System.system_time()}, %{query: query}) + start_time = Telemetry.sql_start(%{query: query}) {result, state} = state |> do_sql_query(query) |> reschedule_heartbeat() - duration = System.monotonic_time() - start_time - :telemetry.execute(@sql_stop, %{duration: duration}) + Telemetry.sql_stop(start_time) + Process.send_after(self(), :gc, @gc_delay_ms) {:reply, result, state} end def handle_call({:param_query, query, params}, _from, state) do - start_time = System.monotonic_time() - - :telemetry.execute(@param_start, %{system_time: System.system_time()}, %{ - query: query, - params: params - }) + start_time = Telemetry.param_start(%{query: query, params: params}) {result, state} = state |> do_param_query(query, params) |> reschedule_heartbeat() - duration = System.monotonic_time() - start_time - :telemetry.execute(@param_stop, %{duration: duration}) + Telemetry.param_stop(start_time) + {:reply, result, state} end @@ -150,7 +136,7 @@ defmodule Snowflex.Worker do defp do_param_query(%{pid: pid} = state, query, params) do ch_query = to_charlist(query) - ch_params = prepare_params(params) + ch_params = Params.prepare(params) case :odbc.param_query(pid, ch_query, ch_params) do {:error, reason} -> @@ -162,38 +148,6 @@ defmodule Snowflex.Worker do end end - defp prepare_params(params) do - Enum.map(params, &prepare_param/1) - end - - defp prepare_param({type, values}) when not is_list(values) do - prepare_param({type, [values]}) - end - - defp prepare_param({{type_atom, _size} = type, values}) when type_atom in @string_types do - {type, Enum.map(values, &null_or_charlist/1)} - end - - defp prepare_param({type, values}) do - {type, Enum.map(values, &null_or_any/1)} - end - - defp null_or_charlist(nil) do - :null - end - - defp null_or_charlist(val) do - to_charlist(val) - end - - defp null_or_any(nil) do - :null - end - - defp null_or_any(any) do - any - end - defp send_heartbeat(state) do Logger.info("sending heartbeat") diff --git a/mix.exs b/mix.exs index f4971d4..0a81593 100644 --- a/mix.exs +++ b/mix.exs @@ -44,6 +44,7 @@ defmodule Snowflex.MixProject do {:poolboy, "~> 1.5.1"}, {:backoff, "~> 1.1.6"}, {:ecto, "~> 3.0"}, + {:db_connection, "~> 2.4"}, {:telemetry, "~> 0.4 or ~> 1.0"}, {:dialyxir, "~> 1.0", only: :dev, runtime: false}, {:ex_doc, "~> 0.21", only: :dev, runtime: false}, diff --git a/mix.lock b/mix.lock index 1a6da8b..b5f69e4 100644 --- a/mix.lock +++ b/mix.lock @@ -1,5 +1,7 @@ %{ "backoff": {:hex, :backoff, "1.1.6", "83b72ed2108ba1ee8f7d1c22e0b4a00cfe3593a67dbc792799e8cce9f42f796b", [:rebar3], [], "hexpm", "cf0cfff8995fb20562f822e5cc47d8ccf664c5ecdc26a684cbe85c225f9d7c39"}, + "connection": {:hex, :connection, "1.1.0", "ff2a49c4b75b6fb3e674bfc5536451607270aac754ffd1bdfe175abe4a6d7a68", [:mix], [], "hexpm", "722c1eb0a418fbe91ba7bd59a47e28008a189d47e37e0e7bb85585a016b2869c"}, + "db_connection": {:hex, :db_connection, "2.4.1", "6411f6e23f1a8b68a82fa3a36366d4881f21f47fc79a9efb8c615e62050219da", [:mix], [{:connection, "~> 1.0", [hex: :connection, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "ea36d226ec5999781a9a8ad64e5d8c4454ecedc7a4d643e4832bf08efca01f00"}, "decimal": {:hex, :decimal, "2.0.0", "a78296e617b0f5dd4c6caf57c714431347912ffb1d0842e998e9792b5642d697", [:mix], [], "hexpm", "34666e9c55dea81013e77d9d87370fe6cb6291d1ef32f46a1600230b1d44f577"}, "dialyxir": {:hex, :dialyxir, "1.1.0", "c5aab0d6e71e5522e77beff7ba9e08f8e02bad90dfbeffae60eaf0cb47e29488", [:mix], [{:erlex, ">= 0.2.6", [hex: :erlex, repo: "hexpm", optional: false]}], "hexpm", "07ea8e49c45f15264ebe6d5b93799d4dd56a44036cf42d0ad9c960bc266c0b9a"}, "earmark": {:hex, :earmark, "1.4.3", "364ca2e9710f6bff494117dbbd53880d84bebb692dafc3a78eb50aa3183f2bfd", [:mix], [], "hexpm", "8cf8a291ebf1c7b9539e3cddb19e9cef066c2441b1640f13c34c1d3cfc825fec"}, diff --git a/test/snowflex/db_connection/server_test.exs b/test/snowflex/db_connection/server_test.exs new file mode 100644 index 0000000..2d8814a --- /dev/null +++ b/test/snowflex/db_connection/server_test.exs @@ -0,0 +1,107 @@ +defmodule Snowflex.DBConnection.ServerTest do + use ExUnit.Case, async: false + + import ExUnit.CaptureLog + + alias Snowflex.DBConnection.Server + + @connection_args [ + connection: [ + server: "snowflex.us-east-8.snowflakecomputing.com", + role: "DEV", + warehouse: "CUSTOMER_DEV_WH" + ] + ] + + setup do + :meck.new(:odbc, [:no_link]) + on_exit(fn -> :meck.unload(:odbc) end) + end + + describe "telemetry events" do + setup do + :meck.expect(:odbc, :connect, fn _, _ -> {:ok, "mock pid"} end) + :meck.expect(:odbc, :sql_query, fn "mock pid", 'SELECT 1' -> "1" end) + on_exit(fn -> assert :meck.validate(:odbc) end) + end + + test "sends start and stop events" do + start_req_id = {:start, :rand.uniform(100)} + stop_req_id = {:stop, :rand.uniform(100)} + + on_exit(fn -> + :telemetry.detach(start_req_id) + :telemetry.detach(stop_req_id) + end) + + capture_log(fn -> + attach(start_req_id, [:snowflex, :sql_query, :start], self()) + attach(stop_req_id, [:snowflex, :sql_query, :stop], self()) + end) + + :meck.expect(:odbc, :sql_query, fn "mock pid", 'SELECT * FROM my_table' -> + {:selected, ['name'], [{'dustin'}]} + end) + + capture_log(fn -> + server = start_supervised!({Server, @connection_args}) + Server.sql_query(server, "SELECT * FROM my_table") + end) + + assert_received {:event, [:snowflex, :sql_query, :start], %{system_time: _}, + %{query: "SELECT * FROM my_table"}} + + assert_received {:event, [:snowflex, :sql_query, :stop], %{duration: _}, %{}} + end + end + + describe "param_query" do + setup do + :meck.expect(:odbc, :connect, fn _, _ -> {:ok, "mock pid"} end) + on_exit(fn -> assert :meck.validate(:odbc) end) + end + + test "with a string type, converts nil values to :null and strings to charlists" do + :meck.expect(:odbc, :param_query, fn "mock pid", + '[some param query]', + [{{:sql_varchar, 255}, ['abc', :null, 'def']}] -> + "1" + end) + + query = "[some param query]" + params = [{{:sql_varchar, 255}, ["abc", nil, "def"]}] + + capture_log(fn -> + server = start_supervised!({Server, @connection_args}) + Server.param_query(server, query, params) + end) + end + + test "with an integer type, converts nil values to :null" do + :meck.expect(:odbc, :param_query, fn "mock pid", + '[some param query]', + [{{:sql_integer, 255}, [123, :null, 456]}] -> + "1" + end) + + query = "[some param query]" + params = [{{:sql_integer, 255}, [123, nil, 456]}] + + capture_log(fn -> + server = start_supervised!({Server, @connection_args}) + Server.param_query(server, query, params) + end) + end + end + + defp attach(handler_id, event, pid) do + :telemetry.attach( + handler_id, + event, + fn event, measurements, metadata, _ -> + send(pid, {:event, event, measurements, metadata}) + end, + nil + ) + end +end diff --git a/test/snowflex/db_connection_test.exs b/test/snowflex/db_connection_test.exs new file mode 100644 index 0000000..4e41a6c --- /dev/null +++ b/test/snowflex/db_connection_test.exs @@ -0,0 +1,78 @@ +defmodule Snowflex.DBConnectionTest do + use ExUnit.Case, async: true + + defmodule SnowflakeDBConnection do + use Snowflex.DBConnection, + otp_app: :snowflex + end + + defmodule MockWorker do + use GenServer + + # API + + def start_link(_) do + GenServer.start_link(__MODULE__, nil, []) + end + + def sql_query(pid, query, _opts) do + GenServer.call(pid, {:sql_query, query}, 1000) + end + + def param_query(pid, query, params, _opts) do + GenServer.call(pid, {:param_query, query, params}, 1000) + end + + # Callbacks + + def init(_) do + {:ok, %{}} + end + + def handle_call({:sql_query, "insert " <> _}, _from, state) do + {:reply, {:ok, {:updated, 123}}, state} + end + + def handle_call({:sql_query, _}, _from, state) do + {:reply, {:ok, {:selected, ['col'], [{1}, {2}]}}, state} + end + + def handle_call({:param_query, "insert " <> _, _}, _from, state) do + {:reply, {:ok, {:updated, 456}}, state} + end + + def handle_call({:param_query, _, _}, _from, state) do + {:reply, {:ok, {:selected, ['col'], [{1}, {2}]}}, state} + end + end + + setup_all do + start_supervised!(SnowflakeDBConnection) + + :ok + end + + describe "execute/1" do + test "should execute a sql query" do + assert {:ok, result} = SnowflakeDBConnection.execute("my query") + assert [%{"col" => 1}, %{"col" => 2}] == SnowflakeDBConnection.process_result(result) + end + + test "should execute an insert query" do + assert {:ok, result} = SnowflakeDBConnection.execute("insert query") + assert {:updated, 123} == SnowflakeDBConnection.process_result(result) + end + end + + describe "execute/2" do + test "should execute a param query" do + assert {:ok, result} = SnowflakeDBConnection.execute("my query", []) + assert [%{"col" => 1}, %{"col" => 2}] == SnowflakeDBConnection.process_result(result) + end + + test "should execute an insert param query" do + assert {:ok, result} = SnowflakeDBConnection.execute("insert query", ["params"]) + assert {:updated, 456} == SnowflakeDBConnection.process_result(result) + end + end +end