diff --git a/clip/eval.py b/clip/linear_probe.py similarity index 87% rename from clip/eval.py rename to clip/linear_probe.py index 5f3dd2165..27c90be39 100644 --- a/clip/eval.py +++ b/clip/linear_probe.py @@ -7,23 +7,22 @@ from model import CLIPModel from PIL import Image from sklearn.linear_model import LogisticRegression +from tqdm import tqdm def get_cifar10(batch_size, root=None): - tr = load_cifar10(root=root) - tr_iter = tr.to_stream().batch(batch_size) + tr = load_cifar10(root=root).batch(batch_size) - test = load_cifar10(root=root, train=False) - test_iter = test.to_stream().batch(batch_size) + test = load_cifar10(root=root, train=False).batch(batch_size) - return tr_iter, test_iter + return tr, test def get_features(model, image_proc, iter): all_features = [] all_labels = [] - for _, batch in enumerate(iter): + for batch in tqdm(iter): image, label = batch["image"], batch["label"] x = image_proc([Image.fromarray(im) for im in image]) y = mx.array(label)