Commit
* implement lazy selection
* update docstring in prio queue
* complete implementation of acquisition function
* setup faiss adapter
* small changes
* fix dependencies
* cleanup
* fix pyright error
* cleanup
* update docstring
* drop assumption that priority queue is ordered at initialization
* update docstring
* fix initial values in pq
* fix signs in the value computation
* black
* fix computation of cached inverse matrix
* fix computation of cached inverse
* only recompute inv when necessary
* fix referenced faiss ip index
* fix pyright error
* fixes
* black
* cleanup
Showing 5 changed files with 440 additions and 2 deletions.
@@ -0,0 +1,366 @@
from typing import List, NamedTuple, Tuple

import torch

from afsl.acquisition_functions import (
    EmbeddingBased,
    SequentialAcquisitionFunction,
    Targeted,
)
from afsl.gaussian import GaussianCovarianceMatrix
from afsl.model import ModelWithEmbeddingOrKernel
from afsl.utils import (
    DEFAULT_EMBEDDING_BATCH_SIZE,
    DEFAULT_MINI_BATCH_SIZE,
    DEFAULT_NUM_WORKERS,
    DEFAULT_SUBSAMPLE,
    PriorityQueue,
)

__all__ = ["LazyVTL", "LazyVTLState"]


class LazyVTLState(NamedTuple):
    """State of lazy VTL."""

    covariance_matrix: GaussianCovarianceMatrix
    r"""Kernel matrix of the target space and all points added so far."""
    m: int
    """Size of the target space."""
    selected_indices: List[int]
    """Indices of points that were already observed."""
    covariance_matrix_indices: List[int]
    """Indices of points that were added to the covariance matrix (excluding the initially added target space)."""
    target: torch.Tensor
    r"""Tensor of shape $m \times d$ which contains the target space."""
    data: torch.Tensor
    r"""Tensor of shape $n \times d$ which contains the sample space."""
    current_inv: torch.Tensor
    """Current inverse of the covariance matrix of the selected data."""


class LazyVTL(
    Targeted,
    EmbeddingBased,
    SequentialAcquisitionFunction[ModelWithEmbeddingOrKernel | None, LazyVTLState],
):
    """Lazy implementation of VTL which maintains a priority queue over the data set."""

    noise_std: float
    """Standard deviation of the noise."""

    priority_queue: PriorityQueue | None = None
    """Priority queue over the data set."""

    def __init__(
        self,
        target: torch.Tensor,
        noise_std: float,
        subsampled_target_frac: float = 1,
        max_target_size: int | None = None,
        mini_batch_size=DEFAULT_MINI_BATCH_SIZE,
        embedding_batch_size=DEFAULT_EMBEDDING_BATCH_SIZE,
        num_workers=DEFAULT_NUM_WORKERS,
        subsample=DEFAULT_SUBSAMPLE,
    ):
        r"""
        :param target: Tensor of prediction targets (shape $m \times d$).
        :param noise_std: Standard deviation of the noise.
        :param subsampled_target_frac: Fraction of the target to be subsampled in each iteration. Must be in $(0, 1]$. Default is $1$. Ignored if `target` is `None`.
        :param max_target_size: Maximum size of the target to be subsampled in each iteration. Default is `None`, in which case the target may be arbitrarily large. Ignored if `target` is `None`.
        :param mini_batch_size: Size of mini-batches used for computing the acquisition function.
        :param embedding_batch_size: Batch size used for computing the embeddings.
        :param num_workers: Number of workers used for parallel computation.
        :param subsample: Whether to subsample the data set.
        """
        SequentialAcquisitionFunction.__init__(
            self,
            mini_batch_size=mini_batch_size,
            num_workers=num_workers,
            subsample=subsample,
            force_nonsequential=False,
        )
        Targeted.__init__(
            self,
            target=target,
            subsampled_target_frac=subsampled_target_frac,
            max_target_size=max_target_size,
        )
        EmbeddingBased.__init__(self, embedding_batch_size=embedding_batch_size)
        self.noise_std = noise_std

    def set_initial_priority_queue(
        self,
        embeddings: torch.Tensor,
        target_embedding: torch.Tensor,
        inner_products: torch.Tensor | None = None,
    ):
        r"""
        Constructs the initial priority queue (of length $k$) over the data set.

        :param embeddings: Tensor of shape $k \times d$ containing the data point embeddings.
        :param target_embedding: Tensor of shape $d$ containing the (mean) target embedding.
        :param inner_products: Tensor of length $k$ containing precomputed (absolute) inner products of the data point embeddings with the target embedding.
        """
        self_inner_products = torch.sum(embeddings * embeddings, dim=1)  # squared norms, shape (k,)
        if inner_products is None:
            inner_products = embeddings @ target_embedding  # shape (k,)
        values = inner_products**2 / (self_inner_products + self.noise_var)

        self.priority_queue = PriorityQueue(values=values.cpu().tolist())

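    # Illustrative usage sketch (the tensors `data` of shape (k, d) and
    # `target` of shape (m, d) are hypothetical, as is the choice of
    # noise_std):
    #
    #   vtl = LazyVTL(target=target, noise_std=1.0)
    #   vtl.set_initial_priority_queue(
    #       embeddings=data,
    #       target_embedding=target.mean(dim=0),
    #   )
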
    def select_from_minibatch(
        self,
        batch_size: int,
        model: ModelWithEmbeddingOrKernel | None,
        data: torch.Tensor,
        device: torch.device | None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Selects the next batch from the given mini-batch `data`.

        :param batch_size: Size of the batch to be selected. Needs to be smaller than `mini_batch_size`.
        :param model: Model used for data selection.
        :param data: Mini-batch of inputs (shape $n \times d$) to be selected from.
        :param device: Device used for computation of the acquisition function.
        :return: Indices of the newly selected batch (with respect to the mini-batch) and corresponding values of the acquisition function.
        """
        state = self.initialize(model, data, device)

        assert (
            self.priority_queue is not None
        ), "Initial priority queue must be set."  # TODO: make optional, compute from scratch if not provided
        assert self.priority_queue.size() == data.size(
            0
        ), "Size of the priority queue must match the size of the data set."

        selected_values = []
        for _ in range(batch_size):
            while True:
                i, _ = self.priority_queue.pop()
                new_value, new_covariance_matrix, state = self.recompute(state, i)

                prev_top_value = self.priority_queue.top_value
                self.priority_queue.push(i, new_value)
                # Done if the recomputed value is at least as large as the
                # largest (stale) upper bound of all other points: in that
                # case, (i, new_value) is the next selected index and value.
                if new_value >= prev_top_value:
                    break
            selected_values.append(new_value)
            state = self.step(state, i, new_covariance_matrix)
        return torch.tensor(state.selected_indices), torch.tensor(selected_values)

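    # Note: this loop is the standard lazy greedy scheme. It relies on
    # acquisition values only decreasing as the selected set grows
    # (conditioning can only shrink posterior covariances), so every stale
    # queue entry upper-bounds the current value of its point. When the
    # freshly recomputed value of the popped point still dominates the best
    # remaining upper bound, it is the exact greedy choice and no other
    # values need to be recomputed.
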
    def initialize(
        self,
        model: ModelWithEmbeddingOrKernel | None,
        data: torch.Tensor,
        device: torch.device | None,
    ) -> LazyVTLState:
        target = self.get_target()
        m = target.size(0)

        # Compute covariance matrix of targets
        assert model is None, "embedding computation via model not supported"
        covariance_matrix = GaussianCovarianceMatrix.from_embeddings(
            Embeddings=target.to(device)
        )

        if self.priority_queue is None:
            self.set_initial_priority_queue(
                embeddings=data,
                target_embedding=torch.mean(target, dim=0),
            )

        return LazyVTLState(
            covariance_matrix=covariance_matrix,
            m=m,
            selected_indices=[],
            covariance_matrix_indices=[],
            target=target,
            data=data,
            current_inv=torch.empty(0, 0),
        )

    def recompute(
        self, state: LazyVTLState, data_idx: int
    ) -> Tuple[float, GaussianCovarianceMatrix, LazyVTLState]:
        """
        Recomputes the value of data point `data_idx` given the current state.
        """
        try:  # index of the data point within the covariance matrix
            idx = state.m + state.covariance_matrix_indices.index(data_idx)
            new_covariance_matrix = state.covariance_matrix
            new_state = state
        except ValueError:
            # update the cached inverse covariance matrix of previously selected data, O(n^2)
            i = state.current_inv.size(0)
            if len(state.selected_indices) > i:
                d = state.data.size(1)
                prev_data = (
                    state.data[
                        torch.tensor(state.selected_indices[:i]).to(state.data.device)
                    ]
                    if i > 0
                    else torch.empty(0, d)
                )  # (i, d)
                new_data = state.data[
                    torch.tensor(state.selected_indices[i:]).to(state.data.device)
                ]  # (n-1-i, d)
                B = prev_data @ new_data.T  # (i, n-1-i)
                I = torch.eye(new_data.size(0)).to(new_data.device)
                C = new_data @ new_data.T + self.noise_var * I  # (n-1-i, n-1-i)
                new_inv = update_inverse(A_inv=state.current_inv, B=B, C=C)
                new_state = LazyVTLState(
                    covariance_matrix=state.covariance_matrix,
                    m=state.m,
                    selected_indices=state.selected_indices,
                    covariance_matrix_indices=state.covariance_matrix_indices,
                    target=state.target,
                    data=state.data,
                    current_inv=new_inv,
                )
            else:
                new_state = state

            # expand the stored covariance matrix if the data point has not been added before, O(n^2)
            new_covariance_matrix = expand_covariance_matrix(
                covariance_matrix=state.covariance_matrix,
                current_inv=new_state.current_inv,
                data=state.data,
                target=state.target,
                data_idx=data_idx,
                covariance_matrix_indices=state.covariance_matrix_indices,
                selected_indices=state.selected_indices,
            )
            idx = new_covariance_matrix.dim - 1

        value = compute(
            covariance_matrix=new_covariance_matrix,
            idx=idx,
            noise_var=self.noise_var,
            m=state.m,
        )
        return value, new_covariance_matrix, new_state

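    # Note: `recompute` handles two cases. If `data_idx` already has a row in
    # the covariance matrix, its value can be read off directly. Otherwise the
    # matrix is expanded by one row/column, which first requires bringing the
    # cached inverse over the selected data up to date; both steps cost O(n^2)
    # instead of the O(n^3) of recomputing from scratch.
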
    def step(
        self,
        state: LazyVTLState,
        data_idx: int,
        new_covariance_matrix: GaussianCovarianceMatrix,
    ) -> LazyVTLState:
        """
        Advances the state.
        Updates the stored covariance matrix and the inverse of the covariance matrix (restricted to the selected data).
        """
        if data_idx not in state.covariance_matrix_indices:
            state.covariance_matrix_indices.append(
                data_idx
            )  # Note: not treating as immutable!

        # update the stored covariance matrix by conditioning on the new data point, O(n^2)
        idx = state.m + state.covariance_matrix_indices.index(data_idx)
        posterior_covariance_matrix = new_covariance_matrix.condition_on(
            idx, noise_std=self.noise_std
        )

        state.selected_indices.append(data_idx)  # Note: not treating as immutable!
        return LazyVTLState(
            covariance_matrix=posterior_covariance_matrix,
            m=state.m,
            selected_indices=state.selected_indices,
            covariance_matrix_indices=state.covariance_matrix_indices,
            target=state.target,
            data=state.data,
            current_inv=state.current_inv,
        )

    @property
    def noise_var(self):
        """Variance of the noise."""
        return self.noise_std**2

    def compute(self, state: LazyVTLState) -> torch.Tensor:
        # Values are maintained lazily via the priority queue in
        # `select_from_minibatch`, so the generic batch-wise computation is
        # not used.
        raise NotImplementedError


def compute(
    covariance_matrix: GaussianCovarianceMatrix, idx: int, noise_var: float, m: int
) -> float:
    """
    Computes the acquisition value of the data point at index `idx` within the covariance matrix.
    Time complexity: O(m)
    """

    def engine(i, j):
        return covariance_matrix[i, j] ** 2 / (covariance_matrix[j, j] + noise_var)

    target_indices = torch.arange(m)
    values = engine(target_indices, idx)
    value = torch.sum(values, dim=0).cpu().item()
    return value


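# Note: for the point at index `idx`, `compute` evaluates
#
#     value = sum_{t < m} Cov[t, idx]^2 / (Cov[idx, idx] + noise_var),
#
# i.e. the total reduction in posterior variance over the m target points
# that would result from observing `idx` under the given noise variance.

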
def expand_covariance_matrix(
    covariance_matrix: GaussianCovarianceMatrix,
    current_inv: torch.Tensor,
    data: torch.Tensor,
    target: torch.Tensor,
    data_idx: int,
    covariance_matrix_indices: List[int],
    selected_indices: List[int],
) -> GaussianCovarianceMatrix:
    """
    Expands the given covariance matrix with the data point at `data_idx`.
    Time complexity: O(n^2)

    :return: Expanded covariance matrix.
    """
    d = data.size(1)
    unique_selected_data = (
        data[torch.tensor(covariance_matrix_indices).to(data.device)]
        if len(covariance_matrix_indices) > 0
        else torch.empty(0, d)
    )  # (n', d)
    selected_data = (
        data[torch.tensor(selected_indices).to(data.device)]
        if len(selected_indices) > 0
        else torch.empty(0, d)
    )  # (n, d)
    new_data = data[data_idx]  # (d,)
    joint_data = torch.cat(
        (target, unique_selected_data, new_data.unsqueeze(0)), dim=0
    )  # (m+n'+1, d)
    I_d = torch.eye(d).to(selected_data.device)
    covariance_vector = (
        joint_data @ (I_d - selected_data.T @ current_inv @ selected_data) @ new_data
    )  # (m+n'+1,)
    assert covariance_vector.size(0) == covariance_matrix.dim + 1
    return covariance_matrix.expand(covariance_vector)


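# Note: with a linear kernel, the posterior covariance between points x and y
# after observing the selected data S (rows are points) under noise variance
# s^2 is
#
#     k(x, y) = x^T y - x^T S^T (S S^T + s^2 I)^{-1} S y
#             = x^T (I - S^T A^{-1} S) y,
#
# where A^{-1} is the cached `current_inv`. `covariance_vector` evaluates this
# expression for `new_data` against every point represented in the matrix.

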
def update_inverse(
    A_inv: torch.Tensor, B: torch.Tensor, C: torch.Tensor
) -> torch.Tensor:
    r"""
    Computes the inverse of an $n \times n$ matrix $A$ after it is extended by a new off-diagonal $n \times m$ block $B$ and a new diagonal $m \times m$ block $C$, given the inverse of $A$.
    Uses blockwise inversion via the Schur complement (a Sherman-Morrison-Woodbury-type update).
    Time complexity: O(n^2 m + m^3)
    """
    if A_inv.numel() == 0:  # the previous matrix is empty
        return torch.inverse(C)

    D = torch.inverse(C - B.T @ A_inv @ B)  # inverse of the Schur complement of A
    A_inv_B = A_inv @ B

    upper_left = A_inv + A_inv_B @ D @ A_inv_B.T
    upper_right = -A_inv_B @ D
    lower_left = -D @ A_inv_B.T
    lower_right = D

    # Combine blocks into the full inverse matrix
    upper = torch.cat((upper_left, upper_right), dim=1)
    lower = torch.cat((lower_left, lower_right), dim=1)
    A_inv_extended = torch.cat((upper, lower), dim=0)

    return A_inv_extended
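
A minimal sanity check for `update_inverse`, as a sketch: assuming only `torch` and the function as defined above, the blockwise update should agree with directly inverting the full block matrix.

    import torch

    torch.manual_seed(0)
    n, m = 5, 2
    X = torch.randn(n + m, n + m, dtype=torch.float64)
    M = X @ X.T + 1e-2 * torch.eye(n + m, dtype=torch.float64)  # symmetric positive definite

    # split M into blocks: A is n x n, B is n x m, C is m x m
    A, B, C = M[:n, :n], M[:n, n:], M[n:, n:]

    # the blockwise update must match directly inverting the full matrix
    M_inv = update_inverse(A_inv=torch.inverse(A), B=B, C=C)
    assert torch.allclose(M_inv, torch.inverse(M), atol=1e-8)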