Save embeddings with spatiotemporal metadata to GeoParquet #73
Conversation
Storing the vector embeddings alongside some spatial bounding box and datetime information in a tabular GeoParquet format, instead of an npy file! Using geopandas to create a GeoDataFrame with three columns - date, embeddings, geometry. The date is stored in Arrow's date32 format, embeddings are in FixedShapeTensorArray, and geometry is in WKB. Have updated the unit test's sample fixture data with the extra spatiotemporal data, and tested that the saved GeoParquet file can be loaded back.
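As a rough sketch of what building that three-column table might look like (variable names and sample values here are illustrative, and plain per-row numpy vectors stand in for the PR's `pa.FixedShapeTensorArray` wrapper for brevity):

```python
import numpy as np
import pandas as pd
import geopandas as gpd
import shapely.geometry

# Illustrative stand-ins; in the PR these come from the model's predict_step.
embeddings_mean = np.random.rand(2, 768).astype(np.float32)  # (batch, 768)
dates = pd.to_datetime(["2021-06-15", "2021-07-20"]).date  # saved as Arrow date32
boxes = [shapely.geometry.box(0, 0, 1, 1), shapely.geometry.box(1, 0, 2, 1)]

gdf = gpd.GeoDataFrame(
    data={
        "date": dates,
        # The PR wraps this column with pa.FixedShapeTensorArray.from_numpy_ndarray;
        # plain per-row numpy vectors are used here for simplicity.
        "embeddings": list(embeddings_mean),
    },
    geometry=boxes,  # serialized to WKB inside the GeoParquet file
    crs="OGC:CRS84",
)
```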
Improve the docstring of predict_step in the LightningModule on how the embeddings are generated, and then saved to a GeoParquet file with the spatiotemporal metadata. Included some ASCII art and a markdown table showing what the tabular data looks like.
            "embeddings": pa.FixedShapeTensorArray.from_numpy_ndarray(
                embeddings_mean.cpu().detach().__array__()
            ),
Although we've converted the embedding into a FixedShapeTensorArray here, pandas/geopandas still interprets this column as an object dtype, and this is saved as an object dtype to the parquet file too (see the unit test). Need to see if there's a way to preserve the dtype.
Found a way to save this embeddings column as a FixedShapeTensorArray dtype instead of an object dtype like so:

-            "embeddings": pa.FixedShapeTensorArray.from_numpy_ndarray(
-                embeddings_mean.cpu().detach().__array__()
-            ),
+            "embeddings": gpd.pd.arrays.ArrowExtensionArray(
+                values=pa.FixedShapeTensorArray.from_numpy_ndarray(embeddings)
+            ),
However, while we can save this FixedShapeTensorArray to GeoParquet, loading this embeddings column back as a FixedShapeTensorArray is challenging, and might involve code that looks like this:
geodataframe: gpd.GeoDataFrame = gpd.read_parquet(
path="data/embeddings/embeddings_0.gpq",
schema=pa.schema(
fields=[
pa.field(
name="embeddings",
type=pa.fixed_shape_tensor(
value_type=pa.float32(), shape=[768]
),
),
pa.field(name="geometry", type=pa.binary()),
]
),
)
But this technically still results in an embeddings column with object dtype... Also, QGIS can load this GeoParquet file with FixedShapeTensorArray, but crashes when you try to open the attribute table, because it can't handle FixedShapeTensorArray yet. So it's probably best to keep it as object dtype for now.
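Given that the column comes back as object dtype anyway, one workable pattern (a sketch of mine, not from the PR; the 4-element vectors below stand in for the real 768-dimensional embeddings) is to stack the per-row vectors back into a single 2-D numpy array after reading:

```python
import numpy as np

# Stand-in for the object-dtype "embeddings" column read back from GeoParquet:
# each row holds one embedding vector (length 4 here for brevity).
column = [np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32) for _ in range(3)]

stacked = np.stack(column)  # shape (num_rows, embedding_dim)
print(stacked.shape)  # (3, 4)
```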
outpath = f"{outfolder}/embeddings_{batch_idx}.gpq"
gdf.to_parquet(path=outpath, schema_version="1.0.0")
print(f"Saved embeddings of shape {tuple(embeddings_mean.shape)} to {outpath}")
It is now possible to save several rows' worth of embeddings to a single GeoParquet file, so we can decide how to lump embeddings together, e.g. save all the embeddings for one MGRS tile in one year together.
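One way to do that grouping (a hypothetical sketch; the `mgrs` column and sample data are assumptions, not from the PR) would be a simple groupby, writing one file per MGRS tile:

```python
import geopandas as gpd
import shapely.geometry

# Toy table with an assumed "mgrs" column identifying the tile for each row.
gdf = gpd.GeoDataFrame(
    data={"mgrs": ["12ABC", "12ABC", "12ABD"], "value": [1, 2, 3]},
    geometry=[shapely.geometry.Point(x, 0) for x in range(3)],
)

# Write one GeoParquet file per MGRS tile.
for mgrs_code, group in gdf.groupby(by="mgrs"):
    outpath = f"embeddings_{mgrs_code}.gpq"
    # group.to_parquet(path=outpath)  # uncomment to actually write the files
    print(outpath, len(group))
```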
New 512x512 image chips are being processed now-ish, see #76 (comment). Will use a new filename convention in a follow-up PR (with the MGRS code in it) once we've got a new model trained on that new dataset.
Document that the embeddings are stored with spatiotemporal metadata as a GeoParquet file. Increased batch size from 1 to 1024.
Should have updated the type hints in #66, but might as well do it here. Also added some more inline comments and fixed a typo.
There are a couple of things that can be improved as mentioned above, such as the file naming scheme, and streamlining how the embeddings are saved to the GeoParquet file, but will merge this in first, and handle those nice-to-haves in follow-up PRs.
Output embeddings to a geopandas.GeoDataFrame with columns 'source_url', 'date', 'embeddings', and 'geometry'. Essentially copying and adapting the code from a767164 in #73, but modifying how the encoder's masking is disabled, and how the mean/average of the embeddings is computed over a slice of the raw embeddings.
(#96)

* 🍻 Implement CLAYModule's predict_step to generate embeddings table

  Output embeddings to a geopandas.GeoDataFrame with columns 'source_url', 'date', 'embeddings', and 'geometry'. Essentially copying and adapting the code from a767164 in #73, but modifying how the encoder's masking is disabled, and how the mean/average of the embeddings is computed over a slice of the raw embeddings.

* 🚚 Rename output file to {MGRS}_{MINDATE}_{MAXDATE}_v{VERSION}.gpq

  The output GeoParquet file now has a filename with a format like "{MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq", e.g. "12ABC_20210101_20231231_v001.gpq". Have implemented this in model_vit.py, and copied over the same `on_predict_epoch_end` method to model_clay.py. Also, we are no longer saving out the index column to the GeoParquet file.

* ✅ Fix failing test by updating to new output filename

  Forgot to update the filename in the unit test to conform to the new `{MGRS}_{MINDATE}_{MAXDATE}_v{VERSION}.gpq` format. Patches f19cf8f.

* ✅ Parametrized test to check CLAYModule's predict loop

  Splitting the previous integration test on the neural network model into separate fit and predict unit tests. Only testing the prediction loop of CLAYModule, because training/validating the model might be too much for CPU-based Continuous Integration. Also for testing CLAYModule, we are using 32-true precision instead of bf16-mixed, because `torch.cat` doesn't work with float16 tensors on the CPU, see pytorch/pytorch#100932 (should be fixed with Pytorch 2.2).

* ⏪ Save index column to GeoParquet file

  Decided that the index column might be good to keep for now, since it might help to speed up row counts? But we are resetting the index first before saving it. Partially reverts f19cf8f.

* ✅ Fix unit test to include index column

  After f1439e3, need to ensure that the index column is checked in the output geodataframe.
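The `{MGRS}_{MINDATE}_{MAXDATE}_v{VERSION}.gpq` naming convention mentioned above can be sketched as follows (variable names are my own, and the sample values come from the example in the commit message):

```python
# Build an output filename following the {MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}
# convention, e.g. "12ABC_20210101_20231231_v001.gpq".
mgrs_code = "12ABC"   # 5-character MGRS tile code
min_date = "20210101"  # earliest acquisition date, YYYYMMDD
max_date = "20231231"  # latest acquisition date, YYYYMMDD
version = 1            # zero-padded to 3 digits

outfilename = f"{mgrs_code}_{min_date}_{max_date}_v{version:03d}.gpq"
print(outfilename)  # 12ABC_20210101_20231231_v001.gpq
```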
What I am changing
How I did it
In the LightningModule's predict_step, use geopandas to create a GeoDataFrame with three columns - date, embeddings, geometry. A sample table would look like this:

The date is stored in Arrow's date32 format, embeddings are in FixedShapeTensorArray (TODO), and geometry is in WKB. Each row would store the embedding for a single 256x256 chip, and the entire table could realistically store N rows for an entire MGRS tile (10000x10000) across different dates.
TODO in this PR:
TODO in the future:
How you can test it

Put the data files in the data/ folder, and then run the predict command (see python trainer.py predict --help for the available options). This should produce an embedding_0.gpq file under the data/embeddings/ folder.
To load the embeddings from the geoparquet file:
If you have a newer version of QGIS, it's also possible to load the GeoParquet file directly. The below screenshot shows the bounding box locations of the 755 embeddings (1 embedding for each 256x256 chip):
Related Issues
Extends #56, continuation of #66.