Skip to content

Commit

Permalink
Trying to test large character spacing
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Dec 2, 2023
1 parent 82b4a35 commit 0abd22e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
2 changes: 1 addition & 1 deletion language_interpolation/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def forward(self, x: Tensor) -> Tensor:
# characters are small spacing
# xp = small_character_spacing(x=x, max_context=self.max_context, positional_embedding=self.positional_embedding)
# characters are large spacing
xp = large_character_spacing(x=x, max_context=self.max_context, positional_embedding=self.positional_embebdding)
xp = large_character_spacing(x=x, max_context=self.max_context, positional_embedding=self.positional_embedding)

query = xp
key = xp
Expand Down
21 changes: 15 additions & 6 deletions tests/test_attention_network.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import pytest

from language_interpolation.networks import HighOrderAttentionNetwork
from language_interpolation.networks import HighOrderAttentionNetwork, large_character_spacing, small_character_spacing
from language_interpolation.lightning_datamodule import TransformerDataModule

from omegaconf import DictConfig
from language_interpolation.utils import generate_transformer_text

import torch

def test_attention_network():
characters_per_feature = 10
Expand Down Expand Up @@ -38,21 +38,30 @@ def test_attention_network():
normalization=None,
layer_type="continuous",
device="cpu",
heads =2,
max_context=max_features
heads=2,
max_context=max_features,
)
result = network(input_data)
print('final result', result)
print("final result", result)
print("result", result.shape)
assert result.shape[0] == 32
assert result.shape[1] == 128

new_sample = torch.rand(1, max_features, 10)*2-1

output = large_character_spacing(
x=new_sample,
max_context=network.max_context,
positional_embedding=network.positional_embedding,
)
print('output', output)

text_list = ["hello sir", "Test this now"]
ans = generate_transformer_text(
model=network,
text_list=text_list,
characters_per_feature=characters_per_feature,
max_characters=1000,
output_size=10
output_size=10,
)
print("ans", ans)

0 comments on commit 0abd22e

Please sign in to comment.