Skip to content

Commit

Permalink
Feature/oro 0 flexible rtc error handling cherrypick (#48)
Browse files Browse the repository at this point in the history
* add a handler for RTC load failure case on cuda.

* [ORO-0] add a handler for RTC load failure case on hip.
  • Loading branch information
AtsushiYoshimura0302 authored Feb 21, 2023
1 parent 2842b92 commit 7d41ed0
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 100 deletions.
64 changes: 46 additions & 18 deletions Orochi/Orochi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,35 +81,63 @@ int oroInitialize( oroApi api, oroU32 flags )
s_api = api;
int e = 0;
s_loadedApis = 0;
if( (api & ORO_API_CUDA) == ORO_API_CUDA )

if( api & ORO_API_CUDA )
{
e = cuewInit( CUEW_INIT_CUDA | CUEW_INIT_NVRTC );
if( e == 0 )
s_loadedApis |= ORO_API_CUDA | ORO_API_CUDADRIVER | ORO_API_CUDARTC;
}
if ((s_loadedApis & ORO_API_CUDA) == 0) {
if (api & ORO_API_CUDADRIVER)
cuuint32_t flag = 0;
if( api & ORO_API_CUDADRIVER )
{
flag |= CUEW_INIT_CUDA;
}
if( api & ORO_API_CUDARTC )
{
flag |= CUEW_INIT_NVRTC;
}

int resultDriver, resultRtc;
cuewInit( &resultDriver, &resultRtc, flag );

if( resultDriver == CUEW_SUCCESS )
{
s_loadedApis |= ORO_API_CUDADRIVER;
}
if( resultRtc == CUEW_SUCCESS )
{
cuuint32_t cuewInitFlags = CUEW_INIT_CUDA;
if ( api & ORO_API_CUDARTC ) cuewInitFlags |= CUEW_INIT_NVRTC;
e = cuewInit( cuewInitFlags );
if( e == 0 )
{
s_loadedApis |= ORO_API_CUDADRIVER;
if ( api & ORO_API_CUDARTC ) s_loadedApis |= ORO_API_CUDARTC;
}
s_loadedApis |= ORO_API_CUDARTC;
}
}
if( api & ORO_API_HIP )
{
e = hipewInit( HIPEW_INIT_HIP );
if( e == 0 )
s_loadedApis |= ORO_API_HIP;
hipuint32_t flag = 0;
if( api & ORO_API_HIPDRIVER )
{
flag |= HIPEW_INIT_HIPDRIVER;
}
if( api & ORO_API_HIPRTC )
{
flag |= HIPEW_INIT_HIPRTC;
}

int resultDriver, resultRtc;
hipewInit( &resultDriver, &resultRtc, flag );

if( resultDriver == HIPEW_SUCCESS )
{
s_loadedApis |= ORO_API_HIPDRIVER;
}
if( resultRtc == HIPEW_SUCCESS )
{
s_loadedApis |= ORO_API_HIPRTC;
}
}
if( s_loadedApis == 0 )
return ORO_ERROR_OPEN_FAILED;
return ORO_SUCCESS;
}
oroApi oroLoadedAPI()
{
return (oroApi)s_loadedApis;
}
oroApi oroGetCurAPI(oroU32 flags)
{
return s_api;
Expand Down
9 changes: 6 additions & 3 deletions Orochi/Orochi.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
enum oroApi
{
ORO_API_AUTOMATIC = 1<<0,
ORO_API_HIP = 1<<1,
ORO_API_CUDADRIVER = 1<<2,
ORO_API_CUDARTC = 1<<3,
ORO_API_HIPDRIVER = 1 << 1,
ORO_API_HIPRTC = 1 << 2,
ORO_API_HIP = ORO_API_HIPDRIVER | ORO_API_HIPRTC,
ORO_API_CUDADRIVER = 1 << 3,
ORO_API_CUDARTC = 1 << 4,
ORO_API_CUDA = ORO_API_CUDADRIVER | ORO_API_CUDARTC,
};

Expand Down Expand Up @@ -762,6 +764,7 @@ enum {


int oroInitialize( oroApi api, oroU32 flags );
oroApi oroLoadedAPI();
oroApi oroGetCurAPI( oroU32 flags );
void* oroGetRawCtx( oroCtx ctx );
oroError oroCtxCreateFromRaw( oroCtx* ctxOut, oroApi api, void* ctxIn );
Expand Down
3 changes: 2 additions & 1 deletion contrib/cuew/include/cuew.h
Original file line number Diff line number Diff line change
Expand Up @@ -1826,11 +1826,12 @@ enum {
CUEW_SUCCESS = 0,
CUEW_ERROR_OPEN_FAILED = -1,
CUEW_ERROR_ATEXIT_FAILED = -2,
CUEW_NOT_INITIALIZED = -3,
};

enum { CUEW_INIT_CUDA = 1, CUEW_INIT_NVRTC = 2 };

int cuewInit(cuuint32_t flags);
void cuewInit( int* resultDriver, int* resultRtc, cuuint32_t flags );
const char *cuewErrorString(CUresult result);
const char *cuewCompilerPath(void);
int cuewCompilerVersion(void);
Expand Down
18 changes: 5 additions & 13 deletions contrib/cuew/src/cuew.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -696,25 +696,17 @@ static int cuewNvrtcInit(void)
return result;
}

int cuewInit(cuuint32_t flags)
void cuewInit( int* resultDriver, int* resultRtc, cuuint32_t flags )
{
int result = CUEW_SUCCESS;
*resultDriver = CUEW_NOT_INITIALIZED;
*resultRtc = CUEW_NOT_INITIALIZED;

if (flags & CUEW_INIT_CUDA) {
result = cuewCudaInit();
if (result != CUEW_SUCCESS) {
return result;
}
*resultDriver = cuewCudaInit();
}

if (flags & CUEW_INIT_NVRTC) {
result = cuewNvrtcInit();
if (result != CUEW_SUCCESS) {
return result;
}
*resultRtc = cuewNvrtcInit();
}

return result;
}

const char *cuewErrorString(CUresult result)
Expand Down
6 changes: 4 additions & 2 deletions contrib/hipew/include/hipew.h
Original file line number Diff line number Diff line change
Expand Up @@ -1457,13 +1457,15 @@ enum {
HIPEW_ERROR_OPEN_FAILED = -1,
HIPEW_ERROR_ATEXIT_FAILED = -2,
HIPEW_ERROR_OLD_DRIVER = -3,
HIPEW_NOT_INITIALIZED = -4,
};

enum {
HIPEW_INIT_HIP = 1,
HIPEW_INIT_HIPDRIVER = 1 << 0,
HIPEW_INIT_HIPRTC = 1 << 1,
};

int hipewInit(hipuint32_t flags);
void hipewInit( int* resultDriver, int* resultRtc, hipuint32_t flags );
const char *hipewErrorString(hipError_t result);
const char *hipewCompilerPath(void);
int hipewCompilerVersion(void);
Expand Down
125 changes: 62 additions & 63 deletions contrib/hipew/src/hipew.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,8 @@ static int hipewHasOldDriver(const char *hip_path) {
}
#endif

static int hipewHipInit(void) {
void hipewInit( int* resultDriver, int* resultRtc, hipuint32_t flags )
{
/* Library paths. */
#ifdef _WIN32
/* Expected in C:/Windows/System32 or similar, no path needed. */
Expand All @@ -283,26 +284,34 @@ static int hipewHipInit(void) {
const char* hiprtc_paths[] = { "/opt/rocm/hip/lib/libhiprtc.so", NULL };
#endif
static int initialized = 0;
static int result = 0;
int error;
static int s_resultDriver = 0;
static int s_resultRtc = 0;

if (initialized) {
return result;
*resultDriver = s_resultDriver;
*resultRtc = s_resultRtc;
return;
}

initialized = 1;

error = atexit(hipewHipExit);
int error = atexit( hipewHipExit );
if (error) {
result = HIPEW_ERROR_ATEXIT_FAILED;
return result;
s_resultDriver = HIPEW_ERROR_ATEXIT_FAILED;
s_resultRtc = HIPEW_NOT_INITIALIZED;
*resultDriver = s_resultDriver;
*resultRtc = s_resultRtc;
return;
}

#ifdef _WIN32
/* Test for driver version. */
if(hipewHasOldDriver(hip_paths[0])) {
result = HIPEW_ERROR_OLD_DRIVER;
return result;
s_resultDriver = HIPEW_ERROR_OLD_DRIVER;
s_resultRtc = HIPEW_NOT_INITIALIZED;
*resultDriver = s_resultDriver;
*resultRtc = s_resultRtc;
return;
}
#endif

Expand All @@ -311,8 +320,11 @@ static int hipewHipInit(void) {
hiprtc_lib = dynamic_library_open_find(hiprtc_paths);

if (hip_lib == NULL) {
result = HIPEW_ERROR_OPEN_FAILED;
return result;
s_resultDriver = HIPEW_ERROR_ATEXIT_FAILED;
s_resultRtc = HIPEW_NOT_INITIALIZED;
*resultDriver = s_resultDriver;
*resultRtc = s_resultRtc;
return;
}

/* Fetch all function pointers. */
Expand Down Expand Up @@ -441,65 +453,52 @@ static int hipewHipInit(void) {
HIP_LIBRARY_FIND_CHECKED(hipImportExternalMemory);
HIP_LIBRARY_FIND_CHECKED(hipExternalMemoryGetMappedBuffer);
HIP_LIBRARY_FIND_CHECKED(hipDestroyExternalMemory);
if(hiprtc_lib)
// HIP_LIBRARY_FIND_CHECKED(hipImportExternalSemaphore);
// HIP_LIBRARY_FIND_CHECKED(hipDestroyExternalSemaphore);
// HIP_LIBRARY_FIND_CHECKED(hipSignalExternalSemaphoresAsync);
// HIP_LIBRARY_FIND_CHECKED(hipWaitExternalSemaphoresAsync);

s_resultDriver = HIPEW_SUCCESS;
*resultDriver = s_resultDriver;

if( ( flags & HIPEW_INIT_HIPRTC ) == 0 )
{
HIPRTC_LIBRARY_FIND_CHECKED(hiprtcGetErrorString);
HIPRTC_LIBRARY_FIND_CHECKED(hiprtcAddNameExpression);
HIPRTC_LIBRARY_FIND_CHECKED(hiprtcCompileProgram);
HIPRTC_LIBRARY_FIND_CHECKED(hiprtcCreateProgram);
HIPRTC_LIBRARY_FIND_CHECKED(hiprtcDestroyProgram);
HIPRTC_LIBRARY_FIND_CHECKED(hiprtcGetLoweredName);
HIPRTC_LIBRARY_FIND_CHECKED(hiprtcGetProgramLog);
HIPRTC_LIBRARY_FIND_CHECKED(hiprtcGetProgramLogSize);
HIPRTC_LIBRARY_FIND_CHECKED(hiprtcGetCode);
HIPRTC_LIBRARY_FIND_CHECKED(hiprtcGetBitcodeSize);
HIPRTC_LIBRARY_FIND_CHECKED(hiprtcGetBitcode);
HIPRTC_LIBRARY_FIND_CHECKED(hiprtcGetCodeSize);
HIPRTC_LIBRARY_FIND_CHECKED( hiprtcLinkCreate );
HIPRTC_LIBRARY_FIND_CHECKED( hiprtcLinkAddFile );
HIPRTC_LIBRARY_FIND_CHECKED( hiprtcLinkAddData );
HIPRTC_LIBRARY_FIND_CHECKED( hiprtcLinkComplete );
HIPRTC_LIBRARY_FIND_CHECKED( hiprtcLinkDestroy );
s_resultRtc = HIPEW_NOT_INITIALIZED;
*resultRtc = s_resultRtc;
return;
}
else

auto rtcLib = hiprtc_lib ? hiprtc_lib : hip_lib;
_LIBRARY_FIND( rtcLib, hiprtcGetErrorString );
if( hiprtcGetErrorString )
{
HIP_LIBRARY_FIND_CHECKED(hiprtcGetErrorString);
HIP_LIBRARY_FIND_CHECKED(hiprtcAddNameExpression);
HIP_LIBRARY_FIND_CHECKED(hiprtcCompileProgram);
HIP_LIBRARY_FIND_CHECKED(hiprtcCreateProgram);
HIP_LIBRARY_FIND_CHECKED(hiprtcDestroyProgram);
HIP_LIBRARY_FIND_CHECKED(hiprtcGetLoweredName);
HIP_LIBRARY_FIND_CHECKED(hiprtcGetProgramLog);
HIP_LIBRARY_FIND_CHECKED(hiprtcGetProgramLogSize);
HIP_LIBRARY_FIND_CHECKED(hiprtcGetCode);
HIP_LIBRARY_FIND_CHECKED(hiprtcGetCodeSize);
HIP_LIBRARY_FIND_CHECKED( hiprtcLinkCreate );
HIP_LIBRARY_FIND_CHECKED( hiprtcLinkAddFile );
HIP_LIBRARY_FIND_CHECKED( hiprtcLinkAddData );
HIP_LIBRARY_FIND_CHECKED( hiprtcLinkComplete );
HIP_LIBRARY_FIND_CHECKED( hiprtcLinkDestroy );
_LIBRARY_FIND_CHECKED( rtcLib, hiprtcAddNameExpression );
_LIBRARY_FIND_CHECKED( rtcLib, hiprtcCompileProgram );
_LIBRARY_FIND_CHECKED( rtcLib, hiprtcCreateProgram );
_LIBRARY_FIND_CHECKED( rtcLib, hiprtcDestroyProgram );
_LIBRARY_FIND_CHECKED( rtcLib, hiprtcGetLoweredName );
_LIBRARY_FIND_CHECKED( rtcLib, hiprtcGetProgramLog );
_LIBRARY_FIND_CHECKED( rtcLib, hiprtcGetProgramLogSize );
_LIBRARY_FIND_CHECKED( rtcLib, hiprtcGetCode );
// _LIBRARY_FIND_CHECKED( rtcLib, hiprtcGetBitcodeSize );
// _LIBRARY_FIND_CHECKED( rtcLib, hiprtcGetBitcode );
_LIBRARY_FIND_CHECKED( rtcLib, hiprtcGetCodeSize );
_LIBRARY_FIND_CHECKED( rtcLib, hiprtcLinkCreate );
_LIBRARY_FIND_CHECKED( rtcLib, hiprtcLinkAddFile );
_LIBRARY_FIND_CHECKED( rtcLib, hiprtcLinkAddData );
_LIBRARY_FIND_CHECKED( rtcLib, hiprtcLinkComplete );
_LIBRARY_FIND_CHECKED( rtcLib, hiprtcLinkDestroy );

s_resultRtc = HIPEW_SUCCESS;
*resultRtc = s_resultRtc;
}

result = HIPEW_SUCCESS;
return result;
}



int hipewInit(hipuint32_t flags) {
int result = HIPEW_SUCCESS;

if (flags & HIPEW_INIT_HIP) {
result = hipewHipInit();
if (result != HIPEW_SUCCESS) {
return result;
}
else
{
s_resultRtc = HIPEW_ERROR_OPEN_FAILED;
*resultRtc = s_resultRtc;
}

return result;
}


const char *hipewErrorString(hipError_t result) {
switch (result) {
case hipSuccess: return "No errors";
Expand Down

0 comments on commit 7d41ed0

Please sign in to comment.