From 490b7abce937b2008cf4b7b5dd6ca8f579f967f7 Mon Sep 17 00:00:00 2001 From: Daiki Katsuragawa <50144563+daikikatsuragawa@users.noreply.github.com> Date: Wed, 16 Nov 2022 14:48:56 +0000 Subject: [PATCH 1/3] Fix incompatible types in assignment Signed-off-by: Daiki Katsuragawa <50144563+daikikatsuragawa@users.noreply.github.com> --- dice_ml/explainer_interfaces/dice_KD.py | 2 +- dice_ml/explainer_interfaces/dice_genetic.py | 13 +++++-- dice_ml/explainer_interfaces/dice_pytorch.py | 9 +++-- dice_ml/explainer_interfaces/dice_random.py | 8 ++-- .../explainer_interfaces/dice_tensorflow2.py | 3 +- .../explainer_interfaces/explainer_base.py | 39 +++++++++++-------- .../explainer_interfaces/feasible_base_vae.py | 8 ++-- .../feasible_model_approx.py | 5 ++- 8 files changed, 51 insertions(+), 36 deletions(-) diff --git a/dice_ml/explainer_interfaces/dice_KD.py b/dice_ml/explainer_interfaces/dice_KD.py index deed36f8..02004082 100644 --- a/dice_ml/explainer_interfaces/dice_KD.py +++ b/dice_ml/explainer_interfaces/dice_KD.py @@ -163,7 +163,7 @@ def vary_valid(self, KD_query_instance, total_CFs, features_to_vary, permitted_r # TODO: this should be a user-specified parameter num_queries = min(len(self.dataset_with_predictions), total_CFs * 10) - cfs = [] + cfs = pd.DataFrame() if self.KD_tree is not None and num_queries > 0: KD_tree_output = self.KD_tree.query(KD_query_instance, num_queries) diff --git a/dice_ml/explainer_interfaces/dice_genetic.py b/dice_ml/explainer_interfaces/dice_genetic.py index 1c9c21a8..9fe27b39 100644 --- a/dice_ml/explainer_interfaces/dice_genetic.py +++ b/dice_ml/explainer_interfaces/dice_genetic.py @@ -5,6 +5,7 @@ import copy import random import timeit +from typing import Any, List, Union import numpy as np import pandas as pd @@ -27,7 +28,7 @@ def __init__(self, data_interface, model_interface): self.num_output_nodes = None # variables required to generate CFs - see generate_counterfactuals() for more info - self.cfs = [] + self.cfs = pd.DataFrame() self.features_to_vary = [] self.cf_init_weights = [] # total_CFs, algorithm, features_to_vary self.loss_weights = [] # yloss_type, diversity_loss_type, feature_weights @@ -343,12 +344,16 @@ def _predict_fn_custom(self, input_instance, desired_class): def compute_yloss(self, cfs, desired_range, desired_class): """Computes the first part (y-loss) of the loss function.""" - yloss = 0.0 + yloss: Any = 0.0 if self.model.model_type == ModelTypes.Classifier: predicted_value = np.array(self.predict_fn_scores(cfs)) if self.yloss_type == 'hinge_loss': maxvalue = np.full((len(predicted_value)), -np.inf) - for c in range(self.num_output_nodes): + if self.num_output_nodes is None: + num_output_nodes = 0 + else: + num_output_nodes = self.num_output_nodes + for c in range(num_output_nodes): if c != desired_class: maxvalue = np.maximum(maxvalue, predicted_value[:, c]) yloss = np.maximum(0, maxvalue - predicted_value[:, int(desired_class)]) @@ -429,7 +434,7 @@ def mate(self, k1, k2, features_to_vary, query_instance): def find_counterfactuals(self, query_instance, desired_range, desired_class, features_to_vary, maxiterations, thresh, verbose): """Finds counterfactuals by generating cfs through the genetic algorithm""" - population = self.cfs.copy() + population: Any = self.cfs.copy() iterations = 0 previous_best_loss = -np.inf current_best_loss = np.inf diff --git a/dice_ml/explainer_interfaces/dice_pytorch.py b/dice_ml/explainer_interfaces/dice_pytorch.py index 12412fda..39f895e0 100644 --- a/dice_ml/explainer_interfaces/dice_pytorch.py +++ b/dice_ml/explainer_interfaces/dice_pytorch.py @@ -4,6 +4,7 @@ import copy import random import timeit +from typing import Any, Optional, Type, Union import numpy as np import torch @@ -223,6 +224,7 @@ def do_optimizer_initializations(self, optimizer, learning_rate): opt_method = optimizer.split(':')[1] # optimizater initialization + self.optimizer: Optional[Union[torch.optim.Adam, torch.optim.RMSprop]] = None if opt_method == "adam": self.optimizer = torch.optim.Adam(self.cfs, lr=learning_rate) elif opt_method == "rmsprop": @@ -230,7 +232,8 @@ def do_optimizer_initializations(self, optimizer, learning_rate): def compute_yloss(self): """Computes the first part (y-loss) of the loss function.""" - yloss = 0.0 + yloss: Any = 0.0 + criterion: Optional[Union[torch.nn.BCEWithLogitsLoss, torch.nn.ReLU]] = None for i in range(self.total_CFs): if self.yloss_type == "l2_loss": temp_loss = torch.pow((self.get_model_output(self.cfs[i]) - self.target_cf_class), 2)[0] @@ -307,7 +310,7 @@ def compute_diversity_loss(self): def compute_regularization_loss(self): """Adds a linear equality constraints to the loss functions - to ensure all levels of a categorical variable sums to one""" - regularization_loss = 0.0 + regularization_loss: Any = 0.0 for i in range(self.total_CFs): for v in self.encoded_categorical_feature_indexes: regularization_loss += torch.pow((torch.sum(self.cfs[i][v[0]:v[-1]+1]) - 1.0), 2) @@ -425,7 +428,7 @@ def find_counterfactuals(self, query_instance, desired_class, optimizer, learnin test_pred = self.predict_fn(torch.tensor(query_instance).float())[0] if desired_class == "opposite": desired_class = 1.0 - np.round(test_pred) - self.target_cf_class = torch.tensor(desired_class).float() + self.target_cf_class: Any = torch.tensor(desired_class).float() self.min_iter = min_iter self.max_iter = max_iter diff --git a/dice_ml/explainer_interfaces/dice_random.py b/dice_ml/explainer_interfaces/dice_random.py index 43acad41..5f15292f 100644 --- a/dice_ml/explainer_interfaces/dice_random.py +++ b/dice_ml/explainer_interfaces/dice_random.py @@ -5,6 +5,7 @@ """ import random import timeit +from typing import List, Optional, Union import numpy as np import pandas as pd @@ -30,10 +31,9 @@ def __init__(self, data_interface, model_interface): self.model.transformer.initialize_transform_func() self.precisions = self.data_interface.get_decimal_precisions(output_type="dict") - if self.data_interface.outcome_name in self.precisions: - self.outcome_precision = [self.precisions[self.data_interface.outcome_name]] - else: - self.outcome_precision = 0 + self.outcome_precision = [ + self.precisions[self.data_interface.outcome_name] + ] if self.data_interface.outcome_name in self.precisions else 0 def _generate_counterfactuals(self, query_instance, total_CFs, desired_range=None, desired_class="opposite", permitted_range=None, diff --git a/dice_ml/explainer_interfaces/dice_tensorflow2.py b/dice_ml/explainer_interfaces/dice_tensorflow2.py index 8004a341..afd1bb51 100644 --- a/dice_ml/explainer_interfaces/dice_tensorflow2.py +++ b/dice_ml/explainer_interfaces/dice_tensorflow2.py @@ -341,8 +341,7 @@ def initialize_CFs(self, query_instance, init_near_query_instance=False): one_init.append(np.random.uniform(self.minx[0][i], self.maxx[0][i])) else: one_init.append(query_instance[0][i]) - one_init = np.array([one_init], dtype=np.float32) - self.cfs[n].assign(one_init) + self.cfs[n].assign(np.array([one_init], dtype=np.float32)) def round_off_cfs(self, assign=False): """function for intermediate projection of CFs.""" diff --git a/dice_ml/explainer_interfaces/explainer_base.py b/dice_ml/explainer_interfaces/explainer_base.py index eab7e875..8a79e818 100644 --- a/dice_ml/explainer_interfaces/explainer_base.py +++ b/dice_ml/explainer_interfaces/explainer_base.py @@ -5,6 +5,7 @@ import pickle from abc import ABC, abstractmethod from collections.abc import Iterable +from typing import Any, Dict, List, Optional, Union import numpy as np import pandas as pd @@ -152,10 +153,9 @@ def generate_counterfactuals(self, query_instances, total_CFs, cf_examples_arr = [] query_instances_list = [] if isinstance(query_instances, pd.DataFrame): - for ix in range(query_instances.shape[0]): - query_instances_list.append(query_instances[ix:(ix+1)]) + query_instances_list = [query_instances[ix:(ix+1)] for ix in range(query_instances.shape[0])] elif isinstance(query_instances, Iterable): - query_instances_list = query_instances + query_instances_list = [query_instance for query_instance in query_instances] for query_instance in tqdm(query_instances_list): self.data_interface.set_continuous_feature_indexes(query_instance) res = self._generate_counterfactuals( @@ -416,7 +416,7 @@ def feature_importance(self, query_instances, cf_examples_list=None, posthoc_sparsity_algorithm=posthoc_sparsity_algorithm, **kwargs).cf_examples_list allcols = self.data_interface.categorical_feature_names + self.data_interface.continuous_feature_names - summary_importance = None + summary_importance: Optional[Union[Dict[int, float]]] = None local_importances = None if global_importance: summary_importance = {} @@ -532,7 +532,7 @@ def do_posthoc_sparsity_enhancement(self, final_cfs_sparse, query_instance, post for feature in features_sorted: # current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.iat[[cf_ix]][self.data_interface.feature_names]) # feat_ix = self.data_interface.continuous_feature_names.index(feature) - diff = query_instance[feature].iat[0] - final_cfs_sparse.at[cf_ix, feature] + diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature]) if(abs(diff) <= quantiles[feature]): if posthoc_sparsity_algorithm == "linear": final_cfs_sparse = self.do_linear_search(diff, decimal_prec, query_instance, cf_ix, @@ -561,16 +561,17 @@ def do_linear_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f while((abs(diff) > 10e-4) and (np.sign(diff*old_diff) > 0) and self.is_cf_valid(current_pred)) and (count_steps < limit_steps_ls): - old_val = final_cfs_sparse.at[cf_ix, feature] + old_val = int(final_cfs_sparse.at[cf_ix, feature]) final_cfs_sparse.at[cf_ix, feature] += np.sign(diff)*change current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names]) old_diff = diff if not self.is_cf_valid(current_pred): final_cfs_sparse.at[cf_ix, feature] = old_val + diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature]) return final_cfs_sparse - diff = query_instance[feature].iat[0] - final_cfs_sparse.at[cf_ix, feature] + diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature]) count_steps += 1 @@ -580,7 +581,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f """Performs a binary search between continuous features of a CF and corresponding values in query_instance until the prediction class changes.""" - old_val = final_cfs_sparse.at[cf_ix, feature] + old_val = int(final_cfs_sparse.at[cf_ix, feature]) final_cfs_sparse.at[cf_ix, feature] = query_instance[feature].iat[0] # Prediction of the query instance current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names]) @@ -593,7 +594,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f # move the CF values towards the query_instance if diff > 0: - left = final_cfs_sparse.at[cf_ix, feature] + left = int(final_cfs_sparse.at[cf_ix, feature]) right = query_instance[feature].iat[0] while left <= right: @@ -613,7 +614,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f else: left = query_instance[feature].iat[0] - right = final_cfs_sparse.at[cf_ix, feature] + right = int(final_cfs_sparse.at[cf_ix, feature]) while right >= left: current_val = right - ((right - left)/2) @@ -731,13 +732,16 @@ def is_cf_valid(self, model_score): model_score = model_score[0] # Converting target_cf_class to a scalar (tf/torch have it as (1,1) shape) if self.model.model_type == ModelTypes.Classifier: - target_cf_class = self.target_cf_class if hasattr(self.target_cf_class, "shape"): if len(self.target_cf_class.shape) == 1: - target_cf_class = self.target_cf_class[0] + temp_target_cf_class = self.target_cf_class[0] elif len(self.target_cf_class.shape) == 2: - target_cf_class = self.target_cf_class[0][0] - target_cf_class = int(target_cf_class) + temp_target_cf_class = self.target_cf_class[0][0] + else: + temp_target_cf_class = int(self.target_cf_class) + else: + temp_target_cf_class = int(self.target_cf_class) + target_cf_class = temp_target_cf_class if len(model_score) == 1: # for tensorflow/pytorch models pred_1 = model_score[0] @@ -757,6 +761,7 @@ def is_cf_valid(self, model_score): return self.target_cf_range[0] <= model_score and model_score <= self.target_cf_range[1] def get_model_output_from_scores(self, model_scores): + output_type: Any = None if self.model.model_type == ModelTypes.Classifier: output_type = np.int32 else: @@ -806,7 +811,6 @@ def build_KD_tree(self, data_df_copy, desired_range, desired_class, predicted_ou data_df_copy[predicted_outcome_name] = predictions # segmenting the dataset according to outcome - dataset_with_predictions = None if self.model.model_type == ModelTypes.Classifier: dataset_with_predictions = data_df_copy.loc[[i == desired_class for i in predictions]].copy() @@ -814,9 +818,12 @@ def build_KD_tree(self, data_df_copy, desired_range, desired_class, predicted_ou dataset_with_predictions = data_df_copy.loc[ [desired_range[0] <= pred <= desired_range[1] for pred in predictions]].copy() + else: + dataset_with_predictions = None + KD_tree = None # Prepares the KD trees for DiCE - if len(dataset_with_predictions) > 0: + if dataset_with_predictions is not None and len(dataset_with_predictions) > 0: dummies = pd.get_dummies(dataset_with_predictions[self.data_interface.feature_names]) KD_tree = KDTree(dummies) diff --git a/dice_ml/explainer_interfaces/feasible_base_vae.py b/dice_ml/explainer_interfaces/feasible_base_vae.py index d31d50ad..96d7d8f1 100644 --- a/dice_ml/explainer_interfaces/feasible_base_vae.py +++ b/dice_ml/explainer_interfaces/feasible_base_vae.py @@ -136,8 +136,9 @@ def train(self, pre_trained=False): train_loss = 0.0 train_size = 0 - train_dataset = torch.tensor(self.vae_train_feat).float() - train_dataset = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) + train_dataset = torch.utils.data.DataLoader( + torch.tensor(self.vae_train_feat).float(), # type: ignore + batch_size=self.batch_size, shuffle=True) for train in enumerate(train_dataset): self.cf_vae_optimizer.zero_grad() @@ -178,8 +179,7 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp final_cf_pred = [] final_test_pred = [] for i in range(len(query_instance)): - train_x = test_dataset[i] - train_x = torch.tensor(train_x).float() + train_x = torch.tensor(test_dataset[i]).float() train_y = torch.argmax(self.pred_model(train_x), dim=1) curr_gen_cf = [] diff --git a/dice_ml/explainer_interfaces/feasible_model_approx.py b/dice_ml/explainer_interfaces/feasible_model_approx.py index 78d01970..9a32b5e4 100644 --- a/dice_ml/explainer_interfaces/feasible_model_approx.py +++ b/dice_ml/explainer_interfaces/feasible_model_approx.py @@ -81,8 +81,9 @@ def train(self, constraint_type, constraint_variables, constraint_direction, con train_loss = 0.0 train_size = 0 - train_dataset = torch.tensor(self.vae_train_feat).float() - train_dataset = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) + train_dataset = torch.utils.data.DataLoader( + torch.tensor(self.vae_train_feat).float(), # type: ignore + batch_size=self.batch_size, shuffle=True) for train in enumerate(train_dataset): self.cf_vae_optimizer.zero_grad() From 904c425e1cafd92714d999e22741eab94875597f Mon Sep 17 00:00:00 2001 From: Daiki Katsuragawa <50144563+daikikatsuragawa@users.noreply.github.com> Date: Wed, 16 Nov 2022 15:16:02 +0000 Subject: [PATCH 2/3] Delete unnecessary differences Signed-off-by: Daiki Katsuragawa <50144563+daikikatsuragawa@users.noreply.github.com> --- dice_ml/explainer_interfaces/explainer_base.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/dice_ml/explainer_interfaces/explainer_base.py b/dice_ml/explainer_interfaces/explainer_base.py index 8a79e818..cf012660 100644 --- a/dice_ml/explainer_interfaces/explainer_base.py +++ b/dice_ml/explainer_interfaces/explainer_base.py @@ -532,7 +532,7 @@ def do_posthoc_sparsity_enhancement(self, final_cfs_sparse, query_instance, post for feature in features_sorted: # current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.iat[[cf_ix]][self.data_interface.feature_names]) # feat_ix = self.data_interface.continuous_feature_names.index(feature) - diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature]) + diff = query_instance[feature].iat[0] - final_cfs_sparse.at[cf_ix, feature] if(abs(diff) <= quantiles[feature]): if posthoc_sparsity_algorithm == "linear": final_cfs_sparse = self.do_linear_search(diff, decimal_prec, query_instance, cf_ix, @@ -561,17 +561,16 @@ def do_linear_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f while((abs(diff) > 10e-4) and (np.sign(diff*old_diff) > 0) and self.is_cf_valid(current_pred)) and (count_steps < limit_steps_ls): - old_val = int(final_cfs_sparse.at[cf_ix, feature]) + old_val = final_cfs_sparse.at[cf_ix, feature] final_cfs_sparse.at[cf_ix, feature] += np.sign(diff)*change current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names]) old_diff = diff if not self.is_cf_valid(current_pred): final_cfs_sparse.at[cf_ix, feature] = old_val - diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature]) return final_cfs_sparse - diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature]) + diff = query_instance[feature].iat[0] - final_cfs_sparse.at[cf_ix, feature] count_steps += 1 @@ -581,7 +580,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f """Performs a binary search between continuous features of a CF and corresponding values in query_instance until the prediction class changes.""" - old_val = int(final_cfs_sparse.at[cf_ix, feature]) + old_val = final_cfs_sparse.at[cf_ix, feature] final_cfs_sparse.at[cf_ix, feature] = query_instance[feature].iat[0] # Prediction of the query instance current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names]) @@ -594,7 +593,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f # move the CF values towards the query_instance if diff > 0: - left = int(final_cfs_sparse.at[cf_ix, feature]) + left = final_cfs_sparse.at[cf_ix, feature] right = query_instance[feature].iat[0] while left <= right: @@ -614,7 +613,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f else: left = query_instance[feature].iat[0] - right = int(final_cfs_sparse.at[cf_ix, feature]) + right = final_cfs_sparse.at[cf_ix, feature] while right >= left: current_val = right - ((right - left)/2) From 61b7d7933acdb4c4e451fd18955ffabfa96bbfda Mon Sep 17 00:00:00 2001 From: Daiki Katsuragawa <50144563+daikikatsuragawa@users.noreply.github.com> Date: Wed, 16 Nov 2022 15:24:47 +0000 Subject: [PATCH 3/3] Fix errors checked by flake8 Signed-off-by: Daiki Katsuragawa <50144563+daikikatsuragawa@users.noreply.github.com> --- dice_ml/explainer_interfaces/dice_genetic.py | 2 +- dice_ml/explainer_interfaces/dice_pytorch.py | 2 +- dice_ml/explainer_interfaces/dice_random.py | 1 - dice_ml/explainer_interfaces/explainer_base.py | 2 +- dice_ml/explainer_interfaces/feasible_base_vae.py | 2 +- dice_ml/explainer_interfaces/feasible_model_approx.py | 2 +- 6 files changed, 5 insertions(+), 6 deletions(-) diff --git a/dice_ml/explainer_interfaces/dice_genetic.py b/dice_ml/explainer_interfaces/dice_genetic.py index 9fe27b39..02692524 100644 --- a/dice_ml/explainer_interfaces/dice_genetic.py +++ b/dice_ml/explainer_interfaces/dice_genetic.py @@ -5,7 +5,7 @@ import copy import random import timeit -from typing import Any, List, Union +from typing import Any import numpy as np import pandas as pd diff --git a/dice_ml/explainer_interfaces/dice_pytorch.py b/dice_ml/explainer_interfaces/dice_pytorch.py index 39f895e0..95aa7eea 100644 --- a/dice_ml/explainer_interfaces/dice_pytorch.py +++ b/dice_ml/explainer_interfaces/dice_pytorch.py @@ -4,7 +4,7 @@ import copy import random import timeit -from typing import Any, Optional, Type, Union +from typing import Any, Optional, Union import numpy as np import torch diff --git a/dice_ml/explainer_interfaces/dice_random.py b/dice_ml/explainer_interfaces/dice_random.py index 5f15292f..b8bfd095 100644 --- a/dice_ml/explainer_interfaces/dice_random.py +++ b/dice_ml/explainer_interfaces/dice_random.py @@ -5,7 +5,6 @@ """ import random import timeit -from typing import List, Optional, Union import numpy as np import pandas as pd diff --git a/dice_ml/explainer_interfaces/explainer_base.py b/dice_ml/explainer_interfaces/explainer_base.py index cf012660..2c50ed5d 100644 --- a/dice_ml/explainer_interfaces/explainer_base.py +++ b/dice_ml/explainer_interfaces/explainer_base.py @@ -5,7 +5,7 @@ import pickle from abc import ABC, abstractmethod from collections.abc import Iterable -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional, Union import numpy as np import pandas as pd diff --git a/dice_ml/explainer_interfaces/feasible_base_vae.py b/dice_ml/explainer_interfaces/feasible_base_vae.py index 96d7d8f1..49840362 100644 --- a/dice_ml/explainer_interfaces/feasible_base_vae.py +++ b/dice_ml/explainer_interfaces/feasible_base_vae.py @@ -137,7 +137,7 @@ def train(self, pre_trained=False): train_size = 0 train_dataset = torch.utils.data.DataLoader( - torch.tensor(self.vae_train_feat).float(), # type: ignore + torch.tensor(self.vae_train_feat).float(), # type: ignore batch_size=self.batch_size, shuffle=True) for train in enumerate(train_dataset): self.cf_vae_optimizer.zero_grad() diff --git a/dice_ml/explainer_interfaces/feasible_model_approx.py b/dice_ml/explainer_interfaces/feasible_model_approx.py index 9a32b5e4..26b4c40b 100644 --- a/dice_ml/explainer_interfaces/feasible_model_approx.py +++ b/dice_ml/explainer_interfaces/feasible_model_approx.py @@ -82,7 +82,7 @@ def train(self, constraint_type, constraint_variables, constraint_direction, con train_size = 0 train_dataset = torch.utils.data.DataLoader( - torch.tensor(self.vae_train_feat).float(), # type: ignore + torch.tensor(self.vae_train_feat).float(), # type: ignore batch_size=self.batch_size, shuffle=True) for train in enumerate(train_dataset): self.cf_vae_optimizer.zero_grad()