Merge pull request #1379 from Paitesanshi/1.1.x
FIX: fix bugs of get_flops and core
Sherry-XLL authored Aug 6, 2022
2 parents 411a45a + 33fc2cd commit 51cdaeb
Showing 4 changed files with 112 additions and 73 deletions.
2 changes: 1 addition & 1 deletion recbole/model/sequential_recommender/core.py
@@ -181,7 +181,7 @@ def predict(self, interaction):
         item_seq = interaction[self.ITEM_SEQ]
         item_seq_len = interaction[self.ITEM_SEQ_LEN]
         test_item = interaction[self.ITEM_ID]
-        seq_output = self.forward(item_seq, item_seq_len)
+        seq_output = self.forward(item_seq)
         test_item_emb = self.item_embedding(test_item)
         scores = torch.mul(seq_output, test_item_emb).sum(dim=1) / self.temperature
         return scores
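For context on the core.py fix: CORE's `forward` consumes only the item sequence, so passing `item_seq_len` as a second positional argument made `predict` crash at evaluation time. A minimal sketch of the mismatch (`ToyCORE` is a hypothetical stand-in, not RecBole's actual model):

```python
import torch
import torch.nn as nn


class ToyCORE(nn.Module):
    """Stand-in for CORE: forward takes only the item sequence."""

    def forward(self, item_seq):
        return item_seq.float().mean(dim=1)


model = ToyCORE()
item_seq = torch.ones(2, 5, dtype=torch.long)
item_seq_len = torch.tensor([5, 5])

model(item_seq)  # fine: one positional argument, as in the patched predict()
# model(item_seq, item_seq_len)  # TypeError: forward() takes 2 positional
#                                # arguments but 3 were given
```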
recbole/quick_start/quick_start.py (2 changes: 1 addition & 1 deletion)
@@ -75,7 +75,7 @@ def run_recbole(
     init_seed(config["seed"] + config["local_rank"], config["reproducibility"])
     model = get_model(config["model"])(config, train_data._dataset).to(config["device"])
     logger.info(model)
-    flops = get_flops(model, dataset, config["device"])
+    flops = get_flops(model, dataset, config["device"], logger)
     logger.info(set_color("FLOPs", "blue") + f": {flops}")
 
     # trainer loading and initialization
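The call-site change mirrors the new `get_flops` signature: the logger is now threaded through so the counter can report which rule handles each module type. A usage sketch, assuming a checkout with this patch applied and that `get_flops` is re-exported from `recbole.utils` (model and dataset names are illustrative):

```python
from logging import getLogger

from recbole.config import Config
from recbole.data import create_dataset, data_preparation
from recbole.utils import get_flops, get_model, init_logger, init_seed

config = Config(model="BPR", dataset="ml-100k")
init_seed(config["seed"], config["reproducibility"])
init_logger(config)
logger = getLogger()

dataset = create_dataset(config)
train_data, valid_data, test_data = data_preparation(config, dataset)
model = get_model(config["model"])(config, train_data._dataset).to(config["device"])

# logger is now a required positional argument; verbose=True additionally
# logs the counting rule chosen for each module type.
flops = get_flops(model, dataset, config["device"], logger, verbose=True)
logger.info(f"FLOPs: {flops}")
```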
recbole/utils/utils.py (176 changes: 107 additions & 69 deletions)
@@ -20,8 +20,9 @@
 
 import numpy as np
 import torch
+import torch.nn as nn
 from torch.utils.tensorboard import SummaryWriter
-from fvcore.nn import FlopCountAnalysis
 
 
 from recbole.utils.enum_type import ModelType
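This import swap is the heart of the change: FLOP counting moves from fvcore's JIT-trace-based `FlopCountAnalysis` to thop-style forward hooks, which is also why requirements.txt below trades `fvcore` for `thop`. A quick sanity check of the thop API the new code builds on (a toy model, not RecBole code):

```python
import torch
import torch.nn as nn
from thop import profile  # provided by the new thop requirement

toy = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
macs, params = profile(toy, inputs=(torch.randn(1, 8),), verbose=False)
print(macs, params)  # 32.0 MACs for the 8x4 linear (ReLU counts as zero), 36 params
```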

@@ -244,73 +245,33 @@ def get_gpu_usage(device=None):
     return "{:.2f} G/{:.2f} G".format(reserved, total)
 
 
-def get_flops(
-    model, dataset, device, uncalled_warnings=False, unsupported_warnings=False
-):
+def get_flops(model, dataset, device, logger, verbose=False):
     r"""Given a model and dataset to the model, compute the per-operator flops
     of the given model.
     Args:
         model: the model to compute flop counts.
         dataset: dataset that are passed to `model` to count flops.
         device: cuda.device. It is the device that the model run on.
-        uncalled_warnings: whether to print warnings for uncalled modules.
-        unsupported_warnings: whether to print warnings for unsupported operations.
+        verbose: whether to print information of modules.
     Returns:
-        total_flops: the number of flops for each operation.
+        total_ops: the number of flops for each operation.
     """
 
-    def get_shape(val):
-        r"""
-        Get the shapes from a jit value object.
-        Args:
-            val: jit value object.
-        Returns:
-            list: return a list of ints.
-        """
-        if val.isCompleteTensor():
-            return val.type().sizes()
-        else:
-            return None
-
-    def binary_ops_jit(inputs, outputs):
-        r"""
-        Count flops for binary operator such as addition, subtraction, multiplication
-        and division.
-        """
-        input_shapes = [get_shape(v) for v in inputs]
-        assert input_shapes[0] == input_shapes[1]
-        flop = np.prod(input_shapes[0])
-        return flop
-
-    def sum_ops_jit(inputs, outputs):
-        r"""
-        Count flops for sum.
-        """
-        input_shapes = [get_shape(v) for v in inputs]
-        assert input_shapes[0]
-        flop = np.prod(input_shapes[0]) - 1
-        return flop
-
-    def sigmoid_ops_jit(inputs, outputs):
-        r"""
-        Count flops for sigmoid.
-        """
-        input_shapes = [get_shape(v) for v in inputs]
-        assert input_shapes[0]
-        flop = np.prod(input_shapes) * 4
-        return flop
-
-    custom_ops = {
-        "aten::add": binary_ops_jit,
-        "aten::sub": binary_ops_jit,
-        "aten::mul": binary_ops_jit,
-        "aten::div": binary_ops_jit,
-        "aten::sum": sum_ops_jit,
-        "aten::sigmoid": sigmoid_ops_jit,
-    }
-    inter = dataset[torch.tensor([1])].to(device)
+    if model.type == ModelType.DECISIONTREE:
+        return 1
+    if model.__class__.__name__ == "Pop":
+        return 1
+
+    def count_normalization(m, x, y):
+        x = x[0]
+        flops = torch.DoubleTensor([2 * x.numel()])
+        m.total_ops += flops
+
+    def count_embedding(m, x, y):
+        x = x[0]
+        nelements = x.numel()
+        hiddensize = y.shape[-1]
+        m.total_ops += nelements * hiddensize
 
     class TracingAdapter(torch.nn.Module):
         def __init__(self, rec_model):

@@ -320,14 +281,91 @@ def __init__(self, rec_model):
         def forward(self, interaction):
             return self.model.predict(interaction)
 
+    custom_ops = {
+        torch.nn.Embedding: count_embedding,
+        torch.nn.LayerNorm: count_normalization,
+    }
     wrapper = TracingAdapter(model)
-    flop_counter = FlopCountAnalysis(wrapper, (inter.interaction,)).set_op_handle(
-        **custom_ops
-    )
-    flop_counter.unsupported_ops_warnings(unsupported_warnings)
-    flop_counter.uncalled_modules_warnings(uncalled_warnings)
-    flop_counter.tracer_warnings("none")
-    total_flops = 0
-    for _, flop in flop_counter.by_operator().items():
-        total_flops += flop
-    return total_flops
+    inter = dataset[torch.tensor([1])].to(device)
+    inputs = (inter,)
+    from thop.profile import register_hooks
+    from thop.vision.basic_hooks import count_parameters
+
+    handler_collection = {}
+    fn_handles = []
+    params_handles = []
+    types_collection = set()
+    if custom_ops is None:
+        custom_ops = {}
+
+    def add_hooks(m: nn.Module):
+        m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
+        m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64))
+
+        m_type = type(m)
+
+        fn = None
+        if m_type in custom_ops:
+            fn = custom_ops[m_type]
+            if m_type not in types_collection and verbose:
+                logger.info("Customize rule %s() %s." % (fn.__qualname__, m_type))
+        elif m_type in register_hooks:
+            fn = register_hooks[m_type]
+            if m_type not in types_collection and verbose:
+                logger.info("Register %s() for %s." % (fn.__qualname__, m_type))
+        else:
+            if m_type not in types_collection and verbose:
+                logger.warning(
+                    "[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params."
+                    % m_type
+                )
+
+        if fn is not None:
+            handle_fn = m.register_forward_hook(fn)
+            handle_paras = m.register_forward_hook(count_parameters)
+            handler_collection[m] = (
+                handle_fn,
+                handle_paras,
+            )
+            fn_handles.append(handle_fn)
+            params_handles.append(handle_paras)
+        types_collection.add(m_type)
+
+    prev_training_status = wrapper.training
+
+    wrapper.eval()
+    wrapper.apply(add_hooks)
+
+    with torch.no_grad():
+        wrapper(*inputs)
+
+    def dfs_count(module: nn.Module, prefix="\t"):
+        total_ops, total_params = module.total_ops.item(), 0
+        ret_dict = {}
+        for n, m in module.named_children():
+            next_dict = {}
+            if m in handler_collection and not isinstance(
+                m, (nn.Sequential, nn.ModuleList)
+            ):
+                m_ops, m_params = m.total_ops.item(), m.total_params.item()
+            else:
+                m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + "\t")
+            ret_dict[n] = (m_ops, m_params, next_dict)
+            total_ops += m_ops
+            total_params += m_params
+
+        return total_ops, total_params, ret_dict
+
+    total_ops, total_params, ret_dict = dfs_count(wrapper)
+
+    # reset wrapper to original status
+    wrapper.train(prev_training_status)
+    for m, (op_handler, params_handler) in handler_collection.items():
+        m._buffers.pop("total_ops")
+        m._buffers.pop("total_params")
+    for i in range(len(fn_handles)):
+        fn_handles[i].remove()
+        params_handles[i].remove()
+
+    return total_ops
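To make the new counting scheme concrete: every hooked module accumulates into a `total_ops` buffer during a single no-grad forward pass, and `dfs_count` then sums those buffers over the module tree, descending into containers rather than double-counting them. A self-contained sketch of the same mechanism using the two custom rules from this patch (the toy model and shapes are illustrative):

```python
import torch
import torch.nn as nn


def count_embedding(m, x, y):
    # One op per looked-up element per hidden dimension, as in the patch.
    m.total_ops += x[0].numel() * y.shape[-1]


def count_normalization(m, x, y):
    # LayerNorm approximated as 2 ops per input element, as in the patch.
    m.total_ops += torch.DoubleTensor([2 * x[0].numel()])


rules = {nn.Embedding: count_embedding, nn.LayerNorm: count_normalization}

model = nn.Sequential(nn.Embedding(100, 16), nn.LayerNorm(16))
handles = []
for m in model.modules():
    fn = rules.get(type(m))
    if fn is not None:
        m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
        handles.append(m.register_forward_hook(fn))

model.eval()
with torch.no_grad():
    model(torch.randint(0, 100, (1, 5)))  # a batch with one 5-item sequence

# get_flops aggregates recursively via dfs_count; a flat sum works for this toy.
total = sum(m.total_ops.item() for m in model.modules() if hasattr(m, "total_ops"))
print(total)  # 5*16 (embedding) + 2*5*16 (layer norm) = 240.0

for h in handles:
    h.remove()
```

One design note: `dfs_count` exists because a naive sum over every `total_ops` buffer would double-count a hooked module and its children, so the recursion takes a module's own buffer only when that module was actually hooked, and otherwise descends into its children.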
requirements.txt (5 changes: 3 additions & 2 deletions)
@@ -9,5 +9,6 @@ pyyaml>=5.1.0
 colorlog==4.7.2
 colorama==0.4.4
 tensorboard>=2.5.0
-fvcore>=0.1.5.post20220512
-ray>=1.13.0
+thop>=0.1.1.post2207130030
+ray>=1.13.0
+tabulate>=0.8.10
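Since the dependency pins changed, a quick way to confirm an existing environment satisfies them (assumes Python 3.8+ for `importlib.metadata`):

```python
from importlib.metadata import version

# Expect thop >= 0.1.1.post2207130030 and tabulate >= 0.8.10 per requirements.txt.
print(version("thop"), version("tabulate"))
```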
