
Added initial hubconf.py #4

Open · wants to merge 4 commits into main
Conversation

@Multihuntr Multihuntr commented Feb 8, 2024

To make the models more accessible, I have added a hubconf.py. This requires instantiating a model from the current code base.

With this change, on my fork, I can get a pretrained SNUNet or Flood_ViT model from anywhere:

import torch
import numpy as np

# SNUNet takes two S1 acquisitions (batch of 8, 2 channels each) plus a DEM.
snunet = torch.hub.load("Multihuntr/KuroSiwo", "snunet", pretrained=True)
inps = [torch.randn(8, 2, 224, 224) for _ in range(2)]
dem = np.random.randn(8, 224, 224)
out = snunet(inps, dem=dem)

# Flood_ViT takes three S1 acquisitions and no DEM.
flood_vit = torch.hub.load("Multihuntr/KuroSiwo", "vit_decoder", pretrained=True)
inps = [torch.randn(8, 2, 224, 224) for _ in range(3)]
out = flood_vit(inps)

Before accepting this pull request, there are some steps you need to take (listed at the bottom) if you want it to work via torch.hub.load("Orion-AI-Lab/KuroSiwo", "snunet").

Decisions

  1. To ensure minimal dependencies, I have moved the FineTunerSegmentation model definition from models/model_utilities.py into its own file. Otherwise it would import model_utilities, and thus require several other libraries.
  2. I have maintained the model structure found in the original weights provided (as best I know how), and added a wrapper. The wrapper normalises the input values and standardises the interface. The idea is that the input is simple S1 images as torch.tensors, and the DEM as a np.ndarray.
  3. The wrapper takes a list of images. Rather than choosing whether to concatenate the inputs, or to split up already-concatenated images, I decided that you should always provide a list of images. You do have to be careful to give the correct number of images for each model.
  4. I have torch.saved just the state_dict for maximum portability.
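To illustrate points 2 and 3, the wrapper behaves roughly like the sketch below. The class name `S1Wrapper` and the normalisation constants are illustrative only, not the actual hubconf.py implementation:

```python
import numpy as np
import torch
import torch.nn as nn


class S1Wrapper(nn.Module):
    """Illustrative wrapper: normalises raw S1 inputs and standardises the
    interface to (list of torch.Tensors, optional DEM np.ndarray)."""

    def __init__(self, model, mean=0.0, std=1.0):
        super().__init__()
        self.model = model
        self.mean = mean  # placeholder normalisation constants
        self.std = std

    def forward(self, images, dem=None):
        # The interface always expects a list of S1 images, one per acquisition.
        if not isinstance(images, (list, tuple)):
            raise TypeError("expected a list of S1 image tensors")
        images = [(img - self.mean) / self.std for img in images]
        x = torch.cat(images, dim=1)  # concatenate along the channel axis
        if dem is not None:
            # The DEM arrives as a np.ndarray; convert and append as a channel.
            dem_t = torch.from_numpy(dem).float().unsqueeze(1)
            x = torch.cat([x, dem_t], dim=1)
        return self.model(x)
```

A caller then passes a list of tensors regardless of how many acquisitions the underlying model expects, which is what makes the list-length requirement something to be careful about.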

Uncertainties

I am not certain about a few aspects. Can you check:

  1. Did I correctly instantiate SNUNet_ECAM?
  2. (EDIT: This point resolved)
  3. (EDIT: This point should be addressed elsewhere)

Actions needed before accepting

To make this work on the main branch, you will need to:

  1. Export just the state_dict for each.
  2. Upload those state_dicts to a release (recommended by the torch.hub docs).
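Step 1 amounts to saving only the parameters rather than the pickled module. A minimal sketch (the model and file path here are placeholders, not the actual SNUNet/Flood_ViT export):

```python
import torch
import torch.nn as nn

# Placeholder model; the real ones are SNUNet and Flood_ViT.
model = nn.Linear(4, 2)

# Step 1: export just the state_dict, which stays loadable across code changes.
torch.save(model.state_dict(), "snunet_state_dict.pt")

# Step 2: after uploading to a GitHub release, hubconf.py can fetch it with
# torch.hub.load_state_dict_from_url(release_url, map_location="cpu").
# Locally, the same round-trip looks like:
state = torch.load("snunet_state_dict.pt", map_location="cpu")
model.load_state_dict(state)
```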

Then we can update the URL to point to your original codebase and keep it contained.

@Multihuntr
Copy link
Author

Based on the newly pushed "dem fix" code, I see that the third channel for SNUNet is likely the slope, not vh/vv. I have updated my code accordingly.

@Multihuntr
Copy link
Author

I have now run these models on a subset of the KuroSiwo dataset and got reasonable IoU numbers. I'm still not 100% certain, but I'm about 95% sure I've got this right now.
