Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented BALD acquisition function within QueryInstanceQBC #35

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions alipy/query_strategy/query_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ class QueryInstanceQBC(BaseIndexQuery):
Method name. This class only implement query_by_bagging for now.

disagreement: str
method to calculate disagreement of committees. should be one of ['vote_entropy', 'KL_divergence']
method to calculate disagreement of committees. should be one of ['vote_entropy', 'KL_divergence', 'BALD']

References
----------
Expand All @@ -373,10 +373,10 @@ class QueryInstanceQBC(BaseIndexQuery):
def __init__(self, X=None, y=None, method='query_by_bagging', disagreement='vote_entropy'):
self._method = method
super(QueryInstanceQBC, self).__init__(X, y)
if disagreement in ['vote_entropy', 'KL_divergence']:
if disagreement in ['vote_entropy', 'KL_divergence', 'BALD']:
self._disagreement = disagreement
else:
raise ValueError("disagreement must be one of ['vote_entropy', 'KL_divergence']")
raise ValueError("disagreement must be one of ['vote_entropy', 'KL_divergence', 'BALD']")

def select(self, label_index, unlabel_index, model=None, batch_size=1, n_jobs=None):
"""Select indexes from the unlabel_index for querying.
Expand Down Expand Up @@ -438,8 +438,10 @@ def select(self, label_index, unlabel_index, model=None, batch_size=1, n_jobs=No
# calc score
if self._disagreement == 'vote_entropy':
score = self.calc_vote_entropy([estimator.predict(unlabel_x) for estimator in est_arr])
else:
elif self._disagreement == 'KL_divergence':
score = self.calc_avg_KL_divergence([estimator.predict_proba(unlabel_x) for estimator in est_arr])
else:
score = self.calc_BALD([estimator.predict_proba(unlabel_x) for estimator in est_arr])
return unlabel_index[nlargestarg(score, batch_size)]

def select_by_prediction_mat(self, unlabel_index, predict, batch_size=1):
Expand Down Expand Up @@ -598,6 +600,42 @@ def calc_avg_KL_divergence(cls, predict_matrices):
"A 2D probabilistic prediction matrix must be provided, with the shape like [n_samples, n_class]")
return score

@classmethod
def calc_BALD(cls, predict_matrices):
"""Calculate the mutual information between model predictions and model parameters
for measuring the level of disagreement in QBC.

Parameters
----------
predict_matrices: list
The prediction matrix for each committee.
Each committee predict matrix should have the shape [n_samples, n_classes] for probabilistic output
or [n_samples] for class output.

Returns
-------
score: list
Score for each instance. Shape [n_samples]

References
----------
[1] Kirsch, Andreas, Joost Van Amersfoort, and Yarin Gal. "Batchbald: Efficient
and diverse batch acquisition for deep bayesian active learning." Advances in neural
information processing systems 32 (2019): 7026-7037.
"""
input_shape, committee_size = cls()._check_committee_results(predict_matrices)
if len(input_shape) == 2:
prob = np.array(predict_matrices)
pb = np.mean(prob, axis=0)
entropy1 = (-prob*np.log(prob)).sum(axis=2).mean(axis=0)
entropy2 = (-pb*np.log(pb)).sum(axis=1)
bald_scores = entropy2 - entropy1
score = bald_scores.tolist()
else:
raise Exception(
"A 2D probabilistic prediction matrix must be provided, with the shape like [n_samples, n_class]")
return score


class QueryExpectedErrorReduction(BaseIndexQuery):
"""The Expected Error Reduction (ERR) algorithm.
Expand Down