Skip to content

Commit

Permalink
More prints of lowered kernels (#201)
Browse files: browse the repository at this point in the history
* Even more prints

* fix

* stream as ptr

* fix

* fix

* fix

* more

* redef

* rename

* fix

* fix

* fix
  • Loading branch information
wsmoses authored Dec 24, 2024
1 parent fdcf401 commit b6d6563
Showing 1 changed file with 180 additions and 14 deletions.
194 changes: 180 additions & 14 deletions src/enzyme_ad/jax/Passes/LowerKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,18 +247,21 @@ void *CompileHostModule(std::string &key, mlir::ModuleOp modOp, bool run_init) {
}

auto ptr = (void *)EntrySym->getValue();
llvm::errs() << " entry ptr: " << ptr << "\n";

kernels[key] = ptr;

auto NVSym = JIT->lookup(LibA.get(), "nv_func_init");
if (!NVSym) {
llvm::errs() << " lookupError " << NVSym.takeError() << "\n";
return nullptr;
}
if (run_init) {
auto NVSym = JIT->lookup(LibA.get(), "nv_func_init");
if (!NVSym) {
llvm::errs() << " lookupError " << NVSym.takeError() << "\n";
return nullptr;
}

auto nvptr = (void *)NVSym->getValue();
auto nvptr = (void *)NVSym->getValue();

((void (*)())(nvptr))();
((void (*)())(nvptr))();
}

return ptr;
}
Expand All @@ -272,6 +275,8 @@ extern "C" void EnzymeGPUCustomCall(void *__restrict__ stream,
XlaCustomCallStatus *__restrict__ status) {
auto ptr = (void (*)(void *, void **))(opaqueptr[0]);
printf("ptr=%p\n", ptr);
printf("stream=%p\n", stream);
printf("bufferptr=%p\n", buffers);
printf("buffer[0]=%p\n", buffers[0]);
// auto ptr = (void(*)(void*, void**, size_t, size_t, size_t, size_t, size_t,
// size_t)) (opaqueptr[0][0]);
Expand Down Expand Up @@ -422,6 +427,19 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,

builder.setInsertionPointToEnd(&submod.getBodyRegion().front());

auto printfunc = builder.create<func::FuncOp>(loc, "printf", calleeType);
printfunc.setVisibility(SymbolTable::Visibility::Private);

LLVM::GlobalOp printStrStream;
{
std::string value = "found pointer [stream] = %p\n";
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
printStrStream = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strstream",
builder.getStringAttr(value + '\0'));
}

auto func = builder.create<func::FuncOp>(loc, "entry", calleeType);

auto &entryBlock = *func.addEntryBlock();
Expand Down Expand Up @@ -453,10 +471,19 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
arguments.push_back(ld);
}
auto dynshmem = builder.create<arith::ConstantIntOp>(loc, shmem, i32);

{
Value printargs1[] = {
builder.create<LLVM::AddressOfOp>(loc, printStrStream)->getResult(0),
stream};
builder.create<func::CallOp>(loc, printfunc, printargs1);
}

stream = builder
.create<UnrealizedConversionCastOp>(
loc, gpu::AsyncTokenType::get(stream.getContext()), stream)
->getResult(0);

builder.create<gpu::LaunchFuncOp>(loc, gpufunc, gridSize, blockSize, dynshmem,
arguments, stream.getType(),
ValueRange(stream));
Expand Down Expand Up @@ -498,6 +525,10 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
pm.run(submod);

OpBuilder builder(submod);

SymbolTable st2(submod);
auto print2 = st2.lookup<LLVM::LLVMFuncOp>("printf");

builder.setInsertionPointToStart(&submod.getBodyRegion().front());
auto ptrty = LLVM::LLVMPointerType::get(builder.getContext());
auto i64 = builder.getIntegerType(64);
Expand All @@ -517,7 +548,7 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
mlir::Type cutys[] = {ptrty, idx, idx, idx, idx, idx,
idx, i32, ptrty, ptrty, ptrty};

auto launch_ty = LLVM::LLVMFunctionType::get(voidty, cutys);
auto launch_ty = LLVM::LLVMFunctionType::get(i32, cutys);
LLVM::LLVMFuncOp launch =
builder.create<LLVM::LLVMFuncOp>(loc, "cuLaunchKernel", launch_ty);

Expand All @@ -536,6 +567,66 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
builder.getStringAttr(value + '\0'));
}

LLVM::GlobalOp printStrSet;
{
std::string value = "found pointer [set] = %p\n";
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
printStrSet = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strset",
builder.getStringAttr(value + '\0'));
}

LLVM::GlobalOp printStrGlob;
{
std::string value = "found pointer [glob] = %p\n";
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
printStrGlob = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strglob",
builder.getStringAttr(value + '\0'));
}

LLVM::GlobalOp printStrCu;
{
std::string value = "found pointer [cu] = %p\n";
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
printStrCu = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strcu",
builder.getStringAttr(value + '\0'));
}

LLVM::GlobalOp printStrMod;
{
std::string value = "found pointer mod = %p\n";
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
printStrMod = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strmod",
builder.getStringAttr(value + '\0'));
}

LLVM::GlobalOp printStrLdFunc;
{
std::string value = "found pointer ld func = %p\n";
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
printStrLdFunc = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal,
"strldfunc", builder.getStringAttr(value + '\0'));
}

LLVM::GlobalOp printStrLaunch;
{
std::string value = "found pointer launch = %p\n";
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
printStrLaunch = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal,
"strlaunch", builder.getStringAttr(value + '\0'));
}

builder.setInsertionPointToStart(&submod.getBodyRegion().front());

