Lower fp8 lora memory usage.
comfyanonymous committed Sep 3, 2024
1 parent d043997 commit 00a5d08
Showing 1 changed file with 7 additions and 3 deletions.
comfy/float.py: 10 changed lines (7 additions, 3 deletions)
@@ -41,9 +41,8 @@ def manual_stochastic_round_to_float8(x, dtype, generator=None):
         (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
         (2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
     )
-    del abs_x
 
-    return sign.to(dtype=dtype)
+    return sign



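In the first hunk the helper stops casting its result to the fp8 dtype (and drops the now-pointless `del abs_x`), returning the fp16 `sign` tensor instead. Below is a minimal sketch, not taken from the commit, of the PyTorch behavior this appears to rely on: `torch.Tensor.copy_` converts dtype while copying, so the cast can happen as the result is written into a preallocated fp8 buffer rather than through an extra `.to(dtype)` allocation.

```python
# Minimal sketch (illustrative, not from the repository): Tensor.copy_ casts
# dtype during the copy, so an fp16 result can be written straight into a
# preallocated fp8 tensor without materializing a separate fp8 temporary.
import torch

src = torch.randn(4, 4, dtype=torch.float16)             # stands in for the fp16 `sign` result
out = torch.empty_like(src, dtype=torch.float8_e4m3fn)   # preallocated fp8 output buffer

out.copy_(src)        # float16 -> float8_e4m3fn conversion happens here
print(out.dtype)      # torch.float8_e4m3fn
```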
@@ -57,6 +56,11 @@ def stochastic_rounding(value, dtype, seed=0):
     if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
         generator = torch.Generator(device=value.device)
         generator.manual_seed(seed)
-        return manual_stochastic_round_to_float8(value, dtype, generator=generator)
+        output = torch.empty_like(value, dtype=dtype)
+        num_slices = max(1, (value.numel() / (4096 * 4096)))
+        slice_size = max(1, round(value.shape[0] / num_slices))
+        for i in range(0, value.shape[0], slice_size):
+            output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
+        return output
 
     return value.to(dtype=dtype)
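The second hunk is where the memory saving comes from: instead of rounding the whole tensor in one call, `stochastic_rounding` now preallocates the fp8 output and converts row-slices of roughly 4096×4096 elements each, so the intermediates built inside `manual_stochastic_round_to_float8` only ever cover one slice at a time. A rough sketch of the slice arithmetic, using an assumed 8192×8192 weight shape chosen purely for illustration:

```python
# Rough sketch of the new slicing bookkeeping in stochastic_rounding, with a
# hypothetical 8192x8192 weight; only the chunk arithmetic is shown, not the
# actual stochastic rounding.
value_shape = (8192, 8192)                       # assumed weight shape
numel = value_shape[0] * value_shape[1]          # 67,108,864 elements

num_slices = max(1, numel / (4096 * 4096))       # 4.0 -> target ~16.8M elements per slice
slice_size = max(1, round(value_shape[0] / num_slices))  # 2048 rows per slice

chunks = [(i, min(i + slice_size, value_shape[0]))
          for i in range(0, value_shape[0], slice_size)]
print(num_slices, slice_size, chunks)
# 4.0 2048 [(0, 2048), (2048, 4096), (4096, 6144), (6144, 8192)]
```

With this shape the old path held the rounding intermediates for all ~67M elements at once; the sliced path holds them for roughly a quarter of that at a time, at the cost of a short Python-level loop over the chunks.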
