[TO-REVIEW] Add multi-domain Monge alignment and JCPOT Target shift method #180

Merged: 21 commits, Jul 19, 2024
7 changes: 7 additions & 0 deletions README.md
@@ -228,3 +228,10 @@ The library is distributed under the 3-Clause BSD license.
[27] S. Si, D. Tao and B. Geng. In IEEE Transactions on Knowledge and Data Engineering, (2010) [Bregman Divergence-Based Regularization for Transfer Subspace Learning](https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=4118b4fc7d61068b9b448fd499876d139baeec81)

[28] Solomon, J., Rustamov, R., Guibas, L., & Butscher, A. (2014, January). [Wasserstein propagation for semi-supervised learning](https://proceedings.mlr.press/v32/solomon14.pdf). In International Conference on Machine Learning (pp. 306-314). PMLR.

[29] Montesuma, Eduardo Fernandes, and Fred Maurice Ngole Mboula. ["Wasserstein barycenter for multi-source domain adaptation."](https://openaccess.thecvf.com/content/CVPR2021/papers/Montesuma_Wasserstein_Barycenter_for_Multi-Source_Domain_Adaptation_CVPR_2021_paper.pdf) In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp. 16785-16793. 2021.

[30] Gnassounou, Theo, Rémi Flamary, and Alexandre Gramfort. ["Convolution Monge Mapping Normalization for learning on sleep data."](https://proceedings.neurips.cc/paper_files/paper/2023/file/21718991f6acf19a42376b5c7a8668c5-Paper-Conference.pdf) Advances in Neural Information Processing Systems 36 (2024).

[31] Redko, Ievgen, Nicolas Courty, Rémi Flamary, and Devis Tuia. ["Optimal transport for multi-source domain adaptation under target shift."](https://proceedings.mlr.press/v89/redko19a/redko19a.pdf) In The 22nd International Conference on Artificial Intelligence and Statistics, pp. 849-858. PMLR, 2019.

101 changes: 100 additions & 1 deletion examples/methods/plot_label_prop_da.py
@@ -31,8 +31,14 @@
from sklearn.kernel_ridge import KernelRidge
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import mean_squared_error
from sklearn.svm import SVC

from skada import OTLabelPropAdapter, make_da_pipeline, source_target_split
from skada import (
JCPOTLabelPropAdapter,
OTLabelPropAdapter,
make_da_pipeline,
source_target_split,
)
from skada.datasets import make_shifted_datasets

# %%
@@ -330,3 +336,96 @@

plt.title("Propagated labels data")
plt.axis(ax)


# %%
# Generate a target shift classification dataset and plot it
# -----------------------------------------------------------
#
# We generate a simple 2D target shift dataset.

X, y, sample_domain = make_shifted_datasets(
n_samples_source=20,
n_samples_target=20,
shift="target_shift",
noise=0.2,
random_state=42,
)


Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)


plt.figure(5, (10, 5))
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
plt.title("Source data")
ax = plt.axis()

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
plt.title("Target data")
plt.axis(ax)


# %%
# Train with LabelProp and JCPOT + classifier
# -------------------------------------------
#
# On this target shift dataset, label propagation performs poorly because it
# matches source and target samples from different classes. JCPOT is more
# robust to this kind of shift because it also estimates the class
# proportions in the target domain.
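# As a quick sanity check (an added sketch, not part of the original
# example), we can verify the target shift by comparing the empirical class
# proportions of the two domains: under target shift they differ while the
# class-conditional distributions stay the same.
print("Source class proportions:", np.bincount(ys) / len(ys))
print("Target class proportions:", np.bincount(yt) / len(yt))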


clf = make_da_pipeline(OTLabelPropAdapter(), SVC())
clf.fit(X, y, sample_domain=sample_domain)

clf_jcpot = make_da_pipeline(JCPOTLabelPropAdapter(reg=0.1), SVC())
clf_jcpot.fit(X, y, sample_domain=sample_domain)


yt_pred = clf.predict(Xt)
acc_t = (yt_pred == yt).mean()

print(f"LabelProp Accuracy on target: {acc_t:.2f}")


yt_pred = clf_jcpot.predict(Xt)
acc_t_jcpot = (yt_pred == yt).mean()

print(f"JCPOT Accuracy on target: {acc_t_jcpot:.2f}")

XX, YY = np.meshgrid(np.linspace(ax[0], ax[1], 100), np.linspace(ax[2], ax[3], 100))
Z = clf.predict(np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)
Z_jcpot = clf_jcpot.predict(np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)


plt.figure(7, (10, 5))


plt.subplot(1, 2, 1)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Prediction")
plt.imshow(
Z,
extent=(ax[0], ax[1], ax[2], ax[3]),
origin="lower",
alpha=0.5,
cmap="tab10",
vmax=9,
)
plt.title(f"LabelProp reglog on target (ACC={acc_t:.2f})")
plt.axis(ax)

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Prediction")
plt.imshow(
Z_jcpot,
extent=(ax[0], ax[1], ax[2], ax[3]),
origin="lower",
alpha=0.5,
cmap="tab10",
vmax=9,
)
plt.title(f"JCPOT reglog on target (ACC={acc_s_jcpot:.2f})")
plt.axis(ax)
259 changes: 259 additions & 0 deletions examples/methods/plot_monge_alignment_da.py
@@ -0,0 +1,259 @@
"""
Multi-domain Linear Monge Alignment
===================================

This example illustrates the use of the MultiLinearMongeAlignmentAdapter.

"""

# Author: Remi Flamary
#
# License: BSD 3-Clause
# sphinx_gallery_thumbnail_number = 4

# %% Imports
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LogisticRegression

from skada import (
MultiLinearMongeAlignmentAdapter,
make_da_pipeline,
source_target_split,
)
from skada.datasets import make_shifted_datasets

# %%
# Generate a concept drift classification dataset and plot it
# ------------------------------------------------------------
#
# We generate a simple 2D concept drift dataset.

X, y, sample_domain = make_shifted_datasets(
n_samples_source=20,
n_samples_target=20,
shift="concept_drift",
noise=0.2,
label="multiclass",
random_state=42,
)


Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)


plt.figure(5, (10, 5))
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
plt.title("Source data")
ax = plt.axis()

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
plt.title("Target data")
plt.axis(ax)

# %%
# Align the domains with linear Monge mapping
# -------------------------------------------
#
# We fit the MultiLinearMongeAlignmentAdapter on all domains and transform
# the data. The adapter aligns each domain to a common distribution with a
# linear (Gaussian) Monge map, making source and target samples comparable.
# We also plot the original and adapted data.


clf = MultiLinearMongeAlignmentAdapter()
clf.fit(X, sample_domain=sample_domain)

X_adapt = clf.transform(X, sample_domain=sample_domain, allow_source=True)
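
# For intuition, here is a small illustrative sketch (not the adapter's
# internal code) of the closed-form linear Monge map between Gaussian
# approximations N(m_s, C_s) and N(m_t, C_t) of two domains:
# T(x) = m_t + A (x - m_s), with
# A = C_s^{-1/2} (C_s^{1/2} C_t C_s^{1/2})^{1/2} C_s^{-1/2}.
from scipy.linalg import sqrtm

m_s, m_t = Xs.mean(axis=0), Xt.mean(axis=0)
C_s, C_t = np.cov(Xs.T), np.cov(Xt.T)
C_s_half = sqrtm(C_s)
C_s_half_inv = np.linalg.inv(C_s_half)
A = np.real(C_s_half_inv @ sqrtm(C_s_half @ C_t @ C_s_half) @ C_s_half_inv)
Xs_to_t = (Xs - m_s) @ A.T + m_t  # source samples mapped onto the target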


plt.figure(5, (10, 3))
plt.subplot(1, 3, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
plt.title("Source data")
ax = plt.axis()

plt.subplot(1, 3, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
plt.title("Target data")
plt.axis(ax)

plt.subplot(1, 3, 3)
plt.scatter(
X_adapt[sample_domain >= 0, 0],
X_adapt[sample_domain >= 0, 1],
c=y[sample_domain >= 0],
marker="o",
cmap="tab10",
vmax=9,
label="Source",
alpha=0.5,
)
plt.scatter(
X_adapt[sample_domain < 0, 0],
X_adapt[sample_domain < 0, 1],
c=y[sample_domain < 0],
marker="x",
cmap="tab10",
vmax=9,
label="Target",
alpha=1,
)
plt.legend()
plt.title("Adapted data")


# %%
# Train a classifier on adapted data
# ----------------------------------

clf = make_da_pipeline(
MultiLinearMongeAlignmentAdapter(),
LogisticRegression(),
)

clf.fit(X, y, sample_domain=sample_domain)

print(
"Average accuracy on all domains:",
clf.score(X, y, sample_domain=sample_domain, allow_source=True),
)
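
# A usage sketch (an addition, assuming the skada convention that negative
# sample_domain values mark target samples): score on the target domains
# alone.
mask_t = sample_domain < 0
print(
    "Accuracy on target domains:",
    clf.score(X[mask_t], y[mask_t], sample_domain=sample_domain[mask_t]),
)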

# %% Multi-source and multi-target data


def get_multidomain_data(
Collaborator review comment: "This should go to toy datasets, right?"

n_samples_source=100,
n_samples_target=100,
noise=0.1,
random_state=None,
n_sources=3,
n_targets=2,
):
np.random.seed(random_state)
X, y, sample_domain = make_shifted_datasets(
n_samples_source=n_samples_source,
n_samples_target=n_samples_target,
noise=noise,
shift="concept_drift",
label="multiclass",
random_state=random_state,
)
for ns in range(n_sources - 1):
Xi, yi, sample_domaini = make_shifted_datasets(
n_samples_source=n_samples_source,
n_samples_target=n_samples_target,
noise=noise,
shift="concept_drift",
label="multiclass",
random_state=random_state + ns,
mean=np.random.randn(2),
sigma=np.random.rand(2) * 0.5 + 0.5,
)
Xs, Xt, ys, yt = source_target_split(Xi, yi, sample_domain=sample_domaini)
X = np.vstack([X, Xt])
y = np.hstack([y, yt])
sample_domain = np.hstack([sample_domain, np.ones(Xt.shape[0]) * (ns + 2)])

for nt in range(n_targets - 1):
Xi, yi, sample_domaini = make_shifted_datasets(
n_samples_source=n_samples_source,
n_samples_target=n_samples_target,
noise=noise,
shift="concept_drift",
label="multiclass",
random_state=random_state + nt + 42,
mean=np.random.randn(2),
sigma=np.random.rand(2) * 0.5 + 0.5,
)
Xs, Xt, ys, yt = source_target_split(Xi, yi, sample_domain=sample_domaini)
X = np.vstack([X, Xt])
y = np.hstack([y, yt])
sample_domain = np.hstack([sample_domain, -np.ones(Xt.shape[0]) * (nt + 1)])

return X, y, sample_domain


X, y, sample_domain = get_multidomain_data(
n_samples_source=50,
n_samples_target=50,
noise=0.1,
random_state=43,
n_sources=3,
n_targets=2,
)

Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)


plt.figure(5, (10, 5))
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
plt.title("Source data")
ax = plt.axis()

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
plt.title("Target domains")
plt.axis(ax)


# %%
clf = MultiLinearMongeAlignmentAdapter()
clf.fit(X, sample_domain=sample_domain)

X_adapt = clf.transform(X, sample_domain=sample_domain, allow_source=True)


plt.figure(5, (10, 3))
plt.subplot(1, 3, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
plt.title("Source data")
ax = plt.axis()

plt.subplot(1, 3, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
plt.title("Target data")
plt.axis(ax)

plt.subplot(1, 3, 3)
plt.scatter(
X_adapt[sample_domain >= 0, 0],
X_adapt[sample_domain >= 0, 1],
c=y[sample_domain >= 0],
marker="o",
cmap="tab10",
vmax=9,
label="Source",
alpha=0.5,
)
plt.scatter(
X_adapt[sample_domain < 0, 0],
X_adapt[sample_domain < 0, 1],
c=y[sample_domain < 0],
marker="x",
cmap="tab10",
vmax=9,
label="Target",
alpha=1,
)
plt.legend()
plt.axis(ax)
plt.title("Adapted data")

# %%
# Train a classifier on adapted data
# ----------------------------------

clf = make_da_pipeline(
MultiLinearMongeAlignmentAdapter(),
LogisticRegression(),
)

clf.fit(X, y, sample_domain=sample_domain)

print(
"Average accuracy on all domains:",
clf.score(X, y, sample_domain=sample_domain, allow_source=True),
)