Skip to content

Commit

Permalink
Fix get_info functionality in TP model
Browse files Browse the repository at this point in the history
  • Loading branch information
liord committed Dec 19, 2024
1 parent 6d6dbb0 commit 76a4ec7
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def clone_and_edit_weight_attribute(self, attrs: List[str] = None, **kwargs) ->
# Ensure all attributes exist in the config
for attr in attrs_to_update:
if attr not in qc.attr_weights_configs_mapping:
Logger.critical(f"{attr} does not exist in {qc}.")
Logger.critical(f"{attr} does not exist in {qc}.") # pragma: no cover
updated_attr_mapping = {
attr: qc.attr_weights_configs_mapping[attr].clone_and_edit(**kwargs)
for attr in attrs_to_update
Expand Down Expand Up @@ -348,15 +348,7 @@ class TargetPlatformModelComponent:
"""
Component of TargetPlatformModel (Fusing, OperatorsSet, etc.).
"""
def get_info(self) -> Dict[str, Any]:
"""
Get information about the component to display.
Returns:
Dict[str, Any]: Returns an empty dictionary. The actual component should override
this method to provide relevant information.
"""
return {}
pass


@dataclass(frozen=True)
Expand Down Expand Up @@ -419,17 +411,6 @@ def __post_init__(self):
# Set the inherited name attribute using `object.__setattr__` since the dataclass is frozen
object.__setattr__(self, "name", concatenated_name)

def get_info(self) -> Dict[str, Any]:
"""
Get information about the concatenated set as a dictionary.
Returns:
Dict[str, Any]: A dictionary containing the concatenated name and
the list of names of the operator sets in `operators_set`.
"""
return {"name": self.name,
OPS_SET_LIST: [s.name for s in self.operators_set]}


@dataclass(frozen=True)
class Fusing(TargetPlatformModelComponent):
Expand Down Expand Up @@ -561,7 +542,7 @@ def get_info(self) -> Dict[str, Any]:
"""
return {
"Model name": self.name,
"Operators sets": [o.get_info() for o in self.operator_set],
"Operators sets": [o.get_info() for o in self.operator_set] if self.operator_set else [],
"Fusing patterns": [f.get_info() for f in self.fusing_patterns] if self.fusing_patterns else [],
}

Expand Down
4 changes: 3 additions & 1 deletion tests/common_tests/test_tp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ def test_default_options_more_than_single_qc(self):

def test_tp_model_show(self):
tpm = schema.TargetPlatformModel(TEST_QCO,
operator_set=tuple([schema.OperatorsSet("opA")]),
tpc_minor_version=None,
tpc_patch_version=None,
tpc_platform_type=None,
operator_set=tuple([schema.OperatorsSet("opA"), schema.OperatorsSet("opB")]),
fusing_patterns=tuple(
[schema.Fusing((schema.OperatorsSet("opA"), schema.OperatorsSet("opB")))]),
add_metadata=False)
tpm.show()

Expand Down

0 comments on commit 76a4ec7

Please sign in to comment.