Skip to content

Commit

Permalink
[benchmarks] Set matrix multiplication precision. (#7748)
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi authored Jul 25, 2024
1 parent 6280701 commit 2870e93
Show file tree
Hide file tree
Showing 10 changed files with 26 additions and 25 deletions.
14 changes: 13 additions & 1 deletion benchmarks/experiment_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from benchmark_experiment import ExperimentLoader, BenchmarkExperiment
from util import cleanup, move_to_device, randomize_input, reset_rng_state, us_to_s, ns_to_s, StrOrBool

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

Expand Down Expand Up @@ -939,6 +940,11 @@ def __str__(self):
action="store_true",
help="Whether to enable fast F32 multiplication in PyTorch.",
)
parser.add_argument(
"--matmul-precision",
choices=["default", "high", "highest"],
help="Set matrix multiplication for both PyTorch and PyTorch/XLA.",
)
parser.add_argument(
"--experiment-config",
type=str,
Expand Down Expand Up @@ -1009,9 +1015,15 @@ def main():
logging.basicConfig(level=args.log_level.value, force=True)
logger.debug(f"Parsed args: {args}")

precision = 'highest'
if args.matmul_precision is not None:
precision = args.matmul_precision
# --disable-tf32 flag may overwrite precision settings for BC reasons.
if not args.disable_tf32:
logger.warning('Enabling fast F32 multiplication for PyTorch')
torch.set_float32_matmul_precision('high')
precision = 'high'
torch.set_float32_matmul_precision(precision)
torch_xla._XLAC._xla_set_mat_mul_precision(precision)

if args.profile_xla:
logger.info(
Expand Down
3 changes: 1 addition & 2 deletions test/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,5 @@ def run_benchmarks(args):
args.benchs = benchs

torch.set_default_dtype(torch.float32)
torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
use_full_mat_mul_precision=True)
torch_xla._XLAC._xla_set_mat_mul_precision('highest')
run_benchmarks(args)
3 changes: 1 addition & 2 deletions test/pytorch_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,8 +632,7 @@ def get_primary_device(cls):
def setUpClass(cls):
# Sets the primary test device to the xla_device (CPU or TPU)
cls.primary_device = str(xm.xla_device())
torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
use_full_mat_mul_precision=True)
torch_xla._XLAC._xla_set_mat_mul_precision('highest')

def setUp(self):
super().setUp()
Expand Down
3 changes: 1 addition & 2 deletions test/test_gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,6 @@ def test_gmm_backward_3(self):
logging.getLogger().setLevel(logging.INFO)
torch.set_default_dtype(torch.float32)
torch.manual_seed(42)
torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
use_full_mat_mul_precision=True)
torch_xla._XLAC._xla_set_mat_mul_precision('highest')
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
3 changes: 1 addition & 2 deletions test/test_mp_distributed_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ def _mp_fn(index):

if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
world_size = xr.world_size()
torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
use_full_mat_mul_precision=True)
torch_xla._XLAC._xla_set_mat_mul_precision('highest')
torch.manual_seed(11)
xm.set_rng_state(11)

Expand Down
3 changes: 1 addition & 2 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3104,8 +3104,7 @@ def test_repeat_special(self):
if __name__ == '__main__':
torch.set_default_dtype(torch.float32)
torch.manual_seed(42)
torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
use_full_mat_mul_precision=True)
torch_xla._XLAC._xla_set_mat_mul_precision('highest')
test = unittest.main(verbosity=FLAGS.verbosity, exit=False)
if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
print(met.metrics_report())
Expand Down
3 changes: 1 addition & 2 deletions test/test_operations_hlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ def test_dropout_by_u8_mask(self):
if __name__ == '__main__':
torch.set_default_dtype(torch.float32)
torch.manual_seed(42)
torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
use_full_mat_mul_precision=True)
torch_xla._XLAC._xla_set_mat_mul_precision('highest')
test = unittest.main(verbosity=FLAGS.verbosity, exit=False)
if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
print(met.metrics_report())
Expand Down
3 changes: 1 addition & 2 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,6 @@ def test_flash_attention_sm_scale_backward(self):
logging.getLogger().setLevel(logging.INFO)
torch.set_default_dtype(torch.float32)
torch.manual_seed(42)
torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
use_full_mat_mul_precision=True)
torch_xla._XLAC._xla_set_mat_mul_precision('highest')
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
3 changes: 1 addition & 2 deletions test/test_pallas_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ def test_flash_attention_backward_spmd_data_parallel(self):
logging.getLogger().setLevel(logging.INFO)
torch.set_default_dtype(torch.float32)
torch.manual_seed(42)
torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
use_full_mat_mul_precision=True)
torch_xla._XLAC._xla_set_mat_mul_precision('highest')
xr.use_spmd()
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
13 changes: 5 additions & 8 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1939,14 +1939,11 @@ void InitXlaModuleBindings(py::module m) {
py::arg("nodes_threshold") = 100, py::arg("device") = "");
m.def("_xla_memory_info",
[](const std::string& device) { return GetMemoryInfo(device); });
m.def(
"_xla_set_use_full_mat_mul_precision",
[](bool use_full_mat_mul_precision) {
XlaHelpers::set_mat_mul_precision(use_full_mat_mul_precision
? xla::PrecisionConfig::HIGHEST
: xla::PrecisionConfig::DEFAULT);
},
py::arg("use_full_mat_mul_precision") = true);
m.def("_xla_set_mat_mul_precision", [](const std::string& mat_mul_precision) {
xla::PrecisionConfig::Precision precision =
ConsumeValue(xla::StringToPrecision(mat_mul_precision));
XlaHelpers::set_mat_mul_precision(precision);
});

py::class_<xla::XlaBuilder, op_builder::BuilderPtr>(m, "XlaBuilder");
py::class_<op_builder::Op, op_builder::OpPtr>(m, "XlaOp");
Expand Down

0 comments on commit 2870e93

Please sign in to comment.