Skip to content

Commit

Permalink
Align batch handling for serving run and batched run (#252)
Browse files Browse the repository at this point in the history
Co-authored-by: José Valim <jose.valim@dashbit.co>
  • Loading branch information
jonatanklosko and josevalim authored Sep 25, 2023
1 parent 7302867 commit a503832
Show file tree
Hide file tree
Showing 15 changed files with 45 additions and 26 deletions.
22 changes: 17 additions & 5 deletions lib/bumblebee/audio/speech_to_text_whisper.ex
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
end,
defn_options
)
|> Nx.Serving.process_options(batch_size: batch_size)
|> Nx.Serving.batch_size(batch_size)
|> Nx.Serving.client_preprocessing(fn input ->
if opts[:stream] do
Shared.validate_input_for_stream!(input)
Expand All @@ -81,7 +81,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
{inputs, multi?} =
Shared.validate_serving_input!(input, fn
%Nx.Tensor{shape: {_}} = input ->
{:ok, input}
{:ok, Nx.backend_transfer(input, Nx.BinaryBackend)}

{:file, path} when is_binary(path) ->
ffmpeg_read_as_pcm(path, sampling_rate)
Expand All @@ -104,8 +104,20 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
all_chunks = List.flatten(all_chunks)
{all_chunks, lengths} = Enum.unzip(all_chunks)

inputs = Bumblebee.Featurizer.process_input(featurizer, all_chunks)
{Nx.Batch.concatenate([inputs]), {multi?, all_num_chunks, lengths}}
if batch_size do
stream =
all_chunks
|> Stream.chunk_every(batch_size)
|> Stream.map(fn all_chunks ->
inputs = Bumblebee.Featurizer.process_input(featurizer, all_chunks)
Nx.Batch.concatenate([inputs])
end)

{stream, {multi?, all_num_chunks, lengths}}
else
inputs = Bumblebee.Featurizer.process_input(featurizer, all_chunks)
{Nx.Batch.concatenate([inputs]), {multi?, all_num_chunks, lengths}}
end
end)
|> maybe_stream(opts[:stream], spec, featurizer, tokenizer, timestamps?)
end
Expand Down Expand Up @@ -571,7 +583,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
])
|> case do
{data, 0} ->
{:ok, Nx.from_binary(data, :f32)}
{:ok, Nx.from_binary(data, :f32, backend: Nx.BinaryBackend)}

