Fix fvcore flops counting for torchvision models
Summary: For now, we only support fvcore flops counting for torchvision models.

Reviewed By: FindHao

Differential Revision: D47672523

fbshipit-source-id: dce911484ec63823272d49dea0b49be5e2f62398
xuzhao9 authored and facebook-github-bot committed Jul 21, 2023
1 parent 79bc754 commit 9d84f9e
Showing 4 changed files with 13 additions and 21 deletions.
6 changes: 3 additions & 3 deletions run.py
@@ -101,9 +101,9 @@ def printResultSummaryTime(result_summary, metrics_needed=[], model=None, flops_
     if flops_model_analyzer.metrics_backend_mapping['flops'] == 'dcgm':
         tflops_device_id, tflops = flops_model_analyzer.calculate_flops()
     else:
-        flops, batch_size = model.get_flops()
-        tflops = flops * batch_size / (cpu_walltime / 1.0e3) / 1.0e12
-    print('{:<20} {:>20}'.format("GPU %d FLOPS:" % tflops_device_id, "%.4f TFLOPs per second" % tflops, sep=''))
+        flops = model.get_flops()
+        tflops = flops / (cpu_walltime / 1.0e3) / 1.0e12
+    print('{:<20} {:>20}'.format("GPU FLOPS:", "%.4f TFLOPs per second" % tflops, sep=''))
     if gpu_peak_mem is not None:
         print('{:<20} {:>20}'.format("GPU %d Peak Memory:" % mem_device_id, "%.4f GB" % gpu_peak_mem, sep=''))
     if cpu_peak_mem is not None:
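With batch size folded into the model-side flop count, the harness-side arithmetic reduces to latency normalization. A minimal sketch of that computation, assuming (as in run.py) that cpu_walltime is the per-iteration wall time in milliseconds; the function name is illustrative, not part of the repository:

    def tflops_per_second(flops_per_iteration: float, cpu_walltime_ms: float) -> float:
        # Convert milliseconds to seconds, then scale to units of 10**12 flops.
        seconds = cpu_walltime_ms / 1.0e3
        return flops_per_iteration / seconds / 1.0e12

    # Example: 8.24e12 flops per iteration measured at 50 ms per iteration.
    print("%.4f TFLOPs per second" % tflops_per_second(8.24e12, 50.0))  # 164.8000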
10 changes: 0 additions & 10 deletions torchbenchmark/util/backends/flops.py

This file was deleted.

7 changes: 1 addition & 6 deletions torchbenchmark/util/extra_args.py
@@ -2,9 +2,7 @@
 import enum
 from typing import List, Optional, Tuple
 from torchbenchmark.util.backends import list_backends, BACKENDS
-
-from torchbenchmark.util.backends.flops import enable_fvcore_flops
-from torchbenchmark.util.env_check import is_torchvision_model, is_staged_train_test
+from torchbenchmark.util.env_check import is_staged_train_test
 
 TEST_STAGE = enum.Enum('TEST_STAGE', ['FORWARD', 'BACKWARD', 'OPTIMIZER', 'ALL'])
 AVAILABLE_PRECISIONS = ["fp32", "tf32", "fp16", "amp", "fx_int8", "bf16","amp_fp16", "amp_bf16"]
@@ -127,7 +125,6 @@ def apply_decoration_args(model: 'torchbenchmark.util.model.BenchmarkModel', dar
 def parse_opt_args(model: 'torchbenchmark.util.model.BenchmarkModel', opt_args: List[str]) -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument("--backend", choices=list_backends(), help="enable backends")
-    parser.add_argument("--flops", choices=["fvcore", "dcgm"], help="Return the flops result")
     args, extra_args = parser.parse_known_args(opt_args)
     if model.jit:
         args.backend = "torchscript"
@@ -137,7 +134,5 @@ def parse_opt_args(model: 'torchbenchmark.util.model.BenchmarkModel', opt_args:
     return args, extra_args
 
 def apply_opt_args(model: 'torchbenchmark.util.model.BenchmarkModel', args: argparse.Namespace):
-    if args.flops == "fvcore":
-        enable_fvcore_flops(model)
     if args.backend:
         model._enable_backend()
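Because --flops is no longer registered on this parser, a --flops flag now passes through parse_known_args unconsumed and ends up in extra_args. A small standalone sketch of that split behavior (the choices list here is illustrative, not the real backend registry):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--backend", choices=["torchscript"], help="enable backends")

    # Known options are parsed into args; everything else is returned verbatim.
    args, extra_args = parser.parse_known_args(
        ["--backend", "torchscript", "--flops", "fvcore"])
    print(args.backend)  # torchscript
    print(extra_args)    # ['--flops', 'fvcore']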
11 changes: 9 additions & 2 deletions torchbenchmark/util/framework/vision/model_factory.py
@@ -48,7 +48,15 @@ def __init__(self, model_name, test, device, jit=False, batch_size=None, weights
         self.real_output = ( torch.rand_like(self.example_outputs), )
 
     def get_flops(self):
-        return self.flops, self.batch_size
+        # By default, FlopCountAnalysis counts one fused multiply-add (FMA) as one flop.
+        # However, in our context, we count 1 FMA as 2 flops instead of 1.
+        # https://github.com/facebookresearch/fvcore/blob/7a0ef0c0839fa0f5e24d2ef7f5d48712f36e7cd7/fvcore/nn/flop_count.py
+        assert self.test == "eval", "fvcore flops is only available on inference tests, as it doesn't measure the backward pass."
+        from fvcore.nn import FlopCountAnalysis
+        FLOPS_FMA = 2.0
+        self.flops = FlopCountAnalysis(self.model, tuple(self.example_inputs)).total()
+        self.flops = self.flops * FLOPS_FMA
+        return self.flops
 
     def gen_inputs(self, num_batches:int=1) -> Tuple[Generator, Optional[int]]:
         def _gen_inputs():
@@ -96,4 +104,3 @@ def cudagraph_eval(self):
             self.g.replay()
             break
         return (self.example_outputs, )
-
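The reworked get_flops() defers to fvcore's FlopCountAnalysis and doubles its result, since fvcore counts one fused multiply-add as a single flop. A self-contained sketch of the same convention on a torchvision model, assuming fvcore and torchvision are installed (the model and input shape are illustrative):

    import torch
    import torchvision.models as models
    from fvcore.nn import FlopCountAnalysis

    FLOPS_FMA = 2.0  # report 1 fused multiply-add as 2 flops

    model = models.resnet50().eval()
    inputs = (torch.randn(1, 3, 224, 224),)

    with torch.no_grad():
        # total() sums flops over all supported operators in one forward pass.
        flops = FlopCountAnalysis(model, inputs).total() * FLOPS_FMA

    print("%.2f GFLOPs for one forward pass" % (flops / 1.0e9))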