Working on a very simple positional embedding
jloveric committed Nov 29, 2023
1 parent a52832e commit e191fa0
Showing 2 changed files with 40 additions and 24 deletions.
language_interpolation/networks.py (58 changes: 36 additions & 22 deletions)
@@ -141,27 +141,28 @@ def __init__(
             device=device,
         )
 
-        self.value_layer=value_layer
+        self.value_layer = value_layer
         if value_layer is None:
             self.value_layer = torch.nn.Linear(
                 in_features=self.embed_dim,
                 out_features=self.out_dim * heads,
                 device=device,
             )
 
         self.output_layer = output_layer
-        if output_layer is None :
+        if output_layer is None:
             self.output_layer = torch.nn.Linear(
-                in_features=self.out_dim*heads,
+                in_features=self.out_dim * heads,
                 out_features=self.out_dim,
                 device=device,
             )
 
-        if normalization is None:
-            self.normalization = lambda x: x
+        # It seems pytorch doesn't like lambdas so
+        # explicitly define it.
+        def noop(x):
+            return x
 
-        self.normalization = normalization
+        self.normalization = normalization or noop
 
     def forward(
         self,
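
A note on the lambda-to-noop change above: a plausible reading of the "pytorch doesn't like lambdas" comment is that pickle, which torch.save uses when saving a whole module, cannot serialize a lambda, while a named module-level function pickles by reference. A minimal sketch of that behavior (illustration only, not part of the commit; a function defined inside __init__ is still a local object, so saving the state_dict rather than the whole module sidesteps the issue entirely):

    import pickle

    def noop(x):  # module-level function: pickled by qualified name
        return x

    pickle.dumps(noop)  # works

    try:
        pickle.dumps(lambda x: x)
    except (pickle.PicklingError, AttributeError) as err:
        print(err)  # lambdas have no importable name, so pickling fails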
@@ -191,19 +192,19 @@ def forward(
         vt = vt.reshape(value.shape[0], value.shape[1], vt.shape[1])
 
         qkv_list = []
-        for head in range(self.heads) :
-            start = head*self.out_dim
-            end = (head+1)*self.out_dim
-            qth = qt[:,:,start:end]
-            kth = kt[:,:,start:end]
-            vth = vt[:,:,start:end]
+        for head in range(self.heads):
+            start = head * self.out_dim
+            end = (head + 1) * self.out_dim
+            qth = qt[:, :, start:end]
+            kth = kt[:, :, start:end]
+            vth = vt[:, :, start:end]
 
             qkh = torch.nn.functional.softmax(qth @ kth.transpose(1, 2))
 
             # Matrix multiply of last 2 dimensions
             qkv_list.append(qkh @ vth)
 
-        res = torch.cat(qkv_list,dim=2)
+        res = torch.cat(qkv_list, dim=2)
 
         v = res.reshape(res.shape[0] * res.shape[1], -1)
         output = self.output_layer(v)
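
The loop above computes plain dot-product attention per head, softmax(q @ k^T) @ v, and concatenates the head outputs along the feature dimension. For comparison, a minimal sketch of the standard scaled variant (illustration only, not the commit's code: the version above applies no 1/sqrt(d) scale, and calling softmax without an explicit dim= is deprecated in recent PyTorch):

    import math
    import torch

    def attention_head(q, k, v):
        # q, k, v: (batch, seq, head_dim)
        scores = q @ k.transpose(1, 2) / math.sqrt(q.shape[-1])  # (batch, seq, seq)
        weights = torch.softmax(scores, dim=-1)  # normalize over the key axis
        return weights @ v  # (batch, seq, head_dim)

    q = k = v = torch.randn(2, 5, 8)
    print(attention_head(q, k, v).shape)  # torch.Size([2, 5, 8])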
@@ -251,7 +252,7 @@ def high_order_attention_block(
     output = high_order_fc_layers(
         layer_type=layer_type,
         n=n,
-        in_features=out_dim*heads,
+        in_features=out_dim * heads,
         out_features=out_dim,
         segments=segments,
         device=device,
@@ -273,7 +274,14 @@
 
 class HighOrderAttentionNetwork(torch.nn.Module):
     def __init__(
-        self, layer_type: str, layers: list, n: int, normalization: None, heads:int=1, device: str='cuda'
+        self,
+        layer_type: str,
+        layers: list,
+        n: int,
+        normalization: None,
+        heads: int = 1,
+        device: str = "cuda",
+        max_context: int = 10,
     ):
         super().__init__()
         self._device = device
@@ -304,10 +312,16 @@ def __init__(
             device=device,
         )
 
+        # Make the positions 0 to max_context-1
+        self.positional_embedding = (
+            torch.arange(max_context, dtype=torch.get_default_dtype()).unsqueeze(1) + 0.5
+        ) / (max_context - 1.0)
+
     def forward(self, x: Tensor) -> Tensor:
-        query = x
-        key = x
-        value = x
+        xp = x+self.positional_embedding[:x.shape[1]]
+        query = xp
+        key = xp
+        value = xp
         for layer in self.layer:
             res = layer(query, key, value)
             query = res
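
Despite the "positions 0 to max_context-1" comment, the table stores one scalar per position, evenly spaced with step 1/(max_context - 1) and starting at 0.5/(max_context - 1). The slice used in forward has shape (seq, 1), so the same offset is broadcast across the batch and across every feature of a token. A small sketch with hypothetical sizes (illustration only):

    import torch

    max_context = 10
    pos = (
        torch.arange(max_context, dtype=torch.get_default_dtype()).unsqueeze(1) + 0.5
    ) / (max_context - 1.0)
    print(pos[:4].squeeze(1))   # tensor([0.0556, 0.1667, 0.2778, 0.3889])

    x = torch.zeros(2, 4, 8)    # (batch, seq, features)
    xp = x + pos[: x.shape[1]]  # (4, 1) broadcasts over batch and features

One consequence: max_context must be at least the longest sequence, since a longer input makes the sliced table shorter than seq and the broadcast add fails with a shape mismatch, which is presumably why the test below passes max_context=max_features.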
@@ -364,7 +378,7 @@ def select_network(cfg: DictConfig, device: str = None):
             cfg.net.n,
             normalization=None,
             device=cfg.accelerator,
-            heads=cfg.net.heads
+            heads=cfg.net.heads,
         )
 
     elif cfg.net.model_type == "high_order":
tests/test_attention_network.py (6 changes: 4 additions & 2 deletions)
@@ -9,10 +9,11 @@
 
 def test_attention_network():
     characters_per_feature = 10
+    max_features = 100
 
     data_module = TransformerDataModule(
         characters_per_feature=10,
-        max_features=100,
+        max_features=max_features,
         batch_size=32,
         gutenberg_ids_test=[1],
         gutenberg_ids_train=[2],
@@ -37,7 +37,8 @@ def test_attention_network():
         normalization=None,
         layer_type="continuous",
         device="cpu",
-        heads =2
+        heads =2,
+        max_context=max_features
     )
     result = network(input_data)
    print('result', result)
