Adding _ops and _weight_size metadata checks to tests (#6996)
* Adding _ops and _weight_size metadata checks to tests
* Fixing wrong ops value
* Changing test_schema_meta_validation to instantiate the model only once
* Moving instantiation of quantized models inside get_ops
* Small refactor of test_schema_meta_validation logic
* Reverting to previous ops value
* Simplifying unquantized models logic in test_schema_meta_validation
* Update test/test_extended_models.py

Co-authored-by: Toni Blaslov <tblaslov@fb.com>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
1 parent 264b83b · commit 790f1cd · Showing 2 changed files with 352 additions and 12 deletions.
@@ -0,0 +1,303 @@
import os
from collections import defaultdict
from numbers import Number
from typing import Any, List

import torch
from torch.utils._python_dispatch import TorchDispatchMode

from torch.utils._pytree import tree_map

from torchvision.models._api import Weights

aten = torch.ops.aten
quantized = torch.ops.quantized


def get_shape(i):
    if isinstance(i, torch.Tensor):
        return i.shape
    elif hasattr(i, "weight"):
        return i.weight().shape
    else:
        raise ValueError(f"Unknown type {type(i)}")


def prod(x):
    res = 1
    for i in x:
        res *= i
    return res


def matmul_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for matmul.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of two matrices.
    input_shapes = [get_shape(v) for v in inputs]
    assert len(input_shapes) == 2, input_shapes
    assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
    flop = prod(input_shapes[0]) * input_shapes[-1][-1]
    return flop
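
# Worked example (illustrative, not part of the original file): for an
# (m, k) @ (k, n) product this counts prod((m, k)) * n = m * k * n
# multiply-adds, e.g. (8, 64) @ (64, 32) -> 8 * 64 * 32 = 16384 flops.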


def addmm_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for fully connected layers.
    """
    # Count flop for nn.Linear
    # inputs is a list of length 3.
    input_shapes = [get_shape(v) for v in inputs[1:3]]
    # input_shapes[0]: [batch size, input feature dimension]
    # input_shapes[1]: [input feature dimension, output feature dimension]
    assert len(input_shapes[0]) == 2, input_shapes[0]
    assert len(input_shapes[1]) == 2, input_shapes[1]
    batch_size, input_dim = input_shapes[0]
    output_dim = input_shapes[1][1]
    flops = batch_size * input_dim * output_dim
    return flops


def bmm_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for the bmm operation.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of two tensors.
    assert len(inputs) == 2, len(inputs)
    input_shapes = [get_shape(v) for v in inputs]
    n, c, t = input_shapes[0]
    d = input_shapes[-1][-1]
    flop = n * c * t * d
    return flop


def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> Number:
    """
    Count flops for convolution. Note that only multiplications are
    counted; additions and the bias term are ignored.
    Flops for a transposed convolution are calculated as
    flops = batch_size * prod(w_shape) * prod(x_shape[2:]).
    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): is the convolution transposed
    Returns:
        int: the number of flops
    """
    batch_size = x_shape[0]
    conv_shape = (x_shape if transposed else out_shape)[2:]
    flop = batch_size * prod(w_shape) * prod(conv_shape)
    return flop
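
# Worked example (illustrative, not part of the original file): a 7x7/stride-2
# conv over a (1, 3, 224, 224) input with weights (64, 3, 7, 7) produces a
# (1, 64, 112, 112) output, so the count is
# 1 * prod((64, 3, 7, 7)) * prod((112, 112)) = 9408 * 12544 = 118,013,952 flops.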


def conv_flop(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for convolution.
    """
    x, w = inputs[:2]
    x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0]))
    transposed = inputs[6]

    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)


def quant_conv_flop(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for quantized convolution.
    """
    x, w = inputs[:2]
    x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0]))

    return conv_flop_count(x_shape, w_shape, out_shape, transposed=False)


def transpose_shape(shape):
    return [shape[1], shape[0]] + list(shape[2:])


def conv_backward_flop(inputs: List[Any], outputs: List[Any]):
    grad_out_shape, x_shape, w_shape = [get_shape(i) for i in inputs[:3]]
    output_mask = inputs[-1]
    fwd_transposed = inputs[7]
    flop_count = 0

    if output_mask[0]:
        grad_input_shape = get_shape(outputs[0])
        flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed)
    if output_mask[1]:
        grad_weight_shape = get_shape(outputs[1])
        flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed)

    return flop_count


flop_mapping = {
    aten.mm: matmul_flop,
    aten.matmul: matmul_flop,
    aten.addmm: addmm_flop,
    aten.bmm: bmm_flop,
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
    aten.convolution_backward: conv_backward_flop,
    quantized.conv2d: quant_conv_flop,
    quantized.conv2d_relu: quant_conv_flop,
}

unmapped_ops = set()
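
# Ops dispatched without an entry in flop_mapping are collected in unmapped_ops
# by FlopCounterMode.__torch_dispatch__ below, so gaps in coverage can be
# inspected after a run instead of being silently ignored.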


def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


class FlopCounterMode(TorchDispatchMode):
    def __init__(self, model=None):
        self.flop_counts = defaultdict(lambda: defaultdict(int))
        self.parents = ["Global"]
        if model is not None:
            for name, module in dict(model.named_children()).items():
                module.register_forward_pre_hook(self.enter_module(name))
                module.register_forward_hook(self.exit_module(name))

    def enter_module(self, name):
        def f(module, inputs):
            self.parents.append(name)
            inputs = normalize_tuple(inputs)
            out = self.create_backwards_pop(name)(*inputs)
            return out

        return f

    def exit_module(self, name):
        def f(module, inputs, outputs):
            assert self.parents[-1] == name
            self.parents.pop()
            outputs = normalize_tuple(outputs)
            return self.create_backwards_push(name)(*outputs)

        return f
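
    # The two autograd.Function factories below mirror the forward hooks during
    # the backward pass: PushState.backward pushes the module name onto
    # self.parents when gradients enter through the module's outputs, and
    # PopState.backward pops it when they leave through its inputs, so flops
    # from backward ops are attributed to the module that produced them.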
    def create_backwards_push(self, name):
        class PushState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                self.parents.append(name)
                return grad_outs

        return PushState.apply

    def create_backwards_pop(self, name):
        class PopState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                assert self.parents[-1] == name
                self.parents.pop()
                return grad_outs

        return PopState.apply

    def __enter__(self):
        self.flop_counts.clear()
        super().__enter__()

    def __exit__(self, *args):
        # print(f"Total: {sum(self.flop_counts['Global'].values()) / 1e9} GFLOPS")
        # for mod in self.flop_counts.keys():
        #     print(f"Module: ", mod)
        #     for k, v in self.flop_counts[mod].items():
        #         print(f"{k}: {v / 1e9} GFLOPS")
        #     print()
        super().__exit__(*args)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}

        out = func(*args, **kwargs)
        func_packet = func._overloadpacket
        if func_packet in flop_mapping:
            flop_count = flop_mapping[func_packet](args, normalize_tuple(out))
            for par in self.parents:
                self.flop_counts[par][func_packet] += flop_count
        else:
            unmapped_ops.add(func_packet)

        return out

    def get_flops(self):
        return sum(self.flop_counts["Global"].values()) / 1e9
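
# Minimal usage sketch (illustrative, not part of the original file):
#
#     counter = FlopCounterMode(model)
#     with counter:
#         model(torch.randn(1, 3, 224, 224))
#     gflops = counter.get_flops()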


def get_dims(module_name, height, width):
    # detection models have curated input sizes
    if module_name == "detection":
        # detection models take a list of 3D images, so no batch dimension here
        dims = (3, height, width)
    elif module_name == "video":
        # hard-coding the time dimension to size 16
        dims = (1, 16, 3, height, width)
    else:
        dims = (1, 3, height, width)

    return dims


def get_ops(model: torch.nn.Module, weight: Weights, height=512, width=512):
    module_name = model.__module__.split(".")[-2]
    dims = get_dims(module_name=module_name, height=height, width=width)

    input_tensor = torch.randn(dims)

    preprocess = weight.transforms()
    if module_name == "optical_flow":
        inp = preprocess(input_tensor, input_tensor)
    else:
        # wrap in a list so that model(*inp) below works for both the
        # optical_flow models (two inputs) and everything else (one input)
        inp = [preprocess(input_tensor)]

    model.eval()

    flop_counter = FlopCounterMode(model)
    with flop_counter:
        # detection models expect a list of 3D tensors as inputs
        if module_name == "detection":
            model(inp)
        else:
            model(*inp)

    flops = flop_counter.get_flops()

    return round(flops, 3)


def get_weight_size_mb(weight):
    weights_path = os.path.join(os.getenv("HOME"), ".cache/torch/hub/checkpoints", weight.url.split("/")[-1])
    weights_size_mb = os.path.getsize(weights_path) / 1024 / 1024

    return round(weights_size_mb, 3)
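
For context, a minimal sketch of how these helpers could be exercised (the model/weight pair is illustrative and not part of this commit; any torchvision model whose checkpoint is already cached under ~/.cache/torch/hub/checkpoints works the same way):

    from torchvision.models import resnet50, ResNet50_Weights

    weights = ResNet50_Weights.IMAGENET1K_V1
    model = resnet50(weights=weights)

    # GFLOPs for one forward pass after the weight's own preprocessing,
    # to be compared against the `_ops` metadata entry
    print(get_ops(model, weights))

    # checkpoint size on disk in MB, to be compared against `_weight_size`
    print(get_weight_size_mb(weights))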