Add :type option to load model under specific precision (#311)
jonatanklosko authored Dec 18, 2023
1 parent 47c364d commit b947dd2
Showing 4 changed files with 114 additions and 5 deletions.
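In short, this change lets a model be built or loaded under a lower precision. A minimal usage sketch (not part of the commit itself; the repository is the tiny test model used in the new test below):

    {:ok, %{model: model, params: params}} =
      Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"}, type: :bf16)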
51 changes: 46 additions & 5 deletions lib/bumblebee.ex
@@ -317,16 +317,36 @@ defmodule Bumblebee do
@doc """
Builds an `Axon` model according to the given specification.
## Options
* `:type` - either a type or `Axon.MixedPrecision` policy to apply
to the model
## Example
spec = Bumblebee.configure(Bumblebee.Vision.ResNet, architecture: :base, embedding_size: 128)
model = Bumblebee.build_model(spec)
"""
@doc type: :model
@spec build_model(Bumblebee.ModelSpec.t()) :: Axon.t()
def build_model(%module{} = spec) do
module.model(spec)
@spec build_model(Bumblebee.ModelSpec.t(), keyword()) :: Axon.t()
def build_model(%module{} = spec, opts \\ []) do
opts = Keyword.validate!(opts, [:type])

model = module.model(spec)

case opts[:type] do
nil ->
model

%Axon.MixedPrecision.Policy{} = policy ->
Axon.MixedPrecision.apply_policy(model, policy)

type ->
type = Nx.Type.normalize!(type)
policy = Axon.MixedPrecision.create_policy(params: type, compute: type, output: type)
Axon.MixedPrecision.apply_policy(model, policy)
end
end

@doc """
@@ -446,6 +466,22 @@ defmodule Bumblebee do
   The model is downloaded and cached on your disk; use `cache_dir/0` to
   find the location.

+  ## Parameters precision
+
+  On GPUs, computations that use a numeric type of lower precision can
+  be faster and use less memory, while still providing valid results.
+  You can configure the model to use a particular type by passing the
+  `:type` option, such as `:bf16`.
+
+  Some repositories have multiple variants of the parameter files
+  with different numeric types. The variant is usually indicated in
+  the file extension, and you can load a particular file by specifying
+  `:params_variant` or `:params_filename`. Note, however, that this
+  does not determine the numeric type used for inference; the file
+  type matters only for download bandwidth and disk space. If you
+  want to use a lower precision for inference, make sure to also
+  specify `:type`.
+
   ## Options

     * `:spec` - the model specification to use when building the model.
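To make the distinction concrete, here is a sketch of the two options working together; both the repository name and the variant name are hypothetical:

    # :params_variant picks which parameter files get downloaded (bandwidth
    # and disk space); :type is what actually makes inference run in bf16
    {:ok, model_info} =
      Bumblebee.load_model({:hf, "acme/some-model"},
        params_variant: "bf16",
        type: :bf16
      )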
@@ -470,6 +506,10 @@ defmodule Bumblebee do
     * `:backend` - the backend to allocate the tensors on. It is either
       an atom or a tuple in the shape `{backend, options}`

+    * `:type` - either a type or an `Axon.MixedPrecision` policy to apply
+      to the model. Passing this option automatically casts parameters
+      to the desired type
+
   ## Examples

   By default the model type is inferred from configuration, so loading
@@ -502,13 +542,14 @@ defmodule Bumblebee do
         :architecture,
         :params_variant,
         :params_filename,
+        :log_params_diff,
         :backend,
-        :log_params_diff
+        :type
       ])

     with {:ok, repo_files} <- get_repo_files(repository),
          {:ok, spec} <- maybe_load_model_spec(opts, repository, repo_files),
-         model <- build_model(spec),
+         model <- build_model(spec, Keyword.take(opts, [:type])),
          {:ok, params} <- load_params(spec, model, repository, repo_files, opts) do
       {:ok, %{model: model, params: params, spec: spec}}
     end
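End to end, forwarding `:type` here means both the loaded parameters and the model's compute policy end up in the requested precision. A sketch of a full forward pass under bf16 (the input shape is invented for illustration):

    {:ok, %{model: model, params: params}} =
      Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"}, type: :bf16)

    # Params were cast on load, and the model's policy computes in bf16
    inputs = %{"input_ids" => Nx.tensor([[10, 20, 30]])}
    outputs = Axon.predict(model, params, inputs)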
10 changes: 10 additions & 0 deletions lib/bumblebee/conversion/pytorch.ex
@@ -145,6 +145,7 @@ defmodule Bumblebee.Conversion.PyTorch do

         case verify_param_shape(param_expr, value) do
           :ok ->
+            value = ensure_type(param_expr, value)
             {value, diff}

           {:error, expected, actual} ->
@@ -486,6 +487,15 @@ defmodule Bumblebee.Conversion.PyTorch do
     Utils.Nx.map(expr, &Nx.shape/1)
   end

+  defp ensure_type(param_expr, value) do
+    Utils.Nx.zip_with(param_expr, value, fn expr, tensor ->
+      case {Nx.type(expr), Nx.type(tensor)} do
+        {type, type} -> tensor
+        {expected, _actual} -> Nx.as_type(tensor, expected)
+      end
+    end)
+  end
+
   defp unflatten_leading(tensor, axis_size) do
     shape =
       tensor
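In isolation, the per-tensor rule that `ensure_type/2` applies is simple: if a loaded tensor's type differs from the one the (possibly policy-cast) model expects, cast it with `Nx.as_type/2`. A standalone sketch of that rule, with made-up types:

    expected = Nx.template({2}, {:bf, 16})  # stands in for the param expression
    loaded = Nx.tensor([1.0, 2.0])          # f32, as stored in the checkpoint

    cast =
      case {Nx.type(expected), Nx.type(loaded)} do
        {type, type} -> loaded
        {expected_type, _actual} -> Nx.as_type(loaded, expected_type)
      end

    Nx.type(cast)
    #=> {:bf, 16}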
40 changes: 40 additions & 0 deletions lib/bumblebee/utils/nx.ex
@@ -23,6 +23,46 @@ defmodule Bumblebee.Utils.Nx do
     |> elem(0)
   end

+  @doc """
+  Recursively zips the given containers with the given function.
+  """
+  @spec zip_with(
+          tensor_or_container,
+          tensor_or_container,
+          (Nx.Tensor.t(), Nx.Tensor.t() -> term())
+        ) :: tensor_or_container
+        when tensor_or_container: Nx.Tensor.t() | Nx.Container.t()
+  def zip_with(left, right, fun)
+
+  def zip_with(%Nx.Tensor{} = left, %Nx.Tensor{} = right, fun) do
+    fun.(left, right)
+  end
+
+  def zip_with(left, right, fun) do
+    right_items =
+      right
+      |> Nx.Container.reduce([], fn item, acc -> [item | acc] end)
+      |> Enum.reverse()
+
+    case Nx.Container.traverse(left, right_items, &recur_zip_with(&1, &2, fun)) do
+      {result, []} ->
+        result
+
+      {_result, _leftover} ->
+        raise ArgumentError, "unable to merge arguments with incompatible structure"
+    end
+  end
+
+  defp recur_zip_with(left, [right | right_items], fun) do
+    case {left, right} do
+      {%Nx.Tensor{} = left, %Nx.Tensor{} = right} ->
+        {fun.(left, right), right_items}
+
+      {left, right} ->
+        {recur_zip_with(left, right, fun), right_items}
+    end
+  end
+
   @doc """
   Returns the underlying tensor as a list.
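A quick usage sketch of the new helper (not from the commit): maps and tuples implement `Nx.Container`, so nested structures are zipped leaf by leaf:

    left = %{"dense" => %{"kernel" => Nx.iota({2, 2})}}
    right = %{"dense" => %{"kernel" => Nx.iota({2, 2})}}

    Bumblebee.Utils.Nx.zip_with(left, right, &Nx.add/2)
    #=> %{"dense" => %{"kernel" => Nx.tensor([[0, 2], [4, 6]])}}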
18 changes: 18 additions & 0 deletions test/bumblebee_test.exs
@@ -63,5 +63,23 @@ defmodule BumblebeeTest do
         )
       end
     end

+    test "passing :type casts params accordingly" do
+      assert {:ok, %{params: params}} =
+               Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"},
+                 type: :bf16
+               )
+
+      assert Nx.type(params["decoder.blocks.0.ffn.output"]["kernel"]) == {:bf, 16}
+      assert Nx.type(params["decoder.blocks.0.ffn.output"]["bias"]) == {:bf, 16}
+
+      assert {:ok, %{params: params}} =
+               Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"},
+                 type: Axon.MixedPrecision.create_policy(params: :f16)
+               )
+
+      assert Nx.type(params["decoder.blocks.0.ffn.output"]["kernel"]) == {:f, 16}
+      assert Nx.type(params["decoder.blocks.0.ffn.output"]["bias"]) == {:f, 16}
+    end
   end
 end
