Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Precompiled Kernels in Orochi #37

Merged
merged 7 commits into from
Nov 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 49 additions & 42 deletions Orochi/OrochiUtils.cpp
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
#include <Orochi/OrochiUtils.h>
#include <string>
#include <string.h>
#include <iostream>
#include <fstream>
#include <codecvt>
#include <fstream>
#include <iostream>
#include <string.h>
#include <string>

#if defined( _WIN32 )
#define NOMINMAX
#include <Windows.h>
#else
#include <errno.h>
#include <sys/stat.h>
#include <locale>
#include <sys/stat.h>
#endif

inline std::wstring utf8_to_wstring( const std::string& str )
Expand Down Expand Up @@ -111,8 +111,7 @@ class FileStat

struct OrochiUtilsImpl
{
static
bool readSourceCode( const std::string& path, std::string& sourceCode, std::vector<std::string>* includes )
static bool readSourceCode( const std::string& path, std::string& sourceCode, std::vector<std::string>* includes )
{
std::fstream f( path );
if( f.is_open() )
Expand Down Expand Up @@ -175,7 +174,7 @@ struct OrochiUtilsImpl
return hash;
};

auto hashString =[&]( const char* ss, const size_t size, char buf[9] )
auto hashString = [&]( const char* ss, const size_t size, char buf[9] )
{
const unsigned int hash = hashBin( ss, size );

Expand Down Expand Up @@ -222,8 +221,7 @@ struct OrochiUtilsImpl
deviceName = deviceName.substr( 0, deviceName.find( ":" ) );
binFileName = cacheDirectory + "/"s + moduleHash + "-"s + optionHash + ".v."s + deviceName + "."s + driverVersion + "_"s + std::to_string( 8 * sizeof( void* ) ) + ".bin"s;
}
static
bool isFileUpToDate( const char* binaryFileName, const char* srcFileName )
static bool isFileUpToDate( const char* binaryFileName, const char* srcFileName )
{
FileStat b( binaryFileName );

Expand Down Expand Up @@ -373,49 +371,34 @@ struct OrochiUtilsImpl
return 0;
}

static std::string getCacheName( const char* path, const char* kernelname )
{
std::string a( path );
a += kernelname;
return a;
}
/// Builds the key under which a kernel is stored in the function map: the kernel name appended directly to the path, with no separator.
static std::string getCacheName( const std::string& path, const std::string& kernelname ) noexcept
{
    std::string key{ path };
    key += kernelname;
    return key;
}
};

OrochiUtils::OrochiUtils()
{
m_cacheDirectory = "./cache/";
}
/// Constructs the utility object with the cache directory set to "./cache/".
OrochiUtils::OrochiUtils() : m_cacheDirectory( "./cache/" ) {}

OrochiUtils::~OrochiUtils()
{
}
/// Performs no explicit cleanup; members are destroyed by their own destructors.
OrochiUtils::~OrochiUtils() = default;

bool OrochiUtils::readSourceCode( const std::string& path, std::string& sourceCode, std::vector<std::string>* includes )
{
return OrochiUtilsImpl::readSourceCode( path, sourceCode, includes );
}
/// Forwards all three arguments unchanged to OrochiUtilsImpl::readSourceCode and returns its result.
bool OrochiUtils::readSourceCode( const std::string& path, std::string& sourceCode, std::vector<std::string>* includes )
{
    const bool ok = OrochiUtilsImpl::readSourceCode( path, sourceCode, includes );
    return ok;
}

/// Returns the function `funcName` for the source file at `path`, building it through getFunction() on first use.
///
/// Results are memoized in m_kernelMap under the key produced by OrochiUtilsImpl::getCacheName( path, funcName ),
/// so a later call with the same path and function name returns the stored handle without reading the file again.
///
/// @param device   Device forwarded to getFunction().
/// @param path     Path of the source file to read; also forwarded to getFunction().
/// @param funcName Name of the function to fetch.
/// @param optsIn   Options vector pointer, forwarded unchanged to getFunction().
/// @return The function handle, or 0 when the source file could not be read.
oroFunction OrochiUtils::getFunctionFromFile( oroDevice device, const char* path, const char* funcName, std::vector<const char*>* optsIn )
{
    // Held for the whole call so the lookup and the insertion below cannot interleave with another caller.
    std::lock_guard<std::recursive_mutex> lock( m_mutex );

    const std::string cacheName = OrochiUtilsImpl::getCacheName( path, funcName );

    // Single lookup with the key we already own: no temporary std::string, no second hash lookup.
    const auto cached = m_kernelMap.find( cacheName );
    if( cached != m_kernelMap.end() ) return cached->second;

    std::string source;
    if( !OrochiUtilsImpl::readSourceCode( path, source, 0 ) ) return 0;

    oroFunction f = getFunction( device, source.c_str(), path, funcName, optsIn );
    m_kernelMap[cacheName] = f;
    return f;
}

