-
Notifications
You must be signed in to change notification settings - Fork 920
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(clip): add linear probe evaluation script #960
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the addition. I would rename the script to linear_probe.py
and add the eval. I am wondering if it would be nicer (since it is an example after all) to train a logistic regression model in MLX instead of using scikit-learn.
clip/eval.py
Outdated
all_features = [] | ||
all_labels = [] | ||
|
||
for _, batch in enumerate(iter): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why enumerate? Possibly tqdm would be nice.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in dee703e
clip/eval.py
Outdated
tr_iter = tr.to_stream().batch(batch_size) | ||
|
||
test = load_cifar10(root=root, train=False) | ||
test_iter = test.to_stream().batch(batch_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since it is a very small in memory dataset the stream is unnecessary here. I 'd keep it a buffer so that it is nicer and we have a len
as well.
Namely,
train_iter = load_cifar10(root=root).batch(batch_size)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in dee703e
clip/eval.py
Outdated
|
||
image_embeds = model.get_image_features(x) | ||
all_features.append(image_embeds) | ||
all_labels.append(y) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You need an mx.eval(image_embeds)
at some point otherwise you just create a huge graph for the GPU to compute at the same time which leads to memory problems.
Another thing to consider would be to have two commands in the |
What I had in mind when submitting this PR was to showcase similar performance between the official implementation and the mlx port. Adding a mlx implementation of logistic regression seems like a nice idea but IMO it should reside in a different directory. Maybe another |
Gentle ping @angeloskath |
Adds a script to perform linear probe evaluation using the
mlx.data
module for data loading. Mostly a mirror of the Linear-probe evaluation script from the official CLIP repository.References