[FEEDBACK] Model Registration beta API #6365
Comments
It would be great if `list_models` could list only the models matching a regex, or at least support wildcard searches like `list_models("resnet*")`.
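Wildcard filtering of this kind can be done with the standard library's `fnmatch` module. The sketch below is illustrative only: `filter_model_names` is a hypothetical helper that runs over any list of registered names (for example, the strings a `list_models()` call returns), not an existing torchvision function.

```python
from fnmatch import fnmatchcase
from typing import Iterable, List


def filter_model_names(names: Iterable[str], pattern: str = "*") -> List[str]:
    """Return the sorted names that match a shell-style wildcard such as "resnet*"."""
    # fnmatchcase is case-sensitive on every OS, unlike fnmatch.fnmatch
    return sorted(n for n in names if fnmatchcase(n, pattern))


names = ["resnet18", "resnet50", "resnext50_32x4d", "vit_b_16", "fcn_resnet50"]
print(filter_model_names(names, "resnet*"))  # ['resnet18', 'resnet50']
```

Note that `"resnet*"` matches by prefix, so `resnext50_32x4d` and `fcn_resnet50` are excluded; a pattern such as `"*resnet*"` would pick up the latter as well.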
I agree with @dataplayer12. Whilst going through some tests in vision/test/test_prototype_models.py (lines 8 to 19 at commit 9c3e2bf)
Furthermore, this can get quite tricky when we're dealing with models that do not have the same output shapes or number of outputs, even though they "solve" the same task (see vision/test/test_prototype_models.py, lines 28 to 35 at commit 9c3e2bf).
The backward-compatible fix I see is simple and non-intrusive. We could change `find_model` to:

```python
import re
from typing import Callable, List, Optional

# BUILTIN_MODELS, list_models and the TypeVar M are assumed to be in scope,
# as in torchvision's model registration module.


def find_model(name: str, pattern: str = ".*") -> Optional[Callable[..., M]]:
    name = name.lower()
    try:
        fn = BUILTIN_MODELS[name]
    except KeyError:
        raise ValueError(f"Unknown model {name}")
    # Return None if the name does not match the pattern
    if not re.match(pattern, name):
        return None
    return fn
```

Then we could change `list_model_fns` to:

```python
def list_model_fns(module, pattern: str = ".*") -> List[Callable[..., M]]:
    # ".*" matches every name; a bare "*" is not a valid regular expression
    model_fns = [find_model(name, pattern) for name in list_models(module)]
    return [fn for fn in model_fns if fn is not None]
```

Besides giving users the option of selecting only a specific family of models, I believe this would improve the developer experience when writing tests or other utilities, while keeping the same API. The alternative would be to pass each model class individually through function arguments or a decorator, since we cannot assume that all models from a module behave in exactly the same way.
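To illustrate the testing use case, here is a self-contained sketch built on a hypothetical toy registry (`TOY_MODELS`; the names and output shapes are made up for illustration). A name pattern, applied with `re.match` as in the `find_model` proposal, lets a single output-shape check run only over the family that shares that shape and skip a segmentation model whose output differs.

```python
import re
import unittest
from typing import Callable, Dict, List, Tuple

# Hypothetical toy registry: model name -> builder returning that model's output shape.
TOY_MODELS: Dict[str, Callable[[], Tuple[int, ...]]] = {
    "resnet18": lambda: (1, 1000),
    "resnet50": lambda: (1, 1000),
    "fcn_resnet50": lambda: (1, 21, 224, 224),  # segmentation: different output shape
}


def list_model_fns(pattern: str = ".*") -> List[Callable[[], Tuple[int, ...]]]:
    # re.match anchors at the start of the name only; "$" anchors the end as well
    return [fn for name, fn in sorted(TOY_MODELS.items()) if re.match(pattern, name)]


class ClassificationOutputShapeTest(unittest.TestCase):
    def test_output_shape(self) -> None:
        # Select only the classification ResNets; fcn_resnet50 is filtered out by the pattern
        for fn in list_model_fns(r"resnet\d+$"):
            with self.subTest(fn=fn):
                self.assertEqual(fn(), (1, 1000))
```

Run it with `python -m unittest`. With the default pattern `.*`, the same loop would also pick up `fcn_resnet50` and fail, which is the situation the pattern argument is meant to avoid.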
We're trying to adopt the new API in TorchGeo, but it isn't clear how the registration API works for weights that are not built into torchvision. We list our own WeightsEnums but
So it's possible this is by design. Guess I'll just wait for them to become public and copy-n-paste all the code for now...
Thanks for the feedback, @adamjstewart. The registrators are private right now because they weren't intended to work for external packages. What kind of workflow would you like to enable? It seems like it would work like this for torchgeo users:

```python
import torchgeo.models
from torchvision.models import list_models

list_models(module=torchgeo.models)
```

which IMHO seems awkward; torchgeo users probably just want to use something like … IIRC, from the design stage, we introduced the … We're still open to making those public if we can find a nice, easy and useful way to do so, but for now I think a good old copy-n-paste is your best strat :)
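The copy-n-paste route could look roughly like the following: a minimal registry kept inside the downstream package. This is a hypothetical sketch, not torchvision's code. `_MODELS`, `register_model`, `list_models` and `get_model` are named after the torchvision helpers but written from scratch here, and the example builders return placeholder dicts instead of `nn.Module`s. Because the registry lives in the package itself, its `list_models()` sees every decorated builder without going through torchvision.

```python
from typing import Callable, Dict, List, Optional, TypeVar

F = TypeVar("F", bound=Callable[..., object])

# Hypothetical package-local registry: model name -> builder function.
_MODELS: Dict[str, Callable[..., object]] = {}


def register_model(name: Optional[str] = None) -> Callable[[F], F]:
    """Record a builder under `name` (defaults to the function's own name)."""

    def wrapper(fn: F) -> F:
        key = name or fn.__name__
        if key in _MODELS:
            raise ValueError(f"A model is already registered under {key!r}")
        _MODELS[key] = fn
        return fn

    return wrapper


def list_models() -> List[str]:
    return sorted(_MODELS)


def get_model(name: str, **kwargs: object) -> object:
    try:
        builder = _MODELS[name]
    except KeyError:
        raise ValueError(f"Unknown model {name}") from None
    return builder(**kwargs)


@register_model()
def resnet50(weights: Optional[str] = None) -> dict:
    # Placeholder builder; a real one would construct and return an nn.Module
    return {"arch": "resnet50", "weights": weights}


@register_model("vit_small_patch16_224")
def _build_vit(weights: Optional[str] = None) -> dict:
    return {"arch": "vit_small_patch16_224", "weights": weights}
```

With this in place, `list_models()` returns `['resnet50', 'vit_small_patch16_224']`, and `get_model("resnet50", weights=...)` forwards keyword arguments to the builder. Registering a second builder under an existing name raises `ValueError` rather than silently overwriting it.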
I would love it if that syntax worked, but it doesn't:

```python
>>> import torchgeo.models
>>> from torchvision.models import list_models
>>> list_models(module=torchgeo.models)
[]
```
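A plausible explanation for the empty list, assuming torchvision's `list_models` works roughly like the simplified model below: all builders live in one global registry that only torchvision's private decorator writes to, and the `module` argument merely filters that registry by each builder's `__module__`. Builders defined in torchgeo are never added, so filtering by `torchgeo.models` finds nothing. The module objects and builder names here are stand-ins created with `types.ModuleType`, not the real packages.

```python
import types
from typing import Callable, Dict, List, Optional

# Simplified stand-in for torchvision's global registry; only _register writes to it.
BUILTIN_MODELS: Dict[str, Callable[..., object]] = {}


def _register(fn: Callable[..., object]) -> Callable[..., object]:
    BUILTIN_MODELS[fn.__name__] = fn
    return fn


def list_models(module: Optional[types.ModuleType] = None) -> List[str]:
    # Keep builders whose defining module's parent package is `module`
    return sorted(
        name
        for name, fn in BUILTIN_MODELS.items()
        if module is None or fn.__module__.rsplit(".", 1)[0] == module.__name__
    )


# Fake package objects standing in for torchvision.models and torchgeo.models
tv_models = types.ModuleType("torchvision.models")
tg_models = types.ModuleType("torchgeo.models")


def resnet18() -> None: ...
resnet18.__module__ = "torchvision.models.resnet"
_register(resnet18)


def geo_resnet18() -> None: ...
geo_resnet18.__module__ = "torchgeo.models.resnet"  # defined, but never registered

print(list_models(module=tv_models))  # ['resnet18']
print(list_models(module=tg_models))  # []
```

Under this model, the `module` argument can only narrow what torchvision already registered; it cannot discover builders from another package.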
🚀 Feedback Request
This issue is dedicated to collecting community feedback on the Model Registration API. Please review the dedicated RFC and blog post, where we describe the API in detail and give an overview of its features.
We would love to get your thoughts, comments and input in order to finalize the API and include it in the upcoming release of TorchVision.