oroFunction OrochiUtils::getFunctionFromString( oroDevice device, const char* source, const char* path, const char* funcName, std::vector<const char*>* optsIn,
int numHeaders, const char** headers, const char** includeNames )
oroFunction OrochiUtils::getFunctionFromString( oroDevice device, const char* source, const char* path, const char* funcName, std::vector<const char*>* optsIn, int numHeaders, const char** headers, const char** includeNames )
{
std::lock_guard<std::recursive_mutex> lock( m_mutex );

Expand All @@ -428,9 +411,33 @@ oroFunction OrochiUtils::getFunctionFromString( oroDevice device, const char* so
m_kernelMap[cacheName] = f;
return f;
}

oroFunction OrochiUtils::getFunction( oroDevice device, const char* code, const char* path, const char* funcName, std::vector<const char*>* optsIn,
int numHeaders, const char** headers, const char** includeNames )

/// Loads the function `funcName` from the precompiled module file at `path`.
///
/// The whole file is read into memory, handed to oroModuleLoadData, and the function is then resolved with
/// oroModuleGetFunction. Successful results are memoized in m_kernelMap under
/// OrochiUtilsImpl::getCacheName( path, funcName ); failures are not cached, so a later call retries.
///
/// @param path     Path of the precompiled binary to load.
/// @param funcName Name of the function to resolve inside the loaded module.
/// @return The function handle, or a null handle when the file cannot be opened, is empty, or the module /
///         function lookup reports an error.
///
/// NOTE(review): the oroModule handle is local to this call and is neither stored nor unloaded here.
oroFunction OrochiUtils::getFunctionFromPrecompiledBinary( const std::string& path, const std::string& funcName )
{
    std::lock_guard<std::recursive_mutex> lock( m_mutex );

    const std::string cacheName = OrochiUtilsImpl::getCacheName( path, funcName );

    // Single lookup with the key we already own: no temporary std::string, no second hash lookup.
    const auto cached = m_kernelMap.find( cacheName );
    if( cached != m_kernelMap.end() ) return cached->second;

    std::ifstream instream( path, std::ios::in | std::ios::binary );
    OROASSERT( instream.is_open(), 0 );
    // OROASSERT may compile to a no-op, so every failure also needs an explicit early-out.
    if( !instream.is_open() ) return oroFunction{};

    std::vector<char> binary( ( std::istreambuf_iterator<char>( instream ) ), std::istreambuf_iterator<char>() );
    OROASSERT( !binary.empty(), 0 );
    // An empty vector has no usable data() pointer to hand to the module loader.
    if( binary.empty() ) return oroFunction{};

    oroModule module{};
    oroFunction functionOut{};
    oroError e = oroModuleLoadData( &module, binary.data() );
    OROASSERT( e == oroSuccess, 0 );
    if( e != oroSuccess ) return oroFunction{};

    e = oroModuleGetFunction( &functionOut, module, funcName.c_str() );
    OROASSERT( e == oroSuccess, 0 );
    if( e != oroSuccess ) return oroFunction{};

    m_kernelMap[cacheName] = functionOut;
    return functionOut;
}

oroFunction OrochiUtils::getFunction( oroDevice device, const char* code, const char* path, const char* funcName, std::vector<const char*>* optsIn, int numHeaders, const char** headers, const char** includeNames )
{
std::lock_guard<std::recursive_mutex> lock( m_mutex );

Expand All @@ -451,13 +458,13 @@ oroFunction OrochiUtils::getFunction( oroDevice device, const char* code, const
std::string cacheFile;
{
std::string o;
for(int i=0; i<opts.size(); i++)
for( int i = 0; i < opts.size(); i++ )
o.append( opts[i] );
OrochiUtilsImpl::getCacheFileName( device, path, funcName, o.c_str(), cacheFile, m_cacheDirectory );
}
if( OrochiUtilsImpl::isFileUpToDate( cacheFile.c_str(), path ) )
{
//load cache
// load cache
OrochiUtilsImpl::loadCacheFileToBinary( cacheFile, codec );
}
else
Expand Down Expand Up @@ -489,7 +496,7 @@ oroFunction OrochiUtils::getFunction( oroDevice device, const char* code, const
e = orortcDestroyProgram( &prog );
OROASSERT( e == ORORTC_SUCCESS, 0 );

//store cache
// store cache
OrochiUtilsImpl::createDirectory( m_cacheDirectory.c_str() );
OrochiUtilsImpl::cacheBinaryToFile( codec, cacheFile );
}
Expand All @@ -509,7 +516,7 @@ void OrochiUtils::getData( oroDevice device, const char* code, const char* path,

std::string tmp = "--gpu-architecture=";

if( oroGetCurAPI(0) == ORO_API_HIP )
if( oroGetCurAPI( 0 ) == ORO_API_HIP )
{
oroDeviceProp props;
oroGetDeviceProperties( &props, device );
Expand Down Expand Up @@ -554,7 +561,7 @@ void OrochiUtils::getData( oroDevice device, const char* code, const char* path,
return;
}

void OrochiUtils::getProgram( oroDevice device, const char* code, const char* path, std::vector<const char*>* optsIn, const char* funcName, orortcProgram *prog )
void OrochiUtils::getProgram( oroDevice device, const char* code, const char* path, std::vector<const char*>* optsIn, const char* funcName, orortcProgram* prog )
{
std::vector<const char*> opts;
opts.push_back( "-std=c++17" );
Expand Down Expand Up @@ -598,7 +605,7 @@ void OrochiUtils::getProgram( oroDevice device, const char* code, const char* pa
return;
}

void OrochiUtils::launch1D( oroFunction func, int nx, const void** args, int wgSize, unsigned int sharedMemBytes, oroStream stream )
void OrochiUtils::launch1D( oroFunction func, int nx, const void** args, int wgSize, unsigned int sharedMemBytes, oroStream stream )
{
int4 tpb = { wgSize, 1, 0 };
int4 nb = { ( nx + tpb.x - 1 ) / tpb.x, 1, 0 };
Expand Down
52 changes: 34 additions & 18 deletions Orochi/OrochiUtils.h
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
#pragma once
#include <Orochi/Orochi.h>
#include <vector>
#include <unordered_map>
#include <string>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#if defined( GNUC )
#include <signal.h>
#endif

#if defined(_WIN32)
#define OROASSERT(x, y) if(!(x)) {__debugbreak();}
/// Assertion helper: when `exp` converts to false, runs the platform-specific break selected by the
/// preprocessor branches below; when `exp` converts to true it does nothing.
/// `placeholder` is accepted and ignored.
/// NOTE(review): the SIGTRAP branch is guarded by `GNUC`, not the compiler-predefined `__GNUC__`. Unless the
/// build system defines `GNUC` itself, a failed assertion outside Windows falls through to the empty branch — confirm.
template<typename T, typename U>
constexpr void OROASSERT( T&& exp, [[maybe_unused]] U&& placeholder ) noexcept
{
if( static_cast<bool>( std::forward<T>( exp ) ) != true )
{

#if defined( _WIN32 )
__debugbreak();
#elif defined( GNUC )
raise( SIGTRAP );
#else
#define OROASSERT(x, y) if(!(x)) {;}
;
#endif
}
}

class OrochiUtils
{
Expand All @@ -22,11 +36,11 @@ class OrochiUtils
OrochiUtils();
~OrochiUtils();

oroFunction getFunctionFromPrecompiledBinary( const std::string& path, const std::string& funcName );

oroFunction getFunctionFromFile( oroDevice device, const char* path, const char* funcName, std::vector<const char*>* opts );
oroFunction getFunctionFromString( oroDevice device, const char* source, const char* path, const char* funcName, std::vector<const char*>* opts,
int numHeaders, const char** headers, const char** includeNames );
oroFunction getFunction( oroDevice device, const char* code, const char* path, const char* funcName, std::vector<const char*>* opts,
int numHeaders = 0, const char** headers = 0, const char** includeNames = 0 );
oroFunction getFunctionFromString( oroDevice device, const char* source, const char* path, const char* funcName, std::vector<const char*>* opts, int numHeaders, const char** headers, const char** includeNames );
oroFunction getFunction( oroDevice device, const char* code, const char* path, const char* funcName, std::vector<const char*>* opts, int numHeaders = 0, const char** headers = 0, const char** includeNames = 0 );

static bool readSourceCode( const std::string& path, std::string& sourceCode, std::vector<std::string>* includes = 0 );
static void getData( oroDevice device, const char* code, const char* path, std::vector<const char*>* opts, std::vector<char>& dst );
Expand All @@ -41,15 +55,18 @@ class OrochiUtils
}

template<typename T>
static void free( T* ptr ) { oroFree( (oroDeviceptr)ptr ); }
/// Passes `ptr`, cast to oroDeviceptr, to oroFree. The result of oroFree is not checked.
static void free( T* ptr )
{
    const auto devicePtr = (oroDeviceptr)ptr;
    oroFree( devicePtr );
}

static void memset( void* ptr, int val, size_t n )
{
oroError e = oroMemset( (oroDeviceptr)ptr, val, n );
static void memset( void* ptr, int val, size_t n )
{
oroError e = oroMemset( (oroDeviceptr)ptr, val, n );
OROASSERT( e == oroSuccess, 0 );
}

static void memsetAsync( void* ptr, int val, size_t n, oroStream stream )
static void memsetAsync( void* ptr, int val, size_t n, oroStream stream )
{
oroError e = oroMemsetD8Async( (oroDeviceptr)ptr, val, n, stream );
OROASSERT( e == oroSuccess, 0 );
Expand Down Expand Up @@ -97,14 +114,13 @@ class OrochiUtils
OROASSERT( e == oroSuccess, 0 );
}

static
void waitForCompletion( oroStream stream = 0 )
/// Calls oroStreamSynchronize on `stream` (0 by default) and asserts that it reported oroSuccess.
static void waitForCompletion( oroStream stream = 0 )
{
    const auto status = oroStreamSynchronize( stream );
    OROASSERT( status == oroSuccess, 0 );
}

public:
public:
std::string m_cacheDirectory;
std::recursive_mutex m_mutex;
std::unordered_map<std::string, oroFunction> m_kernelMap;
Expand Down
42 changes: 36 additions & 6 deletions ParallelPrimitives/RadixSort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,21 @@
#include <numeric>

#if defined( ORO_PP_LOAD_FROM_STRING )

// Note: the include order must be in this particular form.
// clang-format off
#include <ParallelPrimitives/cache/Kernels.h>

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you use the automatic include sorting? If I remember correctly, I had some issue with baking when I used the automatic include sorting. I think it must be in this order:
#include <ParallelPrimitives/cache/Kernels.h>
#include <ParallelPrimitives/cache/KernelArgs.h>

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add

// clang-format off
...
// clang-format on

to turn off formatting here

#include <ParallelPrimitives/cache/KernelArgs.h>
// clang-format on
#endif

namespace
{

// Compile-time switch: true when the build defines ORO_PRECOMPILED, false otherwise.
#ifdef ORO_PRECOMPILED
constexpr bool useBitCode{ true };
#else
constexpr bool useBitCode{ false };
#endif

void printKernelInfo( oroFunction func )
{
Expand Down Expand Up @@ -76,9 +84,22 @@ void RadixSort::compileKernels( oroDevice device, OrochiUtils& oroutils, const s
const auto currentKernelPath{ ( kernelPath == "" ) ? defaultKernelPath : kernelPath };
const auto currentIncludeDir{ ( includeDir == "" ) ? defaultIncludeDir : includeDir };

if( m_flags == Flag::LOG )
std::string binaryPath{};
if constexpr( useBitCode )
{
std::cout << "compiling kernels at path : " << currentKernelPath << " in : " << currentIncludeDir << '\n';
const bool isAmd = oroGetCurAPI( 0 ) == ORO_API_HIP;
binaryPath = isAmd ? "../bitcodes/oro_compiled_kernels.hipfb" : "../bitcodes/oro_compiled_kernels.fatbin";
if( m_flags == Flag::LOG )
{
std::cout << "loading pre-compiled kernels at path : " << binaryPath << '\n';
}
}
else
{
if( m_flags == Flag::LOG )
{
std::cout << "compiling kernels at path : " << currentKernelPath << " in : " << currentIncludeDir << '\n';
}
}

const auto includeArg{ "-I" + currentIncludeDir };
Expand All @@ -100,10 +121,19 @@ void RadixSort::compileKernels( oroDevice device, OrochiUtils& oroutils, const s
for( const auto& record : records )
{
#if defined( ORO_PP_LOAD_FROM_STRING )
oroFunctions[record.kernelType] = oroutils.getFunctionFromString( device, hip_RadixSortKernels, currentKernelPath.c_str(), record.kernelName.c_str(), &opts,
1, hip::RadixSortKernelsArgs, hip::RadixSortKernelsIncludes );
oroFunctions[record.kernelType] = oroutils.getFunctionFromString( device, hip_RadixSortKernels, currentKernelPath.c_str(), record.kernelName.c_str(), &opts, 1, hip::RadixSortKernelsArgs, hip::RadixSortKernelsIncludes );
#else
oroFunctions[record.kernelType] = oroutils.getFunctionFromFile( device, currentKernelPath.c_str(), record.kernelName.c_str(), &opts );

if constexpr( useBitCode )
{
oroFunctions[record.kernelType] = oroutils.getFunctionFromPrecompiledBinary( binaryPath.c_str(), record.kernelName.c_str() );
}
else
{

oroFunctions[record.kernelType] = oroutils.getFunctionFromFile( device, currentKernelPath.c_str(), record.kernelName.c_str(), &opts );
}

#endif
if( m_flags == Flag::LOG )
{
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ See more in the [sample application](./Test/main.cpp).
Run premake.

```
./tools/premake5/win/premake5.exe vs2019
./tools/premake5/win/premake5.exe vs2022
```
Note: add the option `--precompiled` to enable precompiled bitcode

Test is a minimal application.

Expand Down
1 change: 1 addition & 0 deletions bitcodes/generate_bitcode.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Compiles oro_compiled_kernels.cpp with hipcc into oro_compiled_kernels.hipfb for every gfx target listed via --offload-arch.
# The include path (-I../) and the input/output file names are relative, so run this from the directory holding the script.
hipcc -O3 -std=c++17 --cuda-device-only --offload-arch=gfx1030 --offload-arch=gfx1031 --offload-arch=gfx1032 --offload-arch=gfx1033 --offload-arch=gfx1034 --offload-arch=gfx1035 --offload-arch=gfx1036 --offload-arch=gfx1010 --offload-arch=gfx1011 --offload-arch=gfx1012 --offload-arch=gfx1013 --offload-arch=gfx900 --offload-arch=gfx906 --genco -I../ oro_compiled_kernels.cpp -o oro_compiled_kernels.hipfb
1 change: 1 addition & 0 deletions bitcodes/generate_bitcode_nvidia.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Compiles oro_compiled_kernels.cpp as CUDA source (-x cu) with nvcc into oro_compiled_kernels.fatbin.
# NOTE(review): -ccbin is hard-coded to one specific Visual Studio 2019 Professional / MSVC 14.29.30133 install path; adjust it for the local toolchain.
nvcc -O3 -std=c++17 -ccbin="C:\Program Files (x86)\Microsoft Visual Studio\2019\Professional\VC\Tools\MSVC\14.29.30133\bin\Hostx86\x64" -fatbin -arch=all -I../ -x cu oro_compiled_kernels.cpp -o oro_compiled_kernels.fatbin
7 changes: 7 additions & 0 deletions bitcodes/oro_compiled_kernels.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#if defined( __CUDACC__ )
#include <cuda_runtime.h>
#include <cmath>
#else
#include <hip/hip_runtime.h>
#endif
#include "../ParallelPrimitives/RadixSortKernels.h"
Binary file added bitcodes/oro_compiled_kernels.hipfb
Binary file not shown.
Loading