implement lazy selection (#58)
* implement lazy selection

* update docstring in prio queue

* complete implementation of acquisition function

* setup faiss adapter

* small changes

* fix dependencies

* cleanup

* fix pyright error

* cleanup

* update docstring

* drop assumption that priority queue is ordered at initialization

* update docstring

* fix initial values in pq

* fix signs in the value computation

* black

* fix computation of cached inverse matrix

* fix computation of cached inverse

* only recompute inv when necessary

* fix referenced faiss ip index

* fix pyright error

* fixes

* black

* cleanup
jonhue authored Sep 12, 2024
1 parent 0952ca9 commit cb26cb5
Showing 5 changed files with 440 additions and 2 deletions.
366 changes: 366 additions & 0 deletions afsl/acquisition_functions/lazy_vtl.py
@@ -0,0 +1,366 @@
from typing import List, NamedTuple, Tuple
import numpy as np
import torch
from afsl.acquisition_functions import (
    EmbeddingBased,
    SequentialAcquisitionFunction,
    Targeted,
)
from afsl.gaussian import GaussianCovarianceMatrix
from afsl.model import ModelWithEmbeddingOrKernel
from afsl.utils import (
    DEFAULT_EMBEDDING_BATCH_SIZE,
    DEFAULT_MINI_BATCH_SIZE,
    DEFAULT_NUM_WORKERS,
    DEFAULT_SUBSAMPLE,
    PriorityQueue,
)

__all__ = ["LazyVTL", "LazyVTLState"]


class LazyVTLState(NamedTuple):
    """State of lazy VTL."""

    covariance_matrix: GaussianCovarianceMatrix
    r"""Covariance matrix over the target space and the points in `covariance_matrix_indices`."""
    m: int
    """Size of the target space."""
    selected_indices: List[int]
    """Indices of points that were already selected."""
    covariance_matrix_indices: List[int]
    """Indices of points that were added to the covariance matrix (excluding the initially added target space)."""
    target: torch.Tensor
    r"""Tensor of shape $m \times d$ containing the target space."""
    data: torch.Tensor
    r"""Tensor of shape $n \times d$ containing the sample space."""
    current_inv: torch.Tensor
    """Cached inverse of the noisy covariance matrix of the selected data."""


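# LazyVTL greedily selects the points that most reduce the posterior variance
# over the target space, but evaluates candidates lazily: a priority queue of
# optimistic value estimates avoids re-scoring the full data set per iteration.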
class LazyVTL(
    Targeted,
    EmbeddingBased,
    SequentialAcquisitionFunction[ModelWithEmbeddingOrKernel | None, LazyVTLState],
):
    noise_std: float
    """Standard deviation of the noise."""

    priority_queue: PriorityQueue | None = None
    """Priority queue over the data set."""

    def __init__(
        self,
        target: torch.Tensor,
        noise_std: float,
        subsampled_target_frac: float = 1,
        max_target_size: int | None = None,
        mini_batch_size=DEFAULT_MINI_BATCH_SIZE,
        embedding_batch_size=DEFAULT_EMBEDDING_BATCH_SIZE,
        num_workers=DEFAULT_NUM_WORKERS,
        subsample=DEFAULT_SUBSAMPLE,
    ):
        r"""
        :param target: Tensor of prediction targets (shape $m \times d$).
        :param noise_std: Standard deviation of the noise.
        :param subsampled_target_frac: Fraction of the target to be subsampled in each iteration. Must be in $(0,1]$. Default is $1$.
        :param max_target_size: Maximum size of the subsampled target in each iteration. Default is `None`, in which case the target may be arbitrarily large.
        :param mini_batch_size: Size of mini-batches used for computing the acquisition function.
        :param embedding_batch_size: Batch size used for computing the embeddings.
        :param num_workers: Number of workers used for parallel computation.
        :param subsample: Whether to subsample the data set.
        """
        SequentialAcquisitionFunction.__init__(
            self,
            mini_batch_size=mini_batch_size,
            num_workers=num_workers,
            subsample=subsample,
            force_nonsequential=False,
        )
        Targeted.__init__(
            self,
            target=target,
            subsampled_target_frac=subsampled_target_frac,
            max_target_size=max_target_size,
        )
        EmbeddingBased.__init__(self, embedding_batch_size=embedding_batch_size)
        self.noise_std = noise_std

    def set_initial_priority_queue(
        self,
        embeddings: torch.Tensor,
        target_embedding: torch.Tensor,
        inner_products: torch.Tensor | None = None,
    ):
        r"""
        Constructs the initial priority queue (of length $k$) over the data set.

        :param embeddings: Tensor of shape $k \times d$ containing the data point embeddings.
        :param target_embedding: Tensor of shape $d$ containing the (mean) target embedding.
        :param inner_products: Tensor of length $k$ containing precomputed (absolute) inner products of the data point embeddings with the target embedding.
        """
        self_inner_products = torch.sum(embeddings * embeddings, dim=1)
        if inner_products is None:
            inner_products = embeddings @ target_embedding
        values = inner_products**2 / (self_inner_products + self.noise_var)

        self.priority_queue = PriorityQueue(values=values.cpu().tolist())

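    # Lazy greedy selection: queue entries are (possibly stale) upper bounds on
    # the true acquisition values, which only decrease as the selected set grows.
    # A popped candidate whose freshly recomputed value still beats the best
    # remaining bound is therefore the exact greedy maximizer, and no further
    # recomputation is needed in that iteration.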
    def select_from_minibatch(
        self,
        batch_size: int,
        model: ModelWithEmbeddingOrKernel | None,
        data: torch.Tensor,
        device: torch.device | None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Selects the next batch from the given mini batch `data`.

        :param batch_size: Size of the batch to be selected. Needs to be smaller than `mini_batch_size`.
        :param model: Model used for data selection.
        :param data: Mini batch of inputs (shape $n \times d$) to be selected from.
        :param device: Device used for computation of the acquisition function.
        :return: Indices of the newly selected batch (with respect to the mini batch) and corresponding values of the acquisition function.
        """
        state = self.initialize(model, data, device)

        assert (
            self.priority_queue is not None
        ), "Initial priority queue must be set."  # TODO: make optional, compute from scratch if not provided
        assert self.priority_queue.size() == data.size(
            0
        ), "Size of the priority queue must match the size of the data set."

        selected_values = []
        for _ in range(batch_size):
            while True:
                i, _ = self.priority_queue.pop()
                new_value, new_covariance_matrix, state = self.recompute(state, i)

                prev_top_value = self.priority_queue.top_value
                self.priority_queue.push(i, new_value)
                if (
                    new_value >= prev_top_value
                ):  # done if the value is larger than the largest upper bound of other points
                    break  # (i, new_value) contains the next selected index and its value
            selected_values.append(new_value)
            state = self.step(state, i, new_covariance_matrix)
        return torch.tensor(state.selected_indices), torch.tensor(selected_values)

    def initialize(
        self,
        model: ModelWithEmbeddingOrKernel | None,
        data: torch.Tensor,
        device: torch.device | None,
    ) -> LazyVTLState:
        target = self.get_target()
        m = target.size(0)

        # Compute covariance matrix of targets
        assert model is None, "embedding computation via model not supported"
        covariance_matrix = GaussianCovarianceMatrix.from_embeddings(
            Embeddings=target.to(device)
        )

        if self.priority_queue is None:
            self.set_initial_priority_queue(
                embeddings=data,
                target_embedding=torch.mean(target, dim=0),
            )

        return LazyVTLState(
            covariance_matrix=covariance_matrix,
            m=m,
            selected_indices=[],
            covariance_matrix_indices=[],
            target=target,
            data=data,
            current_inv=torch.empty(0, 0),
        )

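    # `recompute` distinguishes two cases: if `data_idx` already appears in the
    # stored covariance matrix, its value can be read off directly; otherwise the
    # cached inverse `current_inv` is first brought up to date and the covariance
    # matrix is expanded by one row/column, each in O(n^2).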
    def recompute(
        self, state: LazyVTLState, data_idx: int
    ) -> Tuple[float, GaussianCovarianceMatrix, LazyVTLState]:
        """
        Recomputes the value of the data point `data_idx` under the current selection.
        """
        try:  # index of data point within covariance matrix
            idx = state.m + state.covariance_matrix_indices.index(data_idx)
            new_covariance_matrix = state.covariance_matrix
            new_state = state
        except ValueError:
            # update cached inverse covariance matrix of previously selected data, O(n^2)
            i = state.current_inv.size(0)
            if len(state.selected_indices) > i:
                d = state.data.size(1)
                prev_data = (
                    state.data[
                        torch.tensor(state.selected_indices[:i]).to(state.data.device)
                    ]
                    if i > 0
                    else torch.empty(0, d)
                )  # (i, d)
                new_data = state.data[
                    torch.tensor(state.selected_indices[i:]).to(state.data.device)
                ]  # (k, d) where k = len(selected_indices) - i
                B = prev_data @ new_data.T  # (i, k)
                I = torch.eye(new_data.size(0)).to(new_data.device)
                C = new_data @ new_data.T + self.noise_var * I  # (k, k)
                new_inv = update_inverse(A_inv=state.current_inv, B=B, C=C)
                new_state = LazyVTLState(
                    covariance_matrix=state.covariance_matrix,
                    m=state.m,
                    selected_indices=state.selected_indices,
                    covariance_matrix_indices=state.covariance_matrix_indices,
                    target=state.target,
                    data=state.data,
                    current_inv=new_inv,
                )
            else:
                new_state = state

            # expand the stored covariance matrix if the data point has not been added before, O(n^2)
            new_covariance_matrix = expand_covariance_matrix(
                covariance_matrix=state.covariance_matrix,
                current_inv=new_state.current_inv,
                data=state.data,
                target=state.target,
                data_idx=data_idx,
                covariance_matrix_indices=state.covariance_matrix_indices,
                selected_indices=state.selected_indices,
            )
            idx = new_covariance_matrix.dim - 1

        value = compute(
            covariance_matrix=new_covariance_matrix,
            idx=idx,
            noise_var=self.noise_var,
            m=state.m,
        )
        return value, new_covariance_matrix, new_state

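    # `step` commits the selected point: the stored covariance matrix is
    # conditioned on a noisy observation of it, so subsequent acquisition values
    # account for the enlarged selection.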
    def step(
        self,
        state: LazyVTLState,
        data_idx: int,
        new_covariance_matrix: GaussianCovarianceMatrix,
    ) -> LazyVTLState:
        """
        Advances the state.
        Updates the stored covariance matrix and the inverse of the covariance matrix (restricted to selected data).
        """
        if data_idx not in state.covariance_matrix_indices:
            state.covariance_matrix_indices.append(
                data_idx
            )  # Note: not treating as immutable!

        # update the stored covariance matrix by conditioning on the new data point, O(n^2)
        idx = state.m + state.covariance_matrix_indices.index(data_idx)
        posterior_covariance_matrix = new_covariance_matrix.condition_on(
            idx, noise_std=self.noise_std
        )

        state.selected_indices.append(data_idx)  # Note: not treating as immutable!
        return LazyVTLState(
            covariance_matrix=posterior_covariance_matrix,
            m=state.m,
            selected_indices=state.selected_indices,
            covariance_matrix_indices=state.covariance_matrix_indices,
            target=state.target,
            data=state.data,
            current_inv=state.current_inv,
        )

    @property
    def noise_var(self):
        return self.noise_std**2

    def compute(self, state: LazyVTLState) -> torch.Tensor:
        raise NotImplementedError


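# `compute` returns the total reduction in posterior variance across the m
# target points caused by a single noisy observation of the candidate at `idx`:
# sum_i cov(t_i, x)^2 / (var(x) + noise_var), the standard Gaussian conditioning
# identity applied to each target point.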
def compute(
    covariance_matrix: GaussianCovarianceMatrix, idx: int, noise_var: float, m: int
) -> float:
    """
    Computes the acquisition value of the data point at index `idx` within the covariance matrix.
    Time complexity: O(m)
    """

    def engine(i, j):
        return covariance_matrix[i, j] ** 2 / (covariance_matrix[j, j] + noise_var)

    target_indices = torch.arange(m)
    values = engine(target_indices, idx)
    value = torch.sum(values, dim=0).cpu().item()
    return value


def expand_covariance_matrix(
    covariance_matrix: GaussianCovarianceMatrix,
    current_inv: torch.Tensor,
    data: torch.Tensor,
    target: torch.Tensor,
    data_idx: int,
    covariance_matrix_indices: List[int],
    selected_indices: List[int],
) -> GaussianCovarianceMatrix:
    """
    Expands the given covariance matrix with the point at `data_idx`.
    Time complexity: O(n^2)

    :return: Expanded covariance matrix.
    """
    d = data.size(1)
    unique_selected_data = (
        data[torch.tensor(covariance_matrix_indices).to(data.device)]
        if len(covariance_matrix_indices) > 0
        else torch.empty(0, d)
    )  # (n', d)
    selected_data = (
        data[torch.tensor(selected_indices).to(data.device)]
        if len(selected_indices) > 0
        else torch.empty(0, d)
    )  # (n, d)
    new_data = data[data_idx]  # (d,)
    joint_data = torch.cat(
        (target, unique_selected_data, new_data.unsqueeze(0)), dim=0
    )  # (m+n'+1, d)
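    # Posterior covariance under the linear kernel given the selected points:
    # cov(z, x | S) = z^T (I - A^T (A A^T + noise_var * I)^{-1} A) x, where the
    # rows of A are the embeddings of the selected points and `current_inv`
    # caches (A A^T + noise_var * I)^{-1}.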
    I_d = torch.eye(d).to(selected_data.device)
    covariance_vector = (
        joint_data @ (I_d - selected_data.T @ current_inv @ selected_data) @ new_data
    )  # (m+n'+1,)
    assert covariance_vector.size(0) == covariance_matrix.dim + 1
    return covariance_matrix.expand(covariance_vector)


def update_inverse(
    A_inv: torch.Tensor, B: torch.Tensor, C: torch.Tensor
) -> torch.Tensor:
    r"""
    Updates the inverse of an $n \times n$ matrix $A$ after appending a new $n \times m$ off-diagonal block $B$ and an $m \times m$ diagonal block $C$.
    Uses blockwise inversion via the Schur complement $C - B^\top A^{-1} B$ (closely related to the Sherman-Morrison-Woodbury identity).
    Time complexity: O(n^2 m + m^3)
    """
    if A_inv.numel() == 0:  # Check if the previous matrix is empty
        return torch.inverse(C)

    D = torch.inverse(C - B.T @ A_inv @ B)
    A_inv_B = A_inv @ B

    upper_left = A_inv + A_inv_B @ D @ A_inv_B.T
    upper_right = -A_inv_B @ D
    lower_left = -D @ A_inv_B.T
    lower_right = D

    # Combine blocks into the full inverse matrix
    upper = torch.cat((upper_left, upper_right), dim=1)
    lower = torch.cat((lower_left, lower_right), dim=1)
    A_inv_extended = torch.cat((upper, lower), dim=0)

    return A_inv_extended
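
# A minimal sanity check of `update_inverse` (hypothetical example, not part of
# the committed file): the blockwise update must agree with directly inverting
# the enlarged matrix.
#
#     A = torch.randn(4, 6)
#     X = torch.randn(2, 6)
#     K = A @ A.T + 0.25 * torch.eye(4)  # existing block
#     B = A @ X.T  # new off-diagonal block
#     C = X @ X.T + 0.25 * torch.eye(2)  # new diagonal block
#     full = torch.cat((torch.cat((K, B), dim=1), torch.cat((B.T, C), dim=1)), dim=0)
#     assert torch.allclose(
#         update_inverse(torch.inverse(K), B, C), torch.inverse(full), atol=1e-5
#     )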
8 changes: 8 additions & 0 deletions afsl/adapters/faiss.py
@@ -1,5 +1,7 @@
from typing import NamedTuple, Tuple
from warnings import warn
from afsl.acquisition_functions import AcquisitionFunction, Targeted
from afsl.acquisition_functions.lazy_vtl import LazyVTL
import faiss # type: ignore
import torch
import time
@@ -53,6 +55,12 @@ def __init__(
        only_faiss: bool = False,
        device: torch.device | None = None,
    ):
        """
        :param index: Faiss index object.
        :param acquisition_function: Acquisition function object.
        :param only_faiss: Whether to only use Faiss for search.
        :param device: Device to use for computation.
        """
        self.index = index
        self.acquisition_function = acquisition_function
        self.only_faiss = only_faiss