Skip to content

Commit

Permalink
Fix default defect classification
Browse files Browse the repository at this point in the history
  • Loading branch information
HardNorth committed Nov 22, 2024
1 parent db663fc commit df4d3b5
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions app/machine_learning/models/defect_type_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def predict_proba(self, data: pd.DataFrame) -> np.ndarray:
DEFAULT_VECTORIZER = DummyVectorizer()


def get_model(default_value: Any, model_name: str) -> Any:
def get_model(self: DefaultDict, model_name: str, default_value: any) -> Any:
m = BASE_DEFECT_TYPE_PATTERN.match(model_name)
if not m:
raise KeyError(model_name)
Expand All @@ -67,15 +67,18 @@ def get_model(default_value: Any, model_name: str) -> Any:
base_model_name = m.group(2)
if not base_model_name:
raise KeyError(model_name)
return default_value
if base_model_name in self:
return self[base_model_name]
else:
return default_value


def get_vectorizer_model(_: Any, model_name: str) -> Any:
return get_model(DEFAULT_VECTORIZER, model_name)
def get_vectorizer_model(self: Any, model_name: str) -> Any:
return get_model(self, model_name, DEFAULT_VECTORIZER)


def get_classifier_model(_: Any, model_name: str) -> Any:
return get_model(DEFAULT_MODEL, model_name)
def get_classifier_model(self: Any, model_name: str) -> Any:
return get_model(self, model_name, DEFAULT_MODEL)


class DefectTypeModel(MlModel):
Expand Down

0 comments on commit df4d3b5

Please sign in to comment.