Generate embeddings via prediction loop #56
Conversation
Implement the embedding generator in the LightningModule's predict_step. The embeddings are tensor arrays that are saved to a .npy file in the data/embeddings/ folder. Input data is retrieved from the predict_dataloader, which is currently using the validation datapipe rather than a dedicated datapipe. Documented how to generate the embedding output file using LightningCLI in the main README.md file. Also added a unit test to ensure that saving and loading from an embedding_0.npy file works.
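For orientation, here is a minimal hypothetical sketch of that predict_step flow; the class name, the stand-in encoder, and the file naming are illustrative assumptions, not the PR's actual code.

import numpy as np
import torch
from torch import nn
import lightning.pytorch as pl  # assumption: Lightning 2.x; older code may use `import pytorch_lightning as pl`


class EmbeddingSketch(pl.LightningModule):
    """Illustrative only: a placeholder encoder stands in for the real ViT-MAE model."""

    def __init__(self) -> None:
        super().__init__()
        # Placeholder encoder that maps a flattened image chip to a 768-dim vector.
        self.encoder = nn.LazyLinear(out_features=768)

    def predict_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        embeddings: torch.Tensor = self.encoder(batch.flatten(start_dim=1))
        # Assumes the data/embeddings/ folder already exists.
        np.save(file=f"data/embeddings/embedding_{batch_idx}.npy", arr=embeddings.detach().cpu().numpy())
        return embeddings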
src/model_vit.py (outdated diff):
# Get embeddings generated from encoder
embeddings: torch.Tensor = outputs_encoder.last_hidden_state
assert embeddings.shape == torch.Size(
    [self.B, 17, 768]  # (batch_size, sequence_length, hidden_size)
)
@srmsoumya, if you have time, could you try to see why the outputs of the encoder have a shape like (32, 17, 768), or (batch_size, sequence_length, hidden_size)? Specifically, I'm not sure what the sequence_length (size: 17) dimension is about, and couldn't quite figure it out from reading https://huggingface.co/docs/transformers/model_doc/vit_mae. More just something to understand, since we'll be moving to your MAE implementation in #47 later.
@weiji14 we are masking out 75% of the patches from the image. Given our image size is 256 x 256 and we use a patch size of 32, that gives us 256 / 32 => 8 x 8 => 64 patches. Masking out 75% of these patches leaves 0.25 x 64 => 16 patches. Adding 1 extra cls token gives us a total of 17 patches to input into the transformer portion of the encoder. This results in an output of shape batch_size: 32 x number of unmasked patches: 17 x patch embedding: 768, which is what we get from the encoder.
When creating the embeddings from the encoder of the model, we should switch off the masking strategy.
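A quick arithmetic check of the sequence lengths described above (plain Python, no model code; the numbers all come from this explanation):

num_patches = (256 // 32) ** 2                 # 8 x 8 = 64 patches per 256 x 256 image
kept_patches = int((1 - 0.75) * num_patches)   # masking out 75% leaves 16 patches
seq_len_masked = kept_patches + 1              # + 1 cls token = 17 (training)
seq_len_unmasked = num_patches + 1             # masking switched off = 65 (prediction)
print(seq_len_masked, seq_len_unmasked)        # 17 65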
Wonderful explanation! Thanks @srmsoumya, I've disabled the masking in the predict_step now at commit f09d2e7, and the output is now (32, 65, 768) as expected. Here's a new sample embedding for one image: embedding_0.npy.zip [array shape: (1, 65, 768)]
Perfect, that looks just right!
@weiji14 For embeddings per image, the general practice is either to pick the 1st embedding vector, i.e. batch_size x unmasked_patches[:1] x embedding_dim, or to take the mean of the remaining vectors, i.e. (batch_size x unmasked_patches[1:] x embedding_dim).mean(dim=1). Doing this gives us embeddings of size batch_size x embedding_dim, which is a single vector representing an image.
The first vector represents the cls token, which should capture the feature embedding of the image (this is a concept borrowed from BERT); alternatively, the mean of the remaining vectors should also work (which is what I am trying to implement with vit-pytorch, as it makes the code less complicated).
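In torch indexing terms, the two options above look roughly like this (an illustrative sketch; the shapes assume a batch of 32 and masking switched off):

import torch

embeddings = torch.randn(32, 65, 768)  # (batch_size, 1 cls token + 64 patches, hidden_size)

# Option 1: keep only the cls token (index 0), the BERT-style summary vector.
cls_embedding = embeddings[:, 0, :]                  # shape: (32, 768)

# Option 2: average the patch tokens (indices 1..64), skipping the cls token.
mean_embedding = embeddings[:, 1:, :].mean(dim=1)    # shape: (32, 768)

assert cls_embedding.shape == mean_embedding.shape == torch.Size([32, 768])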
Hmm, do you have any papers describing when either the cls token or the mean of the patch embeddings is used? Also @leothomas, do you have any insights on which method of collapsing the embeddings into a single 1d vector was used for the similarity search/vector database projects you've worked on? Just want some extra context before deciding on which way to go.
Hey there!
Just to clarify, the patches are 32x32 (pixel) subsections of the overall image, correct? Why are we masking 75% of them? Is that due to cloud cover, or seeking patches only over land?
In regards to the cls token, it seems to be a classification token, which represents the entire image.
The theory behind averaging the embeddings is that there is a very high likelihood (but not a guarantee) that the averages of two similar collections will be similar and the averages of two different collections will be different. It seems that this has mostly been researched in the case of text embeddings, where the exact order of words may matter less than in the case of image patches. I suspect that averaging the embeddings for each patch may lose some of the physical relationships of the overall image, but it would be very interesting to compare and contrast the two.
If we're going to use something like pgvector, we can easily have both types of embeddings colocated in the database and build a partial index for each!
> Just to clarify, the patches are 32x32 (pixel) subsections of the overall image, correct? Why are we masking 75% of them? Is that due to cloud cover, or seeking patches only over land?

Yes, each patch is 32x32 pixels. We were masking 75% of the patches because that's how a Masked Autoencoder is trained (see previous PR at #37), but for inference/prediction, we shouldn't apply the mask.

> In regards to the cls token, it seems to be a classification token, which represents the entire image. The theory behind averaging the embeddings is that there is a very high likelihood (but not a guarantee) that the averages of two similar collections will be similar and the averages of two different collections will be different. It seems that this has mostly been researched in the case of text embeddings, where the exact order of words may matter less than in the case of image patches. I suspect that averaging the embeddings for each patch may lose some of the physical relationships of the overall image, but it would be very interesting to compare and contrast the two.
> If we're going to use something like pgvector, we can easily have both types of embeddings colocated in the database and build a partial index for each!

That sounds good actually. We could save two files: a cls_token embedding (1x768) and a patch embedding (either 1x768 or 64x768)?
Just for fun, I took a look at the embeddings in their raw, un-averaged form.
import pandas as pd
import numpy as np
embedding: np.ndarray = np.load(file="data/embeddings/embedding_0.npy")
df: pd.DataFrame = pd.DataFrame(data=embedding.squeeze())
df.shape # (65, 768)
# Get descriptive statistics on each of the 768 columns
df[1:].describe()
# Heatmap of embeddings (first row is cls_token)
df.style.background_gradient(axis="columns", cmap="Greens")
Row 0 is the cls_token, and rows 1-64 are the patches. The columns are the 768 embedding dimensions.
Descriptive stats:
Heatmap:
Scrolling through the 768 columns, I don't think I saw much in terms of outliers; the values seem pretty consistent within a column (the standard deviation is usually <0.01), so it should be ok to just use the mean, I think.
If my intuition is correct, it's not that surprising that most tiles are "semantically flat", since most tiles will be "one thing": all forest, or grass, or mountain... The test for the need for richer embeddings would be to pick examples that are semantically rich, or even semantically polarized, like an image with both land and water, or city and forest.
Also, since we train with an MAE, I can see how each semantic patch will actually learn to include the expected semantics of the surrounding patches within the chip, not just its own content, so that the inference can recreate the missing bits, making the stdev within a patch smaller.
Am I making sense?
Yep, it would be good to explore what the embeddings look like for tiles that have more diverse land cover types. I've done just that by picking the most diverse tile we have; details at #35 (comment).

> Also, since we train with an MAE, I can see how each semantic patch will actually learn to include the expected semantics of the surrounding patches within the chip, not just its own content, so that the inference can recreate the missing bits, making the stdev within a patch smaller. Am I making sense?

Yes, the 1x768 embedding generated from each 32x32 patch actually contains information from the other patches. More details at #67, and we can continue the discussion there!
Previously, 75% of the patches, or 48 out of a total of 64 were masked out, leaving 16 patches plus 1 cls_token = 17 sequences. Disabling the mask gives 64 + 1 cls_token = 65 sequences. Moved some assert statements with a fixed sequence_length dim from the forward function to the training_step. Also updated the unit test to ensure output embeddings have a shape like (batch_size, 65, 768).
Refactoring the setup method in the LightningDataModule to not do a random split on the predict stage. I.e. just do the GeoTIFF to torch.Tensor conversion directly, followed by batching and collating.
Need to explicitly pass an argument to stage in the test_geotiffdatapipemodule unit test. Testing both the fit and predict stages.
Make sure that the generated embeddings do not have NaN values in them.
Gonna leave this for a day or so for review. There are a few nice-to-haves (e.g. better documentation on how to read the embeddings, a better filename that includes some spatiotemporal metadata, etc.), but those can be done in follow-up PRs.
Instead of saving embeddings of shape (1, 65, 768), save out embeddings of shape (1, 768). Done by taking the mean along the sequence_length dim, excluding the cls_token part (first index in the 65).
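Sketched out, that change (together with the NaN guard mentioned above) amounts to something like the following; the file path and variable names are illustrative, not the PR's exact code.

import numpy as np
import torch

raw_embeddings = torch.randn(1, 65, 768)               # stand-in for the encoder output of one image

# Mean along sequence_length, skipping the cls_token at index 0 -> shape (1, 768)
mean_embeddings = raw_embeddings[:, 1:, :].mean(dim=1)
assert not torch.isnan(mean_embeddings).any()           # guard against NaN values

# Assumes the data/embeddings/ folder already exists.
np.save(file="data/embeddings/embedding_0.npy", arr=mean_embeddings.numpy())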
Looks good.
Thanks for reviewing @srmsoumya, and everyone else for the comments. I'll merge this in now, noting that the code here generates (1, 768) shape averaged embeddings rather than the (1, 65, 768) shape raw embeddings. We can revise this later if we decide that we do want the raw embeddings.
Squashed commits:
* 🍻 Generate embeddings via prediction loop
* 🐛 Disable masking of patches on predict_step
* ♻️ Refactor LightningDataModule to not do random split on predict
* ✅ Test predict stage in geotiffdatamodule
* 👔 Ensure that embeddings have no NaN values
* 🗃️ Take mean of the embeddings along sequence_length dim
What I am changing
How I did it
- In the LightningModule's predict_step, implement the logic to do the forward pass and the save-to-npy step
- Retrieve input data from the predict_dataloader (currently reusing the validation datapipe rather than a dedicated one)
Excalidraw link: https://excalidraw.com/#json=IDteKVYDAHd05wT-rCR6K,vmbiRIb5ucGiXP6R1idu4w
TODO in this PR:
- Implement the predict_dataloader and predict_step
TODO in the future:
How you can test it
Make sure there are some GeoTIFF files in the data/ folder, and then run python trainer.py predict. This should produce an embedding_0.npy file under the data/embeddings/ folder. Sample files (need to unzip) are attached to this PR.
Extra configuration options can be found using python trainer.py predict --help
To load the embeddings from the npy file:
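For example (mirroring the np.load snippet used earlier in this thread; the path assumes the default output location):

import numpy as np

embedding: np.ndarray = np.load(file="data/embeddings/embedding_0.npy")
print(embedding.shape)  # (1, 768) after the mean-pooling change in this PR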
Related Issues
Towards #3