Skip to content

Commit

Permalink
[cm] Adding text param descriptions to wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
christhetree committed Feb 6, 2024
1 parent 1f9b011 commit 300d3ee
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions neutone_sdk/non_realtime_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
self.cat_param_n_values = {}
self.cat_param_labels = {}

# TODO(cm): param validation, check default isn't longer than char limit
self.n_text_params = 0
self.text_param_names = []
self.text_param_descriptions = []
self.text_param_max_n_chars = []
self.text_param_default_values = []

Expand Down Expand Up @@ -99,6 +101,7 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
elif p.type == NeutoneParameterType.TEXT:
self.n_text_params += 1
self.text_param_names.append(p.name)
self.text_param_descriptions.append(p.description)
self.text_param_max_n_chars.append(p.max_n_chars)
self.text_param_default_values.append(p.default_value)

Expand All @@ -113,6 +116,7 @@ def __init__(self, model: nn.Module, use_debug_mode: bool = True) -> None:
self.cat_param_labels["__torchscript_typing"] = ["__torchscript_typing"]
if not self.n_text_params:
self.text_param_names.append("__torchscript_typing")
self.text_param_descriptions.append("__torchscript_typing")
self.text_param_max_n_chars.append(-1)
self.text_param_default_values.append("__torchscript_typing")

Expand Down Expand Up @@ -549,6 +553,15 @@ def get_text_param_names(self) -> List[str]:
return []
return self.text_param_names

@tr.jit.export
def get_text_param_descriptions(self) -> List[str]:
"""
Returns the names of the text parameters.
"""
if not self.n_text_params: # Needed for TorchScript typing
return []
return self.text_param_descriptions

@tr.jit.export
def get_text_param_max_n_chars(self) -> List[int]:
"""
Expand Down Expand Up @@ -592,6 +605,7 @@ def get_preserved_attributes(self) -> List[str]:
"get_categorical_param_n_values",
"get_categorical_param_labels",
"get_text_param_names",
"get_text_param_descriptions",
"get_text_param_max_n_chars",
"get_text_param_default_values",
"get_preserved_attributes",
Expand Down

0 comments on commit 300d3ee

Please sign in to comment.