Commit
* initial implementation
* updates
* update CI action
* update public submodules
* fix docs generation
* update build system
* fix black error
Showing 35 changed files with 886 additions and 89 deletions.
@@ -1,6 +1,10 @@
 name: CI
 
-on: [push, pull_request]
+on:
+  push:
+    branches:
+      - main
+  pull_request:
 
 jobs:
   build:
@@ -0,0 +1,122 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
import torch
from afsl.embeddings import M, Embedding
from afsl.types import Target
from afsl.utils import mini_batch_wrapper, mini_batch_wrapper_non_cat


class AcquisitionFunction(ABC):
    mini_batch_size: int

    def __init__(self, mini_batch_size: int = 100):
        self.mini_batch_size = mini_batch_size

    @abstractmethod
    def select(
        self,
        batch_size: int,
        embedding: Embedding[M],
        model: M,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
        force_nonsequential=False,
    ) -> torch.Tensor:
        pass


class BatchAcquisitionFunction(AcquisitionFunction):
    @abstractmethod
    def compute(
        self,
        embedding: Embedding[M],
        model: M,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
    ) -> torch.Tensor:
        pass

    def select(
        self,
        batch_size: int,
        embedding: Embedding[M],
        model: M,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
        force_nonsequential=False,
    ) -> torch.Tensor:
        values = mini_batch_wrapper(
            fn=lambda batch: self.compute(
                embedding=embedding,
                model=model,
                data=batch,
                target=target,
                Sigma=Sigma,
            ),
            data=data,
            batch_size=self.mini_batch_size,
        )
        _, indices = torch.topk(values, batch_size)
        return indices


State = TypeVar("State")


class SequentialAcquisitionFunction(AcquisitionFunction, Generic[State]):
    @abstractmethod
    def initialize(
        self,
        embedding: Embedding[M],
        model: M,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
    ) -> State:
        pass

    @abstractmethod
    def compute(self, state: State) -> torch.Tensor:
        pass

    @abstractmethod
    def step(self, state: State, i: int) -> State:
        pass

    def select(
        self,
        batch_size: int,
        embedding: Embedding[M],
        model: M,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
        force_nonsequential=False,
    ) -> torch.Tensor:
        states = mini_batch_wrapper_non_cat(
            fn=lambda batch: self.initialize(
                embedding=embedding,
                model=model,
                data=batch,
                target=target,
                Sigma=Sigma,
            ),
            data=data,
            batch_size=self.mini_batch_size,
        )

        if force_nonsequential:
            values = torch.cat([self.compute(state) for state in states], dim=0)
            _, indices = torch.topk(values, batch_size)
            return indices
        else:
            indices = []
            for _ in range(batch_size):
                values = torch.cat([self.compute(state) for state in states], dim=0)
                i = int(torch.argmax(values).item())
                indices.append(i)
                states = [self.step(state, i) for state in states]
            return torch.tensor(indices)
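The base classes above define the contract for every acquisition function in the package: BatchAcquisitionFunction scores all candidates at once and keeps the top batch_size via torch.topk, while SequentialAcquisitionFunction re-scores after every pick through initialize/compute/step. As a quick illustration (not part of this commit; the class name is made up), a minimal batch-style subclass only needs to fill in compute:

# Illustrative sketch only -- "EmbeddingNorm" is a hypothetical example class,
# not part of this diff. It scores each candidate by the norm of its embedding;
# BatchAcquisitionFunction.select then keeps the top batch_size candidates.
import torch
from afsl.acquisition_functions import BatchAcquisitionFunction
from afsl.embeddings import M, Embedding
from afsl.types import Target


class EmbeddingNorm(BatchAcquisitionFunction):
    def compute(
        self,
        embedding: Embedding[M],
        model: M,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # One scalar score per candidate point in `data`.
        return torch.linalg.vector_norm(embedding.embed(model, data), dim=1)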
@@ -0,0 +1,42 @@
from typing import NamedTuple
import torch
from afsl.acquisition_functions import SequentialAcquisitionFunction
from afsl.embeddings import M, Embedding
from afsl.gaussian import GaussianCovarianceMatrix
from afsl.types import Target


class BaCEState(NamedTuple):
    covariance_matrix: GaussianCovarianceMatrix
    n: int


class BaCE(SequentialAcquisitionFunction):
    noise_std: float

    def __init__(self, noise_std=1.0):
        super().__init__()
        self.noise_std = noise_std

    def initialize(
        self,
        embedding: Embedding[M],
        model: M,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
    ) -> BaCEState:
        assert target is not None, "Target must be non-empty"

        n = data.size(0)
        data_embeddings = embedding.embed(model, data)
        target_embeddings = embedding.embed(model, target)
        joint_embeddings = torch.cat((data_embeddings, target_embeddings))
        covariance_matrix = GaussianCovarianceMatrix.from_embeddings(
            noise_std=self.noise_std, Embeddings=joint_embeddings, Sigma=Sigma
        )
        return BaCEState(covariance_matrix=covariance_matrix, n=n)

    def step(self, state: BaCEState, i: int) -> BaCEState:
        posterior_covariance_matrix = state.covariance_matrix.condition_on(i)
        return BaCEState(covariance_matrix=posterior_covariance_matrix, n=state.n)
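BaCE.step delegates to GaussianCovarianceMatrix.condition_on, whose implementation is not shown in this diff. Under standard Gaussian assumptions, conditioning on a noisy observation of point i is a rank-one update of the joint covariance; the plain-torch sketch below states that textbook formula (an assumption about what condition_on computes, not the afsl code):

import torch

# Textbook posterior update: after observing index i with noise variance
# noise_std**2, the covariance shrinks by a rank-one Schur-complement term.
def condition_on(Sigma: torch.Tensor, i: int, noise_std: float) -> torch.Tensor:
    v = Sigma[:, i]                       # covariance of every point with point i
    denom = Sigma[i, i] + noise_std**2    # prior variance at i plus observation noise
    return Sigma - torch.outer(v, v) / denom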
@@ -0,0 +1,82 @@
from typing import List, NamedTuple
import torch
from afsl.acquisition_functions import SequentialAcquisitionFunction
from afsl.embeddings import M, Embedding
from afsl.types import Target
from afsl.utils import mini_batch_wrapper_non_cat


class BADGEState(NamedTuple):
    embeddings: torch.Tensor
    centroid_indices: List[torch.Tensor]


def compute_distances(embeddings, centroids):
    # Compute the distance of all points in embeddings from each centroid
    distances = torch.cdist(embeddings, centroids, p=2)
    # Return the minimum distance for each point
    min_distances = torch.min(distances, dim=1).values
    return min_distances


class BADGE(SequentialAcquisitionFunction[BADGEState]):
    def initialize(
        self,
        embedding: Embedding[M],
        model: M,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
    ) -> BADGEState:
        embeddings = embedding.embed(model, data)
        # Choose the first centroid randomly
        centroid_indices = [
            torch.randint(0, embeddings.size(0), (1,)).to(embeddings.device)
        ]
        return BADGEState(embeddings=embeddings, centroid_indices=centroid_indices)

    def step(self, state: BADGEState, i: int) -> BADGEState:
        state.centroid_indices.append(torch.tensor(i).to(state.embeddings.device))
        return state

    def compute(self, state: BADGEState) -> torch.Tensor:
        # Compute the distance of each point to the nearest centroid
        centroids = state.embeddings[
            torch.cat(state.centroid_indices).to(state.embeddings.device)
        ]
        sqd_distances = torch.square(compute_distances(state.embeddings, centroids))
        # Choose the next centroid with a probability proportional to the square of the distance
        probabilities = sqd_distances / sqd_distances.sum()
        return probabilities

    def select(
        self,
        batch_size: int,
        embedding: Embedding[M],
        model: M,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
        force_nonsequential=False,
    ) -> torch.Tensor:
        assert not force_nonsequential, "Non-sequential selection is not supported"

        states = mini_batch_wrapper_non_cat(
            fn=lambda batch: self.initialize(
                embedding=embedding,
                model=model,
                data=batch,
                target=target,
                Sigma=Sigma,
            ),
            data=data,
            batch_size=self.mini_batch_size,
        )

        indices = []
        for _ in range(batch_size):
            probabilities = torch.cat([self.compute(state) for state in states], dim=0)
            i = int(torch.multinomial(probabilities, num_samples=1).item())
            indices.append(i)
            states = [self.step(state, i) for state in states]
        return torch.tensor(indices)
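Taken together, initialize, compute, and step realize the k-means++-style D-squared sampling used by BADGE: the first centroid is drawn uniformly at random, and each subsequent point is drawn with probability proportional to its squared distance to the nearest centroid already chosen. A self-contained sketch of the same loop on a plain (n, d) embedding tensor, independent of the afsl classes above:

import torch

def d2_sample(embeddings: torch.Tensor, k: int) -> torch.Tensor:
    # First centroid: uniform at random.
    indices = [int(torch.randint(0, embeddings.size(0), (1,)).item())]
    for _ in range(k - 1):
        centroids = embeddings[torch.tensor(indices)]
        # Squared distance of every point to its nearest chosen centroid.
        sqd = torch.cdist(embeddings, centroids, p=2).min(dim=1).values.square()
        # Sample the next centroid proportionally to the squared distances.
        indices.append(int(torch.multinomial(sqd / sqd.sum(), num_samples=1).item()))
    return torch.tensor(indices)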
@@ -0,0 +1,35 @@
import torch
import torch.nn.functional as F
from afsl.acquisition_functions import BatchAcquisitionFunction
from afsl.embeddings import Embedding
from afsl.model import LatentModel
from afsl.types import Target
from afsl.utils import get_device


class CosineSimilarity(BatchAcquisitionFunction):
    def compute(
        self,
        embedding: Embedding,
        model: LatentModel,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
    ) -> torch.Tensor:
        assert target is not None, "Target must be non-empty"

        model.eval()
        device = get_device(model)
        with torch.no_grad():
            data_latent = model.latent(data.to(device))
            target_latent = model.latent(target.to(device))

        data_latent_normalized = F.normalize(data_latent, p=2, dim=1)
        target_latent_normalized = F.normalize(target_latent, p=2, dim=1)

        cosine_similarities = torch.matmul(
            data_latent_normalized, target_latent_normalized.T
        )

        average_cosine_similarities = torch.mean(cosine_similarities, dim=1)
        return average_cosine_similarities
@@ -0,0 +1,20 @@
import torch
import wandb
from afsl.acquisition_functions.bace import BaCE, BaCEState


class CTL(BaCE):
    def compute(self, state: BaCEState) -> torch.Tensor:
        ind_a = torch.arange(state.n)
        ind_b = torch.arange(state.n, state.covariance_matrix.dim)
        covariance_aa = state.covariance_matrix[ind_a, :][:, ind_a]
        covariance_bb = state.covariance_matrix[ind_b, :][:, ind_b]
        covariance_ab = state.covariance_matrix[ind_a, :][:, ind_b]

        std_a = torch.sqrt(torch.diag(covariance_aa))
        std_b = torch.sqrt(torch.diag(covariance_bb))
        std_ab = torch.ger(std_a, std_b)  # outer product of standard deviations

        correlations = covariance_ab / std_ab
        average_correlations = torch.mean(correlations, dim=1)
        return average_correlations
@@ -0,0 +1,33 @@
import torch
import wandb
from afsl.acquisition_functions.bace import BaCE, BaCEState
from afsl.embeddings import M, Embedding
from afsl.gaussian import GaussianCovarianceMatrix
from afsl.types import Target


class GreedyMaxDet(BaCE):
    def compute(self, state: BaCEState) -> torch.Tensor:
        variances = torch.diag(state.covariance_matrix[:, :])
        wandb.log(
            {
                "max_var": torch.max(variances),
                "min_var": torch.min(variances),
            }
        )
        return variances

    def initialize(
        self,
        embedding: Embedding[M],
        model: M,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
    ) -> BaCEState:
        n = data.size(0)
        data_embeddings = embedding.embed(model, data)
        covariance_matrix = GaussianCovarianceMatrix.from_embeddings(
            noise_std=self.noise_std, Embeddings=data_embeddings, Sigma=Sigma
        )
        return BaCEState(covariance_matrix=covariance_matrix, n=n)
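GreedyMaxDet scores candidates by their current posterior variance and relies on the inherited BaCE.step to condition the covariance on each pick, so selecting a batch amounts to greedily growing the determinant of the covariance of the chosen points. Note that compute calls wandb.log, so it assumes an active Weights & Biases run. A standalone sketch of that greedy loop in plain torch, reusing the rank-one conditioning assumption from the earlier sketch rather than the afsl internals:

import torch

def greedy_max_det(Sigma: torch.Tensor, batch_size: int, noise_std: float = 1.0) -> torch.Tensor:
    Sigma = Sigma.clone()
    indices = []
    for _ in range(batch_size):
        i = int(torch.argmax(torch.diag(Sigma)).item())  # largest posterior variance
        indices.append(i)
        v = Sigma[:, i]
        # Condition on a noisy observation of point i (rank-one update).
        Sigma = Sigma - torch.outer(v, v) / (Sigma[i, i] + noise_std**2)
    return torch.tensor(indices)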