diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 4ca2791e26ab9..7ba935633ca3c 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -614,6 +614,7 @@ typedef struct OrtTensorRTProviderOptions { typedef struct OrtMIGraphXProviderOptions { int device_id; // hip device id. int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true + int migraphx_fp8_enable; // MIGraphX FP8 precision. Default 0 = false, nonzero = true int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index fd36b8ae5f678..5ef94cb1b9d1b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -104,32 +104,97 @@ std::shared_ptr MIGraphXExecutionProvider::GetKernelRegistry() c MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_(info) { InitProviderOrtApi(); + get_flags_from_session_info(info); + metadef_id_generator_ = ModelMetadefIdGenerator::Create(); + get_flags_from_env(); +} + +MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { +} + +void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info) { // Set GPU device to be used HIP_CALL_THROW(hipSetDevice(info_.device_id)); t_ = migraphx::target(info.target_device.c_str()); + 
// Quantization + fp16_enable_ = info.fp16_enable; + fp8_enable_ = info.fp8_enable; + int8_enable_ = info.int8_enable; + + if (int8_enable_ and fp8_enable_) { + LOGS_DEFAULT(FATAL) << "MIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; + } + + if (int8_enable_ xor fp8_enable_) { + int8_calibration_cache_name_ = info.int8_calibration_table_name; + int8_use_native_migraphx_calibration_table_ = info.int8_use_native_calibration_table; + } + + if (int8_enable_ or fp8_enable_) { + int8_calibration_cache_available_ = !info.int8_calibration_table_name.empty(); + } + + // Load INT8 calibration table + std::unordered_map dynamic_range_map; + if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { + const std::string calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); + if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { + throw std::runtime_error("Session Failed to read INT8 calibration table " + calibration_cache_path); + } + } + + // Save/load migraphx compiled models + save_compiled_model_ = info.save_compiled_model; + save_compiled_path_ = info.save_model_file; + load_compiled_model_ = info.load_compiled_model; + load_compiled_path_ = info.load_model_file; + + exhaustive_tune_ = info.exhaustive_tune; + + LOGS_DEFAULT(WARNING) << "[MIGraphX EP] MIGraphX provider Session Options:"; + print_migraphx_ep_flags(); +} + +void MIGraphXExecutionProvider::get_flags_from_env() { + LOGS_DEFAULT(WARNING) << "\n[MIGraphX EP] MIGraphX ENV Override Variables Set:"; // whether fp16 is enable const std::string fp16_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP16Enable); if (!fp16_enable_env.empty()) { fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? 
false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP16_ENABLE: " << fp16_enable_; + } + + // whether fp8 quantization is enabled + const std::string fp8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP8Enable); + if (!fp8_enable_env.empty()) { + fp8_enable_ = (std::stoi(fp8_enable_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP8_ENABLE: " << fp8_enable_; } // whether int8 is enabled const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable); if (!int8_enable_env.empty()) { int8_enable_ = (std::stoi(int8_enable_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_ENABLE: " << int8_enable_; } - if (int8_enable_) { + if (int8_enable_ and fp8_enable_) { + LOGS_DEFAULT(FATAL) << "\nMIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; + } + + if (int8_enable_ || fp8_enable_) { const std::string int8_calibration_cache_name_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8CalibrationTableName); if (!int8_calibration_cache_name_env.empty()) { int8_calibration_cache_name_ = int8_calibration_cache_name_env; + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CALIBRATION_TABLE_NAME: " << int8_calibration_cache_name_; } const std::string cache_path = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kCachePath); if (!cache_path.empty()) { calibration_cache_path_ = cache_path; + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CACHE_PATH: " << calibration_cache_path_; } const std::string int8_use_native_migraphx_calibration_table_env = @@ -137,19 +202,21 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv if (!int8_use_native_migraphx_calibration_table_env.empty()) { int8_use_native_migraphx_calibration_table_ = (std::stoi(int8_use_native_migraphx_calibration_table_env) == 0 ? 
false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE: " + << int8_use_native_migraphx_calibration_table_; } } - if (int8_enable_) { + if (int8_enable_ or fp8_enable_) { int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty(); } // Load INT8 calibration table std::unordered_map dynamic_range_map; - if (int8_enable_ && int8_calibration_cache_available_) { + if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { const std::string calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { - throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path); + throw std::runtime_error("ENV Failed to read calibration table " + calibration_cache_path); } } @@ -157,55 +224,56 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv const std::string save_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSaveCompiledModel); if (!save_comp_model_env.empty()) { save_compiled_model_ = (std::stoi(save_comp_model_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_SAVE_COMPILED_MODEL: " << save_compiled_model_; } const std::string save_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSavedModelPath); - if (save_compiled_model_ && !save_model_path_env.empty()) { save_compiled_path_ = save_model_path_env; + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_SAVE_COMPILED_PATH: " << save_compiled_path_; } const std::string load_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadCompiledModel); if (!load_comp_model_env.empty()) { load_compiled_model_ = (std::stoi(load_comp_model_env) == 0 ? 
false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_LOAD_COMPILED_MODEL: " << load_compiled_model_; } const std::string load_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadModelPath); if (load_compiled_model_ && !load_model_path_env.empty()) { load_compiled_path_ = load_model_path_env; + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_LOAD_COMPILED_PATH: " << load_compiled_path_; } // dump unsupported ops const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::dumpModelOps); if (!dump_model_ops_env.empty()) { dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_DUMP_MODEL_OPS: " << dump_model_ops_; } // Allow for exhaustive tune during compile const std::string exhaustive_tune_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kExhaustiveTune); if (!exhaustive_tune_env.empty()) { exhaustive_tune_ = (std::stoi(exhaustive_tune_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_EXHAUSTIVE_TUNE_OPS: " << exhaustive_tune_; } - - metadef_id_generator_ = ModelMetadefIdGenerator::Create(); - - LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: " - << "device_id: " << info_.device_id - << ", migraphx_fp16_enable: " << fp16_enable_ - << ", migraphx_int8_enable: " << int8_enable_ - << ", migraphx_int8_enable: " << int8_enable_ - << ", dump_model_ops: " << dump_model_ops_ - << ", exhaustive_tune: " << exhaustive_tune_ - << ", migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_ - << ", int8_calibration_cache_available: " << int8_calibration_cache_available_ - << ", use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_ - << ", migraphx_save_compiled_model: " << save_compiled_model_ - << ", migraphx_save_compiled_model_path: " << save_compiled_path_ - << ", migraphx_load_compiled_model: " << load_compiled_model_ - << ", migraphx_load_compiled_model_path: " << 
load_compiled_path_; } -MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { +void MIGraphXExecutionProvider::print_migraphx_ep_flags() { + LOGS_DEFAULT(WARNING) << "\n device_id: " << info_.device_id + << "\n migraphx_fp16_enable: " << fp16_enable_ + << "\n migraphx_fp8_enable: " << fp8_enable_ + << "\n migraphx_int8_enable: " << int8_enable_ + << "\n dump_model_ops: " << dump_model_ops_ + << "\n exhaustive_tune: " << exhaustive_tune_ + << "\n migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_ + << "\n int8_calibration_cache_available: " << int8_calibration_cache_available_ + << "\n use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_ + << "\n migraphx_save_compiled_model: " << save_compiled_model_ + << "\n migraphx_save_compiled_model_path: " << save_compiled_path_ + << "\n migraphx_load_compiled_model: " << load_compiled_model_ + << "\n migraphx_load_compiled_model_path: " << load_compiled_path_; } std::vector MIGraphXExecutionProvider::CreatePreferredAllocators() { @@ -232,11 +300,17 @@ static bool IsTypeSupported(const NodeArg* node_arg) { switch (type_proto->tensor_type().elem_type()) { case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FNUZ: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2FNUZ: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32: case 
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32: @@ -261,6 +335,21 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: mgx_type = migraphx_shape_double_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ: + mgx_type = migraphx_shape_fp8e4m3fnuz_type; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN: + mgx_type = migraphx_shape_fp8e4m3fn_type; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2: + mgx_type = migraphx_shape_fp8e5m2_type; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ: + mgx_type = migraphx_shape_fp8e5m2fnuz_type; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: + mgx_type = migraphx_shape_int8_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: mgx_type = migraphx_shape_int8_type; break; @@ -273,6 +362,9 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: mgx_type = migraphx_shape_int64_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4: + mgx_type = migraphx_shape_uint8_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: mgx_type = migraphx_shape_uint8_type; break; @@ -1021,7 +1113,7 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v if (dump_model_ops_) { LOGS_DEFAULT(INFO) << "============= Unsupported nodes ===================="; for (auto idx : unsupported_nodes) { - LOGS_DEFAULT(INFO) << graph_viewer.GetNode(idx)->OpType() << std::endl; + LOGS_DEFAULT(INFO) << graph_viewer.GetNode(idx)->OpType(); } LOGS_DEFAULT(INFO) << "************* Unsupported nodes ********************"; } @@ -1094,9 +1186,9 @@ bool get_input_output_names(const GraphViewer& 
graph, bool load_precompiled_model(migraphx::program& prog, bool load_enable, std::string path) { try { if (load_enable) { - LOGS_DEFAULT(INFO) << "Attempting to load model at:" << path; + LOGS_DEFAULT(WARNING) << "Attempting to load model at:" << path; prog = migraphx::load(path.c_str()); - LOGS_DEFAULT(INFO) << "load model : Success"; + LOGS_DEFAULT(WARNING) << "load model : Success"; return true; } else { return false; @@ -1109,14 +1201,73 @@ bool load_precompiled_model(migraphx::program& prog, bool load_enable, std::stri void save_compiled_model(migraphx::program& prog, bool save_enable, std::string out_path) { if (save_enable) { - LOGS_DEFAULT(INFO) << "Model Save at " << out_path << ": Begin" << std::endl; + LOGS_DEFAULT(WARNING) << "Model Save at " << out_path << ": Begin"; migraphx::file_options fo; fo.set_file_format("msgpack"); migraphx::save(prog, out_path.c_str(), fo); - LOGS_DEFAULT(INFO) << "Model Save: Complete" << std::endl; + LOGS_DEFAULT(WARNING) << "Model Save: Complete"; + } +} + +// Order matters here especially if the program uses mixed quantization +// Calibrate on full precision for int8/fp8 and then quantize down to fp16 +void calibrate_and_quantize(migraphx::program& prog, + const migraphx::target& t, + const migraphx::program_parameters quant_params, + bool fp16_enable, + bool int8_enable, + bool fp8_enable, + bool int8_calibration_cache_available, + std::unordered_map& dynamic_range_map) { + // Read in the calibration data and map it to an migraphx paramater map for the calibration ops + if ((int8_enable xor fp8_enable) && int8_calibration_cache_available) { + LOGS_DEFAULT(WARNING) << "Quantizing input program"; + + auto param_shapes = prog.get_parameter_shapes(); + + // Add all calibration data read in from int8 table + for (auto& [cal_key, cal_val] : dynamic_range_map) { + auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); + quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, 
static_cast(std::move(&cal_val)))); + } + + // perform static quantization on the programs + if (int8_enable) { + LOGS_DEFAULT(WARNING) << "Quantizing input program to int8"; + migraphx::quantize_int8_options quant_opts; + quant_opts.add_calibration_data(quant_params); + // specify thing we want to int8 quantize + quant_opts.add_op_name("convolution"); + quant_opts.add_op_name("dot"); + migraphx::quantize_int8(prog, t, quant_opts); + LOGS_DEFAULT(WARNING) << "Quantizing int8: Complete"; + } else if (fp8_enable) { + LOGS_DEFAULT(WARNING) << "Quantizing input program to fp8"; + migraphx::quantize_fp8_options quant_opts; + quant_opts.add_calibration_data(quant_params); + migraphx::quantize_fp8(prog, t, quant_opts); + LOGS_DEFAULT(WARNING) << "Quantizing fp8: Complete"; + } + } + + if (fp16_enable) { + LOGS_DEFAULT(WARNING) << "Quantizing input program to fp16"; + migraphx::quantize_fp16(prog); + LOGS_DEFAULT(WARNING) << "Quantizing fp16: Complete"; } } +void compile_program(migraphx::program& prog, + const migraphx::target& t, + bool exhaustive_tune) { + LOGS_DEFAULT(WARNING) << "Model Compile: Begin"; + migraphx::compile_options co; + co.set_fast_math(false); + co.set_exhaustive_tune_flag(exhaustive_tune); + prog.compile(t, co); + LOGS_DEFAULT(WARNING) << "Model Compile: Complete"; +} + Status MIGraphXExecutionProvider::Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) { migraphx::onnx_options options; @@ -1157,44 +1308,11 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { LOGS_DEFAULT(INFO) << "No Input shapes detected quantizing model"; prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); + migraphx::program_parameters quant_params; - // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if (int8_enable_ && int8_calibration_cache_available_) { - LOGS_DEFAULT(INFO) << "Quantizing 
input program to int8" << std::endl; - migraphx::quantize_int8_options quant_opts; - migraphx::program_parameters quant_params; - - auto param_shapes = prog.get_parameter_shapes(); - - // Add all calibration data read in from int8 table - for (auto& [cal_key, cal_val] : dynamic_range_map_) { - auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); - quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); - } - quant_opts.add_calibration_data(quant_params); - - // specify thing we want to int8 quantize - quant_opts.add_op_name("convolution"); - quant_opts.add_op_name("dot"); - - // perform static quantization on the programs - migraphx::quantize_int8(prog, t_, quant_opts); - LOGS_DEFAULT(INFO) << "Quantizing input program to int8: Complete" << std::endl; - } - - if (fp16_enable_) { - LOGS_DEFAULT(INFO) << "Quantizing input program to fp16" << std::endl; - migraphx::quantize_fp16(prog); - LOGS_DEFAULT(INFO) << "Quantizing input program to fp16: Complete" << std::endl; - } - - migraphx::compile_options co; - co.set_fast_math(false); - co.set_exhaustive_tune_flag(exhaustive_tune_); - LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl; - prog.compile(t_, co); - LOGS_DEFAULT(INFO) << "Model Compile: Complete" << std::endl; - + calibrate_and_quantize(prog, t_, quant_params, fp16_enable_, int8_enable_, + fp8_enable_, int8_calibration_cache_available_, dynamic_range_map_); + compile_program(prog, t_, exhaustive_tune_); save_compiled_model(prog, save_compiled_model_, save_compiled_path_); } @@ -1216,7 +1334,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::unique_ptr p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, - map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_, + 
map_no_input_shape_[context->node_name], fp16_enable_, fp8_enable_, int8_enable_, int8_calibration_cache_available_, dynamic_range_map_, save_compiled_model_, save_compiled_path_, load_compiled_model_, load_compiled_path_, dump_model_ops_}; @@ -1241,6 +1359,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::onnx_options& cmp_options = mgx_state->options; bool& no_input_shape = mgx_state->no_input_shape; bool fp16_enable = mgx_state->fp16_enable; + bool fp8_enable = mgx_state->fp8_enable; bool int8_enable = mgx_state->int8_enable; bool int8_calibration_cache_available = mgx_state->int8_calibration_cache_available; @@ -1249,7 +1368,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& bool input_shape_match = true; migraphx::program_parameter_shapes param_shapes; if (no_input_shape) { - LOGS_DEFAULT(VERBOSE) << "Missing input shape setting input parameters again" << std::endl; + LOGS_DEFAULT(INFO) << "Missing input shape setting input parameters again"; for (auto& it : map_input_name_index) { auto& name = it.first; auto& index = it.second; @@ -1261,7 +1380,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& input_shape_match = false; } } else { - LOGS_DEFAULT(VERBOSE) << "Assigning inputs, and parameters from compiled model" << std::endl; + LOGS_DEFAULT(INFO) << "Assigning inputs, and parameters from compiled model"; param_shapes = prog.get_parameter_shapes(); auto prog_output_shapes = prog.get_output_shapes(); @@ -1296,18 +1415,13 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // re-compile the program if (!input_shape_match) { if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { - LOGS_DEFAULT(VERBOSE) << "No Input shapes mismatch detected. Recompiling" << std::endl; + LOGS_DEFAULT(INFO) << "No Input shapes mismatch detected. 
Recompiling"; cmp_options.set_external_data_path(model_path_.parent_path().string()); prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); + migraphx::program_parameters quant_params; - // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if (int8_enable && int8_calibration_cache_available) { - LOGS_DEFAULT(INFO) << "Quantize Int8: Begin" << std::endl; - migraphx::quantize_int8_options quant_opts; - migraphx::program_parameters quant_params; - + if ((int8_enable xor fp8_enable) and int8_calibration_cache_available) { auto param_shapes = prog.get_parameter_shapes(); - // Add input parameter data and the values they're set to for (auto&& name : param_shapes.names()) { if (map_input_name_index.count(name) > 0) { @@ -1326,34 +1440,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& quant_params.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); } } - - // Add all calibration data read in from int8 table - for (auto& [cal_key, cal_val] : map_dynamic_range) { - auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); - quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); - } - quant_opts.add_calibration_data(quant_params); - // specify thing we want to int8 quantize - quant_opts.add_op_name("convolution"); - quant_opts.add_op_name("dot"); - - // perform static quantization on the programs - migraphx::quantize_int8(prog, t, quant_opts); - LOGS_DEFAULT(INFO) << "Quantize Int8: Completed" << std::endl; } - - if (fp16_enable) { - LOGS_DEFAULT(INFO) << "Quantize fp16: Begin" << std::endl; - migraphx::quantize_fp16(prog); - LOGS_DEFAULT(INFO) << "Quantize fp16: Completed" << std::endl; - } - - LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl; - migraphx::compile_options co; - co.set_fast_math(false); - co.set_exhaustive_tune_flag(exhaustive_tune_); - prog.compile(t, co); - + 
calibrate_and_quantize(prog, t, quant_params, fp16_enable, int8_enable, + fp8_enable, int8_calibration_cache_available, map_dynamic_range); + compile_program(prog, t, exhaustive_tune_); save_compiled_model(prog, mgx_state->save_compiled_mode, mgx_state->save_compiled_path); } @@ -1368,7 +1458,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (param_shapes.size() > 0) { for (auto&& name : param_shapes.names()) { if (map_input_name_index.count(name) > 0) { - LOGS_DEFAULT(INFO) << "Setting parameters for:" << name << std::endl; + LOGS_DEFAULT(INFO) << "Setting parameters for:" << name; auto input_tensor = ctx.GetInput(map_input_name_index[name]); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); const auto tensor_shape = tensor_info.GetShape(); @@ -1382,7 +1472,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch"; } - LOGS_DEFAULT(INFO) << "Writing Raw tensor data " << std::endl; + LOGS_DEFAULT(INFO) << "Writing Raw tensor data "; m.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); } diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 2be6c09551a71..ca7a71b28e05d 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -17,15 +17,16 @@ namespace onnxruntime { namespace migraphx_env_vars { static const char kFP16Enable[] = "ORT_MIGRAPHX_FP16_ENABLE"; +static const char kFP8Enable[] = "ORT_MIGRAPHX_FP8_ENABLE"; static const char kINT8Enable[] = "ORT_MIGRAPHX_INT8_ENABLE"; static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"; static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH"; static const char 
kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"; static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL"; -static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILE_PATH"; +static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILED_PATH"; static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL"; -static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILE_PATH"; +static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILED_PATH"; static const char kExhaustiveTune[] = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"; }; // namespace migraphx_env_vars @@ -43,6 +44,7 @@ struct MIGraphXFuncState { OrtMutex* mgx_mu_ptr = nullptr; bool no_input_shape = false; bool fp16_enable = false; + bool fp8_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; std::unordered_map dynamic_range_map; @@ -60,6 +62,10 @@ class MIGraphXExecutionProvider : public IExecutionProvider { explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info); ~MIGraphXExecutionProvider(); + void get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info); + void get_flags_from_env(); + void print_migraphx_ep_flags(); + Status Sync() const override; Status OnRunStart(const onnxruntime::RunOptions& run_options) override; @@ -89,6 +95,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { private: MIGraphXExecutionProviderInfo info_; bool fp16_enable_ = false; + bool fp8_enable_ = false; bool int8_enable_ = false; std::string int8_calibration_cache_name_; bool int8_calibration_cache_available_ = false; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index 1f9a47d3ad87d..6537d1c12bc9c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ 
b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -14,6 +14,7 @@ namespace migraphx { namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; constexpr const char* kFp16Enable = "trt_fp16_enable"; +constexpr const char* kFp8Enable = "migx_fp8_enable"; constexpr const char* kInt8Enable = "migx_int8_enable"; constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name"; constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table"; @@ -43,6 +44,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions return Status::OK(); }) .AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable) + .AddAssignmentToReference(migraphx::provider_option_names::kFp8Enable, info.fp8_enable) .AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable) .AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model) .AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model) @@ -56,6 +58,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, + {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)}, {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)}, {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)}, @@ -68,6 +71,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap const ProviderOptions options{ 
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, + {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)}, {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)}, {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)}, diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index b8bf86580f03d..554f77b6f7f58 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -16,6 +16,7 @@ struct MIGraphXExecutionProviderInfo { std::string target_device; OrtDevice::DeviceId device_id{0}; bool fp16_enable{false}; + bool fp8_enable{false}; bool int8_enable{false}; std::string int8_calibration_table_name{""}; bool int8_use_native_calibration_table{false}; diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 7b192b657b7cc..098512b2ce69c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -60,6 +60,7 @@ struct MIGraphX_Provider : Provider { info.device_id = static_cast(options.device_id); info.target_device = "gpu"; info.fp16_enable = options.migraphx_fp16_enable; + info.fp8_enable = options.migraphx_fp8_enable; info.exhaustive_tune = options.migraphx_exhaustive_tune; info.int8_enable = options.migraphx_int8_enable; info.int8_calibration_table_name = ""; @@ -85,6 +86,7 
@@ struct MIGraphX_Provider : Provider { auto& migx_options = *reinterpret_cast(provider_options); migx_options.device_id = internal_options.device_id; migx_options.migraphx_fp16_enable = internal_options.fp16_enable; + migx_options.migraphx_fp8_enable = internal_options.fp8_enable; migx_options.migraphx_int8_enable = internal_options.int8_enable; migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index f31734bdfb805..a87b1b5b457ad 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -831,6 +831,7 @@ std::unique_ptr CreateExecutionProviderInstance( 0, 0, 0, + 0, nullptr, 1, "./compiled_model.mxr", @@ -854,6 +855,16 @@ std::unique_ptr CreateExecutionProviderInstance( "[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be" " 'True' or 'False'. Default value is 'False'.\n"); } + } else if (option.first == "migraphx_fp8_enable") { + if (option.second == "True" || option.second == "true") { + params.migraphx_fp8_enable = true; + } else if (option.second == "False" || option.second == "false") { + params.migraphx_fp8_enable = false; + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migraphx_fp8_enable' should be" + " 'True' or 'False'. 
Default value is 'False'.\n"); + } } else if (option.first == "migraphx_int8_enable") { if (option.second == "True" || option.second == "true") { params.migraphx_int8_enable = true; @@ -1430,7 +1441,7 @@ void addGlobalMethods(py::module& m) { ORT_UNUSED_PARAMETER(algo); ORT_THROW("set_cudnn_conv_algo_search is not supported in ROCM"); #else - cudnn_conv_algo_search = algo; + cudnn_conv_algo_search = algo; #endif }); // TODO remove deprecated global config @@ -1441,7 +1452,7 @@ void addGlobalMethods(py::module& m) { ORT_UNUSED_PARAMETER(use_single_stream); ORT_THROW("set_do_copy_in_default_stream is not supported in ROCM"); #else - do_copy_in_default_stream = use_single_stream; + do_copy_in_default_stream = use_single_stream; #endif }); // TODO remove deprecated global config @@ -1806,10 +1817,10 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc") } ORT_THROW_IF_ERROR(options->value.AddExternalInitializers(names_ptrs, values_ptrs)); #else - ORT_UNUSED_PARAMETER(options); - ORT_UNUSED_PARAMETER(names); - ORT_UNUSED_PARAMETER(ort_values); - ORT_THROW("External initializers are not supported in this build."); + ORT_UNUSED_PARAMETER(options); + ORT_UNUSED_PARAMETER(names); + ORT_UNUSED_PARAMETER(ort_values); + ORT_THROW("External initializers are not supported in this build."); #endif }); @@ -1871,7 +1882,8 @@ including arg name, arg type (contains both type and shape).)pbdoc") return *(na.Type()); }, "node type") - .def("__str__", [](const onnxruntime::NodeArg& na) -> std::string { + .def( + "__str__", [](const onnxruntime::NodeArg& na) -> std::string { std::ostringstream res; res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape="; auto shape = na.Shape(); @@ -1898,7 +1910,8 @@ including arg name, arg type (contains both type and shape).)pbdoc") res << ")"; return std::string(res.str()); }, "converts the node into a readable string") - .def_property_readonly("shape", [](const onnxruntime::NodeArg& na) -> std::vector { + 
.def_property_readonly( + "shape", [](const onnxruntime::NodeArg& na) -> std::vector { auto shape = na.Shape(); std::vector arr; if (shape == nullptr || shape->dim_size() == 0) { @@ -2106,25 +2119,32 @@ including arg name, arg type (contains both type and shape).)pbdoc") .def_property_readonly("get_profiling_start_time_ns", [](const PyInferenceSession* sess) -> uint64_t { return sess->GetSessionHandle()->GetProfiling().GetStartTimeNs(); }) - .def("get_providers", [](const PyInferenceSession* sess) -> const std::vector& { return sess->GetSessionHandle()->GetRegisteredProviderTypes(); }, py::return_value_policy::reference_internal) - .def("get_provider_options", [](const PyInferenceSession* sess) -> const ProviderOptionsMap& { return sess->GetSessionHandle()->GetAllProviderOptions(); }, py::return_value_policy::reference_internal) - .def_property_readonly("session_options", [](const PyInferenceSession* sess) -> PySessionOptions* { + .def( + "get_providers", [](const PyInferenceSession* sess) -> const std::vector& { return sess->GetSessionHandle()->GetRegisteredProviderTypes(); }, py::return_value_policy::reference_internal) + .def( + "get_provider_options", [](const PyInferenceSession* sess) -> const ProviderOptionsMap& { return sess->GetSessionHandle()->GetAllProviderOptions(); }, py::return_value_policy::reference_internal) + .def_property_readonly( + "session_options", [](const PyInferenceSession* sess) -> PySessionOptions* { auto session_options = std::make_unique(); session_options->value = sess->GetSessionHandle()->GetSessionOptions(); return session_options.release(); }, py::return_value_policy::take_ownership) - .def_property_readonly("inputs_meta", [](const PyInferenceSession* sess) -> const std::vector& { + .def_property_readonly( + "inputs_meta", [](const PyInferenceSession* sess) -> const std::vector& { auto res = sess->GetSessionHandle()->GetModelInputs(); OrtPybindThrowIfError(res.first); return *(res.second); }, 
py::return_value_policy::reference_internal) - .def_property_readonly("outputs_meta", [](const PyInferenceSession* sess) -> const std::vector& { + .def_property_readonly( + "outputs_meta", [](const PyInferenceSession* sess) -> const std::vector& { auto res = sess->GetSessionHandle()->GetModelOutputs(); OrtPybindThrowIfError(res.first); return *(res.second); }, py::return_value_policy::reference_internal) - .def_property_readonly("overridable_initializers", [](const PyInferenceSession* sess) -> const std::vector& { + .def_property_readonly( + "overridable_initializers", [](const PyInferenceSession* sess) -> const std::vector& { auto res = sess->GetSessionHandle()->GetOverridableInitializers(); OrtPybindThrowIfError(res.first); return *(res.second); }, py::return_value_policy::reference_internal) - .def_property_readonly("model_meta", [](const PyInferenceSession* sess) -> const onnxruntime::ModelMetadata& { + .def_property_readonly( + "model_meta", [](const PyInferenceSession* sess) -> const onnxruntime::ModelMetadata& { auto res = sess->GetSessionHandle()->GetModelMetadata(); OrtPybindThrowIfError(res.first); return *(res.second); }, py::return_value_policy::reference_internal) @@ -2152,8 +2172,8 @@ including arg name, arg type (contains both type and shape).)pbdoc") return ret; #else - ORT_UNUSED_PARAMETER(sess); - ORT_THROW("TunableOp and get_tuning_results are not supported in this build."); + ORT_UNUSED_PARAMETER(sess); + ORT_THROW("TunableOp and get_tuning_results are not supported in this build."); #endif }) .def("set_tuning_results", [](PyInferenceSession* sess, py::list results, bool error_on_invalid) -> void { @@ -2184,10 +2204,10 @@ including arg name, arg type (contains both type and shape).)pbdoc") throw std::runtime_error("Error in execution: " + status.ErrorMessage()); } #else - ORT_UNUSED_PARAMETER(sess); - ORT_UNUSED_PARAMETER(results); - ORT_UNUSED_PARAMETER(error_on_invalid); - ORT_THROW("TunableOp and set_tuning_results are not supported in this 
build."); + ORT_UNUSED_PARAMETER(sess); + ORT_UNUSED_PARAMETER(results); + ORT_UNUSED_PARAMETER(error_on_invalid); + ORT_THROW("TunableOp and set_tuning_results are not supported in this build."); #endif }); diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 1feba20e32bbb..eecf194e7c9a1 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -76,6 +76,7 @@ std::unique_ptr DefaultMIGraphXExecutionProvider() { 0, 0, 0, + 0, nullptr, 1, "./compiled_model.mxr",