From 993a85558432ce7ccf37b581b6012800b88e3a2d Mon Sep 17 00:00:00 2001 From: Srikumar Sastry Date: Tue, 11 Jan 2022 00:43:28 +0100 Subject: [PATCH] Implemented BALD acquisition function within QueryInstanceQBC --- alipy/query_strategy/query_labels.py | 46 +++++++++++++++++++++++++--- 1 file changed, 42 insertions(+), 4 deletions(-) diff --git a/alipy/query_strategy/query_labels.py b/alipy/query_strategy/query_labels.py index efa3071..32122df 100644 --- a/alipy/query_strategy/query_labels.py +++ b/alipy/query_strategy/query_labels.py @@ -357,7 +357,7 @@ class QueryInstanceQBC(BaseIndexQuery): Method name. This class only implement query_by_bagging for now. disagreement: str - method to calculate disagreement of committees. should be one of ['vote_entropy', 'KL_divergence'] + method to calculate disagreement of committees. should be one of ['vote_entropy', 'KL_divergence', 'BALD'] References ---------- @@ -373,10 +373,10 @@ class QueryInstanceQBC(BaseIndexQuery): def __init__(self, X=None, y=None, method='query_by_bagging', disagreement='vote_entropy'): self._method = method super(QueryInstanceQBC, self).__init__(X, y) - if disagreement in ['vote_entropy', 'KL_divergence']: + if disagreement in ['vote_entropy', 'KL_divergence', 'BALD']: self._disagreement = disagreement else: - raise ValueError("disagreement must be one of ['vote_entropy', 'KL_divergence']") + raise ValueError("disagreement must be one of ['vote_entropy', 'KL_divergence', 'BALD']") def select(self, label_index, unlabel_index, model=None, batch_size=1, n_jobs=None): """Select indexes from the unlabel_index for querying. @@ -438,8 +438,10 @@ def select(self, label_index, unlabel_index, model=None, batch_size=1, n_jobs=No # calc score if self._disagreement == 'vote_entropy': score = self.calc_vote_entropy([estimator.predict(unlabel_x) for estimator in est_arr]) - else: + elif self._disagreement == 'KL_divergence': score = self.calc_avg_KL_divergence([estimator.predict_proba(unlabel_x) for estimator in est_arr]) + else: + score = self.calc_BALD([estimator.predict_proba(unlabel_x) for estimator in est_arr]) return unlabel_index[nlargestarg(score, batch_size)] def select_by_prediction_mat(self, unlabel_index, predict, batch_size=1): @@ -598,6 +600,42 @@ def calc_avg_KL_divergence(cls, predict_matrices): "A 2D probabilistic prediction matrix must be provided, with the shape like [n_samples, n_class]") return score + @classmethod + def calc_BALD(cls, predict_matrices): + """Calculate the mutual information between model predictions and model parameters + for measuring the level of disagreement in QBC. + + Parameters + ---------- + predict_matrices: list + The prediction matrix for each committee. + Each committee predict matrix should have the shape [n_samples, n_classes] for probabilistic output + or [n_samples] for class output. + + Returns + ------- + score: list + Score for each instance. Shape [n_samples] + + References + ---------- + [1] Kirsch, Andreas, Joost Van Amersfoort, and Yarin Gal. "Batchbald: Efficient + and diverse batch acquisition for deep bayesian active learning." Advances in neural + information processing systems 32 (2019): 7026-7037. + """ + input_shape, committee_size = cls()._check_committee_results(predict_matrices) + if len(input_shape) == 2: + prob = np.array(predict_matrices) + pb = np.mean(prob, axis=0) + entropy1 = (-prob*np.log(prob)).sum(axis=2).mean(axis=0) + entropy2 = (-pb*np.log(pb)).sum(axis=1) + bald_scores = entropy2 - entropy1 + score = bald_scores.tolist() + else: + raise Exception( + "A 2D probabilistic prediction matrix must be provided, with the shape like [n_samples, n_class]") + return score + class QueryExpectedErrorReduction(BaseIndexQuery): """The Expected Error Reduction (ERR) algorithm.