diff --git a/language_interpolation/utils.py b/language_interpolation/utils.py
index 04f3dc4..ee7ac44 100644
--- a/language_interpolation/utils.py
+++ b/language_interpolation/utils.py
@@ -107,24 +107,24 @@ def generate_transformer_text(
     model.eval()
 
     for index, text in enumerate(text_list):
-        just = ((len(text) // characters_per_feature)+1) * characters_per_feature
+        just = ((len(text) // characters_per_feature) + 1) * characters_per_feature
         text_list[index] = text.rjust(just)
-    print('text_list', text_list)
-    print('text lengths', [len(text) for text in text_list])
+    print("text_list", text_list)
+    print("text lengths", [len(text) for text in text_list])
 
     results = []
     for text_in in text_list:
         for i in range(output_size):
             encoding, text_used = encode_input_from_text(
                 text_in=text_in, features=max_characters
             )
-            print('encoding length', len(encoding), encoding.shape)
+            print("encoding length", len(encoding), encoding.shape)
             encoding = (
                 ascii_to_float(encoding)
                 .to(model._device)
                 .reshape(1, -1, characters_per_feature)
             )
-            print('encoding', encoding)
+            print("encoding", encoding)
             model.eval()
             output = model(encoding)
             values, indices, ascii = decode_output_to_text(
@@ -135,10 +135,11 @@ def generate_transformer_text(
             # prevents the same response for every query.
             actual = random.choices(ascii, values.tolist())
             text_in = text_in + actual[0]
-            just = ((len(text_in) // characters_per_feature)+1) * characters_per_feature
+            just = (
+                (len(text_in) // characters_per_feature) + 1
+            ) * characters_per_feature
             text_in = text_in.rjust(just)
-
         results.append(text_in.replace("\n", " "))
 
     return results
 
@@ -152,14 +153,25 @@ def __init__(self, cfg):
 
     def on_train_epoch_end(self, trainer, pl_module, outputs=None):
         with torch.no_grad():
             for topk in range(1, self._cfg.topk + 1):
-                predictions = generate_text(
-                    pl_module,
-                    features=self._cfg.net.features,
-                    text_list=self._cfg.prompts,
-                    output_size=self._cfg.num_predict,
-                    topk=topk,
-                    add_channel_dimension=self._cfg.data.add_channel_dimension,
-                )
+                if self._cfg.model_type == "high_order_transformer":
+                    predictions = generate_transformer_text(
+                        pl_module,
+                        characters_per_feature=self._cfg.data.characters_per_feature,
+                        max_characters=self._cfg.data.characters_per_feature
+                        * self._cfg.data.max_features,
+                        text_list=self._cfg.prompts,
+                        output_size=self._cfg.num_predict,
+                        topk=topk,
+                    )
+                else:
+                    predictions = generate_text(
+                        pl_module,
+                        features=self._cfg.net.features,
+                        text_list=self._cfg.prompts,
+                        output_size=self._cfg.num_predict,
+                        topk=topk,
+                        add_channel_dimension=self._cfg.data.add_channel_dimension,
+                    )
                 for index, text in enumerate(predictions):
                     trainer.logger.experiment.add_text(
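
Note on the `just` expression both generation hunks touch: it rounds the prompt
length up to the next multiple of characters_per_feature so the padded string
splits evenly into feature blocks, and when the length is already an exact
multiple it still left-pads one full extra block. A minimal standalone sketch
of that rule (helper name and the block size of 4 are illustrative, not from
the repo):

    def pad_to_feature_boundary(text: str, characters_per_feature: int) -> str:
        # Same arithmetic as the patch: (len // c + 1) * c always lands on a
        # multiple of c, and adds a whole extra block when len already is one.
        just = ((len(text) // characters_per_feature) + 1) * characters_per_feature
        return text.rjust(just)  # str.rjust left-pads with spaces

    assert len(pad_to_feature_boundary("abcde", 4)) == 8  # 5 -> 8
    assert len(pad_to_feature_boundary("abcd", 4)) == 8   # 4 -> 8, full extra block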
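
On the new dispatch in on_train_epoch_end: the callback now keys off
cfg.model_type and, for "high_order_transformer", derives max_characters as
characters_per_feature * max_features before calling
generate_transformer_text. A hypothetical config fragment exercising that
branch (field names come from the patch; the values and the OmegaConf usage
are assumptions, not the repo's defaults):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "model_type": "high_order_transformer",
            "topk": 2,           # sampler sweeps topk = 1..2
            "num_predict": 100,  # characters generated per prompt
            "prompts": ["The meaning of life is"],
            "data": {"characters_per_feature": 16, "max_features": 10},
        }
    )
    # With these values the callback passes max_characters = 16 * 10 = 160
    # to generate_transformer_text and logs one sample per prompt and topk.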