refactor distance-based acquisition functions
jonhue committed Feb 21, 2024
1 parent 4832073 commit 513b038
Showing 10 changed files with 176 additions and 177 deletions.
9 changes: 5 additions & 4 deletions afsl/acquisition_functions/__init__.py
@@ -4,9 +4,9 @@
You can use a custom acquisition function as follows:
```python
-from afsl.acquisition_functions.greedy_max_det import GreedyMaxDet
+from afsl.acquisition_functions.undirected_itl import UndirectedITL
-acquisition_function = GreedyMaxDet()
+acquisition_function = UndirectedITL()
data_loader = afsl.ActiveDataLoader(data, batch_size=64, acquisition_function=acquisition_function)
```
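A possible next step is shown below as a standalone sketch (not part of this diff): the `model` variable and the `next(...)` selection call are assumptions for illustration, so check the afsl documentation for the actual method.

```python
# Hypothetical usage sketch; `model` and `data_loader.next(...)` are assumptions for illustration.
indices = data_loader.next(model)  # indices of the inputs selected by the acquisition function
batch = data[indices]
```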
@@ -20,8 +20,9 @@
| [VTL](acquisition_functions/vtl) | ✅ | ✅ | ✅ | embedding / kernel |
| [CTL](acquisition_functions/ctl) | ✅ | ❌ | ✅ | embedding / kernel |
| [Cosine Similarity](acquisition_functions/cosine_similarity) | ✅ | ❌ | ❌ | embedding |
-| [GreedyMaxDet](acquisition_functions/greedy_max_det) | ❌ | ✅ | ✅ | embedding / kernel |
-| [GreedyMaxDist](acquisition_functions/greedy_max_dist) | ❌ | (✅) | ✅ | embedding / kernel |
+| [Undirected ITL](acquisition_functions/undirected_itl) | ❌ | ✅ | ✅ | embedding / kernel |
+| [Undirected VTL](acquisition_functions/undirected_vtl) | ❌ | ✅ | ✅ | embedding / kernel |
+| [MaxDist](acquisition_functions/max_dist) | ❌ | (✅) | ✅ | embedding / kernel |
| [k-means++](acquisition_functions/kmeans_pp) | ❌ | (✅) | ✅ | embedding / kernel |
| [Uncertainty Sampling](acquisition_functions/uncertainty_sampling) | ❌ | ✅ | ❌ | embedding / kernel |
| [MaxMargin](acquisition_functions/max_margin) | ❌ | (✅) | ❌ | softmax |
2 changes: 1 addition & 1 deletion afsl/acquisition_functions/cosine_similarity.py
@@ -10,7 +10,7 @@

class CosineSimilarity(Targeted, BatchAcquisitionFunction):
r"""
-The cosine similarity between two vectors $\vphi$ and $\vphip$ is \\[\angle(\vphi, \vphip) = \frac{\vphi^\top \vphip}{\|\vphi\|_2 \|\vphip\|_2}.\\]
+The cosine similarity between two vectors $\vphi$ and $\vphip$ is \\[\angle(\vphi, \vphip) \defeq \frac{\vphi^\top \vphip}{\|\vphi\|_2 \|\vphip\|_2}.\\]
Given a set of targets $\spA$ and a model which for an input $\vx$ computes an embedding $\vphi(\vx)$, `CosineSimilarity`[^1] selects the inputs $\vx$ which maximize \\[ \frac{1}{|\spA|} \sum_{\vxp \in \spA} \angle(\vphi(\vx), \vphi(\vxp)). \\]
Intuitively, this selects the points that are most similar to the targets $\spA$.
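As a standalone illustration of this scoring rule (an editorial sketch, not part of the commit; `embeddings` and `target_embeddings` are assumed to have been computed with the model's embedding map):

```python
import torch
import torch.nn.functional as F

def cosine_similarity_scores(
    embeddings: torch.Tensor,         # (n, d) candidate embeddings phi(x)
    target_embeddings: torch.Tensor,  # (m, d) embeddings of the target set A
) -> torch.Tensor:
    """Average cosine similarity of each candidate to the targets; shape (n,)."""
    emb = F.normalize(embeddings, dim=-1)
    targets = F.normalize(target_embeddings, dim=-1)
    # angle(phi(x), phi(x')) for all pairs, then mean over the targets
    return (emb @ targets.T).mean(dim=-1)

# Selecting a batch then amounts to taking the top-k scores, e.g.:
# batch_indices = torch.topk(cosine_similarity_scores(embeddings, target_embeddings), k=64).indices
```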
102 changes: 0 additions & 102 deletions afsl/acquisition_functions/distance.py

This file was deleted.

38 changes: 0 additions & 38 deletions afsl/acquisition_functions/greedy_max_dist.py

This file was deleted.

25 changes: 6 additions & 19 deletions afsl/acquisition_functions/kmeans_pp.py
@@ -1,18 +1,15 @@
import torch
-from afsl.acquisition_functions.distance import (
-    DistanceBasedAcquisitionFunction,
-    DistanceState,
-)
+from afsl.acquisition_functions.max_dist import MaxDist


-class KMeansPP(DistanceBasedAcquisitionFunction):
+class KMeansPP(MaxDist):
r"""
Given a model which for two inputs $\vx$ and $\vxp$ induces a distance $d(\vx,\vxp)$,[^1] `KMeansPP`[^2] selects the batch via [k-means++ seeding](https://en.wikipedia.org/wiki/K-means%2B%2B).
That is, the first centroid $\vx_1$ is chosen randomly and the subsequent centroids are chosen with a probability proportional to the square of the distance to the nearest previously selected centroid: \\[ \Pr{\vx_i = \vx} \propto \min_{j < i} d(\vx, \vx_j)^2. \\]
.. note::
-This acquisition function is similar to [GreedyMaxDist](greedy_max_dist) but selects the batch randomly rather than deterministically.
+This acquisition function is similar to [MaxDist](max_dist) but selects the batch randomly rather than deterministically.
`KMeansPP` explicitly enforces *diversity* in the selected batch.
If the selected centroids from previous batches are used to initialize the centroids for the current batch,[^3] then `KMeansPP` heuristically also leads to *informative* samples since samples are chosen to be different from previously seen data.
@@ -23,7 +20,7 @@ class KMeansPP(DistanceBasedAcquisitionFunction):
Using the afsl.embeddings.classification.CrossEntropyEmbedding embeddings, this acquisition function is known as BADGE (*Batch Active learning by Diverse Gradient Embeddings*).[^4]
-[^1]: See afsl.acquisition_functions.distance.DistanceBasedAcquisitionFunction for a discussion of how a distance is induced by embeddings or a kernel.
+[^1]: See [here](max_dist#where-does-the-distance-come-from) for a discussion of how a distance is induced by embeddings or a kernel.
[^2]: Holzmüller, D., Zaverkin, V., Kästner, J., and Steinwart, I. A framework and benchmark for deep batch active learning for regression. JMLR, 24(164), 2023.
@@ -32,17 +29,7 @@ class KMeansPP(DistanceBasedAcquisitionFunction):
[^4]: Ash, J. T., Zhang, C., Krishnamurthy, A., Langford, J., and Agarwal, A. Deep batch active learning by diverse, uncertain gradient lower bounds. ICLR, 2020.
"""

-def compute(self, state: DistanceState) -> torch.Tensor:
-    if len(state.centroid_indices) == 0:
-        # Choose the first centroid randomly
-        return torch.ones(state.data.size(0))

-    # Compute the distance of each point to the nearest centroid
-    sqd_distances = torch.square(self.compute_min_distances(state))
-    # Choose the next centroid with a probability proportional to the square of the distance
-    probabilities = sqd_distances / sqd_distances.sum()
-    return probabilities

@staticmethod
-def selector(probabilities: torch.Tensor) -> int:
+def selector(min_sqd_distances: torch.Tensor) -> int:
+    probabilities = min_sqd_distances / min_sqd_distances.sum()
    return int(torch.multinomial(probabilities, num_samples=1).item())
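For reference, the seeding rule can be sketched end to end in plain PyTorch. This is an illustrative standalone version over precomputed embeddings with Euclidean distances, not the afsl implementation:

```python
import torch

def kmeans_pp_seeding(embeddings: torch.Tensor, batch_size: int) -> list[int]:
    """Select batch_size indices via k-means++ seeding on squared Euclidean embedding distances."""
    n = embeddings.size(0)
    selected = [int(torch.randint(n, (1,)).item())]  # first centroid chosen uniformly at random
    min_sqd = torch.square(embeddings - embeddings[selected[0]]).sum(dim=-1)
    for _ in range(batch_size - 1):
        # sample the next centroid proportionally to the squared distance to the nearest centroid
        probabilities = min_sqd / min_sqd.sum()
        i = int(torch.multinomial(probabilities, num_samples=1).item())
        selected.append(i)
        new_sqd = torch.square(embeddings - embeddings[i]).sum(dim=-1)
        min_sqd = torch.minimum(min_sqd, new_sqd)
    return selected
```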
135 changes: 135 additions & 0 deletions afsl/acquisition_functions/max_dist.py
@@ -0,0 +1,135 @@
from typing import NamedTuple
import torch
from afsl.acquisition_functions import SequentialAcquisitionFunction
from afsl.model import ModelWithEmbedding, ModelWithEmbeddingOrKernel, ModelWithKernel
from afsl.utils import DEFAULT_MINI_BATCH_SIZE, compute_embedding

__all__ = ["MaxDist", "DistanceState", "sqd_kernel_distance"]


class DistanceState(NamedTuple):
"""State of sequential batch selection."""

centroid_indices: torch.Tensor
"""Indices of previously selected centroids."""
min_sqd_distances: torch.Tensor
"""Minimum squared distances to previously selected centroids. Tensor of shape $n$."""
kernel_matrix: torch.Tensor
r"""Kernel matrix of the data. Tensor of shape $n \times n$."""


class MaxDist(SequentialAcquisitionFunction[ModelWithEmbeddingOrKernel, DistanceState]):
r"""
Given a model which for two inputs $\vx$ and $\vxp$ induces a distance $d(\vx,\vxp)$,[^1] `MaxDist`[^2] constructs the batch by choosing the point with the maximum distance to the nearest previously selected point: \\[ \vx_i = \argmax_{\vx} \min_{j < i} d(\vx, \vx_j). \\]
The first point $\vx_1$ is chosen randomly.
.. note::
This acquisition function is similar to [k-means++](kmeans_pp) but selects the batch deterministically rather than randomly.
`MaxDist` explicitly enforces *diversity* in the selected batch.
If the selected centroids from previous batches are used to initialize the centroids for the current batch,[^3] then `MaxDist` heuristically also leads to *informative* samples since samples are chosen to be different from previously seen data.
| Relevance? | Informativeness? | Diversity? | Model Requirement |
|------------|------------------|------------|--------------------|
| ❌ | (✅) | ✅ | embedding / kernel |
#### Where does the distance come from?
This acquisition function rests on the assumption that the model induces a distance $d(\vx,\vxp)$ between points $\vx$ and $\vxp$, either via an embedding or a kernel.
- **Embeddings** $\vphi(\cdot)$ induce the (euclidean) *embedding distance* \\[ d_\vphi(\vx,\vxp) \defeq \norm{\vphi(\vx) - \vphi(\vxp)}_2. \\]
- A **kernel** $k$ induces the *kernel distance* \\[ d_k(\vx,\vxp) \defeq \sqrt{k(\vx,\vx) + k(\vxp,\vxp) - 2 k(\vx,\vxp)}. \\]
It is straightforward to see that if $k(\vx,\vxp) = \vphi(\vx)^\top \vphi(\vxp)$ then embedding and kernel distances coincide, i.e., $d_\vphi(\vx,\vxp) = d_k(\vx,\vxp)$.
[^2]: Holzmüller, D., Zaverkin, V., Kästner, J., and Steinwart, I. A framework and benchmark for deep batch active learning for regression. JMLR, 24(164), 2023.
[^3]: see `initialize_with_previous_samples`
"""

initialize_with_previous_samples: bool = True
"""Whether to initialize the centroids with the samples from previous batches."""

def __init__(
self,
mini_batch_size=DEFAULT_MINI_BATCH_SIZE,
force_nonsequential=False,
initialize_with_previous_samples=True,
):
super().__init__(
mini_batch_size=mini_batch_size, force_nonsequential=force_nonsequential
)
self.initialize_with_previous_samples = initialize_with_previous_samples

def initialize(
self,
model: ModelWithEmbeddingOrKernel,
data: torch.Tensor,
) -> DistanceState:
if isinstance(model, ModelWithEmbedding):
embeddings = compute_embedding(
model, data, mini_batch_size=self.mini_batch_size
)

if self.initialize_with_previous_samples:
centroid_indices = self.selected
if isinstance(model, ModelWithEmbedding):
centroids = embeddings[centroid_indices.to(embeddings.device)]
distances = torch.square(torch.cdist(embeddings, centroids, p=2))
else:
centroids = data[centroid_indices.to(data.device)]
distances = sqd_kernel_distance(data, centroids, model)
min_sqd_distances = torch.min(distances, dim=1).values
else:
centroid_indices = torch.tensor([])
min_sqd_distances = torch.full(size=(data.size(0),), fill_value=torch.inf)

if isinstance(model, ModelWithEmbedding):
kernel_matrix = embeddings @ embeddings.T
else:
kernel_matrix = model.kernel(data, data)

return DistanceState(
centroid_indices=centroid_indices,
min_sqd_distances=min_sqd_distances,
kernel_matrix=kernel_matrix,
)

def compute(self, state: DistanceState) -> torch.Tensor:
return state.min_sqd_distances

def step(self, state: DistanceState, i: int) -> DistanceState:
centroid_indices = torch.cat(
[
state.centroid_indices,
torch.tensor([i]).to(state.centroid_indices.device),
]
)
new_sqd_distances = (
state.kernel_matrix[i, i]
+ torch.diag(state.kernel_matrix)
- 2 * state.kernel_matrix[i, :]
)
min_sqd_distances = torch.min(state.min_sqd_distances, new_sqd_distances)
return DistanceState(
centroid_indices=centroid_indices,
min_sqd_distances=min_sqd_distances,
kernel_matrix=state.kernel_matrix,
)


def sqd_kernel_distance(
x1: torch.Tensor, x2: torch.Tensor, model: ModelWithKernel
) -> torch.Tensor:
r"""
Returns the squared *kernel distance* \\[ d_k(\vx,\vxp)^2 \defeq \norm{\vphi(\vx) - \vphi(\vxp)}_2^2 = k(\vx,\vx) + k(\vxp,\vxp) - 2 k(\vx,\vxp) \\] induced by the kernel $k(\vx,\vxp) = \vphi(\vx)^\top \vphi(\vxp)$.
:param x1: Tensor of shape $n \times d$.
:param x2: Tensor of shape $m \times d$.
:param model: Model with a kernel method.
:return: Tensor of shape $n \times m$ of pairwise squared distances.
"""
    # k(x_i, x_i) + k(x_j, x_j) - 2 k(x_i, x_j) for all pairs (shape n x m); no square root,
    # since the function returns *squared* distances, as its name and docstring state.
    return (
        torch.diag(model.kernel(x1, x1)).unsqueeze(-1)
        + torch.diag(model.kernel(x2, x2)).unsqueeze(0)
        - 2 * model.kernel(x1, x2)
    )
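For intuition, the greedy selection that `MaxDist` performs can be sketched on a precomputed kernel (Gram) matrix. This is an editorial illustration of the same update rule as `step`, not part of the commit:

```python
import torch

def max_dist_selection(kernel_matrix: torch.Tensor, batch_size: int, first: int = 0) -> list[int]:
    """Greedy farthest-point selection: each new point maximizes the squared kernel distance
    d_k(x, x')^2 = k(x, x) + k(x', x') - 2 k(x, x') to the nearest previously selected point."""
    diag = torch.diag(kernel_matrix)
    selected = [first]
    min_sqd = diag + diag[first] - 2 * kernel_matrix[first, :]  # squared distances to the first centroid
    for _ in range(batch_size - 1):
        i = int(torch.argmax(min_sqd).item())
        selected.append(i)
        min_sqd = torch.minimum(min_sqd, diag + diag[i] - 2 * kernel_matrix[i, :])
    return selected
```

The equivalence of embedding and kernel distances stated in the docstring is also easy to check numerically for the linear kernel $k(\vx,\vxp) = \vphi(\vx)^\top \vphi(\vxp)$ (again an illustrative sketch with random embeddings):

```python
phi = torch.randn(5, 16)                  # 5 points with random 16-dimensional embeddings
gram = phi @ phi.T                        # k(x_i, x_j) = phi(x_i)^T phi(x_j)
d_embedding = torch.cdist(phi, phi, p=2)  # Euclidean embedding distance
d_kernel = torch.sqrt(torch.clamp(
    torch.diag(gram)[:, None] + torch.diag(gram)[None, :] - 2 * gram, min=0.0
))
assert torch.allclose(d_embedding, d_kernel, atol=1e-5)
```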
6 changes: 3 additions & 3 deletions afsl/acquisition_functions/uncertainty_sampling.py
@@ -1,8 +1,8 @@
-from afsl.acquisition_functions.greedy_max_det import GreedyMaxDet
+from afsl.acquisition_functions.undirected_itl import UndirectedITL
from afsl.utils import DEFAULT_MINI_BATCH_SIZE


-class UncertaintySampling(GreedyMaxDet):
+class UncertaintySampling(UndirectedITL):
r"""
`UncertaintySampling`[^1] selects the most uncertain data point: \\[ \argmax_\vx\ \sigma^2(\vx) \\] where $\sigma^2(\vx) = k(\vx, \vx)$ denotes the variance of $\vx$ induced by the kernel $k$.[^3]
@@ -12,7 +12,7 @@ class UncertaintySampling(GreedyMaxDet):
.. note::
-`UncertaintySampling` coincides with [GreedyMaxDet](greedy_max_det) with `force_nonsequential=True`.
+`UncertaintySampling` coincides with [Undirected ITL](undirected_itl) with `force_nonsequential=True`.
| Relevance? | Informativeness? | Diversity? | Model Requirement |
|------------|------------------|------------|--------------------|
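A minimal sketch of this rule on a precomputed kernel matrix (editorial illustration, not part of the commit; with `force_nonsequential=True` the batch is simply the top-k most uncertain points):

```python
import torch

def uncertainty_sampling(kernel_matrix: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Select the batch_size points with the largest variance sigma^2(x) = k(x, x)."""
    variances = torch.diag(kernel_matrix)
    return torch.topk(variances, k=batch_size).indices
```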
@@ -3,7 +3,7 @@
from afsl.acquisition_functions.bace import BaCE, BaCEState


-class GreedyMaxDet(BaCE):
+class UndirectedITL(BaCE):
def compute(self, state: BaCEState) -> torch.Tensor:
variances = torch.diag(state.covariance_matrix[:, :])
wandb.log(
