diff --git a/lib/paper_trail/multi.ex b/lib/paper_trail/multi.ex index 434a7358..91e00339 100644 --- a/lib/paper_trail/multi.ex +++ b/lib/paper_trail/multi.ex @@ -143,14 +143,22 @@ defmodule PaperTrail.Multi do {:error, Map.merge(changeset, %{repo: repo, changes: filtered_changes})} {:ok, map} -> - {:ok, Map.drop(map, [:initial_version])} + {:ok, map |> Map.drop([:initial_version]) |> return_operation(options)} end _ -> case transaction do {:error, :model, changeset, %{}} -> {:error, Map.merge(changeset, %{repo: repo})} - _ -> transaction + {:ok, result} -> {:ok, return_operation(result, options)} end end end + + @spec return_operation(map, Keyword.t()) :: any + defp return_operation(result, options) do + case Keyword.get(options, :return_operation) do + nil -> result + operation -> Map.fetch!(result, operation) + end + end end diff --git a/test/paper_trail/base_test.exs b/test/paper_trail/base_test.exs index d9ccb564..c01d5ccd 100644 --- a/test/paper_trail/base_test.exs +++ b/test/paper_trail/base_test.exs @@ -91,6 +91,18 @@ defmodule PaperTrailTest do assert company == first(Company, :id) |> @repo.one |> serialize end + test "creating a company with return_operation option works" do + {:ok, company} = create_company_with_version(@create_company_params, return_operation: :model) + + company_count = Company.count() + version_count = Version.count() + + assert company_count == 1 + assert version_count == 1 + + assert company == Company |> first(:id) |> @repo.one + end + test "PaperTrail.insert/2 with an error returns and error tuple like Repo.insert/2" do result = create_company_with_version(%{name: nil, is_active: true, city: "Greenwich"})