-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding _ops and _weight_size metadata checks to tests #6996
Adding _ops and _weight_size metadata checks to tests #6996
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the work @toni057. Just a few comments:
test/common_extended_utils.py
Outdated
detection_models_input_dims = { | ||
"fasterrcnn_mobilenet_v3_large_320_fpn": (320, 320), | ||
"fasterrcnn_mobilenet_v3_large_fpn": (800, 800), | ||
"fasterrcnn_resnet50_fpn": (800, 800), | ||
"fasterrcnn_resnet50_fpn_v2": (800, 800), | ||
"fcos_resnet50_fpn": (800, 800), | ||
"keypointrcnn_resnet50_fpn": (1333, 1333), | ||
"maskrcnn_resnet50_fpn": (800, 800), | ||
"maskrcnn_resnet50_fpn_v2": (800, 800), | ||
"retinanet_resnet50_fpn": (800, 800), | ||
"retinanet_resnet50_fpn_v2": (800, 800), | ||
"ssd300_vgg16": (300, 300), | ||
"ssdlite320_mobilenet_v3_large": (320, 320), | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I think this doesn't belong on the common_extended_utils.py
file but rather on the test/test_extended_models.py
file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, can move it there.
test/test_extended_models.py
Outdated
else: | ||
if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()): | ||
incorrect_params.append(w) | ||
|
||
calculated_ops = get_ops(module_name, model_name, w) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we need to review this logic. This way we initialize the models multiple times. Once on the model_fn
call above and once within get_ops()
. What we can do is initialize the model once and then use it in both cases.
test/test_extended_models.py
Outdated
if module_name == "quantization": | ||
# parameters() count doesn't work well with quantization, so we check against the non-quantized | ||
unquantized_w = w.meta.get("unquantized") | ||
if unquantized_w is not None and w.meta.get("num_params") != unquantized_w.meta.get("num_params"): | ||
incorrect_params.append(w) | ||
|
||
# the methodology for quantized ops count doesn't work as well, so we take unquantized FLOPs instead | ||
calculated_ops = get_ops(model=None, module_name="models", model_name=model_name, weight=unquantized_w) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't have to do this estimation. We can follow the same approach as with the num_params
. More precisely:
we fetch the unquantized_w.meta.get("_ops")
and confirm that the match what we have here. Basically we reproduce the logic on lines 219-220.
test/common_extended_utils.py
Outdated
return sum(self.flop_counts["Global"].values()) / 1e9 | ||
|
||
|
||
def get_ops(model: torch.nn.Module, module_name: str, model_name: str, weight: Weights, h=512, w=512): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's assume here that model
is not None. Then we don't need the model_name
parameter. The module_name
is also unnecessary as it can be fetched from the model. More specifically:
>>> m = resnet50()
>>> m.__module__
'torchvision.models.resnet'
test/common_extended_utils.py
Outdated
if model is None: | ||
kwargs = {"quantize": True} if module_name == "quantization" else {} | ||
model = models.get_model(model_name, weights=weight, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can go away:
if model is None: | |
kwargs = {"quantize": True} if module_name == "quantization" else {} | |
model = models.get_model(model_name, weights=weight, **kwargs) |
test/test_extended_models.py
Outdated
# loading the model and using it for parameter and ops verification | ||
kwargs = {"quantize": True} if module_name == "quantization" else {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not necessary. We already checked it's not quantization above.
test/test_extended_models.py
Outdated
) | ||
|
||
# assert that weight flops are correctly pasted to metadata | ||
assert calculated_ops == w.meta["_ops"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We shouldn't assert like this because it will fail immediately the test without showing us other issues. Instead we should be collecting all issues in one list and showing them to the user. Previously we had incorrect_params
which was monitoring issues with the number of parameters. Now that we have more, it's worth switching this into something like incorrect_meta
and append to it not only the weight but also the meta name that failed. For example: incorrect_params.append((w, "num_params"))
.
test/test_extended_models.py
Outdated
assert not problematic_weights | ||
assert not incorrect_params | ||
assert not bad_names | ||
assert weight_size_mb == w.meta["_weight_size"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to the above. This needs to be asserted properly for all weights. You can use the proposed incorrect_meta
to track it as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot @toni057. Looks great. The comment below is optional.
Let's wait for the tests to see whether there is any randomness, otherwise we should be good.
incorrect_meta.append((w, "num_params")) | ||
|
||
# the methodology for quantized ops count doesn't work as well, so we take unquantized FLOPs instead | ||
if unquantized_w is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor Nit: Since this check is needed for both num_params
and _ops
we can perhaps do it once for both and simplify the code?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, only one optional Nit below. Your call.
Otherwise we can merge on green CI.
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Summary: * Adding _ops and _weight_size metadata checks to tests * Fixing wrong ops value * Changing test_schema_meta_validation to instantiate the model only once * moving instantiating quantized models inside get_ops * Small refactor of test_schema_meta_validation logic * Reverting to previous ops value * Simplifying unquantized models logic in test_schema_meta_validation * Update test/test_extended_models.py Reviewed By: datumbox Differential Revision: D41836893 fbshipit-source-id: 9174c95ee1843d972898fcd89c3d4e1697e83bca Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com> Co-authored-by: Toni Blaslov <tblaslov@fb.com> Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Continuing on PR6936 where number of operations and model sizes were added, in this PR we are adding the logic for calculating the mentioned metadata to test, and verifying that the values added to metadata correspond to the values hardcoded for weights.
Due to the relatively long run times, we are limiting the solution to default weights only.
cc: @datumbox
cc @datumbox @pmeier