diff --git a/benchmarks/benchmark_aq.py b/benchmarks/benchmark_aq.py
index fc19d1ab8..e00f32c4f 100644
--- a/benchmarks/benchmark_aq.py
+++ b/benchmarks/benchmark_aq.py
@@ -93,15 +93,14 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, kwargs=None):
     # warmup
     WARMUP = 5
     RUNS = 100
-    input_tensor = example_inputs[0]
     m = torch.compile(m, mode='max-autotune', fullgraph=True)

-    benchmark_model(m, WARMUP, input_tensor)
-    elapsed_time = benchmark_model(m, RUNS, input_tensor)
+    benchmark_model(m, WARMUP, example_inputs)
+    elapsed_time = benchmark_model(m, RUNS, example_inputs)

     m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
-    benchmark_model(m_ref, WARMUP, input_tensor)
-    ref_elapsed_time = benchmark_model(m_ref, RUNS, input_tensor)
+    benchmark_model(m_ref, WARMUP, example_inputs)
+    ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)

     print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
     assert elapsed_time < 1.05 * ref_elapsed_time
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index d8b6d71a5..d7d63f4ec 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -1532,7 +1532,7 @@ def run_benchmark_model(self, device):
         example_inputs = m.example_inputs(dtype=dtype, device=device)
         m_bf16 = torch.compile(m_bf16, mode='max-autotune')
         num_runs = 1
-        return benchmark_model(m_bf16, num_runs, example_inputs[0])
+        return benchmark_model(m_bf16, num_runs, example_inputs)

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_benchmark_model_cuda(self):
diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md
index bc4f6e4f3..5eebd8626 100644
--- a/torchao/quantization/README.md
+++ b/torchao/quantization/README.md
@@ -119,9 +119,9 @@
 from torchao.utils import benchmark_model

 num_runs = 100
 torch._dynamo.reset()
-bf16_time = benchmark_model(m_bf16, num_runs, example_inputs[0])
+bf16_time = benchmark_model(m_bf16, num_runs, example_inputs)
 print(f"bf16 mean time: {bf16_time}")
-int4_time = benchmark_model(m, num_runs, example_inputs[0])
+int4_time = benchmark_model(m, num_runs, example_inputs)
 print(f"int4 weight only quantized mean time: {int4_time}")
 print(f"speedup: {bf16_time / int4_time}")
diff --git a/torchao/utils.py b/torchao/utils.py
index 4f2ae85b1..801968b2a 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -42,8 +42,16 @@ def _assert_and_get_unique_device(module: torch.nn.Module) -> Any:
     return device


-def benchmark_model(model, num_runs, input_tensor):
-    device_type = _assert_and_get_unique_device(model).type
+def benchmark_model(model, num_runs, args=(), kwargs=None, device_type=None):
+    """Benchmark model runs; `args` and `kwargs` are both optional.
+    """
+    if kwargs is None:
+        kwargs = {}
+
+    if device_type is None:
+        assert isinstance(model, torch.nn.Module), "Expecting `model` to be torch.nn.Module if device_type is not provided"
+        device_type = _assert_and_get_unique_device(model).type
+
     if device_type == "cuda":
         torch.cuda.synchronize()
         start_event = torch.cuda.Event(enable_timing=True)
@@ -53,7 +61,7 @@ def benchmark_model(model, num_runs, input_tensor):
         # benchmark
         for _ in range(num_runs):
             with torch.autograd.profiler.record_function("timed region"):
-                model(input_tensor)
+                model(*args, **kwargs)

         end_event.record()
         torch.cuda.synchronize()
@@ -68,7 +76,7 @@ def benchmark_model(model, num_runs, input_tensor):
         # benchmark
         for _ in range(num_runs):
             with torch.autograd.profiler.record_function("timed region"):
-                model(input_tensor)
+                model(*args, **kwargs)

         end_event.record()
         torch.mps.synchronize()
@@ -81,7 +89,7 @@ def benchmark_model(model, num_runs, input_tensor):
         # benchmark
         for _ in range(num_runs):
             with torch.autograd.profiler.record_function("timed region"):
-                model(input_tensor)
+                model(*args, **kwargs)

         end_time = time.time()
         torch.cpu.synchronize()
@@ -264,7 +272,7 @@ def unwrap_tensor_subclass(model, filter_fn=None):
             parametrize.register_parametrization(child, "weight", UnwrapTensorSubclass())
         unwrap_tensor_subclass(child)
     return model
-    
+
 def is_fbcode():
     return not hasattr(torch.version, "git_version")

diff --git a/tutorials/quantize_vit/run_vit_b.py b/tutorials/quantize_vit/run_vit_b.py
index dae7dde70..7c60fcdba 100644
--- a/tutorials/quantize_vit/run_vit_b.py
+++ b/tutorials/quantize_vit/run_vit_b.py
@@ -11,15 +11,15 @@
 model.eval().cuda().to(torch.bfloat16)

 # Input tensor (batch_size, channels, height, width)
-input_tensor = torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda')
+inputs = (torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda'),)

 model = torch.compile(model, mode='max-autotune')

 # Must run with no_grad when optimizing for inference
 with torch.no_grad():
     # warmup
-    benchmark_model(model, 5, input_tensor)
+    benchmark_model(model, 5, inputs)
     # benchmark
-    print("elapsed_time: ", benchmark_model(model, 100, input_tensor), " milliseconds")
+    print("elapsed_time: ", benchmark_model(model, 100, inputs), " milliseconds")
     # Create a trace
-    profiler_runner("bfloat16.json.gz", benchmark_model, model, 5, input_tensor)
+    profiler_runner("bfloat16.json.gz", benchmark_model, model, 5, inputs)
diff --git a/tutorials/quantize_vit/run_vit_b_quant.py b/tutorials/quantize_vit/run_vit_b_quant.py
index a082cfe53..a826f43b9 100644
--- a/tutorials/quantize_vit/run_vit_b_quant.py
+++ b/tutorials/quantize_vit/run_vit_b_quant.py
@@ -12,7 +12,7 @@
 model.eval().cuda().to(torch.bfloat16)

 # Input tensor (batch_size, channels, height, width)
-input_tensor = torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda')
+inputs = (torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda'),)

 ## Quantization code - start
 # int8 dynamic quantization act, int8 weight, see ao/torchao/quantization/README.md
@@ -39,8 +39,8 @@
 # Must run with no_grad when optimizing for inference
 with torch.no_grad():
     # warmup
-    benchmark_model(model, 20, input_tensor)
+    benchmark_model(model, 20, inputs)
     # benchmark
-    print("elapsed_time: ", benchmark_model(model, 1000, input_tensor), " milliseconds")
+    print("elapsed_time: ", benchmark_model(model, 1000, inputs), " milliseconds")
     # Create a trace
-    profiler_runner("quant.json.gz", benchmark_model, model, 5, input_tensor)
+    profiler_runner("quant.json.gz", benchmark_model, model, 5, inputs)
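
Usage note: a minimal sketch of the updated `benchmark_model` API from the `torchao/utils.py` change above. `args` is a tuple of positional inputs, `kwargs` an optional dict, and `device_type` must be passed explicitly when `model` is not an `nn.Module` (the device can no longer be inferred). The toy model, shapes, and the `scaled` helper below are illustrative assumptions, not part of this patch.

```python
import torch
from torchao.utils import benchmark_model

# Illustrative toy model; any nn.Module on a single device works.
model = torch.nn.Linear(1024, 1024).cuda().to(torch.bfloat16)
example_inputs = (torch.randn(8, 1024, dtype=torch.bfloat16, device='cuda'),)

with torch.no_grad():
    benchmark_model(model, 5, example_inputs)           # warmup
    ms = benchmark_model(model, 100, example_inputs)    # timed runs
    print(f"mean time: {ms} milliseconds")

    # kwargs are forwarded to the call; device_type is required here
    # because a bare function is not an nn.Module, so the device
    # cannot be inferred from its parameters.
    def scaled(x, scale=1.0):
        return model(x) * scale

    ms = benchmark_model(scaled, 100, example_inputs,
                         kwargs={"scale": 2.0}, device_type="cuda")
```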