From 531f35cbaa705cb9341de2e7ada824e0df8d3d37 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Fri, 29 Nov 2024 15:56:23 +0000 Subject: [PATCH 01/11] Add fp8 and int4 types in supported list for Onnxruntime EP --- .../migraphx/migraphx_execution_provider.cc | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index fd36b8ae5f678..f320bf61f0ddf 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -232,6 +232,10 @@ static bool IsTypeSupported(const NodeArg* node_arg) { switch (type_proto->tensor_type().elem_type()) { case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FNUZ: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2FNUZ: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16: @@ -261,6 +265,18 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: mgx_type = migraphx_shape_double_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ: + mgx_type = migraphx_shape_fp8e4m3fnuz_type; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN: + mgx_type = migraphx_shape_fp8e4m3fn_type; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2: + mgx_type = migraphx_shape_fp8e5m2_type; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ: + mgx_type = migraphx_shape_fp8e5m2fnuz_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: mgx_type = migraphx_shape_int8_type; break; From 7076f7463507e72f1fa01f81ef3ebb33e21ab4b8 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Wed, 4 Dec 2024 23:05:49 +0000 Subject: [PATCH 02/11] Add support for int4 inputs Map things to int8 right now as we don't explicitly set an int4 input type and pack/unpack int4 operands --- .../providers/migraphx/migraphx_execution_provider.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index f320bf61f0ddf..c1cae43480ea2 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -237,10 +237,12 @@ static bool IsTypeSupported(const NodeArg* node_arg) { case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2FNUZ: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32: @@ -277,6 +279,9 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ: mgx_type = migraphx_shape_fp8e5m2fnuz_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: + mgx_type = migraphx_shape_int8_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: mgx_type = migraphx_shape_int8_type; break; @@ -289,6 +294,9 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: mgx_type = migraphx_shape_int64_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4: + mgx_type = migraphx_shape_uint8_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: mgx_type = migraphx_shape_uint8_type; break; From 11ff6449deb32bdabf7f08c670bc958e3bab416b Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Wed, 18 Dec 2024 20:20:45 +0000 Subject: [PATCH 03/11] Add flag to allow for fp8 quantization through Onnxruntime API --- include/onnxruntime/core/session/onnxruntime_c_api.h | 1 + .../migraphx/migraphx_execution_provider.cc | 12 ++++++++++-- .../providers/migraphx/migraphx_execution_provider.h | 3 +++ .../migraphx/migraphx_execution_provider_info.cc | 4 ++++ .../migraphx/migraphx_execution_provider_info.h | 1 + .../providers/migraphx/migraphx_provider_factory.cc | 2 ++ onnxruntime/python/onnxruntime_pybind_state.cc | 11 +++++++++++ onnxruntime/test/util/default_providers.cc | 1 + 8 files changed, 33 insertions(+), 2 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 4ca2791e26ab9..7ba935633ca3c 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -614,6 +614,7 @@ typedef struct OrtTensorRTProviderOptions { typedef struct OrtMIGraphXProviderOptions { int device_id; // hip device id. int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true + int migraphx_fp8_enable; // MIGraphX FP8 precision. Default 0 = false, nonzero = true int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index c1cae43480ea2..755da8611eedd 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -114,6 +114,12 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); } + // whether fp8 quantization is enabled + const std::string fp8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP8Enable); + if (!fp8_enable_env.empty()) { + fp8_enable_ = (std::stoi(fp8_enable_env) == 0 ? false : true); + } + // whether int8 is enabled const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable); if (!int8_enable_env.empty()) { @@ -192,6 +198,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: " << "device_id: " << info_.device_id << ", migraphx_fp16_enable: " << fp16_enable_ + << ", migraphx_fp8_enable: " << fp8_enable_ << ", migraphx_int8_enable: " << int8_enable_ << ", migraphx_int8_enable: " << int8_enable_ << ", dump_model_ops: " << dump_model_ops_ @@ -1183,7 +1190,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if (int8_enable_ && int8_calibration_cache_available_) { + if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { LOGS_DEFAULT(INFO) << "Quantizing input program to int8" << std::endl; migraphx::quantize_int8_options quant_opts; migraphx::program_parameters quant_params; @@ -1240,7 +1247,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::unique_ptr p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, - map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_, + map_no_input_shape_[context->node_name], fp16_enable_, fp8_enable_, int8_enable_, int8_calibration_cache_available_, dynamic_range_map_, save_compiled_model_, save_compiled_path_, load_compiled_model_, load_compiled_path_, dump_model_ops_}; @@ -1265,6 +1272,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::onnx_options& cmp_options = mgx_state->options; bool& no_input_shape = mgx_state->no_input_shape; bool fp16_enable = mgx_state->fp16_enable; + bool fp8_enable = mgx_state->fp8_enable; bool int8_enable = mgx_state->int8_enable; bool int8_calibration_cache_available = mgx_state->int8_calibration_cache_available; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 2be6c09551a71..dd7cfaedac361 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -17,6 +17,7 @@ namespace onnxruntime { namespace migraphx_env_vars { static const char kFP16Enable[] = "ORT_MIGRAPHX_FP16_ENABLE"; +static const char kFP8Enable[] = "ORT_MIGRAPHX_FP8_ENABLE"; static const char kINT8Enable[] = "ORT_MIGRAPHX_INT8_ENABLE"; static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"; @@ -43,6 +44,7 @@ struct MIGraphXFuncState { OrtMutex* mgx_mu_ptr = nullptr; bool no_input_shape = false; bool fp16_enable = false; + bool fp8_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; std::unordered_map dynamic_range_map; @@ -89,6 +91,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { private: MIGraphXExecutionProviderInfo info_; bool fp16_enable_ = false; + bool fp8_enable_ = false; bool int8_enable_ = false; std::string int8_calibration_cache_name_; bool int8_calibration_cache_available_ = false; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index 1f9a47d3ad87d..6537d1c12bc9c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -14,6 +14,7 @@ namespace migraphx { namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; constexpr const char* kFp16Enable = "trt_fp16_enable"; +constexpr const char* kFp8Enable = "migx_fp8_enable"; constexpr const char* kInt8Enable = "migx_int8_enable"; constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name"; constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table"; @@ -43,6 +44,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions return Status::OK(); }) .AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable) + .AddAssignmentToReference(migraphx::provider_option_names::kFp8Enable, info.fp8_enable) .AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable) .AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model) .AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model) @@ -56,6 +58,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, + {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)}, {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)}, {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)}, @@ -68,6 +71,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, + {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)}, {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)}, {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)}, diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index b8bf86580f03d..554f77b6f7f58 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -16,6 +16,7 @@ struct MIGraphXExecutionProviderInfo { std::string target_device; OrtDevice::DeviceId device_id{0}; bool fp16_enable{false}; + bool fp8_enable{false}; bool int8_enable{false}; std::string int8_calibration_table_name{""}; bool int8_use_native_calibration_table{false}; diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 7b192b657b7cc..098512b2ce69c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -60,6 +60,7 @@ struct MIGraphX_Provider : Provider { info.device_id = static_cast(options.device_id); info.target_device = "gpu"; info.fp16_enable = options.migraphx_fp16_enable; + info.fp8_enable = options.migraphx_fp8_enable; info.exhaustive_tune = options.migraphx_exhaustive_tune; info.int8_enable = options.migraphx_int8_enable; info.int8_calibration_table_name = ""; @@ -85,6 +86,7 @@ struct MIGraphX_Provider : Provider { auto& migx_options = *reinterpret_cast(provider_options); migx_options.device_id = internal_options.device_id; migx_options.migraphx_fp16_enable = internal_options.fp16_enable; + migx_options.migraphx_fp8_enable = internal_options.fp8_enable; migx_options.migraphx_int8_enable = internal_options.int8_enable; migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index f31734bdfb805..1862072c36e1f 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -831,6 +831,7 @@ std::unique_ptr CreateExecutionProviderInstance( 0, 0, 0, + 0, nullptr, 1, "./compiled_model.mxr", @@ -854,6 +855,16 @@ std::unique_ptr CreateExecutionProviderInstance( "[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be" " 'True' or 'False'. Default value is 'False'.\n"); } + } else if (option.first == "migraphx_fp8_enable") { + if (option.second == "True" || option.second == "true") { + params.migraphx_fp8_enable = true; + } else if (option.second == "False" || option.second == "false") { + params.migraphx_fp8_enable = false; + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migx_fp8_enable' should be" + " 'True' or 'False'. Default value is 'False'.\n"); + } } else if (option.first == "migraphx_int8_enable") { if (option.second == "True" || option.second == "true") { params.migraphx_int8_enable = true; diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 1feba20e32bbb..eecf194e7c9a1 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -76,6 +76,7 @@ std::unique_ptr DefaultMIGraphXExecutionProvider() { 0, 0, 0, + 0, nullptr, 1, "./compiled_model.mxr", From ac77aacb174a30b1e682645fb73846af78de6e0e Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 19 Dec 2024 01:21:47 +0000 Subject: [PATCH 04/11] Add fp8 quantization to the compile stage of the MIGraphX EP Mirror the same calibration code we use for int8 and just change which quantize we call through the MIGraphx API --- .../migraphx/migraphx_execution_provider.cc | 54 +++++++++++++------ 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 755da8611eedd..980ac7720b8d9 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1190,9 +1190,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { + if ((int8_enable_ xor fp8_enable_) && int8_calibration_cache_available_) { LOGS_DEFAULT(INFO) << "Quantizing input program to int8" << std::endl; - migraphx::quantize_int8_options quant_opts; migraphx::program_parameters quant_params; auto param_shapes = prog.get_parameter_shapes(); @@ -1202,15 +1201,26 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); } - quant_opts.add_calibration_data(quant_params); - - // specify thing we want to int8 quantize - quant_opts.add_op_name("convolution"); - quant_opts.add_op_name("dot"); // perform static quantization on the programs - migraphx::quantize_int8(prog, t_, quant_opts); - LOGS_DEFAULT(INFO) << "Quantizing input program to int8: Complete" << std::endl; + if(int8_enable_) + { + migraphx::quantize_int8_options quant_opts; + quant_opts.add_calibration_data(quant_params); + // specify thing we want to int8 quantize + quant_opts.add_op_name("convolution"); + quant_opts.add_op_name("dot"); + migraphx::quantize_int8(prog, t_, quant_opts); + LOGS_DEFAULT(INFO) << "Quantizing input program to int8: Complete" << std::endl; + } + else if(fp8_enable_) + { + migraphx::quantize_fp8_options quant_opts; + quant_opts.add_calibration_data(quant_params); + migraphx::quantize_fp8(prog, t_, quant_opts); + LOGS_DEFAULT(INFO) << "Quantizing input program to fp8: Complete" << std::endl; + } + } if (fp16_enable_) { @@ -1333,9 +1343,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if (int8_enable && int8_calibration_cache_available) { + if ((int8_enable xor fp8_enable) && int8_calibration_cache_available) { LOGS_DEFAULT(INFO) << "Quantize Int8: Begin" << std::endl; - migraphx::quantize_int8_options quant_opts; migraphx::program_parameters quant_params; auto param_shapes = prog.get_parameter_shapes(); @@ -1364,14 +1373,25 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); } - quant_opts.add_calibration_data(quant_params); - // specify thing we want to int8 quantize - quant_opts.add_op_name("convolution"); - quant_opts.add_op_name("dot"); // perform static quantization on the programs - migraphx::quantize_int8(prog, t, quant_opts); - LOGS_DEFAULT(INFO) << "Quantize Int8: Completed" << std::endl; + if(int8_enable) + { + migraphx::quantize_int8_options quant_opts; + quant_opts.add_calibration_data(quant_params); + // specify thing we want to int8 quantize + quant_opts.add_op_name("convolution"); + quant_opts.add_op_name("dot"); + migraphx::quantize_int8(prog, t, quant_opts); + LOGS_DEFAULT(INFO) << "Quantizing input program to fp8: Complete" << std::endl; + } + else if(fp8_enable) + { + migraphx::quantize_fp8_options quant_opts; + quant_opts.add_calibration_data(quant_params); + migraphx::quantize_fp8(prog, t, quant_opts); + LOGS_DEFAULT(INFO) << "Quantizing input program to fp8: Complete" << std::endl; + } } if (fp16_enable) { From 3d8c69ae6622c5e7c5231f19a92d268704ffd29b Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 19 Dec 2024 22:05:02 +0000 Subject: [PATCH 05/11] cleanup logging --- .../migraphx/migraphx_execution_provider.cc | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 980ac7720b8d9..bc1f9fe5d711d 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1052,7 +1052,7 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v if (dump_model_ops_) { LOGS_DEFAULT(INFO) << "============= Unsupported nodes ===================="; for (auto idx : unsupported_nodes) { - LOGS_DEFAULT(INFO) << graph_viewer.GetNode(idx)->OpType() << std::endl; + LOGS_DEFAULT(INFO) << graph_viewer.GetNode(idx)->OpType(); } LOGS_DEFAULT(INFO) << "************* Unsupported nodes ********************"; } @@ -1140,11 +1140,11 @@ bool load_precompiled_model(migraphx::program& prog, bool load_enable, std::stri void save_compiled_model(migraphx::program& prog, bool save_enable, std::string out_path) { if (save_enable) { - LOGS_DEFAULT(INFO) << "Model Save at " << out_path << ": Begin" << std::endl; + LOGS_DEFAULT(INFO) << "Model Save at " << out_path << ": Begin"; migraphx::file_options fo; fo.set_file_format("msgpack"); migraphx::save(prog, out_path.c_str(), fo); - LOGS_DEFAULT(INFO) << "Model Save: Complete" << std::endl; + LOGS_DEFAULT(INFO) << "Model Save: Complete"; } } @@ -1191,7 +1191,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // Read in the calibration data and map it to an migraphx paramater map for the calibration ops if ((int8_enable_ xor fp8_enable_) && int8_calibration_cache_available_) { - LOGS_DEFAULT(INFO) << "Quantizing input program to int8" << std::endl; + LOGS_DEFAULT(INFO) << "Quantizing input program to int8"; migraphx::program_parameters quant_params; auto param_shapes = prog.get_parameter_shapes(); @@ -1211,30 +1211,30 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& quant_opts.add_op_name("convolution"); quant_opts.add_op_name("dot"); migraphx::quantize_int8(prog, t_, quant_opts); - LOGS_DEFAULT(INFO) << "Quantizing input program to int8: Complete" << std::endl; + LOGS_DEFAULT(INFO) << "Quantizing input program to int8: Complete"; } else if(fp8_enable_) { migraphx::quantize_fp8_options quant_opts; quant_opts.add_calibration_data(quant_params); migraphx::quantize_fp8(prog, t_, quant_opts); - LOGS_DEFAULT(INFO) << "Quantizing input program to fp8: Complete" << std::endl; + LOGS_DEFAULT(INFO) << "Quantizing input program to fp8: Complete"; } } if (fp16_enable_) { - LOGS_DEFAULT(INFO) << "Quantizing input program to fp16" << std::endl; + LOGS_DEFAULT(INFO) << "Quantizing input program to fp16"; migraphx::quantize_fp16(prog); - LOGS_DEFAULT(INFO) << "Quantizing input program to fp16: Complete" << std::endl; + LOGS_DEFAULT(INFO) << "Quantizing input program to fp16: Complete"; } migraphx::compile_options co; co.set_fast_math(false); co.set_exhaustive_tune_flag(exhaustive_tune_); - LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl; + LOGS_DEFAULT(INFO) << "Model Compile: Begin"; prog.compile(t_, co); - LOGS_DEFAULT(INFO) << "Model Compile: Complete" << std::endl; + LOGS_DEFAULT(INFO) << "Model Compile: Complete"; save_compiled_model(prog, save_compiled_model_, save_compiled_path_); } @@ -1291,7 +1291,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& bool input_shape_match = true; migraphx::program_parameter_shapes param_shapes; if (no_input_shape) { - LOGS_DEFAULT(VERBOSE) << "Missing input shape setting input parameters again" << std::endl; + LOGS_DEFAULT(VERBOSE) << "Missing input shape setting input parameters again"; for (auto& it : map_input_name_index) { auto& name = it.first; auto& index = it.second; @@ -1303,7 +1303,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& input_shape_match = false; } } else { - LOGS_DEFAULT(VERBOSE) << "Assigning inputs, and parameters from compiled model" << std::endl; + LOGS_DEFAULT(VERBOSE) << "Assigning inputs, and parameters from compiled model"; param_shapes = prog.get_parameter_shapes(); auto prog_output_shapes = prog.get_output_shapes(); @@ -1338,13 +1338,13 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // re-compile the program if (!input_shape_match) { if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { - LOGS_DEFAULT(VERBOSE) << "No Input shapes mismatch detected. Recompiling" << std::endl; + LOGS_DEFAULT(VERBOSE) << "No Input shapes mismatch detected. Recompiling"; cmp_options.set_external_data_path(model_path_.parent_path().string()); prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); // Read in the calibration data and map it to an migraphx paramater map for the calibration ops if ((int8_enable xor fp8_enable) && int8_calibration_cache_available) { - LOGS_DEFAULT(INFO) << "Quantize Int8: Begin" << std::endl; + LOGS_DEFAULT(INFO) << "Quantize Int8: Begin"; migraphx::program_parameters quant_params; auto param_shapes = prog.get_parameter_shapes(); @@ -1383,24 +1383,24 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& quant_opts.add_op_name("convolution"); quant_opts.add_op_name("dot"); migraphx::quantize_int8(prog, t, quant_opts); - LOGS_DEFAULT(INFO) << "Quantizing input program to fp8: Complete" << std::endl; + LOGS_DEFAULT(INFO) << "Quantizing input program to fp8: Complete"; } else if(fp8_enable) { migraphx::quantize_fp8_options quant_opts; quant_opts.add_calibration_data(quant_params); migraphx::quantize_fp8(prog, t, quant_opts); - LOGS_DEFAULT(INFO) << "Quantizing input program to fp8: Complete" << std::endl; + LOGS_DEFAULT(INFO) << "Quantizing input program to fp8: Complete"; } } if (fp16_enable) { - LOGS_DEFAULT(INFO) << "Quantize fp16: Begin" << std::endl; + LOGS_DEFAULT(INFO) << "Quantize fp16: Begin"; migraphx::quantize_fp16(prog); - LOGS_DEFAULT(INFO) << "Quantize fp16: Completed" << std::endl; + LOGS_DEFAULT(INFO) << "Quantize fp16: Completed"; } - LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl; + LOGS_DEFAULT(INFO) << "Model Compile: Begin"; migraphx::compile_options co; co.set_fast_math(false); co.set_exhaustive_tune_flag(exhaustive_tune_); @@ -1420,7 +1420,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (param_shapes.size() > 0) { for (auto&& name : param_shapes.names()) { if (map_input_name_index.count(name) > 0) { - LOGS_DEFAULT(INFO) << "Setting parameters for:" << name << std::endl; + LOGS_DEFAULT(INFO) << "Setting parameters for:" << name; auto input_tensor = ctx.GetInput(map_input_name_index[name]); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); const auto tensor_shape = tensor_info.GetShape(); @@ -1434,7 +1434,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch"; } - LOGS_DEFAULT(INFO) << "Writing Raw tensor data " << std::endl; + LOGS_DEFAULT(INFO) << "Writing Raw tensor data "; m.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); } From 14d8a5868d26f518acf1a21e238c31ef8f54e194 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Tue, 24 Dec 2024 17:32:11 +0000 Subject: [PATCH 06/11] Cleanup and encapsulate quantization / compile functions - Add additional flags for fp8 thats shared for int8 - Add lockout warning message when int8/fp8 used at the same time --- .../migraphx/migraphx_execution_provider.cc | 204 ++++++++---------- 1 file changed, 95 insertions(+), 109 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index bc1f9fe5d711d..3e063e35e7cc4 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -126,7 +126,12 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv int8_enable_ = (std::stoi(int8_enable_env) == 0 ? false : true); } - if (int8_enable_) { + if(int8_enable_ and fp8_enable_) + { + LOGS_DEFAULT(FATAL) << "MIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; + } + + if (int8_enable_ || fp8_enable_) { const std::string int8_calibration_cache_name_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8CalibrationTableName); if (!int8_calibration_cache_name_env.empty()) { @@ -146,13 +151,13 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv } } - if (int8_enable_) { + if (int8_enable_ or fp8_enable_) { int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty(); } // Load INT8 calibration table std::unordered_map dynamic_range_map; - if (int8_enable_ && int8_calibration_cache_available_) { + if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { const std::string calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path); @@ -195,21 +200,20 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv metadef_id_generator_ = ModelMetadefIdGenerator::Create(); - LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: " - << "device_id: " << info_.device_id - << ", migraphx_fp16_enable: " << fp16_enable_ - << ", migraphx_fp8_enable: " << fp8_enable_ - << ", migraphx_int8_enable: " << int8_enable_ - << ", migraphx_int8_enable: " << int8_enable_ - << ", dump_model_ops: " << dump_model_ops_ - << ", exhaustive_tune: " << exhaustive_tune_ - << ", migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_ - << ", int8_calibration_cache_available: " << int8_calibration_cache_available_ - << ", use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_ - << ", migraphx_save_compiled_model: " << save_compiled_model_ - << ", migraphx_save_compiled_model_path: " << save_compiled_path_ - << ", migraphx_load_compiled_model: " << load_compiled_model_ - << ", migraphx_load_compiled_model_path: " << load_compiled_path_; + LOGS_DEFAULT(WARNING) << "[MIGraphX EP] MIGraphX provider options:" + << "\n device_id: " << info_.device_id + << "\n migraphx_fp16_enable: " << fp16_enable_ + << "\n migraphx_fp8_enable: " << fp8_enable_ + << "\n migraphx_int8_enable: " << int8_enable_ + << "\n dump_model_ops: " << dump_model_ops_ + << "\n exhaustive_tune: " << exhaustive_tune_ + << "\n migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_ + << "\n int8_calibration_cache_available: " << int8_calibration_cache_available_ + << "\n use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_ + << "\n migraphx_save_compiled_model: " << save_compiled_model_ + << "\n migraphx_save_compiled_model_path: " << save_compiled_path_ + << "\n migraphx_load_compiled_model: " << load_compiled_model_ + << "\n migraphx_load_compiled_model_path: " << load_compiled_path_; } MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { @@ -1148,6 +1152,70 @@ void save_compiled_model(migraphx::program& prog, bool save_enable, std::string } } + +// Order matters here especially if the program uses mixed quantization +// Calibrate on full precision for int8/fp8 and then quantize down to fp16 +void calibrate_and_quantize(migraphx::program& prog, + const migraphx::target& t, + const migraphx::program_parameters quant_params, + bool fp16_enable, + bool int8_enable, + bool fp8_enable, + bool int8_calibration_cache_available, + std::unordered_map& dynamic_range_map) +{ + // Read in the calibration data and map it to an migraphx paramater map for the calibration ops + if ((int8_enable xor fp8_enable) && int8_calibration_cache_available) { + LOGS_DEFAULT(INFO) << "Quantizing input program"; + + auto param_shapes = prog.get_parameter_shapes(); + + // Add all calibration data read in from int8 table + for (auto& [cal_key, cal_val] : dynamic_range_map) { + auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); + quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); + } + + // perform static quantization on the programs + if(int8_enable) + { + migraphx::quantize_int8_options quant_opts; + quant_opts.add_calibration_data(quant_params); + // specify thing we want to int8 quantize + quant_opts.add_op_name("convolution"); + quant_opts.add_op_name("dot"); + migraphx::quantize_int8(prog, t, quant_opts); + LOGS_DEFAULT(WARNING) << "Quantizing input program to int8: Complete"; + } + else if(fp8_enable) + { + migraphx::quantize_fp8_options quant_opts; + quant_opts.add_calibration_data(quant_params); + migraphx::quantize_fp8(prog, t, quant_opts); + LOGS_DEFAULT(WARNING) << "Quantizing input program to fp8: Complete"; + } + + } + + if (fp16_enable) { + LOGS_DEFAULT(INFO) << "Quantizing input program to fp16"; + migraphx::quantize_fp16(prog); + LOGS_DEFAULT(INFO) << "Quantizing input program to fp16: Complete"; + } +} + +void compile_program(migraphx::program& prog, + const migraphx::target& t, + bool exhaustive_tune) +{ + LOGS_DEFAULT(INFO) << "Model Compile: Begin"; + migraphx::compile_options co; + co.set_fast_math(false); + co.set_exhaustive_tune_flag(exhaustive_tune); + prog.compile(t, co); + LOGS_DEFAULT(INFO) << "Model Compile: Complete"; +} + Status MIGraphXExecutionProvider::Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) { migraphx::onnx_options options; @@ -1188,54 +1256,11 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { LOGS_DEFAULT(INFO) << "No Input shapes detected quantizing model"; prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); + migraphx::program_parameters quant_params; - // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if ((int8_enable_ xor fp8_enable_) && int8_calibration_cache_available_) { - LOGS_DEFAULT(INFO) << "Quantizing input program to int8"; - migraphx::program_parameters quant_params; - - auto param_shapes = prog.get_parameter_shapes(); - - // Add all calibration data read in from int8 table - for (auto& [cal_key, cal_val] : dynamic_range_map_) { - auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); - quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); - } - - // perform static quantization on the programs - if(int8_enable_) - { - migraphx::quantize_int8_options quant_opts; - quant_opts.add_calibration_data(quant_params); - // specify thing we want to int8 quantize - quant_opts.add_op_name("convolution"); - quant_opts.add_op_name("dot"); - migraphx::quantize_int8(prog, t_, quant_opts); - LOGS_DEFAULT(INFO) << "Quantizing input program to int8: Complete"; - } - else if(fp8_enable_) - { - migraphx::quantize_fp8_options quant_opts; - quant_opts.add_calibration_data(quant_params); - migraphx::quantize_fp8(prog, t_, quant_opts); - LOGS_DEFAULT(INFO) << "Quantizing input program to fp8: Complete"; - } - - } - - if (fp16_enable_) { - LOGS_DEFAULT(INFO) << "Quantizing input program to fp16"; - migraphx::quantize_fp16(prog); - LOGS_DEFAULT(INFO) << "Quantizing input program to fp16: Complete"; - } - - migraphx::compile_options co; - co.set_fast_math(false); - co.set_exhaustive_tune_flag(exhaustive_tune_); - LOGS_DEFAULT(INFO) << "Model Compile: Begin"; - prog.compile(t_, co); - LOGS_DEFAULT(INFO) << "Model Compile: Complete"; - + calibrate_and_quantize(prog, t_, quant_params, fp16_enable_, int8_enable_, + fp8_enable_, int8_calibration_cache_available_, dynamic_range_map_); + compile_program(prog, t_, exhaustive_tune_); save_compiled_model(prog, save_compiled_model_, save_compiled_path_); } @@ -1341,14 +1366,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& LOGS_DEFAULT(VERBOSE) << "No Input shapes mismatch detected. Recompiling"; cmp_options.set_external_data_path(model_path_.parent_path().string()); prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); + migraphx::program_parameters quant_params; - // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if ((int8_enable xor fp8_enable) && int8_calibration_cache_available) { - LOGS_DEFAULT(INFO) << "Quantize Int8: Begin"; - migraphx::program_parameters quant_params; - + if((int8_enable xor fp8_enable) and int8_calibration_cache_available) { auto param_shapes = prog.get_parameter_shapes(); - // Add input parameter data and the values they're set to for (auto&& name : param_shapes.names()) { if (map_input_name_index.count(name) > 0) { @@ -1367,45 +1388,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& quant_params.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); } } - - // Add all calibration data read in from int8 table - for (auto& [cal_key, cal_val] : map_dynamic_range) { - auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); - quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); - } - - // perform static quantization on the programs - if(int8_enable) - { - migraphx::quantize_int8_options quant_opts; - quant_opts.add_calibration_data(quant_params); - // specify thing we want to int8 quantize - quant_opts.add_op_name("convolution"); - quant_opts.add_op_name("dot"); - migraphx::quantize_int8(prog, t, quant_opts); - LOGS_DEFAULT(INFO) << "Quantizing input program to fp8: Complete"; - } - else if(fp8_enable) - { - migraphx::quantize_fp8_options quant_opts; - quant_opts.add_calibration_data(quant_params); - migraphx::quantize_fp8(prog, t, quant_opts); - LOGS_DEFAULT(INFO) << "Quantizing input program to fp8: Complete"; - } - } - - if (fp16_enable) { - LOGS_DEFAULT(INFO) << "Quantize fp16: Begin"; - migraphx::quantize_fp16(prog); - LOGS_DEFAULT(INFO) << "Quantize fp16: Completed"; } - - LOGS_DEFAULT(INFO) << "Model Compile: Begin"; - migraphx::compile_options co; - co.set_fast_math(false); - co.set_exhaustive_tune_flag(exhaustive_tune_); - prog.compile(t, co); - + calibrate_and_quantize(prog, t, quant_params, fp16_enable, int8_enable, + fp8_enable, int8_calibration_cache_available, map_dynamic_range); + compile_program(prog, t, exhaustive_tune_); save_compiled_model(prog, mgx_state->save_compiled_mode, mgx_state->save_compiled_path); } From e70be4256d0e876e1ffb0255a3e6c1bd73271312 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Tue, 24 Dec 2024 17:35:39 +0000 Subject: [PATCH 07/11] Run lintrunner pass --- .../migraphx/migraphx_execution_provider.cc | 30 +++++++------------ .../migraphx/migraphx_execution_provider.h | 4 +-- 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 3e063e35e7cc4..e2bda11eaf0dd 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -126,8 +126,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv int8_enable_ = (std::stoi(int8_enable_env) == 0 ? false : true); } - if(int8_enable_ and fp8_enable_) - { + if (int8_enable_ and fp8_enable_) { LOGS_DEFAULT(FATAL) << "MIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; } @@ -1152,7 +1151,6 @@ void save_compiled_model(migraphx::program& prog, bool save_enable, std::string } } - // Order matters here especially if the program uses mixed quantization // Calibrate on full precision for int8/fp8 and then quantize down to fp16 void calibrate_and_quantize(migraphx::program& prog, @@ -1162,8 +1160,7 @@ void calibrate_and_quantize(migraphx::program& prog, bool int8_enable, bool fp8_enable, bool int8_calibration_cache_available, - std::unordered_map& dynamic_range_map) -{ + std::unordered_map& dynamic_range_map) { // Read in the calibration data and map it to an migraphx paramater map for the calibration ops if ((int8_enable xor fp8_enable) && int8_calibration_cache_available) { LOGS_DEFAULT(INFO) << "Quantizing input program"; @@ -1177,8 +1174,7 @@ void calibrate_and_quantize(migraphx::program& prog, } // perform static quantization on the programs - if(int8_enable) - { + if (int8_enable) { migraphx::quantize_int8_options quant_opts; quant_opts.add_calibration_data(quant_params); // specify thing we want to int8 quantize @@ -1186,15 +1182,12 @@ void calibrate_and_quantize(migraphx::program& prog, quant_opts.add_op_name("dot"); migraphx::quantize_int8(prog, t, quant_opts); LOGS_DEFAULT(WARNING) << "Quantizing input program to int8: Complete"; - } - else if(fp8_enable) - { + } else if (fp8_enable) { migraphx::quantize_fp8_options quant_opts; quant_opts.add_calibration_data(quant_params); migraphx::quantize_fp8(prog, t, quant_opts); LOGS_DEFAULT(WARNING) << "Quantizing input program to fp8: Complete"; } - } if (fp16_enable) { @@ -1206,8 +1199,7 @@ void calibrate_and_quantize(migraphx::program& prog, void compile_program(migraphx::program& prog, const migraphx::target& t, - bool exhaustive_tune) -{ + bool exhaustive_tune) { LOGS_DEFAULT(INFO) << "Model Compile: Begin"; migraphx::compile_options co; co.set_fast_math(false); @@ -1259,7 +1251,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::program_parameters quant_params; calibrate_and_quantize(prog, t_, quant_params, fp16_enable_, int8_enable_, - fp8_enable_, int8_calibration_cache_available_, dynamic_range_map_); + fp8_enable_, int8_calibration_cache_available_, dynamic_range_map_); compile_program(prog, t_, exhaustive_tune_); save_compiled_model(prog, save_compiled_model_, save_compiled_path_); } @@ -1316,7 +1308,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& bool input_shape_match = true; migraphx::program_parameter_shapes param_shapes; if (no_input_shape) { - LOGS_DEFAULT(VERBOSE) << "Missing input shape setting input parameters again"; + LOGS_DEFAULT(INFO) << "Missing input shape setting input parameters again"; for (auto& it : map_input_name_index) { auto& name = it.first; auto& index = it.second; @@ -1328,7 +1320,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& input_shape_match = false; } } else { - LOGS_DEFAULT(VERBOSE) << "Assigning inputs, and parameters from compiled model"; + LOGS_DEFAULT(INFO) << "Assigning inputs, and parameters from compiled model"; param_shapes = prog.get_parameter_shapes(); auto prog_output_shapes = prog.get_output_shapes(); @@ -1363,12 +1355,12 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // re-compile the program if (!input_shape_match) { if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { - LOGS_DEFAULT(VERBOSE) << "No Input shapes mismatch detected. Recompiling"; + LOGS_DEFAULT(INFO) << "No Input shapes mismatch detected. Recompiling"; cmp_options.set_external_data_path(model_path_.parent_path().string()); prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); migraphx::program_parameters quant_params; - if((int8_enable xor fp8_enable) and int8_calibration_cache_available) { + if ((int8_enable xor fp8_enable) and int8_calibration_cache_available) { auto param_shapes = prog.get_parameter_shapes(); // Add input parameter data and the values they're set to for (auto&& name : param_shapes.names()) { @@ -1390,7 +1382,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& } } calibrate_and_quantize(prog, t, quant_params, fp16_enable, int8_enable, - fp8_enable, int8_calibration_cache_available, map_dynamic_range); + fp8_enable, int8_calibration_cache_available, map_dynamic_range); compile_program(prog, t, exhaustive_tune_); save_compiled_model(prog, mgx_state->save_compiled_mode, mgx_state->save_compiled_path); } diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index dd7cfaedac361..be372b3012b36 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -44,7 +44,7 @@ struct MIGraphXFuncState { OrtMutex* mgx_mu_ptr = nullptr; bool no_input_shape = false; bool fp16_enable = false; - bool fp8_enable = false; + bool fp8_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; std::unordered_map dynamic_range_map; @@ -91,7 +91,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { private: MIGraphXExecutionProviderInfo info_; bool fp16_enable_ = false; - bool fp8_enable_ = false; + bool fp8_enable_ = false; bool int8_enable_ = false; std::string int8_calibration_cache_name_; bool int8_calibration_cache_available_ = false; From 647d7d32b0c1365f9a65e1ae8d75658b3e272946 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Wed, 25 Dec 2024 00:06:36 +0000 Subject: [PATCH 08/11] Fix session options inputs + add better logging. Previous runs using session options failed as we were missing pulling in inputs from the python interface. This plus additional logging allowed me to track what options were invoked via env and what were added during the start of an inference session --- .../migraphx/migraphx_execution_provider.cc | 86 +++++++++++++++---- .../migraphx/migraphx_execution_provider.h | 4 + 2 files changed, 72 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index e2bda11eaf0dd..7f362859609fb 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -104,10 +104,59 @@ std::shared_ptr MIGraphXExecutionProvider::GetKernelRegistry() c MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_(info) { InitProviderOrtApi(); + get_flags_from_session_info(info); + metadef_id_generator_ = ModelMetadefIdGenerator::Create(); + get_flags_from_env(); +} + +MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { +} + +void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info) { // Set GPU device to be used HIP_CALL_THROW(hipSetDevice(info_.device_id)); t_ = migraphx::target(info.target_device.c_str()); + // Quantization + fp16_enable_ = info.fp16_enable; + fp8_enable_ = info.fp8_enable; + int8_enable_ = info.int8_enable; + + if (int8_enable_ and fp8_enable_) { + LOGS_DEFAULT(FATAL) << "MIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; + } + + if (int8_enable_ xor fp8_enable_) { + int8_calibration_cache_name_ = info.int8_calibration_table_name; + int8_use_native_migraphx_calibration_table_ = info.int8_use_native_calibration_table; + } + + if (int8_enable_ or fp8_enable_) { + int8_calibration_cache_available_ = !info.int8_calibration_table_name.empty(); + } + + // Load INT8 calibration table + std::unordered_map dynamic_range_map; + if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { + const std::string calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); + if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { + throw std::runtime_error("Session Failed to read INT8 calibration table " + calibration_cache_path); + } + } + + // Save/load migraphx compiled models + save_compiled_model_ = info.save_compiled_model; + save_compiled_path_ = info.save_model_file; + load_compiled_model_ = info.load_compiled_model; + load_compiled_path_ = info.load_model_file; + + exhaustive_tune_ = info.exhaustive_tune; + + LOGS_DEFAULT(WARNING) << "[MIGraphX EP] MIGraphX provider Session Options:"; + print_migraphx_ep_flags(); +} + +void MIGraphXExecutionProvider::get_flags_from_env() { // whether fp16 is enable const std::string fp16_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP16Enable); if (!fp16_enable_env.empty()) { @@ -159,7 +208,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { const std::string calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { - throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path); + throw std::runtime_error("ENV Failed to read INT8 calibration table " + calibration_cache_path); } } @@ -197,10 +246,12 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv exhaustive_tune_ = (std::stoi(exhaustive_tune_env) == 0 ? false : true); } - metadef_id_generator_ = ModelMetadefIdGenerator::Create(); + LOGS_DEFAULT(WARNING) << "[MIGraphX EP] MIGraphX provider ENV Variables:"; + print_migraphx_ep_flags(); +} - LOGS_DEFAULT(WARNING) << "[MIGraphX EP] MIGraphX provider options:" - << "\n device_id: " << info_.device_id +void MIGraphXExecutionProvider::print_migraphx_ep_flags() { + LOGS_DEFAULT(WARNING) << "\n device_id: " << info_.device_id << "\n migraphx_fp16_enable: " << fp16_enable_ << "\n migraphx_fp8_enable: " << fp8_enable_ << "\n migraphx_int8_enable: " << int8_enable_ @@ -215,9 +266,6 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv << "\n migraphx_load_compiled_model_path: " << load_compiled_path_; } -MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { -} - std::vector MIGraphXExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo default_memory_info( [](OrtDevice::DeviceId device_id) { return CreateMIGraphXAllocator(device_id, onnxruntime::CUDA); }, info_.device_id); @@ -1128,9 +1176,9 @@ bool get_input_output_names(const GraphViewer& graph, bool load_precompiled_model(migraphx::program& prog, bool load_enable, std::string path) { try { if (load_enable) { - LOGS_DEFAULT(INFO) << "Attempting to load model at:" << path; + LOGS_DEFAULT(WARNING) << "Attempting to load model at:" << path; prog = migraphx::load(path.c_str()); - LOGS_DEFAULT(INFO) << "load model : Success"; + LOGS_DEFAULT(WARNING) << "load model : Success"; return true; } else { return false; @@ -1143,11 +1191,11 @@ bool load_precompiled_model(migraphx::program& prog, bool load_enable, std::stri void save_compiled_model(migraphx::program& prog, bool save_enable, std::string out_path) { if (save_enable) { - LOGS_DEFAULT(INFO) << "Model Save at " << out_path << ": Begin"; + LOGS_DEFAULT(WARNING) << "Model Save at " << out_path << ": Begin"; migraphx::file_options fo; fo.set_file_format("msgpack"); migraphx::save(prog, out_path.c_str(), fo); - LOGS_DEFAULT(INFO) << "Model Save: Complete"; + LOGS_DEFAULT(WARNING) << "Model Save: Complete"; } } @@ -1163,7 +1211,7 @@ void calibrate_and_quantize(migraphx::program& prog, std::unordered_map& dynamic_range_map) { // Read in the calibration data and map it to an migraphx paramater map for the calibration ops if ((int8_enable xor fp8_enable) && int8_calibration_cache_available) { - LOGS_DEFAULT(INFO) << "Quantizing input program"; + LOGS_DEFAULT(WARNING) << "Quantizing input program"; auto param_shapes = prog.get_parameter_shapes(); @@ -1175,37 +1223,39 @@ void calibrate_and_quantize(migraphx::program& prog, // perform static quantization on the programs if (int8_enable) { + LOGS_DEFAULT(WARNING) << "Quantizing input program to int8"; migraphx::quantize_int8_options quant_opts; quant_opts.add_calibration_data(quant_params); // specify thing we want to int8 quantize quant_opts.add_op_name("convolution"); quant_opts.add_op_name("dot"); migraphx::quantize_int8(prog, t, quant_opts); - LOGS_DEFAULT(WARNING) << "Quantizing input program to int8: Complete"; + LOGS_DEFAULT(WARNING) << "Quantizing int8: Complete"; } else if (fp8_enable) { + LOGS_DEFAULT(WARNING) << "Quantizing input program to fp8"; migraphx::quantize_fp8_options quant_opts; quant_opts.add_calibration_data(quant_params); migraphx::quantize_fp8(prog, t, quant_opts); - LOGS_DEFAULT(WARNING) << "Quantizing input program to fp8: Complete"; + LOGS_DEFAULT(WARNING) << "Quantizing fp8: Complete"; } } if (fp16_enable) { - LOGS_DEFAULT(INFO) << "Quantizing input program to fp16"; + LOGS_DEFAULT(WARNING) << "Quantizing input program to fp16"; migraphx::quantize_fp16(prog); - LOGS_DEFAULT(INFO) << "Quantizing input program to fp16: Complete"; + LOGS_DEFAULT(WARNING) << "Quantizing fp16: Complete"; } } void compile_program(migraphx::program& prog, const migraphx::target& t, bool exhaustive_tune) { - LOGS_DEFAULT(INFO) << "Model Compile: Begin"; + LOGS_DEFAULT(WARNING) << "Model Compile: Begin"; migraphx::compile_options co; co.set_fast_math(false); co.set_exhaustive_tune_flag(exhaustive_tune); prog.compile(t, co); - LOGS_DEFAULT(INFO) << "Model Compile: Complete"; + LOGS_DEFAULT(WARNING) << "Model Compile: Complete"; } Status MIGraphXExecutionProvider::Compile(const std::vector& fused_nodes, diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index be372b3012b36..ad241add2bbb2 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -62,6 +62,10 @@ class MIGraphXExecutionProvider : public IExecutionProvider { explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info); ~MIGraphXExecutionProvider(); + void get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info); + void get_flags_from_env(); + void print_migraphx_ep_flags(); + Status Sync() const override; Status OnRunStart(const onnxruntime::RunOptions& run_options) override; From 63dee58cf158200a74790c8684792a20d79edf3c Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Wed, 25 Dec 2024 00:16:57 +0000 Subject: [PATCH 09/11] Fix naming for save/load path varibles to be consistent with enable. --- .../core/providers/migraphx/migraphx_execution_provider.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index ad241add2bbb2..ca7a71b28e05d 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -24,9 +24,9 @@ static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_T static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH"; static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"; static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL"; -static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILE_PATH"; +static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILED_PATH"; static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL"; -static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILE_PATH"; +static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILED_PATH"; static const char kExhaustiveTune[] = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"; }; // namespace migraphx_env_vars From d1a26097b67e9c48f5cb2ebfbc5a61d8199b8f5a Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Wed, 25 Dec 2024 03:51:50 +0000 Subject: [PATCH 10/11] Print only env variables that are set as warnings need this so the user knows there's any of the environment variables running in the background to ensure proper consistently between runs. --- .../migraphx/migraphx_execution_provider.cc | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 7f362859609fb..5ef94cb1b9d1b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -157,26 +157,30 @@ void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecut } void MIGraphXExecutionProvider::get_flags_from_env() { + LOGS_DEFAULT(WARNING) << "\n[MIGraphX EP] MIGraphX ENV Override Variables Set:"; // whether fp16 is enable const std::string fp16_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP16Enable); if (!fp16_enable_env.empty()) { fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP16_ENABLE: " << fp16_enable_; } // whether fp8 quantization is enabled const std::string fp8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP8Enable); if (!fp8_enable_env.empty()) { fp8_enable_ = (std::stoi(fp8_enable_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP8_ENABLE: " << fp8_enable_; } // whether int8 is enabled const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable); if (!int8_enable_env.empty()) { int8_enable_ = (std::stoi(int8_enable_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_ENABLE: " << int8_enable_; } if (int8_enable_ and fp8_enable_) { - LOGS_DEFAULT(FATAL) << "MIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; + LOGS_DEFAULT(FATAL) << "\nMIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; } if (int8_enable_ || fp8_enable_) { @@ -184,11 +188,13 @@ void MIGraphXExecutionProvider::get_flags_from_env() { onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8CalibrationTableName); if (!int8_calibration_cache_name_env.empty()) { int8_calibration_cache_name_ = int8_calibration_cache_name_env; + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CALIBRATION_TABLE_NAME: " << int8_calibration_cache_name_; } const std::string cache_path = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kCachePath); if (!cache_path.empty()) { calibration_cache_path_ = cache_path; + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CACHE_PATH: " << calibration_cache_path_; } const std::string int8_use_native_migraphx_calibration_table_env = @@ -196,6 +202,8 @@ void MIGraphXExecutionProvider::get_flags_from_env() { if (!int8_use_native_migraphx_calibration_table_env.empty()) { int8_use_native_migraphx_calibration_table_ = (std::stoi(int8_use_native_migraphx_calibration_table_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE: " + << int8_use_native_migraphx_calibration_table_; } } @@ -208,7 +216,7 @@ void MIGraphXExecutionProvider::get_flags_from_env() { if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { const std::string calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { - throw std::runtime_error("ENV Failed to read INT8 calibration table " + calibration_cache_path); + throw std::runtime_error("ENV Failed to read calibration table " + calibration_cache_path); } } @@ -216,38 +224,40 @@ void MIGraphXExecutionProvider::get_flags_from_env() { const std::string save_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSaveCompiledModel); if (!save_comp_model_env.empty()) { save_compiled_model_ = (std::stoi(save_comp_model_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_SAVE_COMPILED_MODEL: " << save_compiled_model_; } const std::string save_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSavedModelPath); - if (save_compiled_model_ && !save_model_path_env.empty()) { save_compiled_path_ = save_model_path_env; + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_SAVE_COMPILED_PATH: " << save_compiled_path_; } const std::string load_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadCompiledModel); if (!load_comp_model_env.empty()) { load_compiled_model_ = (std::stoi(load_comp_model_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_LOAD_COMPILED_MODEL: " << load_compiled_model_; } const std::string load_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadModelPath); if (load_compiled_model_ && !load_model_path_env.empty()) { load_compiled_path_ = load_model_path_env; + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_LOAD_COMPILED_PATH: " << load_compiled_path_; } // dump unsupported ops const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::dumpModelOps); if (!dump_model_ops_env.empty()) { dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_DUMP_MODEL_OPS: " << dump_model_ops_; } // Allow for exhaustive tune during compile const std::string exhaustive_tune_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kExhaustiveTune); if (!exhaustive_tune_env.empty()) { exhaustive_tune_ = (std::stoi(exhaustive_tune_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_EXHAUSTIVE_TUNE_OPS: " << exhaustive_tune_; } - - LOGS_DEFAULT(WARNING) << "[MIGraphX EP] MIGraphX provider ENV Variables:"; - print_migraphx_ep_flags(); } void MIGraphXExecutionProvider::print_migraphx_ep_flags() { From 87f1f91020ab6d7fd6cc43b163d18667e3d4a925 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Fri, 3 Jan 2025 16:59:35 +0000 Subject: [PATCH 11/11] lintrunner pass --- .../python/onnxruntime_pybind_state.cc | 51 +++++++++++-------- 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 1862072c36e1f..a87b1b5b457ad 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1441,7 +1441,7 @@ void addGlobalMethods(py::module& m) { ORT_UNUSED_PARAMETER(algo); ORT_THROW("set_cudnn_conv_algo_search is not supported in ROCM"); #else - cudnn_conv_algo_search = algo; + cudnn_conv_algo_search = algo; #endif }); // TODO remove deprecated global config @@ -1452,7 +1452,7 @@ void addGlobalMethods(py::module& m) { ORT_UNUSED_PARAMETER(use_single_stream); ORT_THROW("set_do_copy_in_default_stream is not supported in ROCM"); #else - do_copy_in_default_stream = use_single_stream; + do_copy_in_default_stream = use_single_stream; #endif }); // TODO remove deprecated global config @@ -1817,10 +1817,10 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc") } ORT_THROW_IF_ERROR(options->value.AddExternalInitializers(names_ptrs, values_ptrs)); #else - ORT_UNUSED_PARAMETER(options); - ORT_UNUSED_PARAMETER(names); - ORT_UNUSED_PARAMETER(ort_values); - ORT_THROW("External initializers are not supported in this build."); + ORT_UNUSED_PARAMETER(options); + ORT_UNUSED_PARAMETER(names); + ORT_UNUSED_PARAMETER(ort_values); + ORT_THROW("External initializers are not supported in this build."); #endif }); @@ -1882,7 +1882,8 @@ including arg name, arg type (contains both type and shape).)pbdoc") return *(na.Type()); }, "node type") - .def("__str__", [](const onnxruntime::NodeArg& na) -> std::string { + .def( + "__str__", [](const onnxruntime::NodeArg& na) -> std::string { std::ostringstream res; res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape="; auto shape = na.Shape(); @@ -1909,7 +1910,8 @@ including arg name, arg type (contains both type and shape).)pbdoc") res << ")"; return std::string(res.str()); }, "converts the node into a readable string") - .def_property_readonly("shape", [](const onnxruntime::NodeArg& na) -> std::vector { + .def_property_readonly( + "shape", [](const onnxruntime::NodeArg& na) -> std::vector { auto shape = na.Shape(); std::vector arr; if (shape == nullptr || shape->dim_size() == 0) { @@ -2117,25 +2119,32 @@ including arg name, arg type (contains both type and shape).)pbdoc") .def_property_readonly("get_profiling_start_time_ns", [](const PyInferenceSession* sess) -> uint64_t { return sess->GetSessionHandle()->GetProfiling().GetStartTimeNs(); }) - .def("get_providers", [](const PyInferenceSession* sess) -> const std::vector& { return sess->GetSessionHandle()->GetRegisteredProviderTypes(); }, py::return_value_policy::reference_internal) - .def("get_provider_options", [](const PyInferenceSession* sess) -> const ProviderOptionsMap& { return sess->GetSessionHandle()->GetAllProviderOptions(); }, py::return_value_policy::reference_internal) - .def_property_readonly("session_options", [](const PyInferenceSession* sess) -> PySessionOptions* { + .def( + "get_providers", [](const PyInferenceSession* sess) -> const std::vector& { return sess->GetSessionHandle()->GetRegisteredProviderTypes(); }, py::return_value_policy::reference_internal) + .def( + "get_provider_options", [](const PyInferenceSession* sess) -> const ProviderOptionsMap& { return sess->GetSessionHandle()->GetAllProviderOptions(); }, py::return_value_policy::reference_internal) + .def_property_readonly( + "session_options", [](const PyInferenceSession* sess) -> PySessionOptions* { auto session_options = std::make_unique(); session_options->value = sess->GetSessionHandle()->GetSessionOptions(); return session_options.release(); }, py::return_value_policy::take_ownership) - .def_property_readonly("inputs_meta", [](const PyInferenceSession* sess) -> const std::vector& { + .def_property_readonly( + "inputs_meta", [](const PyInferenceSession* sess) -> const std::vector& { auto res = sess->GetSessionHandle()->GetModelInputs(); OrtPybindThrowIfError(res.first); return *(res.second); }, py::return_value_policy::reference_internal) - .def_property_readonly("outputs_meta", [](const PyInferenceSession* sess) -> const std::vector& { + .def_property_readonly( + "outputs_meta", [](const PyInferenceSession* sess) -> const std::vector& { auto res = sess->GetSessionHandle()->GetModelOutputs(); OrtPybindThrowIfError(res.first); return *(res.second); }, py::return_value_policy::reference_internal) - .def_property_readonly("overridable_initializers", [](const PyInferenceSession* sess) -> const std::vector& { + .def_property_readonly( + "overridable_initializers", [](const PyInferenceSession* sess) -> const std::vector& { auto res = sess->GetSessionHandle()->GetOverridableInitializers(); OrtPybindThrowIfError(res.first); return *(res.second); }, py::return_value_policy::reference_internal) - .def_property_readonly("model_meta", [](const PyInferenceSession* sess) -> const onnxruntime::ModelMetadata& { + .def_property_readonly( + "model_meta", [](const PyInferenceSession* sess) -> const onnxruntime::ModelMetadata& { auto res = sess->GetSessionHandle()->GetModelMetadata(); OrtPybindThrowIfError(res.first); return *(res.second); }, py::return_value_policy::reference_internal) @@ -2163,8 +2172,8 @@ including arg name, arg type (contains both type and shape).)pbdoc") return ret; #else - ORT_UNUSED_PARAMETER(sess); - ORT_THROW("TunableOp and get_tuning_results are not supported in this build."); + ORT_UNUSED_PARAMETER(sess); + ORT_THROW("TunableOp and get_tuning_results are not supported in this build."); #endif }) .def("set_tuning_results", [](PyInferenceSession* sess, py::list results, bool error_on_invalid) -> void { @@ -2195,10 +2204,10 @@ including arg name, arg type (contains both type and shape).)pbdoc") throw std::runtime_error("Error in execution: " + status.ErrorMessage()); } #else - ORT_UNUSED_PARAMETER(sess); - ORT_UNUSED_PARAMETER(results); - ORT_UNUSED_PARAMETER(error_on_invalid); - ORT_THROW("TunableOp and set_tuning_results are not supported in this build."); + ORT_UNUSED_PARAMETER(sess); + ORT_UNUSED_PARAMETER(results); + ORT_UNUSED_PARAMETER(error_on_invalid); + ORT_THROW("TunableOp and set_tuning_results are not supported in this build."); #endif });