From 8f71f6cb8e9dc6eab551fac03d9914e045dee4d3 Mon Sep 17 00:00:00 2001 From: jloveric Date: Wed, 29 Nov 2023 19:39:01 -0800 Subject: [PATCH] Special positional embedding --- language_interpolation/networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/language_interpolation/networks.py b/language_interpolation/networks.py index bba4616..b9c10a2 100644 --- a/language_interpolation/networks.py +++ b/language_interpolation/networks.py @@ -343,8 +343,8 @@ def __init__( def forward(self, x: Tensor) -> Tensor: - # Scale the input to [-0.5*max_context, 0.5*max_context] where every token is bumped by 1 - # the 0th token is 0 and the max_context token is 0.5*max_context-1 + # Scale the input to [-1, 1] where every token is bumped by 1/(2*max_context) + # the 0th token is -1 and the nth token is 1 # THIS LOOKS RIGHT! xp = ((0.5 * (x + 1) + self.positional_embedding[: x.shape[1]])*2 - self.max_context)/self.max_context