Skip to content

Commit

Permalink
Updating the sample
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Nov 19, 2023
1 parent cea24ab commit a165987
Showing 1 changed file with 27 additions and 15 deletions.
42 changes: 27 additions & 15 deletions language_interpolation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,24 +107,24 @@ def generate_transformer_text(
model.eval()

for index, text in enumerate(text_list):
just = ((len(text) // characters_per_feature)+1) * characters_per_feature
just = ((len(text) // characters_per_feature) + 1) * characters_per_feature
text_list[index] = text.rjust(just)

print('text_list', text_list)
print('text lengths', [len(text) for text in text_list])
print("text_list", text_list)
print("text lengths", [len(text) for text in text_list])
results = []
for text_in in text_list:
for i in range(output_size):
encoding, text_used = encode_input_from_text(
text_in=text_in, features=max_characters
)
print('encoding length', len(encoding), encoding.shape)
print("encoding length", len(encoding), encoding.shape)
encoding = (
ascii_to_float(encoding)
.to(model._device)
.reshape(1, -1, characters_per_feature)
)
print('encoding', encoding)
print("encoding", encoding)
model.eval()
output = model(encoding)
values, indices, ascii = decode_output_to_text(
Expand All @@ -135,10 +135,11 @@ def generate_transformer_text(
# prevents the same response for every query.
actual = random.choices(ascii, values.tolist())
text_in = text_in + actual[0]
just = ((len(text_in) // characters_per_feature)+1) * characters_per_feature
just = (
(len(text_in) // characters_per_feature) + 1
) * characters_per_feature
text_in = text_in.rjust(just)


results.append(text_in.replace("\n", " "))

return results
Expand All @@ -152,14 +153,25 @@ def __init__(self, cfg):
def on_train_epoch_end(self, trainer, pl_module, outputs=None):
with torch.no_grad():
for topk in range(1, self._cfg.topk + 1):
predictions = generate_text(
pl_module,
features=self._cfg.net.features,
text_list=self._cfg.prompts,
output_size=self._cfg.num_predict,
topk=topk,
add_channel_dimension=self._cfg.data.add_channel_dimension,
)
if self._cfg.model_type == "high_order_transformer":
predictions = generate_transformer_text(
pl_module,
characters_per_feature=self._cfg.data.characters_per_feature,
max_characters=self._cfg.data.characters_per_feature
* self._cfg.data.max_features,
text_list=self._cfg.prompts,
output_size=self._cfg.num_predict,
topk=topk,
)
else:
predictions = generate_text(
pl_module,
features=self._cfg.net.features,
text_list=self._cfg.prompts,
output_size=self._cfg.num_predict,
topk=topk,
add_channel_dimension=self._cfg.data.add_channel_dimension,
)

for index, text in enumerate(predictions):
trainer.logger.experiment.add_text(
Expand Down

0 comments on commit a165987

Please sign in to comment.