From 300d3eeb4e3fdee8f482feb3cc7085c07836654a Mon Sep 17 00:00:00 2001 From: christhetree Date: Tue, 6 Feb 2024 10:54:48 +0000 Subject: [PATCH] [cm] Adding text param descriptions to wrapper --- neutone_sdk/non_realtime_wrapper.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/neutone_sdk/non_realtime_wrapper.py b/neutone_sdk/non_realtime_wrapper.py index 4d11e99..2d52948 100644 --- a/neutone_sdk/non_realtime_wrapper.py +++ b/neutone_sdk/non_realtime_wrapper.py @@ -63,8 +63,10 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None: self.cat_param_n_values = {} self.cat_param_labels = {} + # TODO(cm): param validation, check default isn't longer than char limit self.n_text_params = 0 self.text_param_names = [] + self.text_param_descriptions = [] self.text_param_max_n_chars = [] self.text_param_default_values = [] @@ -99,6 +101,7 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None: elif p.type == NeutoneParameterType.TEXT: self.n_text_params += 1 self.text_param_names.append(p.name) + self.text_param_descriptions.append(p.description) self.text_param_max_n_chars.append(p.max_n_chars) self.text_param_default_values.append(p.default_value) @@ -113,6 +116,7 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None: self.cat_param_labels["__torchscript_typing"] = ["__torchscript_typing"] if not self.n_text_params: self.text_param_names.append("__torchscript_typing") + self.text_param_descriptions.append("__torchscript_typing") self.text_param_max_n_chars.append(-1) self.text_param_default_values.append("__torchscript_typing") @@ -549,6 +553,15 @@ def get_text_param_names(self) -> List[str]: return [] return self.text_param_names + @tr.jit.export + def get_text_param_descriptions(self) -> List[str]: + """ + Returns the names of the text parameters. + """ + if not self.n_text_params: # Needed for TorchScript typing + return [] + return self.text_param_descriptions + @tr.jit.export def get_text_param_max_n_chars(self) -> List[int]: """ @@ -592,6 +605,7 @@ def get_preserved_attributes(self) -> List[str]: "get_categorical_param_n_values", "get_categorical_param_labels", "get_text_param_names", + "get_text_param_descriptions", "get_text_param_max_n_chars", "get_text_param_default_values", "get_preserved_attributes",