Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add migx ep fp8 int4 #78

Open
wants to merge 10 commits into
base: rocm6.3_internal_testing
Choose a base branch
from
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,7 @@ typedef struct OrtTensorRTProviderOptions {
typedef struct OrtMIGraphXProviderOptions {
int device_id; // hip device id.
int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true
int migraphx_fp8_enable; // MIGraphX FP8 precision. Default 0 = false, nonzero = true
int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true
int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true
const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name
Expand Down
300 changes: 195 additions & 105 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@ namespace onnxruntime {

namespace migraphx_env_vars {
static const char kFP16Enable[] = "ORT_MIGRAPHX_FP16_ENABLE";
static const char kFP8Enable[] = "ORT_MIGRAPHX_FP8_ENABLE";
static const char kINT8Enable[] = "ORT_MIGRAPHX_INT8_ENABLE";
static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS";
static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME";
static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH";
static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE";
static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL";
static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILE_PATH";
static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILED_PATH";
static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL";
static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILE_PATH";
static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILED_PATH";
static const char kExhaustiveTune[] = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE";

}; // namespace migraphx_env_vars
Expand All @@ -43,6 +44,7 @@ struct MIGraphXFuncState {
OrtMutex* mgx_mu_ptr = nullptr;
bool no_input_shape = false;
bool fp16_enable = false;
bool fp8_enable = false;
bool int8_enable = false;
bool int8_calibration_cache_available = false;
std::unordered_map<std::string, float> dynamic_range_map;
Expand All @@ -60,6 +62,10 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info);
~MIGraphXExecutionProvider();

void get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info);
void get_flags_from_env();
void print_migraphx_ep_flags();

Status Sync() const override;

Status OnRunStart(const onnxruntime::RunOptions& run_options) override;
Expand Down Expand Up @@ -89,6 +95,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
private:
MIGraphXExecutionProviderInfo info_;
bool fp16_enable_ = false;
bool fp8_enable_ = false;
bool int8_enable_ = false;
std::string int8_calibration_cache_name_;
bool int8_calibration_cache_available_ = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace migraphx {
namespace provider_option_names {
constexpr const char* kDeviceId = "device_id";
constexpr const char* kFp16Enable = "trt_fp16_enable";
constexpr const char* kFp8Enable = "migx_fp8_enable";
constexpr const char* kInt8Enable = "migx_int8_enable";
constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name";
constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table";
Expand Down Expand Up @@ -43,6 +44,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
return Status::OK();
})
.AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kFp8Enable, info.fp8_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model)
.AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model)
Expand All @@ -56,6 +58,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE
const ProviderOptions options{
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
{migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)},
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
Expand All @@ -68,6 +71,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap
const ProviderOptions options{
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)},
{migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)},
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ struct MIGraphXExecutionProviderInfo {
std::string target_device;
OrtDevice::DeviceId device_id{0};
bool fp16_enable{false};
bool fp8_enable{false};
bool int8_enable{false};
std::string int8_calibration_table_name{""};
bool int8_use_native_calibration_table{false};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ struct MIGraphX_Provider : Provider {
info.device_id = static_cast<OrtDevice::DeviceId>(options.device_id);
info.target_device = "gpu";
info.fp16_enable = options.migraphx_fp16_enable;
info.fp8_enable = options.migraphx_fp8_enable;
info.exhaustive_tune = options.migraphx_exhaustive_tune;
info.int8_enable = options.migraphx_int8_enable;
info.int8_calibration_table_name = "";
Expand All @@ -85,6 +86,7 @@ struct MIGraphX_Provider : Provider {
auto& migx_options = *reinterpret_cast<OrtMIGraphXProviderOptions*>(provider_options);
migx_options.device_id = internal_options.device_id;
migx_options.migraphx_fp16_enable = internal_options.fp16_enable;
migx_options.migraphx_fp8_enable = internal_options.fp8_enable;
migx_options.migraphx_int8_enable = internal_options.int8_enable;
migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune;

Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
0,
0,
0,
0,
nullptr,
1,
"./compiled_model.mxr",
Expand All @@ -854,6 +855,16 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
"[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_fp8_enable") {
if (option.second == "True" || option.second == "true") {
params.migraphx_fp8_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_fp8_enable = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_fp8_enable' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_int8_enable") {
if (option.second == "True" || option.second == "true") {
params.migraphx_int8_enable = true;
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/test/util/default_providers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider() {
0,
0,
0,
0,
nullptr,
1,
"./compiled_model.mxr",
Expand Down
Loading