Skip to content

Commit

Permalink
Adding _ops and _weight_size metadata checks to tests (#6996)
Browse files Browse the repository at this point in the history
* Adding _ops and _weight_size metadata checks to tests

* Fixing wrong ops value

* Changing test_schema_meta_validation to instantiate the model only once

* moving instantiating quantized models inside get_ops

* Small refactor of test_schema_meta_validation logic

* Reverting to previous ops value

* Simplifying unquantized models logic in test_schema_meta_validation

* Update test/test_extended_models.py

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

Co-authored-by: Toni Blaslov <tblaslov@fb.com>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 1, 2022
1 parent 264b83b commit 790f1cd
Show file tree
Hide file tree
Showing 2 changed files with 352 additions and 12 deletions.
303 changes: 303 additions & 0 deletions test/common_extended_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,303 @@
import os
from collections import defaultdict
from numbers import Number
from typing import Any, List

import torch
from torch.utils._python_dispatch import TorchDispatchMode

from torch.utils._pytree import tree_map

from torchvision.models._api import Weights

aten = torch.ops.aten
quantized = torch.ops.quantized


def get_shape(i):
if isinstance(i, torch.Tensor):
return i.shape
elif hasattr(i, "weight"):
return i.weight().shape
else:
raise ValueError(f"Unknown type {type(i)}")


def prod(x):
res = 1
for i in x:
res *= i
return res


def matmul_flop(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for matmul.
"""
# Inputs should be a list of length 2.
# Inputs contains the shapes of two matrices.
input_shapes = [get_shape(v) for v in inputs]
assert len(input_shapes) == 2, input_shapes
assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
flop = prod(input_shapes[0]) * input_shapes[-1][-1]
return flop


def addmm_flop(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for fully connected layers.
"""
# Count flop for nn.Linear
# inputs is a list of length 3.
input_shapes = [get_shape(v) for v in inputs[1:3]]
# input_shapes[0]: [batch size, input feature dimension]
# input_shapes[1]: [batch size, output feature dimension]
assert len(input_shapes[0]) == 2, input_shapes[0]
assert len(input_shapes[1]) == 2, input_shapes[1]
batch_size, input_dim = input_shapes[0]
output_dim = input_shapes[1][1]
flops = batch_size * input_dim * output_dim
return flops


def bmm_flop(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for the bmm operation.
"""
# Inputs should be a list of length 2.
# Inputs contains the shapes of two tensor.
assert len(inputs) == 2, len(inputs)
input_shapes = [get_shape(v) for v in inputs]
n, c, t = input_shapes[0]
d = input_shapes[-1][-1]
flop = n * c * t * d
return flop


def conv_flop_count(
x_shape: List[int],
w_shape: List[int],
out_shape: List[int],
transposed: bool = False,
) -> Number:
"""
Count flops for convolution. Note only multiplication is
counted. Computation for addition and bias is ignored.
Flops for a transposed convolution are calculated as
flops = (x_shape[2:] * prod(w_shape) * batch_size).
Args:
x_shape (list(int)): The input shape before convolution.
w_shape (list(int)): The filter shape.
out_shape (list(int)): The output shape after convolution.
transposed (bool): is the convolution transposed
Returns:
int: the number of flops
"""
batch_size = x_shape[0]
conv_shape = (x_shape if transposed else out_shape)[2:]
flop = batch_size * prod(w_shape) * prod(conv_shape)
return flop


def conv_flop(inputs: List[Any], outputs: List[Any]):
"""
Count flops for convolution.
"""
x, w = inputs[:2]
x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0]))
transposed = inputs[6]

return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)


def quant_conv_flop(inputs: List[Any], outputs: List[Any]):
"""
Count flops for quantized convolution.
"""
x, w = inputs[:2]
x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0]))

return conv_flop_count(x_shape, w_shape, out_shape, transposed=False)


def transpose_shape(shape):
return [shape[1], shape[0]] + list(shape[2:])


def conv_backward_flop(inputs: List[Any], outputs: List[Any]):
grad_out_shape, x_shape, w_shape = [get_shape(i) for i in inputs[:3]]
output_mask = inputs[-1]
fwd_transposed = inputs[7]
flop_count = 0

if output_mask[0]:
grad_input_shape = get_shape(outputs[0])
flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed)
if output_mask[1]:
grad_weight_shape = get_shape(outputs[1])
flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed)

return flop_count


flop_mapping = {
aten.mm: matmul_flop,
aten.matmul: matmul_flop,
aten.addmm: addmm_flop,
aten.bmm: bmm_flop,
aten.convolution: conv_flop,
aten._convolution: conv_flop,
aten.convolution_backward: conv_backward_flop,
quantized.conv2d: quant_conv_flop,
quantized.conv2d_relu: quant_conv_flop,
}

unmapped_ops = set()


def normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x


class FlopCounterMode(TorchDispatchMode):
def __init__(self, model=None):
self.flop_counts = defaultdict(lambda: defaultdict(int))
self.parents = ["Global"]
# global mod
if model is not None:
for name, module in dict(model.named_children()).items():
module.register_forward_pre_hook(self.enter_module(name))
module.register_forward_hook(self.exit_module(name))

def enter_module(self, name):
def f(module, inputs):
self.parents.append(name)
inputs = normalize_tuple(inputs)
out = self.create_backwards_pop(name)(*inputs)
return out

return f

def exit_module(self, name):
def f(module, inputs, outputs):
assert self.parents[-1] == name
self.parents.pop()
outputs = normalize_tuple(outputs)
return self.create_backwards_push(name)(*outputs)

return f

def create_backwards_push(self, name):
class PushState(torch.autograd.Function):
@staticmethod
def forward(ctx, *args):
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
if len(args) == 1:
return args[0]
return args

@staticmethod
def backward(ctx, *grad_outs):
self.parents.append(name)
return grad_outs

return PushState.apply

def create_backwards_pop(self, name):
class PopState(torch.autograd.Function):
@staticmethod
def forward(ctx, *args):
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
if len(args) == 1:
return args[0]
return args

@staticmethod
def backward(ctx, *grad_outs):
assert self.parents[-1] == name
self.parents.pop()
return grad_outs

return PopState.apply

def __enter__(self):
self.flop_counts.clear()
super().__enter__()

def __exit__(self, *args):
# print(f"Total: {sum(self.flop_counts['Global'].values()) / 1e9} GFLOPS")
# for mod in self.flop_counts.keys():
# print(f"Module: ", mod)
# for k, v in self.flop_counts[mod].items():
# print(f"{k}: {v / 1e9} GFLOPS")
# print()
super().__exit__(*args)

def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = kwargs if kwargs else {}

out = func(*args, **kwargs)
func_packet = func._overloadpacket
if func_packet in flop_mapping:
flop_count = flop_mapping[func_packet](args, normalize_tuple(out))
for par in self.parents:
self.flop_counts[par][func_packet] += flop_count
else:
unmapped_ops.add(func_packet)

return out

def get_flops(self):
return sum(self.flop_counts["Global"].values()) / 1e9


def get_dims(module_name, height, width):
# detection models have curated input sizes
if module_name == "detection":
# we can feed a batch of 1 for detection model instead of a list of 1 image
dims = (3, height, width)
elif module_name == "video":
# hard-coding the time dimension to size 16
dims = (1, 16, 3, height, width)
else:
dims = (1, 3, height, width)

return dims


def get_ops(model: torch.nn.Module, weight: Weights, height=512, width=512):
module_name = model.__module__.split(".")[-2]
dims = get_dims(module_name=module_name, height=height, width=width)

input_tensor = torch.randn(dims)

# try:
preprocess = weight.transforms()
if module_name == "optical_flow":
inp = preprocess(input_tensor, input_tensor)
else:
# hack to enable mod(*inp) for optical_flow models
inp = [preprocess(input_tensor)]

model.eval()

flop_counter = FlopCounterMode(model)
with flop_counter:
# detection models expect a list of 3d tensors as inputs
if module_name == "detection":
model(inp)
else:
model(*inp)

flops = flop_counter.get_flops()

return round(flops, 3)


def get_weight_size_mb(weight):
weights_path = os.path.join(os.getenv("HOME"), ".cache/torch/hub/checkpoints", weight.url.split("/")[-1])
weights_size_mb = os.path.getsize(weights_path) / 1024 / 1024

return round(weights_size_mb, 3)
Loading

0 comments on commit 790f1cd

Please sign in to comment.