Merge pull request #156 from basf/val_data-fix
fix validation dataset bug
AnFreTh authored Nov 12, 2024
2 parents 4d1f787 + 7de4982 · commit 74ad3aa
Showing 3 changed files with 46 additions and 42 deletions.
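
All three files get the same two kinds of changes. The actual bug fix is the guard on the optional validation split: `build_model` and `fit` previously tested `if X_val:`, but pandas refuses to reduce a DataFrame to a single boolean, so callers who passed `X_val` as a DataFrame (or a multi-element NumPy array) got a `ValueError` instead of a validation set. The new guard, `if X_val is not None:`, only asks whether validation data was supplied at all. A minimal repro of the failure mode, with toy data:

```python
import pandas as pd

X_val = pd.DataFrame({"f1": [1, 2], "f2": [3, 4]})

# Old guard: truth-testing a DataFrame always raises.
try:
    if X_val:
        print("unreachable")
except ValueError as err:
    print(err)  # "The truth value of a DataFrame is ambiguous. ..."

# Fixed guard: checks presence, not truthiness.
if X_val is not None:
    print("validation data provided")
```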
mambular/models/sklearn_base_classifier.py (16 additions, 14 deletions)

@@ -188,7 +188,7 @@ def build_model(
             X = pd.DataFrame(X)
         if isinstance(y, pd.Series):
             y = y.values
-        if X_val:
+        if X_val is not None:
             if not isinstance(X_val, pd.DataFrame):
                 X_val = pd.DataFrame(X_val)
             if isinstance(y_val, pd.Series):
@@ -218,14 +218,14 @@ def build_model(
             config=self.config,
             cat_feature_info=self.data_module.cat_feature_info,
             num_feature_info=self.data_module.num_feature_info,
-            lr_patience=lr_patience
-            if lr_patience is not None
-            else self.config.lr_patience,
+            lr_patience=(
+                lr_patience if lr_patience is not None else self.config.lr_patience
+            ),
             lr=lr if lr is not None else self.config.lr,
             lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor,
-            weight_decay=weight_decay
-            if weight_decay is not None
-            else self.config.weight_decay,
+            weight_decay=(
+                weight_decay if weight_decay is not None else self.config.weight_decay
+            ),
             optimizer_type=self.optimizer_type,
             optimizer_args=self.optimizer_kwargs,
         )
@@ -345,7 +345,7 @@ def fit(
             X = pd.DataFrame(X)
         if isinstance(y, pd.Series):
             y = y.values
-        if X_val:
+        if X_val is not None:
             if not isinstance(X_val, pd.DataFrame):
                 X_val = pd.DataFrame(X_val)
             if isinstance(y_val, pd.Series):
@@ -375,14 +375,16 @@ def fit(
             config=self.config,
             cat_feature_info=self.data_module.cat_feature_info,
             num_feature_info=self.data_module.num_feature_info,
-            lr_patience=lr_patience
-            if lr_patience is not None
-            else self.config.lr_patience,
+            lr_patience=(
+                lr_patience if lr_patience is not None else self.config.lr_patience
+            ),
             lr=lr if lr is not None else self.config.lr,
             lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor,
-            weight_decay=weight_decay
-            if weight_decay is not None
-            else self.config.weight_decay,
+            weight_decay=(
+                weight_decay
+                if weight_decay is not None
+                else self.config.weight_decay
+            ),
             optimizer_type=self.optimizer_type,
             optimizer_args=self.optimizer_kwargs,
         )
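
For context, these guards sit on the sklearn-style entry points, so the fix unblocks explicitly supplied validation splits. A hedged usage sketch (model class and default hyperparameters assumed; any model built on the `SklearnBaseClassifier` base should take the same path):

```python
import numpy as np
import pandas as pd
from mambular.models import MambularClassifier  # class name assumed

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(120, 4)), columns=["f1", "f2", "f3", "f4"])
y = rng.integers(0, 2, size=120)

clf = MambularClassifier()
# An explicit validation split passed as a DataFrame: exactly the input
# the old `if X_val:` guard choked on before this commit.
clf.fit(X.iloc[:100], y[:100], X_val=X.iloc[100:], y_val=y[100:])
```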
mambular/models/sklearn_base_lss.py (14 additions, 14 deletions)

@@ -208,7 +208,7 @@ def build_model(
             X = pd.DataFrame(X)
         if isinstance(y, pd.Series):
             y = y.values
-        if X_val:
+        if X_val is not None:
             if not isinstance(X_val, pd.DataFrame):
                 X_val = pd.DataFrame(X_val)
             if isinstance(y_val, pd.Series):
@@ -238,13 +238,13 @@ def build_model(
             cat_feature_info=self.data_module.cat_feature_info,
             num_feature_info=self.data_module.num_feature_info,
             lr=lr if lr is not None else self.config.lr,
-            lr_patience=lr_patience
-            if lr_patience is not None
-            else self.config.lr_patience,
+            lr_patience=(
+                lr_patience if lr_patience is not None else self.config.lr_patience
+            ),
             lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor,
-            weight_decay=weight_decay
-            if weight_decay is not None
-            else self.config.weight_decay,
+            weight_decay=(
+                weight_decay if weight_decay is not None else self.config.weight_decay
+            ),
             lss=True,
             optimizer_type=self.optimizer_type,
             optimizer_args=self.optimizer_kwargs,
@@ -388,7 +388,7 @@ def fit(
             X = pd.DataFrame(X)
         if isinstance(y, pd.Series):
             y = y.values
-        if X_val:
+        if X_val is not None:
             if not isinstance(X_val, pd.DataFrame):
                 X_val = pd.DataFrame(X_val)
             if isinstance(y_val, pd.Series):
@@ -418,13 +418,13 @@ def fit(
             cat_feature_info=self.data_module.cat_feature_info,
             num_feature_info=self.data_module.num_feature_info,
             lr=lr if lr is not None else self.config.lr,
-            lr_patience=lr_patience
-            if lr_patience is not None
-            else self.config.lr_patience,
+            lr_patience=(
+                lr_patience if lr_patience is not None else self.config.lr_patience
+            ),
             lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor,
-            weight_decay=weight_decay
-            if weight_decay is not None
-            else self.config.weight_decay,
+            weight_decay=(
+                weight_decay if weight_decay is not None else self.config.weight_decay
+            ),
             lss=True,
             optimizer_type=self.optimizer_type,
             optimizer_args=self.optimizer_kwargs,
mambular/models/sklearn_base_regressor.py (16 additions, 14 deletions)

@@ -189,7 +189,7 @@ def build_model(
             X = pd.DataFrame(X)
         if isinstance(y, pd.Series):
             y = y.values
-        if X_val:
+        if X_val is not None:
             if not isinstance(X_val, pd.DataFrame):
                 X_val = pd.DataFrame(X_val)
             if isinstance(y_val, pd.Series):
@@ -217,13 +217,13 @@ def build_model(
             cat_feature_info=self.data_module.cat_feature_info,
             num_feature_info=self.data_module.num_feature_info,
             lr=lr if lr is not None else self.config.lr,
-            lr_patience=lr_patience
-            if lr_patience is not None
-            else self.config.lr_patience,
+            lr_patience=(
+                lr_patience if lr_patience is not None else self.config.lr_patience
+            ),
             lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor,
-            weight_decay=weight_decay
-            if weight_decay is not None
-            else self.config.weight_decay,
+            weight_decay=(
+                weight_decay if weight_decay is not None else self.config.weight_decay
+            ),
             optimizer_type=self.optimizer_type,
             optimizer_args=self.optimizer_kwargs,
         )
@@ -341,7 +341,7 @@ def fit(
             X = pd.DataFrame(X)
         if isinstance(y, pd.Series):
             y = y.values
-        if X_val:
+        if X_val is not None:
             if not isinstance(X_val, pd.DataFrame):
                 X_val = pd.DataFrame(X_val)
             if isinstance(y_val, pd.Series):
@@ -369,13 +369,15 @@ def fit(
             cat_feature_info=self.data_module.cat_feature_info,
             num_feature_info=self.data_module.num_feature_info,
             lr=lr if lr is not None else self.config.lr,
-            lr_patience=lr_patience
-            if lr_patience is not None
-            else self.config.lr_patience,
+            lr_patience=(
+                lr_patience if lr_patience is not None else self.config.lr_patience
+            ),
             lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor,
-            weight_decay=weight_decay
-            if weight_decay is not None
-            else self.config.weight_decay,
+            weight_decay=(
+                weight_decay
+                if weight_decay is not None
+                else self.config.weight_decay
+            ),
             optimizer_type=self.optimizer_type,
             optimizer_args=self.optimizer_kwargs,
         )
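
The remaining hunks in all three files are formatting-only: the multi-line conditional expressions assigned to `lr_patience` and `weight_decay` are wrapped in parentheses, the layout formatters such as Black produce for long conditional expressions, with no behavioral change. A quick equivalence check (variable names are hypothetical stand-ins for the config lookups):

```python
lr_patience = None
config_lr_patience = 10  # stand-in for self.config.lr_patience

# Old layout: bare conditional expression continued across lines.
old_style = lr_patience if lr_patience is not None else config_lr_patience

# New layout: the same expression, parenthesized so it wraps cleanly.
new_style = (
    lr_patience if lr_patience is not None else config_lr_patience
)

assert old_style == new_style == 10
```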
