Skip to content

Commit

Permalink
feat: simplify data handling
Browse files Browse the repository at this point in the history
  • Loading branch information
SauravMaheshkar committed Sep 14, 2024
1 parent 213a950 commit dee703e
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions clip/eval.py → clip/linear_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,22 @@
from model import CLIPModel
from PIL import Image
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm


def get_cifar10(batch_size, root=None):
tr = load_cifar10(root=root)
tr_iter = tr.to_stream().batch(batch_size)
tr = load_cifar10(root=root).batch(batch_size)

test = load_cifar10(root=root, train=False)
test_iter = test.to_stream().batch(batch_size)
test = load_cifar10(root=root, train=False).batch(batch_size)

return tr_iter, test_iter
return tr, test


def get_features(model, image_proc, iter):
all_features = []
all_labels = []

for _, batch in enumerate(iter):
for batch in tqdm(iter):
image, label = batch["image"], batch["label"]
x = image_proc([Image.fromarray(im) for im in image])
y = mx.array(label)
Expand Down

0 comments on commit dee703e

Please sign in to comment.