Skip to content

Commit

Permalink
Fix duplicate QCOs error: raise error only if duplicate QCOs are not …
Browse files Browse the repository at this point in the history
…identical. (#1282)
  • Loading branch information
elad-c authored Dec 1, 2024
1 parent 952746a commit 52bb1c5
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
7 changes: 4 additions & 3 deletions model_compression_toolkit/core/common/graph/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,9 +556,10 @@ def get_qco(self, tpc: TargetPlatformCapabilities) -> QuantizationConfigOptions:
# Extract qco with is_match_type to overcome mismatch of function types in TF 2.15
matching_qcos = [_qco for _type, _qco in tpc.layer2qco.items() if self.is_match_type(_type)]
if matching_qcos:
if len(matching_qcos) > 1:
Logger.error('Found duplicate qco types!')
return matching_qcos[0]
if all([_qco == matching_qcos[0] for _qco in matching_qcos]):
return matching_qcos[0]
else:
Logger.critical(f"Found duplicate qco types for node '{self.name}' of type '{self.type}'!") # pragma: no cover
return tpc.tp_model.default_qco

def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,5 @@ def is_match_type(self, _type: Type) -> bool:
Whether _type matches the self node type
"""
names_match = _type.__name__ == self.type.__name__ if FOUND_TF else False
names_match = _type.__name__ == self.type.__name__
return super().is_match_type(_type) or names_match

0 comments on commit 52bb1c5

Please sign in to comment.