Commit

Transformer seems to be running
jloveric committed Nov 19, 2023
1 parent a165987 · commit fe6e32d
Showing 1 changed file with 1 addition and 5 deletions.
language_interpolation/utils.py (6 changes: 1 addition & 5 deletions)
@@ -110,21 +110,17 @@ def generate_transformer_text(
         just = ((len(text) // characters_per_feature) + 1) * characters_per_feature
         text_list[index] = text.rjust(just)
 
-    print("text_list", text_list)
-    print("text lengths", [len(text) for text in text_list])
     results = []
     for text_in in text_list:
         for i in range(output_size):
             encoding, text_used = encode_input_from_text(
                 text_in=text_in, features=max_characters
             )
-            print("encoding length", len(encoding), encoding.shape)
             encoding = (
                 ascii_to_float(encoding)
                 .to(model._device)
                 .reshape(1, -1, characters_per_feature)
             )
-            print("encoding", encoding)
             model.eval()
             output = model(encoding)
             values, indices, ascii = decode_output_to_text(
@@ -153,7 +149,7 @@ def __init__(self, cfg):
     def on_train_epoch_end(self, trainer, pl_module, outputs=None):
         with torch.no_grad():
             for topk in range(1, self._cfg.topk + 1):
-                if self._cfg.model_type == "high_order_transformer":
+                if self._cfg.net.model_type == "high_order_transformer":
                     predictions = generate_transformer_text(
                         pl_module,
                         characters_per_feature=self._cfg.data.characters_per_feature,
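Aside from dropping the debug prints, the only functional change is the config lookup: the callback now reads model_type from a nested net group instead of the top level of the config. A minimal sketch of the assumed config shape, written with OmegaConf; the surrounding keys and values are illustrative, not taken from the repository:

from omegaconf import OmegaConf

# Hypothetical config sketch (illustrative, not from the repository):
# model_type is assumed to sit under a "net" group, which is what the new
# cfg.net.model_type access in the callback above expects.
cfg = OmegaConf.create(
    {
        "topk": 2,  # illustrative value
        "net": {"model_type": "high_order_transformer"},
        "data": {"characters_per_feature": 10},  # illustrative value
    }
)

# The nested lookup used after this commit:
assert cfg.net.model_type == "high_order_transformer"

Grouping the network settings under net keeps them separate from data and trainer options, which matches the typical Hydra-style layout of config groups.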
