From b6d6563aa3a3050474a4250bf18322f7ebf0b486 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 23 Dec 2024 21:57:34 -0500 Subject: [PATCH] More prints of lowered kernels (#201) * Even more prints * fix * stream as ptr * fix * fix * fix * more * redef * rename * fix * fix * fix --- src/enzyme_ad/jax/Passes/LowerKernel.cpp | 194 +++++++++++++++++++++-- 1 file changed, 180 insertions(+), 14 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/LowerKernel.cpp b/src/enzyme_ad/jax/Passes/LowerKernel.cpp index baac840c..c960ffe9 100644 --- a/src/enzyme_ad/jax/Passes/LowerKernel.cpp +++ b/src/enzyme_ad/jax/Passes/LowerKernel.cpp @@ -247,18 +247,21 @@ void *CompileHostModule(std::string &key, mlir::ModuleOp modOp, bool run_init) { } auto ptr = (void *)EntrySym->getValue(); + llvm::errs() << " entry ptr: " << ptr << "\n"; kernels[key] = ptr; - auto NVSym = JIT->lookup(LibA.get(), "nv_func_init"); - if (!NVSym) { - llvm::errs() << " lookupError " << NVSym.takeError() << "\n"; - return nullptr; - } + if (run_init) { + auto NVSym = JIT->lookup(LibA.get(), "nv_func_init"); + if (!NVSym) { + llvm::errs() << " lookupError " << NVSym.takeError() << "\n"; + return nullptr; + } - auto nvptr = (void *)NVSym->getValue(); + auto nvptr = (void *)NVSym->getValue(); - ((void (*)())(nvptr))(); + ((void (*)())(nvptr))(); + } return ptr; } @@ -272,6 +275,8 @@ extern "C" void EnzymeGPUCustomCall(void *__restrict__ stream, XlaCustomCallStatus *__restrict__ status) { auto ptr = (void (*)(void *, void **))(opaqueptr[0]); printf("ptr=%p\n", ptr); + printf("stream=%p\n", stream); + printf("bufferptr=%p\n", buffers); printf("buffer[0]=%p\n", buffers[0]); // auto ptr = (void(*)(void*, void**, size_t, size_t, size_t, size_t, size_t, // size_t)) (opaqueptr[0][0]); @@ -422,6 +427,19 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc, builder.setInsertionPointToEnd(&submod.getBodyRegion().front()); + auto printfunc = builder.create(loc, "printf", calleeType); + printfunc.setVisibility(SymbolTable::Visibility::Private); + + LLVM::GlobalOp printStrStream; + { + std::string value = "found pointer [stream] = %p\n"; + auto type = LLVM::LLVMArrayType::get( + mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1); + printStrStream = builder.create( + loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strstream", + builder.getStringAttr(value + '\0')); + } + auto func = builder.create(loc, "entry", calleeType); auto &entryBlock = *func.addEntryBlock(); @@ -453,10 +471,19 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc, arguments.push_back(ld); } auto dynshmem = builder.create(loc, shmem, i32); + + { + Value printargs1[] = { + builder.create(loc, printStrStream)->getResult(0), + stream}; + builder.create(loc, printfunc, printargs1); + } + stream = builder .create( loc, gpu::AsyncTokenType::get(stream.getContext()), stream) ->getResult(0); + builder.create(loc, gpufunc, gridSize, blockSize, dynshmem, arguments, stream.getType(), ValueRange(stream)); @@ -498,6 +525,10 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc, pm.run(submod); OpBuilder builder(submod); + + SymbolTable st2(submod); + auto print2 = st2.lookup("printf"); + builder.setInsertionPointToStart(&submod.getBodyRegion().front()); auto ptrty = LLVM::LLVMPointerType::get(builder.getContext()); auto i64 = builder.getIntegerType(64); @@ -517,7 +548,7 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc, mlir::Type cutys[] = {ptrty, idx, idx, idx, idx, idx, idx, i32, ptrty, ptrty, ptrty}; - auto launch_ty = LLVM::LLVMFunctionType::get(voidty, cutys); + auto launch_ty = LLVM::LLVMFunctionType::get(i32, cutys); LLVM::LLVMFuncOp launch = builder.create(loc, "cuLaunchKernel", launch_ty); @@ -536,6 +567,66 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc, builder.getStringAttr(value + '\0')); } + LLVM::GlobalOp printStrSet; + { + std::string value = "found pointer [set] = %p\n"; + auto type = LLVM::LLVMArrayType::get( + mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1); + printStrSet = builder.create( + loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strset", + builder.getStringAttr(value + '\0')); + } + + LLVM::GlobalOp printStrGlob; + { + std::string value = "found pointer [glob] = %p\n"; + auto type = LLVM::LLVMArrayType::get( + mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1); + printStrGlob = builder.create( + loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strglob", + builder.getStringAttr(value + '\0')); + } + + LLVM::GlobalOp printStrCu; + { + std::string value = "found pointer [cu] = %p\n"; + auto type = LLVM::LLVMArrayType::get( + mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1); + printStrCu = builder.create( + loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strcu", + builder.getStringAttr(value + '\0')); + } + + LLVM::GlobalOp printStrMod; + { + std::string value = "found pointer mod = %p\n"; + auto type = LLVM::LLVMArrayType::get( + mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1); + printStrMod = builder.create( + loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strmod", + builder.getStringAttr(value + '\0')); + } + + LLVM::GlobalOp printStrLdFunc; + { + std::string value = "found pointer ld func = %p\n"; + auto type = LLVM::LLVMArrayType::get( + mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1); + printStrLdFunc = builder.create( + loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, + "strldfunc", builder.getStringAttr(value + '\0')); + } + + LLVM::GlobalOp printStrLaunch; + { + std::string value = "found pointer launch = %p\n"; + auto type = LLVM::LLVMArrayType::get( + mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1); + printStrLaunch = builder.create( + loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, + "strlaunch", builder.getStringAttr(value + '\0')); + } + builder.setInsertionPointToStart(&submod.getBodyRegion().front()); LLVM::LLVMFuncOp initfn = builder.create( @@ -547,6 +638,16 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc, builder.create(loc, builder.getArrayAttr(funcs), builder.getArrayAttr(idxs)); + LLVM::GlobalOp printStrFunc; + { + std::string value = "found pointer func = %p\n"; + auto type = LLVM::LLVMArrayType::get( + mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1); + printStrFunc = builder.create( + loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strfunc", + builder.getStringAttr(value + '\0')); + } + LLVM::GlobalOp binary; submod.walk([&](gpu::BinaryOp op) { gpu::ObjectAttr object = getSelectedObject(op); @@ -583,16 +684,28 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc, auto addr_modbin = builder.create(loc, binary); SmallVector modargs = {modptr->getResult(0), addr_modbin->getResult(0)}; + + mlir::Value loadRes; if (cuModuleLoadDataPtr) { auto addr_glob_int = builder.create( loc, i64, builder.getI64IntegerAttr(cuModuleLoadDataPtr)); auto addr_glob = builder.create(loc, ptrty, addr_glob_int); modargs.insert(modargs.begin(), addr_glob); - builder.create(loc, modload_ty, modargs); + loadRes = builder.create(loc, modload_ty, modargs) + ->getResult(0); } else { - builder.create(loc, modload, modargs); + loadRes = + builder.create(loc, modload, modargs)->getResult(0); + } + loadRes = builder.create(loc, ptrty, loadRes); + { + Value printargs1[] = { + builder.create(loc, printStrMod)->getResult(0), + loadRes}; + builder.create(loc, print2, printargs1); } + auto mod = builder.create(loc, ptrty, modptr); auto addr_kernstr = @@ -601,19 +714,45 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc, SmallVector funcargs = {funcptr->getResult(0), mod->getResult(0), addr_kernstr->getResult(0)}; + mlir::Value getRes; if (cuModuleGetFunctionPtr) { auto addr_glob_int = builder.create( loc, i64, builder.getI64IntegerAttr(cuModuleGetFunctionPtr)); auto addr_glob = builder.create(loc, ptrty, addr_glob_int); funcargs.insert(funcargs.begin(), addr_glob); - builder.create(loc, funcload_ty, funcargs); + getRes = builder.create(loc, funcload_ty, funcargs) + ->getResult(0); } else { - builder.create(loc, funcload, funcargs); + getRes = builder.create(loc, funcload, funcargs) + ->getResult(0); + } + + getRes = builder.create(loc, ptrty, getRes); + { + Value printargs1[] = { + builder.create(loc, printStrFunc) + ->getResult(0), + getRes}; + builder.create(loc, print2, printargs1); } + auto func = builder.create(loc, ptrty, funcptr); + { + Value printargs1[] = { + builder.create(loc, printStrLdFunc) + ->getResult(0), + func}; + builder.create(loc, print2, printargs1); + } auto addr_glob = builder.create(loc, glob); + { + Value printargs1[] = { + builder.create(loc, printStrSet)->getResult(0), + addr_glob}; + builder.create(loc, print2, printargs1); + } builder.create(loc, func, addr_glob); builder.create(loc, ValueRange()); } @@ -639,15 +778,42 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc, params, builder.create(loc, ptrty)}; + { + Value printargs1[] = { + builder.create(loc, printStrGlob) + ->getResult(0), + addr_glob}; + builder.create(loc, print2, printargs1); + } + + { + Value printargs1[] = { + builder.create(loc, printStrCu)->getResult(0), + cufunc}; + builder.create(loc, print2, printargs1); + } + + mlir::Value callRes; if (cuLaunchKernelPtr) { auto addr_glob_int = builder.create( loc, i64, builder.getI64IntegerAttr(cuLaunchKernelPtr)); auto addr_glob = builder.create(loc, ptrty, addr_glob_int); args.insert(args.begin(), addr_glob); - builder.create(loc, launch_ty, args); + callRes = + builder.create(loc, launch_ty, args)->getResult(0); } else { - builder.create(loc, launch, args); + callRes = + builder.create(loc, launch, args)->getResult(0); + } + + callRes = builder.create(loc, ptrty, callRes); + { + Value printargs1[] = { + builder.create(loc, printStrLaunch) + ->getResult(0), + callRes}; + builder.create(loc, print2, printargs1); } op.erase();