{_, 1} ->
{:error, "ffmpeg failed to decode the given file"}
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/diffusion/stable_diffusion.ex
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
fn defn_options -> apply(&init/10, init_args ++ [defn_options]) end,
defn_options
)
|> Nx.Serving.process_options(batch_size: batch_size)
|> Nx.Serving.batch_size(batch_size)
|> Nx.Serving.client_preprocessing(&client_preprocessing(&1, tokenizer, sequence_length))
|> Nx.Serving.client_postprocessing(&client_postprocessing(&1, &2, safety_checker))
end
Expand Down
3 changes: 2 additions & 1 deletion lib/bumblebee/text/conversation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ defmodule Bumblebee.Text.Conversation do
end,
defn_options
)
|> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys)
|> Nx.Serving.batch_size(batch_size)
|> Nx.Serving.process_options(batch_keys: batch_keys)
|> Nx.Serving.client_preprocessing(fn input ->
{histories, multi?} = Shared.validate_serving_input!(input, &validate_input/1)

Expand Down
3 changes: 2 additions & 1 deletion lib/bumblebee/text/fill_mask.ex
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ defmodule Bumblebee.Text.FillMask do
end,
defn_options
)
|> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys)
|> Nx.Serving.batch_size(batch_size)
|> Nx.Serving.process_options(batch_keys: batch_keys)
|> Nx.Serving.client_preprocessing(fn input ->
{texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1)

Expand Down
3 changes: 2 additions & 1 deletion lib/bumblebee/text/generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,8 @@ defmodule Bumblebee.Text.Generation do
end,
defn_options
)
|> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys)
|> Nx.Serving.batch_size(batch_size)
|> Nx.Serving.process_options(batch_keys: batch_keys)
|> Nx.Serving.client_preprocessing(fn input ->
if opts[:stream] do
Shared.validate_input_for_stream!(input)
Expand Down
3 changes: 2 additions & 1 deletion lib/bumblebee/text/question_answering.ex
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ defmodule Bumblebee.Text.QuestionAnswering do
end,
defn_options
)
|> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys)
|> Nx.Serving.batch_size(batch_size)
|> Nx.Serving.process_options(batch_keys: batch_keys)
|> Nx.Serving.client_preprocessing(fn raw_input ->
{raw_inputs, multi?} =
Shared.validate_serving_input!(raw_input, fn
Expand Down
3 changes: 2 additions & 1 deletion lib/bumblebee/text/text_classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ defmodule Bumblebee.Text.TextClassification do
end,
defn_options
)
|> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys)
|> Nx.Serving.batch_size(batch_size)
|> Nx.Serving.process_options(batch_keys: batch_keys)
|> Nx.Serving.client_preprocessing(fn input ->
{texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1)

Expand Down
3 changes: 2 additions & 1 deletion lib/bumblebee/text/text_embedding.ex
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ defmodule Bumblebee.Text.TextEmbedding do
end,
defn_options
)
|> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys)
|> Nx.Serving.batch_size(batch_size)
|> Nx.Serving.process_options(batch_keys: batch_keys)
|> Nx.Serving.client_preprocessing(fn input ->
{texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1)

Expand Down
3 changes: 2 additions & 1 deletion lib/bumblebee/text/token_classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ defmodule Bumblebee.Text.TokenClassification do
end,
defn_options
)
|> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys)
|> Nx.Serving.batch_size(batch_size)
|> Nx.Serving.process_options(batch_keys: batch_keys)
|> Nx.Serving.client_preprocessing(fn input ->
{texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1)

Expand Down
3 changes: 2 additions & 1 deletion lib/bumblebee/text/zero_shot_classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ defmodule Bumblebee.Text.ZeroShotClassification do
end,
defn_options
)
|> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys)
|> Nx.Serving.batch_size(batch_size)
|> Nx.Serving.process_options(batch_keys: batch_keys)
|> Nx.Serving.client_preprocessing(fn input ->
{texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1)

Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/vision/image_classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ defmodule Bumblebee.Vision.ImageClassification do
end,
defn_options
)
|> Nx.Serving.process_options(batch_size: batch_size)
|> Nx.Serving.batch_size(batch_size)
|> Nx.Serving.client_preprocessing(fn input ->
{images, multi?} = Shared.validate_serving_input!(input, &Shared.validate_image/1)
inputs = Bumblebee.Featurizer.process_input(featurizer, images)
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/vision/image_embedding.ex
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ defmodule Bumblebee.Vision.ImageEmbedding do
end,
defn_options
)
|> Nx.Serving.process_options(batch_size: batch_size)
|> Nx.Serving.batch_size(batch_size)
|> Nx.Serving.client_preprocessing(fn input ->
{images, multi?} = Shared.validate_serving_input!(input, &Shared.validate_image/1)

Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/vision/image_to_text.ex
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ defmodule Bumblebee.Vision.ImageToText do
end,
defn_options
)
|> Nx.Serving.process_options(batch_size: batch_size)
|> Nx.Serving.batch_size(batch_size)
|> Nx.Serving.client_preprocessing(fn input ->
{images, multi?} = Shared.validate_serving_input!(input, &Shared.validate_image/1)
inputs = Bumblebee.Featurizer.process_input(featurizer, images)
Expand Down
6 changes: 3 additions & 3 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ defmodule Bumblebee.MixProject do
[
{:axon, "~> 0.6.0", axon_opts()},
{:tokenizers, "~> 0.4"},
{:nx, "~> 0.6.1"},
{:exla, "~> 0.6.1", only: [:dev, :test]},
{:torchx, "~> 0.6.1", only: [:dev, :test]},
{:nx, "~> 0.6.2"},
{:exla, ">= 0.0.0", only: [:dev, :test]},
{:torchx, ">= 0.0.0", only: [:dev, :test]},
# {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
# {:exla, github: "elixir-nx/nx", sparse: "exla", override: true, only: [:dev, :test]},
# {:torchx, github: "elixir-nx/nx", sparse: "torchx", override: true, only: [:dev, :test]},
Expand Down
11 changes: 5 additions & 6 deletions mix.lock
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,18 @@
"cowboy_telemetry": {:hex, :cowboy_telemetry, "0.4.0", "f239f68b588efa7707abce16a84d0d2acf3a0f50571f8bb7f56a15865aae820c", [:rebar3], [{:cowboy, "~> 2.7", [hex: :cowboy, repo: "hexpm", optional: false]}, {:telemetry, "~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "7d98bac1ee4565d31b62d59f8823dfd8356a169e7fcbb83831b8a5397404c9de"},
"cowlib": {:hex, :cowlib, "2.11.0", "0b9ff9c346629256c42ebe1eeb769a83c6cb771a6ee5960bd110ab0b9b872063", [:make, :rebar3], [], "hexpm", "2b3e9da0b21c4565751a6d4901c20d1b4cc25cbb7fd50d91d2ab6dd287bc86a9"},
"decimal": {:hex, :decimal, "2.1.1", "5611dca5d4b2c3dd497dec8f68751f1f1a54755e8ed2a966c2633cf885973ad6", [:mix], [], "hexpm", "53cfe5f497ed0e7771ae1a475575603d77425099ba5faef9394932b35020ffcc"},
"dll_loader_helper": {:hex, :dll_loader_helper, "1.1.0", "e7d015e980942a0d67e306827ec907e7e853a21186bd92bb968d986698591a0f", [:mix], [{:dll_loader_helper_beam, "~> 1.1", [hex: :dll_loader_helper_beam, repo: "hexpm", optional: false]}], "hexpm", "2b6c11ee7bb48f6a132ce8f872202f9e828c019988da1e2d40ad41496195df0c"},
"dll_loader_helper_beam": {:hex, :dll_loader_helper_beam, "1.1.0", "d51232663985dbc998c59b5d080feecd5398d5b75a9f0293a9855db774c2684d", [:rebar3], [], "hexpm", "aa85d0d0e9398916a80b2fd751885877934ae3ea008288f99ff829c0b8ef1f55"},
"dll_loader_helper_beam": {:hex, :dll_loader_helper_beam, "1.2.0", "557c43befb8e3b119b718da302adccde3bd855acdb999498a14a2a8d2814b8b9", [:rebar3], [], "hexpm", "a2115d4bf1cca488a7b33f3c648847f64019b32c0382d10286d84dd5c3cbc0e5"},
"earmark_parser": {:hex, :earmark_parser, "1.4.33", "3c3fd9673bb5dcc9edc28dd90f50c87ce506d1f71b70e3de69aa8154bc695d44", [:mix], [], "hexpm", "2d526833729b59b9fdb85785078697c72ac5e5066350663e5be6a1182da61b8f"},
"elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"},
"ex_doc": {:hex, :ex_doc, "0.30.6", "5f8b54854b240a2b55c9734c4b1d0dd7bdd41f71a095d42a70445c03cf05a281", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "bd48f2ddacf4e482c727f9293d9498e0881597eae6ddc3d9562bd7923375109f"},
"exla": {:hex, :exla, "0.6.1", "a4400933a04d018c5fb508c75a080c73c3c1986f6c16a79bbfee93ba22830d4d", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.6.1", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.5.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "f0e95b0f91a937030cf9fcbe900c9d26933cb31db2a26dfc8569aa239679e6d4"},
"exla": {:git, "https://github.com/elixir-nx/nx.git", "70de11d469d58de776ed829be767505930207e58", [sparse: "exla"]},
"jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"},
"makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"},
"makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"},
"makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"},
"mime": {:hex, :mime, "2.0.3", "3676436d3d1f7b81b5a2d2bd8405f412c677558c81b1c92be58c00562bb59095", [:mix], [], "hexpm", "27a30bf0db44d25eecba73755acf4068cbfe26a4372f9eb3e4ea3a45956bff6b"},
"nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"},
"nx": {:hex, :nx, "0.6.1", "df65cd61312bcaa756559fb994596d403c822e353616094fdfc31a15181c95f8", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "23dcc8e2824a6e19fcdebef39145fdff7625fd7d26fd50c1990ac0a1dd05f960"},
"nx": {:git, "https://github.com/elixir-nx/nx.git", "70de11d469d58de776ed829be767505930207e58", [sparse: "nx"]},
"nx_image": {:hex, :nx_image, "0.1.1", "69cf0d2fd873d12b028583aa49b5e0a25f6aca307afc337a5d871851a20fba1d", [:mix], [{:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "55c8206a822237f6027168f11214e3887263c5b8a1f8e0634eea82c96e5093e3"},
"nx_signal": {:hex, :nx_signal, "0.2.0", "e1ca0318877b17c81ce8906329f5125f1e2361e4c4235a5baac8a95ee88ea98e", [:mix], [{:nx, "~> 0.6", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "7247e5e18a177a59c4cb5355952900c62fdeadeb2bad02a9a34237b68744e2bb"},
"plug": {:hex, :plug, "1.14.2", "cff7d4ec45b4ae176a227acd94a7ab536d9b37b942c8e8fa6dfc0fff98ff4d80", [:mix], [{:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:plug_crypto, "~> 1.1.1 or ~> 1.2", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.3 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "842fc50187e13cf4ac3b253d47d9474ed6c296a8732752835ce4a86acdf68d13"},
Expand All @@ -34,8 +33,8 @@
"stb_image": {:hex, :stb_image, "0.6.2", "d680a418416b1d778231d1d16151be3474d187e8505e1bd524aa0d08d2de094f", [:make, :mix], [{:cc_precompiler, "~> 0.1.0", [hex: :cc_precompiler, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7.0", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: true]}], "hexpm", "231ad012f649dd2bd5ef99e9171e814f3235e8f7c45009355789ac4836044a39"},
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
"tokenizers": {:hex, :tokenizers, "0.4.0", "140283ca74a971391ddbd83cd8cbdb9bd03736f37a1b6989b82d245a95e1eb97", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.6", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "ef1a9824f5a893cd3b831c0e5b3d72caa250d2ec462035cc6afef6933b13a82e"},
"torchx": {:hex, :torchx, "0.6.1", "2a9862ebc4b397f42c51f0fa3f9f4e3451a83df6fba42882f8523cbc925c8ae1", [:make, :mix], [{:dll_loader_helper, "~> 0.1 or ~> 1.0", [hex: :dll_loader_helper, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.6.1", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "99b3fc73b52d6cfbe5cad8bdd74277ddc99297ce8fc6765b1dabec80681e8d9d"},
"torchx": {:git, "https://github.com/elixir-nx/nx.git", "70de11d469d58de776ed829be767505930207e58", [sparse: "torchx"]},
"unpickler": {:hex, :unpickler, "0.1.0", "c2262c0819e6985b761e7107546cef96a485f401816be5304a65fdd200d5bd6a", [:mix], [], "hexpm", "e2b3f61e62406187ac52afead8a63bfb4e49394028993f3c4c42712743cab79e"},
"unzip": {:hex, :unzip, "0.8.0", "ee21d87c21b01567317387dab4228ac570ca15b41cfc221a067354cbf8e68c4d", [:mix], [], "hexpm", "ffa67a483efcedcb5876971a50947222e104d5f8fea2c4a0441e6f7967854827"},
"xla": {:hex, :xla, "0.5.0", "fb8a02c02e5a4f4531fbf18a90c325e471037f983f0115d23f510e7dd9a6aa65", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "571ac797a4244b8ba8552ed0295a54397bd896708be51e4da6cbb784f6678061"},
"xla": {:hex, :xla, "0.5.1", "8ba4c2c51c1a708ff54e9d4f88158c1a75b7f2cb3e5db02bf222b5b3852afffd", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "82a2490f6e9a76c8a29d1aedb47f07c59e3d5081095eac5a74db34d46c8212bc"},
}

0 comments on commit a503832

Please sign in to comment.