openai-clip-matlab

MATLAB implementation of the OpenAI CLIP deep learning model

For the training code, see trainCLIP.mlx.

CLIP Interface

if ~isfile("net-gpu-poor-squeezenet-bert-tiny.mat")
    % -L follows the redirect that GitHub issues for release downloads
    !curl -L -O https://github.com/Lxrd-AJ/openai-clip-matlab/releases/download/v1.0.0/net-gpu-poor-squeezenet-bert-tiny.mat
end

load net-gpu-poor-squeezenet-bert-tiny.mat net logTemperature
imageInputSize = net.Layers(1).Layers(1).InputSize(1:2);
clip = CLIP(net, Temperature=exp(logTemperature), ImageInputSize=imageInputSize)
clip = 
  CLIP with no properties.

% Try encoding some images

% First gather all image paths from the dataset
imgBaseDir = "./flickr-dataset/Flicker8k_Dataset/";

% Development-set images
devImages = readlines("flickr-dataset/Flickr_8k.devImages.txt");
devImages = fullfile(imgBaseDir, devImages);

% Get some random images from the dataset
someImagePaths = randsample(devImages, 20);
images = arrayfun(@(x) imread(x), someImagePaths, UniformOutput=false);
montage(images)

figure_1.png

[probs, logits] = clip.predict(someImagePaths, ["two Dogs", "Birthday Party"]);
disp(probs)
  20x2 single gpuArray dlarray

    0.0000    0.0003
    0.0000    0.0000
    0.0000    0.9989
    0.0000    0.0005
    0.0000    0.0001
    0.0000    0.0000
    0.0000    0.0000
    0.0000    0.0000
    0.0000    0.0000
    0.0000    0.0000
    0.0000    0.0000
    0.9998    0.0000
    0.0000    0.0000
    0.0000    0.0000
    0.0000    0.0000
    0.0000    0.0000
    0.0002    0.0000
    0.0000    0.0002
    0.0000    0.0000
    0.0000    0.0000
[maxProb, maxIdx] = max(extractdata(gather(probs)));
disp("Query Match Probability: " + maxProb)
    "Query Match Probability: 0.99983"    "Query Match Probability: 0.99886"
maxImages = {};
for idx=1:numel(maxIdx)
    maxImages{end+1} = imread(someImagePaths(maxIdx(idx)));
end
montage(maxImages)

figure_2.png
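
The README does not show what `clip.predict` does internally. As a rough sketch, assuming the standard CLIP formulation (L2-normalised embeddings, cosine similarity scaled by the temperature, then a softmax), the probabilities could be computed along the following lines. The random embeddings, the 256-dimensional size, and the temperature value are placeholders, and normalising over images rather than over queries is inferred from the output above, where each column sums to roughly 1.

```matlab
% Sketch only: random stand-ins for the encoder outputs (not the real model).
rng(0);
numImages = 20; numQueries = 2; embDim = 256;
imageEmb = randn(numImages, embDim);   % one row per image
textEmb  = randn(numQueries, embDim);  % one row per text query
temperature = 10;                      % assumed value; the README uses exp(logTemperature)

% L2-normalise each row so the dot product is a cosine similarity.
imageEmb = imageEmb ./ vecnorm(imageEmb, 2, 2);
textEmb  = textEmb  ./ vecnorm(textEmb, 2, 2);

% Temperature-scaled similarities: numImages x numQueries.
logits = temperature * (imageEmb * textEmb');

% Softmax over the image dimension (dim 1), shifted for numerical stability.
shifted = logits - max(logits, [], 1);
probs = exp(shifted) ./ sum(exp(shifted), 1);

% Best-matching image for each query.
[maxProb, maxIdx] = max(probs, [], 1);
```

With real embeddings, a larger temperature sharpens the distribution, which is consistent with the near one-hot columns shown above.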

Other Notes

datastore = CLIPDatastore(ImageFolder="./flickr-dataset/Flicker8k_Dataset");
ds = shuffle(datastore, "PercentageToKeep", 1)
ds = 
  CLIPDatastore with no properties.

disp("Number of training images " + numel(ds))
Number of training images 30000
testDatastore = CLIPDatastore(ImageFolder="./flickr-dataset/Flicker8k_Dataset", TrainTestVal="./flickr-dataset/Flickr_8k.testImages.txt");
tds = shuffle(testDatastore, "PercentageToKeep", 1)
tds = 
  CLIPDatastore with no properties.

disp("Number of test images " + numel(tds))
Number of test images 5000
[net,tokenizer] = bert();
r = read(ds)
         1                    2              3
    1    375x500x3 uint8      1x19 double    "A brown and white dog be stand on a beach with a tennis ball beside it ."
[im, tokens, caption] = r{:};
imshow(im)

figure_0.png

%caption = "A group of horse and their rider be race each other .";
[~, segments] = encode(tokenizer, caption);

dltoken = dlarray(tokens, 'CT');
dlsegment = dlarray(segments{1}, 'CT');
mask = dlarray(ones(1, numel(dlsegment)), 'CT');

pred = predict(net, dltoken, mask, dlsegment);

decodedTokens = decode(tokenizer, tokens)
decodedTokens = "[CLS] a brown and white dog be stand on a beach with a tennis ball beside it . [SEP]"
disp(size(pred))
   768    19
last = pred(:,1) % Embedding of the first ([CLS]) token. The official CLIP implementation trained its own text encoder and used the last token's embedding instead
last = 
  768(C) x 1(T) single dlarray

   -0.1619
    0.1122
    0.0165
    0.2698
    0.4631
   -0.1553
   -0.8530
     ...
   -0.6838
    0.0352
    0.1187
   -0.6164
    0.5134
    0.2817
   -0.7342
    0.0385
    0.1096
   -0.1773
   -0.6413


% NB: Using BERT for batched prediction
[net, tokenizer] = bert() %bert("Model","tiny");
net = 
  dlnetwork with properties:

         Layers: [129x1 nnet.cnn.layer.Layer]
    Connections: [164x2 table]
     Learnables: [197x3 table]
          State: [0x3 table]
     InputNames: {'input_ids'  'attention_mask'  'seg_ids'}
    OutputNames: {'enc12_layernorm2'}
    Initialized: 1

  View summary with summary.

tokenizer = 
  bertTokenizer with properties:

        IgnoreCase: 1
      StripAccents: 1
      PaddingToken: "[PAD]"
       PaddingCode: 1
        StartToken: "[CLS]"
         StartCode: 102
      UnknownToken: "[UNK]"
       UnknownCode: 101
    SeparatorToken: "[SEP]"
     SeparatorCode: 103
       ContextSize: 512

paddingValue = tokenizer.PaddingCode;
str = [
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];
[inputIdsStr, segmentIdsStr] = encode(tokenizer, str);

% `maskStr` is 1 at real tokens and 0 at padded positions, so that the
% model ignores the padding
[inputIdsStr, maskStr] = padsequences(inputIdsStr, 2,"PaddingValue",paddingValue);
segmentIdsStr = padsequences(segmentIdsStr, 2,"PaddingValue",paddingValue);

inputIdsStr = dlarray(inputIdsStr, "CTB");
maskStr = dlarray(maskStr, "CTB");
segmentIdsStr = dlarray(segmentIdsStr, "CTB");

predictions = predict(net,inputIdsStr,maskStr,segmentIdsStr);
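
Earlier, the [CLS] embedding is taken with `pred(:,1)`. For a padded batch, picking the last real token instead (the [SEP] token in BERT) needs the mask. Below is a minimal sketch on plain numeric arrays, assuming the embeddings are arranged as channel x time x batch and the mask as 1 x time x batch, which is the shape `padsequences` returns above. `emb`, `seqLengths`, and the sizes are made up for illustration. A formatted dlarray may store its dimensions in a different order, so the indexing would need adapting.

```matlab
% Hypothetical sizes: 4 channels, 5 time steps (after padding), 3 sequences.
C = 4; T = 5; B = 3;
emb = reshape(1:C*T*B, C, T, B);   % stand-in for encoder output (C x T x B)

% Build a padding mask: 1 for real tokens, 0 for padding.
seqLengths = [3 5 2];              % assumed true lengths of each sequence
mask = zeros(1, T, B);
for b = 1:B
    mask(1, 1:seqLengths(b), b) = 1;
end

% Index of the last real token per sequence is the number of unmasked steps.
lastIdx = squeeze(sum(mask, 2));   % B x 1

% First-token ([CLS]) embeddings versus last-real-token embeddings.
clsEmb  = squeeze(emb(:, 1, :));   % C x B
lastEmb = zeros(C, B);
for b = 1:B
    lastEmb(:, b) = emb(:, lastIdx(b), b);
end
```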
% Image encoder: build it and check the output size on a random input
net = imageEncoder();
randX = dlarray(randn(net.Layers(1).InputSize), 'SSC');
net = dlnetwork(net, randX);
out = predict(net, randX);
disp(size(out))
      100352           1
% Text encoder: initialize with dummy inputs, then take the first ([CLS]) token embedding
net = textEncoder();
randX = dlarray(randn(1,1,10), 'CBT');
net = dlnetwork(net, Initialize=false);
net = initialize(net, randX, randX, randX)
net = 
  dlnetwork with properties:

         Layers: [1x1 nnet.cnn.layer.NetworkLayer]
    Connections: [0x2 table]
     Learnables: [37x3 table]
          State: [0x3 table]
     InputNames: {'bert_encoder/bert_model/input_ids'  'bert_encoder/bert_model/attention_mask'  'bert_encoder/bert_model/seg_ids'}
    OutputNames: {'bert_encoder'}
    Initialized: 1

  View summary with summary.

randInputIDs = dlarray(randi(1000, [1 3 10]), 'CBT');
attentionMask = dlarray(ones(size(randInputIDs)), 'CBT');
segmentIDs = dlarray(ones(size(randInputIDs)), 'CBT');

out = predict(net, randInputIDs, attentionMask, segmentIDs);
clsEmbeddings = out(:,:,1);
% Projection head: maps 2048-dimensional features to a 256-dimensional embedding
projHead = projectionHead();
net = dlnetwork(projHead, dlarray(randn(1,2048), 'BC'))
net = 
  dlnetwork with properties:

         Layers: [1x1 nnet.cnn.layer.NetworkLayer]
    Connections: [0x2 table]
     Learnables: [4x3 table]
          State: [0x3 table]
     InputNames: {'proj'}
    OutputNames: {'proj'}
    Initialized: 1

  View summary with summary.

out = predict(net, dlarray(randn(1,2048), 'BC'));
size(out)
ans = 1x2
   256     1

Flickr Dataset

See the data sources listed at https://hockenmaier.cs.illinois.edu/8k-pictures.html to download the dataset.
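
As a rough illustration of working with the caption files, the sketch below parses a few caption lines, assuming the tab-separated `image.jpg#k<TAB>caption` layout used by the Flickr8k token files. The image names and captions here are hypothetical samples, and the repository's `CLIPDatastore` may read the files differently.

```matlab
% Hypothetical sample lines; real files use the actual Flickr image names.
tab = sprintf("\t");
lines = [
    "img_a.jpg#0" + tab + "A brown and white dog be stand on a beach ."
    "img_a.jpg#1" + tab + "A dog stand next to a tennis ball ."
    "img_b.jpg#0" + tab + "A group of horse and their rider be race each other ."];

% Split each line into the "name#k" part and the caption (N x 2 string array).
parts = split(lines, tab);
imageNames = extractBefore(parts(:, 1), "#");   % drop the caption index
captions = parts(:, 2);

% Count captions per image, keeping first-seen order.
[uniqueNames, ~, groupIdx] = unique(imageNames, "stable");
captionsPerImage = accumarray(groupIdx, 1);
```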

Other notes

TODO (Training) -- Probably won't get around to performing this

  • Design a smaller model (use BERT tiny and build a smaller image encoder from an existing pretrained image model, using SqueezeNet)
    • Allow the encoder models to learn but with a smaller learning rate
  • Use the [SEP] token from BERT rather than the [CLS] token
  • Allow the model to learn the logit scaling
  • Support training on the train, validation and test sets
    • Update datastore
    • Calculate accuracy metric: argmax(logits) == targets
    • In training loop perform validation while training
    • Compute accuracy on validation set
  • Follow the model design and training guidance in Sections 2.4 and 2.5 of the CLIP paper
    • Use a cosine learning-rate schedule
    • Clip the logit-scaling temperature parameter to a maximum of 100
  • Move image resizing outside of the processMiniBatch function and into a transform function for the datastore
  • Upgrade the datastore class to use the provided train, validation and test sets
  • Save the model at different checkpoints during training
  • Train on Flickr30k dataset
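
Three of the items above (the cosine schedule, the temperature cap, and the accuracy metric) can be sketched in a few lines. This is an illustrative sketch rather than the code in trainCLIP.mlx: `baseLR`, `numIterations`, the starting log-temperature, and the toy `logits` matrix are all assumed values, and the schedule omits any warm-up.

```matlab
% Cosine learning-rate schedule: decays from baseLR to 0 over numIterations.
baseLR = 1e-3;          % assumed base learning rate
numIterations = 1000;   % assumed total number of training iterations
cosineLR = @(iter) 0.5 * baseLR * (1 + cos(pi * iter / numIterations));

% Cap the learnable logit scale at 100 by clamping the log-temperature,
% so that exp(logTemperature) never exceeds 100.
logTemperature = log(250);                        % assumed value after an update
logTemperature = min(logTemperature, log(100));
temperature = exp(logTemperature);

% Accuracy for a batch of N matched image/text pairs: the correct text for
% image i is column i, so the targets are 1..N.
logits = [5 1 0; 0 4 1; 2 3 1];                   % toy 3 x 3 similarity matrix
targets = (1:size(logits, 1))';
[~, predicted] = max(logits, [], 2);              % argmax over texts, per image
accuracy = mean(predicted == targets);            % 2 of 3 rows are correct here
```

In a training loop, `cosineLR(iter)` would supply the learning rate for each iteration, and the clamp would be applied after every parameter update.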

TODO (Model Interface)