LLVM::LLVMFuncOp initfn = builder.create<LLVM::LLVMFuncOp>(
Expand All @@ -547,6 +638,16 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
builder.create<LLVM::GlobalCtorsOp>(loc, builder.getArrayAttr(funcs),
builder.getArrayAttr(idxs));

LLVM::GlobalOp printStrFunc;
{
std::string value = "found pointer func = %p\n";
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
printStrFunc = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strfunc",
builder.getStringAttr(value + '\0'));
}

LLVM::GlobalOp binary;
submod.walk([&](gpu::BinaryOp op) {
gpu::ObjectAttr object = getSelectedObject(op);
Expand Down Expand Up @@ -583,16 +684,28 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
auto addr_modbin = builder.create<LLVM::AddressOfOp>(loc, binary);
SmallVector<mlir::Value> modargs = {modptr->getResult(0),
addr_modbin->getResult(0)};

mlir::Value loadRes;
if (cuModuleLoadDataPtr) {
auto addr_glob_int = builder.create<LLVM::ConstantOp>(
loc, i64, builder.getI64IntegerAttr(cuModuleLoadDataPtr));
auto addr_glob =
builder.create<LLVM::IntToPtrOp>(loc, ptrty, addr_glob_int);
modargs.insert(modargs.begin(), addr_glob);
builder.create<LLVM::CallOp>(loc, modload_ty, modargs);
loadRes = builder.create<LLVM::CallOp>(loc, modload_ty, modargs)
->getResult(0);
} else {
builder.create<LLVM::CallOp>(loc, modload, modargs);
loadRes =
builder.create<LLVM::CallOp>(loc, modload, modargs)->getResult(0);
}
loadRes = builder.create<LLVM::IntToPtrOp>(loc, ptrty, loadRes);
{
Value printargs1[] = {
builder.create<LLVM::AddressOfOp>(loc, printStrMod)->getResult(0),
loadRes};
builder.create<LLVM::CallOp>(loc, print2, printargs1);
}

auto mod = builder.create<LLVM::LoadOp>(loc, ptrty, modptr);

auto addr_kernstr =
Expand All @@ -601,19 +714,45 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
SmallVector<mlir::Value> funcargs = {funcptr->getResult(0),
mod->getResult(0),
addr_kernstr->getResult(0)};
mlir::Value getRes;
if (cuModuleGetFunctionPtr) {
auto addr_glob_int = builder.create<LLVM::ConstantOp>(
loc, i64, builder.getI64IntegerAttr(cuModuleGetFunctionPtr));
auto addr_glob =
builder.create<LLVM::IntToPtrOp>(loc, ptrty, addr_glob_int);
funcargs.insert(funcargs.begin(), addr_glob);
builder.create<LLVM::CallOp>(loc, funcload_ty, funcargs);
getRes = builder.create<LLVM::CallOp>(loc, funcload_ty, funcargs)
->getResult(0);
} else {
builder.create<LLVM::CallOp>(loc, funcload, funcargs);
getRes = builder.create<LLVM::CallOp>(loc, funcload, funcargs)
->getResult(0);
}

getRes = builder.create<LLVM::IntToPtrOp>(loc, ptrty, getRes);
{
Value printargs1[] = {
builder.create<LLVM::AddressOfOp>(loc, printStrFunc)
->getResult(0),
getRes};
builder.create<LLVM::CallOp>(loc, print2, printargs1);
}

auto func = builder.create<LLVM::LoadOp>(loc, ptrty, funcptr);
{
Value printargs1[] = {
builder.create<LLVM::AddressOfOp>(loc, printStrLdFunc)
->getResult(0),
func};
builder.create<LLVM::CallOp>(loc, print2, printargs1);
}

auto addr_glob = builder.create<LLVM::AddressOfOp>(loc, glob);
{
Value printargs1[] = {
builder.create<LLVM::AddressOfOp>(loc, printStrSet)->getResult(0),
addr_glob};
builder.create<LLVM::CallOp>(loc, print2, printargs1);
}
builder.create<LLVM::StoreOp>(loc, func, addr_glob);
builder.create<LLVM::ReturnOp>(loc, ValueRange());
}
Expand All @@ -639,15 +778,42 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
params,
builder.create<LLVM::ZeroOp>(loc, ptrty)};

{
Value printargs1[] = {
builder.create<LLVM::AddressOfOp>(loc, printStrGlob)
->getResult(0),
addr_glob};
builder.create<LLVM::CallOp>(loc, print2, printargs1);
}

{
Value printargs1[] = {
builder.create<LLVM::AddressOfOp>(loc, printStrCu)->getResult(0),
cufunc};
builder.create<LLVM::CallOp>(loc, print2, printargs1);
}

mlir::Value callRes;
if (cuLaunchKernelPtr) {
auto addr_glob_int = builder.create<LLVM::ConstantOp>(
loc, i64, builder.getI64IntegerAttr(cuLaunchKernelPtr));
auto addr_glob =
builder.create<LLVM::IntToPtrOp>(loc, ptrty, addr_glob_int);
args.insert(args.begin(), addr_glob);
builder.create<LLVM::CallOp>(loc, launch_ty, args);
callRes =
builder.create<LLVM::CallOp>(loc, launch_ty, args)->getResult(0);
} else {
builder.create<LLVM::CallOp>(loc, launch, args);
callRes =
builder.create<LLVM::CallOp>(loc, launch, args)->getResult(0);
}

callRes = builder.create<LLVM::IntToPtrOp>(loc, ptrty, callRes);
{
Value printargs1[] = {
builder.create<LLVM::AddressOfOp>(loc, printStrLaunch)
->getResult(0),
callRes};
builder.create<LLVM::CallOp>(loc, print2, printargs1);
}

op.erase();
Expand Down

0 comments on commit b6d6563

Please sign in to comment.