Skip to content

Commit

Permalink
Working on the sample generator
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Nov 19, 2023
1 parent 26cf879 commit cea24ab
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 13 deletions.
3 changes: 1 addition & 2 deletions language_interpolation/single_text_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def unify_ids(specific_ids: List[int], id_range: List[int]):
return ids


def encode_input_from_text(text_in: str, features: int) -> Tuple[torch.tensor, str]:
def encode_input_from_text(text_in: str, features: int=0) -> Tuple[torch.tensor, str]:
"""
Convert a string to input that the network can take. Take the last "features" number
of characters and convert to numbers. Return those numbers as the network input, also
Expand All @@ -76,7 +76,6 @@ def encode_input_from_text(text_in: str, features: int) -> Tuple[torch.tensor, s
encoding = [ord(val) for val in raw_sample]
return torch.tensor(encoding), raw_sample


def decode_output_to_text(
encoding: torch.tensor, topk: int = 1
) -> Tuple[torch.tensor, str]:
Expand Down
49 changes: 48 additions & 1 deletion language_interpolation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def generate_text(
topk: int = 1,
add_channel_dimension: bool = False,
):

model.eval()

features = features
Expand Down Expand Up @@ -97,6 +96,54 @@ def generate_text(
return results


def generate_transformer_text(
    model: nn.Module,
    characters_per_feature: int,
    max_characters: int,
    text_list: List[str],
    output_size: int,
    topk: int = 1,
):
    """Autoregressively generate text continuations for a list of prompts.

    For each prompt, repeatedly encodes the (left-padded) text, runs the
    model, samples the next character weighted by the model's probabilities,
    and appends it — `output_size` times.

    Args:
        model: Network whose input is shaped
            (batch, features, characters_per_feature). Must expose `_device`.
        characters_per_feature: Characters packed into each feature vector;
            prompts are left-padded to a multiple of this length.
        max_characters: Maximum number of trailing characters fed to the
            encoder each step (passed as `features` to
            `encode_input_from_text`).
        text_list: Prompts to continue. Not modified by this function.
        output_size: Number of characters to generate per prompt.
        topk: Number of top candidate characters to sample among.

    Returns:
        List[str]: one generated string per prompt, with newlines replaced
        by spaces.
    """

    def _pad(text: str) -> str:
        # Left-pad so the length is the next multiple of
        # characters_per_feature (always adds at least one block, matching
        # the original padding arithmetic).
        target = (len(text) // characters_per_feature + 1) * characters_per_feature
        return text.rjust(target)

    model.eval()

    results = []
    for prompt in text_list:
        # Work on a padded copy so the caller's list is never mutated.
        text_in = _pad(prompt)
        for _ in range(output_size):
            encoding, _ = encode_input_from_text(
                text_in=text_in, features=max_characters
            )
            encoding = (
                ascii_to_float(encoding)
                .to(model._device)
                .reshape(1, -1, characters_per_feature)
            )
            output = model(encoding)
            values, indices, characters = decode_output_to_text(
                encoding=output[0], topk=topk
            )

            # Pick the next character weighted by the probability of each
            # candidate; prevents the same response for every query.
            chosen = random.choices(characters, weights=values.tolist())
            text_in = _pad(text_in + chosen[0])

        results.append(text_in.replace("\n", " "))

    return results


class TextGenerationSampler(Callback):
def __init__(self, cfg):
super().__init__()
Expand Down
34 changes: 24 additions & 10 deletions tests/test_attention_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,52 @@
from language_interpolation.lightning_datamodule import TransformerDataModule

from omegaconf import DictConfig
from language_interpolation.utils import generate_transformer_text


def test_attention_network():
    """Smoke-test HighOrderAttentionNetwork end to end.

    Builds a TransformerDataModule from Gutenberg books, checks the batch
    shapes it produces, runs one forward pass through the network, and then
    exercises `generate_transformer_text` with the network as the model.
    """
    characters_per_feature = 10

    data_module = TransformerDataModule(
        characters_per_feature=characters_per_feature,
        max_features=100,
        batch_size=32,
        gutenberg_ids_test=[1],
        gutenberg_ids_train=[2],
        gutenberg_ids_val=[3],
        pre_process_workers=0,
    )

    data_module.setup()

    train_dataloader = data_module.train_dataloader()
    input_data, output, indexes = next(iter(train_dataloader))

    # One sample index per batch element; features are grouped into
    # characters_per_feature-wide chunks along the last axis.
    assert len(indexes) == 32
    assert input_data.shape[0] == 32
    assert input_data.shape[2] == characters_per_feature

    network = HighOrderAttentionNetwork(
        layers=[[10, 5], [5, 5]],
        n=3,
        segments=2,
        normalization=None,
        layer_type="continuous",
        device="cpu",
    )

    result = network(input_data)
    # NOTE(review): expected width 128 taken from the newer diff side —
    # confirm against the network's actual output dimension.
    assert result.shape[0] == 32
    assert result.shape[1] == 128

    text_list = ["hello sir", "Test this now"]
    ans = generate_transformer_text(
        model=network,
        text_list=text_list,
        characters_per_feature=characters_per_feature,
        max_characters=1000,
        output_size=10,
    )
    print("ans", ans)

0 comments on commit cea24ab

Please sign in to comment.