feat(clip): add linear probe evaluation script #960

Merged 3 commits on Oct 25, 2024

SauravMaheshkar
Contributor

Adds a script to perform linear probe evaluation using the mlx.data module for data loading. Mostly a mirror of the Linear-probe evaluation script from the official CLIP repository.

References

Member

@angeloskath angeloskath left a comment

Thanks for the addition. I would rename the script to linear_probe.py and add the eval. I am wondering if it would be nicer (since it is an example after all) to train a logistic regression model in MLX instead of using scikit-learn.

clip/eval.py Outdated
all_features = []
all_labels = []

for _, batch in enumerate(iter):
Member

Why enumerate? Possibly tqdm would be nice.

Contributor Author

Fixed in dee703e
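
For reference, the loop shape the review asks for might look like the sketch below. It assumes each batch is a dict with "image" and "label" keys, which is a hypothetical layout, and `collect_batches` is an illustrative name rather than a function from the PR. tqdm is treated as optional here so the snippet still runs when it is not installed.

```python
try:
    from tqdm import tqdm  # third-party progress bar, optional in this sketch
except ImportError:
    def tqdm(iterable, **kwargs):
        # Fallback: iterate without a progress bar.
        return iterable


def collect_batches(batches):
    """Gather per-batch images and labels (hypothetical batch layout)."""
    all_features = []
    all_labels = []
    # No enumerate: the index was never used, so iterate over batches directly.
    for batch in tqdm(batches, desc="Extracting features"):
        all_features.append(batch["image"])
        all_labels.append(batch["label"])
    return all_features, all_labels
```

Passing a sized container rather than a generator lets tqdm show a total, which ties in with the later remark about having a len.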

clip/eval.py Outdated
tr_iter = tr.to_stream().batch(batch_size)

test = load_cifar10(root=root, train=False)
test_iter = test.to_stream().batch(batch_size)
Member

Since it is a very small in-memory dataset, the stream is unnecessary here. I'd keep it as a buffer, which is nicer and gives us a len as well.

Namely,

train_iter = load_cifar10(root=root).batch(batch_size)

Contributor Author

Fixed in dee703e

clip/eval.py Outdated

image_embeds = model.get_image_features(x)
all_features.append(image_embeds)
all_labels.append(y)
Member

You need an mx.eval(image_embeds) at some point; otherwise you just build a huge graph for the GPU to compute all at once, which leads to memory problems.
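
To illustrate why the per-batch evaluation matters, here is a toy model of deferred computation in plain Python. It is not MLX code: `LazyValue`, `encode`, and `extract_features` are hypothetical stand-ins, and `evaluate()` only plays the role that `mx.eval(image_embeds)` would play in the script. The sketch counts how many results are still pending at the end of each loop iteration.

```python
class LazyValue:
    """Toy stand-in for a lazily evaluated array (hypothetical, not an MLX type)."""

    pending = 0  # how many values are still waiting to be computed

    def __init__(self, thunk):
        self._thunk = thunk
        self._value = None
        LazyValue.pending += 1

    def evaluate(self):
        # Compute once, then drop the thunk so its captured inputs can be freed.
        if self._thunk is not None:
            self._value = self._thunk()
            self._thunk = None
            LazyValue.pending -= 1
        return self._value


def encode(batch):
    # Stand-in for model.get_image_features(x): only records the work to do.
    return LazyValue(lambda: [sum(row) for row in batch])


def extract_features(batches, eval_each_batch):
    all_features, peak_pending = [], 0
    for batch in batches:
        embeds = encode(batch)
        if eval_each_batch:
            embeds.evaluate()  # plays the role of mx.eval(image_embeds)
        all_features.append(embeds)
        peak_pending = max(peak_pending, LazyValue.pending)
    return [f.evaluate() for f in all_features], peak_pending


batches = [[[1, 2], [3, 4]] for _ in range(5)]
_, peak_deferred = extract_features(batches, eval_each_batch=False)
_, peak_eager = extract_features(batches, eval_each_batch=True)
print(peak_deferred, peak_eager)  # 5 0
```

Without the per-batch call, every batch's work stays queued until the end; with it, nothing is left pending after each iteration.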

@angeloskath
Member

Another thing to consider would be to have two commands in the linear_probe.py script: one that extracts features and saves them in a safetensors file, and another that trains the logistic regression classifier given that file. The first part might be generally useful, for instance for extracting CLIP features for a dataset.
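
One way the two-command layout could be wired up is with argparse subparsers, as in the sketch below. The subcommand names (`extract`, `train`), every flag, and the default file name are assumptions made for illustration; none of them come from the merged script, and the handlers only return a description of what they would do.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser(prog="linear_probe.py")
    commands = parser.add_subparsers(dest="command", required=True)

    # Hypothetical "extract" command: compute features and save them to a file.
    extract = commands.add_parser("extract", help="extract and save image features")
    extract.add_argument("--model", required=True, help="path to the converted model")
    extract.add_argument("--output", default="features.safetensors")
    extract.add_argument("--batch-size", type=int, default=256)

    # Hypothetical "train" command: fit the classifier on previously saved features.
    train = commands.add_parser("train", help="train a classifier on saved features")
    train.add_argument("--features", default="features.safetensors")
    train.add_argument("--max-iter", type=int, default=1000)
    return parser


def main(argv=None):
    args = build_parser().parse_args(argv)
    if args.command == "extract":
        return f"would extract features with {args.model} into {args.output}"
    return f"would train a classifier on {args.features}"


if __name__ == "__main__":
    print(main())
```

Keeping the feature file as the only interface between the two commands means the extraction step can be reused on its own, which is the benefit the comment points to.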

@SauravMaheshkar
Contributor Author

> Thanks for the addition. I would rename the script to linear_probe.py and add the eval. I am wondering if it would be nicer (since it is an example after all) to train a logistic regression model in MLX instead of using scikit-learn.

What I had in mind when submitting this PR was to showcase similar performance between the official implementation and the MLX port. Adding an MLX implementation of logistic regression seems like a nice idea, but IMO it should reside in a different directory, maybe a misc/ or core/ directory that contains implementations of various fundamental models.
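
For readers unfamiliar with what the classifier on top of the frozen features does, here is a minimal sketch of binary logistic regression trained with full-batch gradient descent. It is written in plain Python so that it runs anywhere; it is neither the MLX implementation discussed above nor what scikit-learn does internally, and the function names, learning rate, and toy data are all illustrative assumptions. A real linear probe on CIFAR-10 would be a 10-class model over CLIP embeddings.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def train_logreg(features, labels, lr=0.5, epochs=200):
    """Fit weights w and bias b by gradient descent on the mean log loss."""
    dim, n = len(features[0]), len(features)
    w, b = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in zip(features, labels):
            p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
            err = p - y  # derivative of the log loss w.r.t. the logit
            for j in range(dim):
                grad_w[j] += err * x[j] / n
            grad_b += err / n
        w = [wi - lr * g for wi, g in zip(w, grad_w)]
        b -= lr * grad_b
    return w, b


def predict(w, b, x):
    # Class 1 when the logit is non-negative, i.e. probability >= 0.5.
    return 1 if sum(wi * xi for wi, xi in zip(w, x)) + b >= 0 else 0


# Toy, linearly separable "features" standing in for image embeddings.
features = [[0.0, 0.1], [0.2, 0.0], [0.9, 1.0], [1.0, 0.8]]
labels = [0, 0, 1, 1]
w, b = train_logreg(features, labels)
predictions = [predict(w, b, x) for x in features]
```

Only the linear layer is trained; the features themselves are fixed inputs, which is what makes this a probe of the encoder rather than a fine-tune.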

@SauravMaheshkar
Contributor Author

SauravMaheshkar commented Oct 22, 2024

Gentle ping @angeloskath

@angeloskath angeloskath merged commit 4971462 into ml-explore:main Oct 25, 2024
1 of 2 checks passed
2 participants