Commit 50f9a5e

fix eps for nTransformer
lucidrains committed Oct 28, 2024
1 parent f56477a commit 50f9a5e
Showing 3 changed files with 5 additions and 6 deletions.
nGPT_pytorch/nGPTExperimental.py (5 changes: 2 additions & 3 deletions)

@@ -44,15 +44,14 @@ def l2norm(
     eps = None,
     groups = 1
 ):
-    eps = default(eps, 1e-5 if t.dtype == torch.float16 else 1e-10)
-
     if groups > 1:
         t = t.chunk(groups, dim = dim)
         t = torch.stack(t)

     if norm_eps == 0.:
-        out = F.normalize(t, dim = dim, p = 2, eps = eps)
+        out = F.normalize(t, dim = dim, p = 2)
     else:
+        eps = default(eps, 1e-5 if t.dtype == torch.float16 else 1e-10)
         norm = t.norm(dim = dim, keepdim = True)
         target_norm = norm.detach().clamp(min = 1. - norm_eps, max = 1. + norm_eps)
         divisor = norm / target_norm
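For context, here is a minimal, self-contained sketch of how l2norm reads after this change. Only the lines inside the hunk above are verbatim; the leading parameters (t, dim, norm_eps), the body of the default helper, the tail of the else branch, and the un-grouping at the end fall outside the hunk and are assumptions filled in for illustration, not a copy of the repository code.

import torch
import torch.nn.functional as F

def default(v, d):
    # assumed behavior of the repo's helper: fall back to d when v is None
    return v if v is not None else d

def l2norm(t, dim = -1, norm_eps = 0., eps = None, groups = 1):
    if groups > 1:
        t = t.chunk(groups, dim = dim)
        t = torch.stack(t)

    if norm_eps == 0.:
        # after this commit, eps is no longer resolved or passed here;
        # F.normalize relies on its own default eps
        out = F.normalize(t, dim = dim, p = 2)
    else:
        # eps is now resolved only on the branch that actually uses it
        eps = default(eps, 1e-5 if t.dtype == torch.float16 else 1e-10)
        norm = t.norm(dim = dim, keepdim = True)
        target_norm = norm.detach().clamp(min = 1. - norm_eps, max = 1. + norm_eps)
        divisor = norm / target_norm
        # assumed remainder of the branch: rescale toward the clamped target norm
        out = t / divisor.clamp(min = eps)

    if groups > 1:
        # assumed: undo the grouping stack from above
        out = torch.cat(tuple(out), dim = dim)

    return out

The effect visible in the hunk is that eps is computed only when norm_eps is nonzero, and the common norm_eps == 0. path leaves the epsilon handling to F.normalize.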
nGPT_pytorch/nTransformer.py (4 changes: 2 additions & 2 deletions)

@@ -44,15 +44,15 @@ def l2norm(
     eps = None,
     groups = 1
 ):
-    eps = default(eps, 1e-5 if t.dtype == torch.float16 else 1e-10)

     if groups > 1:
         t = t.chunk(groups, dim = dim)
         t = torch.stack(t)

     if norm_eps == 0.:
-        out = F.normalize(t, dim = dim, p = 2, eps = eps)
+        out = F.normalize(t, dim = dim, p = 2)
     else:
+        eps = default(eps, 1e-5 if t.dtype == torch.float16 else 1e-10)
         norm = t.norm(dim = dim, keepdim = True)
         target_norm = norm.detach().clamp(min = 1. - norm_eps, max = 1. + norm_eps)
         divisor = norm / target_norm
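The nTransformer.py hunk is the same fix as in nGPTExperimental.py. A quick usage example, assuming the sketched signature above; note that F.normalize clamps the denominator at its own eps (default 1e-12), which is what the norm_eps == 0. path now relies on.

x = torch.randn(2, 64, dtype = torch.float16)

y = l2norm(x)                           # norm_eps == 0.: plain F.normalize, no custom eps
z = l2norm(x.float(), norm_eps = 0.01)  # norm_eps > 0.: eps defaults to 1e-10 for float32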
pyproject.toml (2 changes: 1 addition & 1 deletion)

@@ -1,6 +1,6 @@
 [project]
 name = "nGPT-pytorch"
-version = "0.1.12"
+version = "0.1.14"
 description = "nGPT"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
