initial implementation (#2)
* initial implementation

* updates

* update CI action

* update public submodules

* fix docs generation

* update build system

* fix black error
jonhue authored Feb 19, 2024
1 parent b60c3bc commit 91c3e63
Showing 35 changed files with 886 additions and 89 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/ci.yml
@@ -1,6 +1,10 @@
 name: CI

-on: [push, pull_request]
+on:
+  push:
+    branches:
+      - main
+  pull_request:

 jobs:
   build:
11 changes: 5 additions & 6 deletions README.md
@@ -14,9 +14,8 @@ To start a local server hosting the documentation run ```pdoc ./afsl --math```.

 ### Publishing

-1. update version number in `setup.py` and `afsl/__init__.py`
-2. test package metadata: `python setup.py check`
-3. generate distribution archives: `python setup.py sdist`
-4. *(optional)* upload to test PyPI: `twine upload --repository-url https://test.pypi.org/legacy/ dist/active-few-shot-learning-VERSION.tar.gz`
-5. *(optional)* test installation from test PyPI: `pip install --index-url https://test.pypi.org/simple/ active-few-shot-learning --user`
-6. upload to PyPI: `twine upload dist/active-few-shot-learning-VERSION.tar.gz`
+1. update version number in `pyproject.toml` and `afsl/__init__.py`
+2. build: `poetry build`
+3. publish: `poetry publish`
+4. push version update to GitHub
+5. create new release on GitHub
9 changes: 9 additions & 0 deletions afsl/__init__.py
@@ -2,6 +2,15 @@
 Active Few-Shot Learning
 """

+from afsl.active_data_loader import ActiveDataLoader
+from afsl import acquisition_functions, embeddings, model
+
+__all__ = [
+    "ActiveDataLoader",
+    "acquisition_functions",
+    "embeddings",
+    "model",
+]
 __version__ = "0.1.0"
 __author__ = "Jonas Hübotter"
 __credits__ = "ETH Zurich, Switzerland"
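
The new `__all__` pins the package's public surface: it is the list of names that `from afsl import *` re-exports. A minimal, self-contained sketch of that standard Python behavior (using a throwaway module, not afsl itself):

```python
import types

# Build a toy module to show the effect of __all__ on star-imports.
mod = types.ModuleType("toy_pkg")
mod.__dict__.update({"__all__": ["foo"], "foo": 1, "bar": 2})

# Emulate `from toy_pkg import *`: only names listed in __all__ are exported.
exported = {name: getattr(mod, name) for name in mod.__all__}
assert exported == {"foo": 1}

# Explicit attribute access and explicit imports are unaffected.
assert mod.bar == 2
```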
122 changes: 122 additions & 0 deletions afsl/acquisition_functions/__init__.py
@@ -0,0 +1,122 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
import torch
from afsl.embeddings import M, Embedding
from afsl.types import Target
from afsl.utils import mini_batch_wrapper, mini_batch_wrapper_non_cat


class AcquisitionFunction(ABC):
    mini_batch_size: int

    def __init__(self, mini_batch_size: int = 100):
        self.mini_batch_size = mini_batch_size

    @abstractmethod
    def select(
        self,
        batch_size: int,
        embedding: Embedding[M],
        model: M,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
        force_nonsequential=False,
    ) -> torch.Tensor:
        pass


class BatchAcquisitionFunction(AcquisitionFunction):
    @abstractmethod
    def compute(
        self,
        embedding: Embedding[M],
        model: M,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
    ) -> torch.Tensor:
        pass

    def select(
        self,
        batch_size: int,
        embedding: Embedding[M],
        model: M,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
        force_nonsequential=False,
    ) -> torch.Tensor:
        values = mini_batch_wrapper(
            fn=lambda batch: self.compute(
                embedding=embedding,
                model=model,
                data=batch,
                target=target,
                Sigma=Sigma,
            ),
            data=data,
            batch_size=self.mini_batch_size,
        )
        _, indices = torch.topk(values, batch_size)
        return indices


State = TypeVar("State")


class SequentialAcquisitionFunction(AcquisitionFunction, Generic[State]):
    @abstractmethod
    def initialize(
        self,
        embedding: Embedding[M],
        model: M,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
    ) -> State:
        pass

    @abstractmethod
    def compute(self, state: State) -> torch.Tensor:
        pass

    @abstractmethod
    def step(self, state: State, i: int) -> State:
        pass

    def select(
        self,
        batch_size: int,
        embedding: Embedding[M],
        model: M,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
        force_nonsequential=False,
    ) -> torch.Tensor:
        states = mini_batch_wrapper_non_cat(
            fn=lambda batch: self.initialize(
                embedding=embedding,
                model=model,
                data=batch,
                target=target,
                Sigma=Sigma,
            ),
            data=data,
            batch_size=self.mini_batch_size,
        )

        if force_nonsequential:
            values = torch.cat([self.compute(state) for state in states], dim=0)
            _, indices = torch.topk(values, batch_size)
            return indices
        else:
            indices = []
            for _ in range(batch_size):
                values = torch.cat([self.compute(state) for state in states], dim=0)
                i = int(torch.argmax(values).item())
                indices.append(i)
                states = [self.step(state, i) for state in states]
            return torch.tensor(indices)
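
The two `select` paths differ in when scores are recomputed: with `force_nonsequential`, scores are computed once and the top `batch_size` are taken in a single `topk`; otherwise the loop re-scores after every pick, letting `step` fold the chosen index back into the state. A minimal sketch of that greedy loop with a toy state (a plain score vector masked after each pick; `greedy_select` is illustrative, not part of afsl):

```python
import torch

def greedy_select(scores: torch.Tensor, batch_size: int) -> torch.Tensor:
    # Toy version of the sequential branch: re-score, pick argmax, update state.
    state = scores.clone()  # stand-in for a richer state (e.g. a covariance matrix)
    indices = []
    for _ in range(batch_size):
        i = int(torch.argmax(state).item())  # "compute" + argmax
        indices.append(i)
        state[i] = -float("inf")             # "step": fold the pick into the state
    return torch.tensor(indices)

picked = greedy_select(torch.tensor([0.1, 0.9, 0.4, 0.8]), batch_size=2)
assert picked.tolist() == [1, 3]
```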
42 changes: 42 additions & 0 deletions afsl/acquisition_functions/bace.py
@@ -0,0 +1,42 @@
from typing import NamedTuple
import torch
from afsl.acquisition_functions import SequentialAcquisitionFunction
from afsl.embeddings import M, Embedding
from afsl.gaussian import GaussianCovarianceMatrix
from afsl.types import Target


class BaCEState(NamedTuple):
    covariance_matrix: GaussianCovarianceMatrix
    n: int


class BaCE(SequentialAcquisitionFunction[BaCEState]):
    noise_std: float

    def __init__(self, noise_std=1.0):
        super().__init__()
        self.noise_std = noise_std

    def initialize(
        self,
        embedding: Embedding[M],
        model: M,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
    ) -> BaCEState:
        assert target is not None, "Target must be non-empty"

        n = data.size(0)
        data_embeddings = embedding.embed(model, data)
        target_embeddings = embedding.embed(model, target)
        joint_embeddings = torch.cat((data_embeddings, target_embeddings))
        covariance_matrix = GaussianCovarianceMatrix.from_embeddings(
            noise_std=self.noise_std, Embeddings=joint_embeddings, Sigma=Sigma
        )
        return BaCEState(covariance_matrix=covariance_matrix, n=n)

    def step(self, state: BaCEState, i: int) -> BaCEState:
        posterior_covariance_matrix = state.covariance_matrix.condition_on(i)
        return BaCEState(covariance_matrix=posterior_covariance_matrix, n=state.n)
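
`step` delegates to `GaussianCovarianceMatrix.condition_on`, defined elsewhere in this commit. Assuming it performs standard Gaussian conditioning on a single noisy observation at index `i`, the update is the rank-one formula Sigma' = Sigma - Sigma[:, i] Sigma[i, :] / (Sigma[i, i] + noise_std^2). A hedged sketch of that update (not the actual afsl.gaussian implementation):

```python
import torch

def condition_on(Sigma: torch.Tensor, i: int, noise_std: float) -> torch.Tensor:
    # Rank-one Gaussian conditioning on a noisy observation at index i:
    #   Sigma' = Sigma - Sigma[:, i] Sigma[i, :] / (Sigma[i, i] + noise_std^2)
    col = Sigma[:, i]
    return Sigma - torch.outer(col, col) / (Sigma[i, i] + noise_std**2)

# Toy check: conditioning never increases marginal variances.
emb = torch.randn(5, 3)
Sigma = emb @ emb.T + 1e-3 * torch.eye(5)  # a valid toy covariance matrix
post = condition_on(Sigma, i=2, noise_std=1.0)
assert torch.all(torch.diag(post) <= torch.diag(Sigma) + 1e-6)
```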
82 changes: 82 additions & 0 deletions afsl/acquisition_functions/badge.py
@@ -0,0 +1,82 @@
from typing import List, NamedTuple
import torch
from afsl.acquisition_functions import SequentialAcquisitionFunction
from afsl.embeddings import M, Embedding
from afsl.types import Target
from afsl.utils import mini_batch_wrapper_non_cat


class BADGEState(NamedTuple):
    embeddings: torch.Tensor
    centroid_indices: List[torch.Tensor]


def compute_distances(embeddings, centroids):
    # Compute the distance of all points in embeddings from each centroid
    distances = torch.cdist(embeddings, centroids, p=2)
    # Return the minimum distance for each point
    min_distances = torch.min(distances, dim=1).values
    return min_distances


class BADGE(SequentialAcquisitionFunction[BADGEState]):
    def initialize(
        self,
        embedding: Embedding[M],
        model: M,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
    ) -> BADGEState:
        embeddings = embedding.embed(model, data)
        # Choose the first centroid randomly
        centroid_indices = [
            torch.randint(0, embeddings.size(0), (1,)).to(embeddings.device)
        ]
        return BADGEState(embeddings=embeddings, centroid_indices=centroid_indices)

    def step(self, state: BADGEState, i: int) -> BADGEState:
        # Wrap i in a (1,)-shaped tensor so it can be concatenated with the
        # (1,)-shaped first centroid index (torch.cat rejects 0-dim tensors).
        state.centroid_indices.append(torch.tensor([i]).to(state.embeddings.device))
        return state

    def compute(self, state: BADGEState) -> torch.Tensor:
        # Compute the distance of each point to the nearest centroid
        centroids = state.embeddings[
            torch.cat(state.centroid_indices).to(state.embeddings.device)
        ]
        sqd_distances = torch.square(compute_distances(state.embeddings, centroids))
        # Choose the next centroid with a probability proportional to the square of the distance
        probabilities = sqd_distances / sqd_distances.sum()
        return probabilities

    def select(
        self,
        batch_size: int,
        embedding: Embedding[M],
        model: M,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
        force_nonsequential=False,
    ) -> torch.Tensor:
        assert not force_nonsequential, "Non-sequential selection is not supported"

        states = mini_batch_wrapper_non_cat(
            fn=lambda batch: self.initialize(
                embedding=embedding,
                model=model,
                data=batch,
                target=target,
                Sigma=Sigma,
            ),
            data=data,
            batch_size=self.mini_batch_size,
        )

        indices = []
        for _ in range(batch_size):
            probabilities = torch.cat([self.compute(state) for state in states], dim=0)
            i = int(torch.multinomial(probabilities, num_samples=1).item())
            indices.append(i)
            states = [self.step(state, i) for state in states]
        return torch.tensor(indices)
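
`BADGE` is k-means++-style D² sampling in embedding space: the first centroid is drawn uniformly, and each subsequent point is drawn with probability proportional to its squared distance to the nearest centroid chosen so far. A self-contained sketch of that loop (single batch, so the mini-batch plumbing is omitted; `d2_sample` is illustrative, not afsl API):

```python
import torch

def d2_sample(embeddings: torch.Tensor, k: int) -> torch.Tensor:
    n = embeddings.size(0)
    centroid_indices = [torch.randint(0, n, (1,))]  # first centroid: uniform
    for _ in range(k - 1):
        centroids = embeddings[torch.cat(centroid_indices)]
        # Distance of every point to its nearest centroid so far
        min_dists = torch.cdist(embeddings, centroids, p=2).min(dim=1).values
        probs = min_dists.square() / min_dists.square().sum()
        centroid_indices.append(torch.multinomial(probs, num_samples=1))
    return torch.cat(centroid_indices)

indices = d2_sample(torch.randn(100, 8), k=5)
assert indices.shape == (5,)
```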
35 changes: 35 additions & 0 deletions afsl/acquisition_functions/cosine_similarity.py
@@ -0,0 +1,35 @@
import torch
import torch.nn.functional as F
from afsl.acquisition_functions import BatchAcquisitionFunction
from afsl.embeddings import Embedding
from afsl.model import LatentModel
from afsl.types import Target
from afsl.utils import get_device


class CosineSimilarity(BatchAcquisitionFunction):
    def compute(
        self,
        embedding: Embedding,
        model: LatentModel,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
    ) -> torch.Tensor:
        assert target is not None, "Target must be non-empty"

        model.eval()
        device = get_device(model)
        with torch.no_grad():
            data_latent = model.latent(data.to(device))
            target_latent = model.latent(target.to(device))

        data_latent_normalized = F.normalize(data_latent, p=2, dim=1)
        target_latent_normalized = F.normalize(target_latent, p=2, dim=1)

        cosine_similarities = torch.matmul(
            data_latent_normalized, target_latent_normalized.T
        )

        average_cosine_similarities = torch.mean(cosine_similarities, dim=1)
        return average_cosine_similarities
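
`F.normalize` followed by a matmul is the batched form of pairwise cosine similarity; each data point is then scored by its mean similarity to the target set. A quick equivalence check on random stand-in latents:

```python
import torch
import torch.nn.functional as F

a = torch.randn(4, 16)  # stand-in for data latents
b = torch.randn(3, 16)  # stand-in for target latents

# Batched cosine similarity, as in CosineSimilarity.compute
fast = F.normalize(a, p=2, dim=1) @ F.normalize(b, p=2, dim=1).T

# Pairwise reference computation
slow = torch.stack(
    [torch.stack([F.cosine_similarity(x, y, dim=0) for y in b]) for x in a]
)

assert torch.allclose(fast, slow, atol=1e-6)
print(fast.mean(dim=1))  # the per-point scores compute() would return
```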
20 changes: 20 additions & 0 deletions afsl/acquisition_functions/ctl.py
@@ -0,0 +1,20 @@
import torch
from afsl.acquisition_functions.bace import BaCE, BaCEState


class CTL(BaCE):
    def compute(self, state: BaCEState) -> torch.Tensor:
        ind_a = torch.arange(state.n)
        ind_b = torch.arange(state.n, state.covariance_matrix.dim)
        covariance_aa = state.covariance_matrix[ind_a, :][:, ind_a]
        covariance_bb = state.covariance_matrix[ind_b, :][:, ind_b]
        covariance_ab = state.covariance_matrix[ind_a, :][:, ind_b]

        std_a = torch.sqrt(torch.diag(covariance_aa))
        std_b = torch.sqrt(torch.diag(covariance_bb))
        std_ab = torch.ger(std_a, std_b)  # outer product of standard deviations

        correlations = covariance_ab / std_ab
        average_correlations = torch.mean(correlations, dim=1)
        return average_correlations
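
`CTL` scores each candidate (block `a`) by its average posterior correlation with the target points (block `b`); dividing the cross-covariance by the outer product of standard deviations is the usual covariance-to-correlation normalization. A sanity check of that identity against `torch.corrcoef`, using `torch.outer` (the modern name for `torch.ger`):

```python
import torch

x = torch.randn(6, 200)              # 6 variables, 200 samples
cov = torch.cov(x)                   # (6, 6) covariance matrix
std = torch.sqrt(torch.diag(cov))
corr = cov / torch.outer(std, std)   # same normalization as CTL.compute
assert torch.allclose(corr, torch.corrcoef(x), atol=1e-5)
```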
33 changes: 33 additions & 0 deletions afsl/acquisition_functions/greedy_max_det.py
@@ -0,0 +1,33 @@
import torch
import wandb
from afsl.acquisition_functions.bace import BaCE, BaCEState
from afsl.embeddings import M, Embedding
from afsl.gaussian import GaussianCovarianceMatrix
from afsl.types import Target


class GreedyMaxDet(BaCE):
    def compute(self, state: BaCEState) -> torch.Tensor:
        variances = torch.diag(state.covariance_matrix[:, :])
        wandb.log(
            {
                "max_var": torch.max(variances),
                "min_var": torch.min(variances),
            }
        )
        return variances

    def initialize(
        self,
        embedding: Embedding[M],
        model: M,
        data: torch.Tensor,
        target: Target,
        Sigma: torch.Tensor | None = None,
    ) -> BaCEState:
        n = data.size(0)
        data_embeddings = embedding.embed(model, data)
        covariance_matrix = GaussianCovarianceMatrix.from_embeddings(
            noise_std=self.noise_std, Embeddings=data_embeddings, Sigma=Sigma
        )
        return BaCEState(covariance_matrix=covariance_matrix, n=n)
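
`GreedyMaxDet` scores every point by its current posterior variance (the `wandb.log` call just tracks the variance range). Combined with the conditioning in `BaCE.step`, repeatedly taking the argmax greedily maximizes det(K_S + noise_std² I) over the selected set S, since by the Schur-complement identity each pick multiplies that determinant by (posterior variance + noise_std²). This D-optimal-design reading is offered as interpretation, not documentation. A toy run of the loop with the conditioning written out explicitly:

```python
import torch

def greedy_max_det(K: torch.Tensor, batch_size: int, noise_std: float = 1.0):
    Sigma = K.clone()
    picked = []
    for _ in range(batch_size):
        i = int(torch.argmax(torch.diag(Sigma)).item())  # max posterior variance
        picked.append(i)
        col = Sigma[:, i]  # condition on a noisy observation at i (cf. BaCE.step)
        Sigma = Sigma - torch.outer(col, col) / (Sigma[i, i] + noise_std**2)
    return picked

emb = torch.randn(20, 5)
K = emb @ emb.T  # toy prior covariance built from embeddings
print(greedy_max_det(K, batch_size=3))
```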