diff --git a/lib/Analysis/Alias.cpp b/lib/Analysis/Alias.cpp
index 8a1a66ba4fb4..db00c16d1600 100644
--- a/lib/Analysis/Alias.cpp
+++ b/lib/Analysis/Alias.cpp
@@ -30,7 +30,6 @@ void SharedMemoryAliasAnalysis::visitOperation(
   // XXX(Keren): the following ops are always aliasing for now
   if (isa<triton::gpu::ExtractSliceOp, triton::TransOp,
           triton::nvidia_gpu::ExtractMBarrierOp>(op)) {
-    llvm::errs() << "ExtractSliceOp, TransOp, ExtractMBarrierOp\n";
     // extract_slice %src
     // trans %src
     aliasInfo = AliasInfo(operands[0]->getValue());
diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
index d487fc9e767d..53f23ea3a434 100644
--- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -1330,8 +1330,12 @@ struct InsertSliceAsyncV2OpConversion
     Value pred = icmp_eq(threadId, i32_val(0));
 
     auto mask = adaptor.getMask();
-    if (mask)
+    if (mask) {
+      // TODO(thomas): What is the right implementation for this case?
+      assert(mask.getType().isInteger(1) &&
+             "need to implement cases with tensor mask");
       pred = rewriter.create<arith::AndIOp>(loc, pred, mask);
+    }
 
     Value mcastMask = getMCastMask(sharedLayout, rewriter, loc, clusterCTAId);
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
index 932a5b025ac6..33d80b92644f 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "triton/Analysis/Allocation.h"
 #include "triton/Analysis/AxisInfo.h"
 #include "triton/Analysis/Membar.h"
@@ -62,6 +63,28 @@ class TritonLLVMFunctionConversionTarget : public ConversionTarget {
   }
 };
 
+class FoldSplatMaskInInsertAsync : public mlir::RewritePattern {
+
+public:
+  FoldSplatMaskInInsertAsync(mlir::MLIRContext *context)
+      : mlir::RewritePattern(
+            triton::nvidia_gpu::InsertSliceAsyncV2Op::getOperationName(), 1,
+            context) {}
+
+  LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto insertOp = cast<triton::nvidia_gpu::InsertSliceAsyncV2Op>(op);
+    auto splatOp = insertOp.getMask().getDefiningOp<triton::SplatOp>();
+    if (!splatOp)
+      return failure();
+    rewriter.updateRootInPlace(insertOp, [&]() {
+      insertOp.getMaskMutable().assign(splatOp->getOperand(0));
+    });
+    return mlir::success();
+  }
+};
+
 struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
   using ConvertOpToLLVMPattern<triton::ReturnOp>::ConvertOpToLLVMPattern;
 
@@ -152,6 +175,18 @@ struct FuncOpConversion : public FuncOpConversionBase {
     if (!allocation.isRoot(funcOp))
       amendedFuncOp = amendFuncOp(funcOp, rewriter);
 
+    // Collect TMA information.
+    unsigned numTMALoad = 0;
+    funcOp.walk(
+        [&numTMALoad](triton::nvidia_gpu::InsertSliceAsyncV2Op insertSliceOp) {
+          numTMALoad++;
+        });
+    unsigned numTMAStore = 0;
+    funcOp.walk([&numTMAStore](triton::nvidia_gpu::StoreAsyncOp storeAsyncOp) {
+      numTMAStore++;
+    });
+    unsigned numTMA = numTMALoad + numTMAStore;
+
     auto newFuncOp = convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter);
     if (!newFuncOp) {
       return failure();
@@ -177,17 +212,6 @@ struct FuncOpConversion : public FuncOpConversionBase {
     // The call graph is updated by mapping the old function to the new one.
     allocation.mapFuncOp(funcOp, newFuncOp);
 
-    unsigned numTMALoad = 0;
-    funcOp.walk(
-        [&numTMALoad](triton::nvidia_gpu::InsertSliceAsyncV2Op insertSliceOp) {
-          numTMALoad++;
-        });
-    unsigned numTMAStore = 0;
-    funcOp.walk([&numTMAStore](triton::nvidia_gpu::StoreAsyncOp storeAsyncOp) {
-      numTMAStore++;
-    });
-    unsigned numTMA = numTMALoad + numTMAStore;
-
     // Append arguments to receive TMADesc in global memory in the runtime
     auto i8PtrTy = LLVM::LLVMPointerType::get(
         this->getTypeConverter()->convertType(rewriter.getI8Type()), 1);
@@ -384,6 +408,14 @@ class ConvertTritonGPUToLLVM
       }
     });
 
+    // Hack: cleanup
+    {
+      RewritePatternSet patterns(context);
+      patterns.add<FoldSplatMaskInInsertAsync>(context);
+      if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed())
+        signalPassFailure();
+    }
+
     // Lower functions
     {
       mlir::LowerToLLVMOptions option(context);
diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
index 9c5f99a0a60c..1a1daca03acb 100644
--- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
@@ -41,7 +41,27 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
 
 Type TritonGPUToLLVMTypeConverter::convertTritonPointerType(
     triton::PointerType type) {
-  // Recursively translate pointee type
+  auto ctx = type.getContext();
+  auto pointeeType = type.getPointeeType();
+  if (pointeeType.isa<RankedTensorType>()) {
+    auto rankedTensorType = pointeeType.cast<RankedTensorType>();
+    // struct { offset0, offset1, shape0, shape1, stride0,
+    //          stride1, base_ptr};
+    auto eleType = rankedTensorType.getElementType();
+    auto shape = rankedTensorType.getShape();
+    SmallVector<Type, 4> types;
+    // offsets
+    for (size_t i = 0; i < shape.size(); ++i)
+      types.push_back(IntegerType::get(ctx, 32));
+    // shapes, strides
+    for (size_t i = 0; i < 2 * shape.size(); ++i)
+      types.push_back(IntegerType::get(ctx, 64));
+
+    types.push_back(
+        LLVM::LLVMPointerType::get(eleType, type.getAddressSpace()));
+
+    return LLVM::LLVMStructType::getLiteral(ctx, types);
+  }
   return LLVM::LLVMPointerType::get(convertType(type.getPointeeType()),
                                     type.getAddressSpace());
 }
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
index 604a6f857839..deff7bd71d94 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -697,6 +697,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
   MLIRContext *context = patterns.getContext();
   patterns
      .insert< // TODO: view should have custom pattern that views the layout
+          TritonGenericPattern<triton::AdvanceOp>,
+          TritonGenericPattern<triton::MakeTensorPtrOp>,
          TritonGenericPattern<triton::ViewOp>,
          TritonGenericPattern<triton::BitcastOp>,
          TritonGenericPattern<triton::FpToFpOp>,
diff --git a/lib/Target/PTX/PTXTranslation.cpp b/lib/Target/PTX/PTXTranslation.cpp
index 809eb15d3cc4..379bd25696a5 100644
--- a/lib/Target/PTX/PTXTranslation.cpp
+++ b/lib/Target/PTX/PTXTranslation.cpp
@@ -45,9 +45,16 @@ static bool findAndReplace(std::string &str, const std::string &begin,
 
 static void linkExternal(llvm::Module &module) {
   namespace fs = std::filesystem;
-  static const std::filesystem::path path =
-      std::filesystem::path(__BUILD_DIR__) / "lib" / "Hopper" /
-      "libhopper_helpers.bc";
+
+  // TODO: enable generating bc file from clang.
+  static const auto this_file_path = std::filesystem::path(__FILE__);
+  static const auto path =
+      this_file_path.parent_path().parent_path().parent_path().parent_path() /
+      "python" / "triton" / "hopper_lib" / "libhopper_helpers.bc";
+
+  // static const std::filesystem::path path =
+  //     std::filesystem::path(__BUILD_DIR__) / "lib" / "Hopper" /
+  //     "libhopper_helpers.bc";
   if (mlir::triton::linkExternLib(module, "libhopper_helpers", path.string(),
                                   /*isROCM*/ false))
     llvm::errs() << "Link failed for: libhopper_helpers.bc";
diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py
index c052425f4d17..3f3438fdbfcc 100644
--- a/python/triton/compiler/compiler.py
+++ b/python/triton/compiler/compiler.py
@@ -28,7 +28,8 @@
 from ..tools.disasm import extract
 from .code_generator import ast_to_ttir
 from .make_launcher import make_stub
-from .utils import TensorMapManager, get_ids_of_tensormaps, parse_tma_info
+from .utils import (InfoFromBackendForTensorMap, TensorMapManager,
+                    get_ids_of_tensormaps, parse_tma_info)
 
 
 def inline_triton_ir(mod):
@@ -507,6 +508,9 @@ def compile(fn, **kwargs):
     if metadata_path is not None:
         with open(metadata_path) as f:
             metadata = json.load(f)
+            if 'tensormaps_info' in metadata:
+                metadata['tensormaps_info'] = [
+                    InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']]
     else:
         metadata = {"num_warps": num_warps,
                     "num_ctas": num_ctas,
@@ -522,6 +526,7 @@ def compile(fn, **kwargs):
 
     # Add device type to meta information
     metadata["device_type"] = device_type
+    metadata["cache_key"] = fn_cache_manager.key
 
     first_stage = list(stages.keys()).index(ext)
     asm = dict()
@@ -597,7 +602,7 @@ def compile(fn, **kwargs):
     so_path = _device_backend.make_launcher_stub(name, signature, constants, ids)
     # write-back metadata, if it didn't come from the cache
     if metadata_path is None:
-        metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata), metadata_filename, binary=False)
+        metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False)
     fn_cache_manager.put_group(metadata_filename, metadata_group)
 
     # return handle to compiled kernel
@@ -631,6 +636,7 @@ def __init__(self, fn, so_path, metadata, asm):
         if "tensormaps_info" in metadata:
             self.tensormaps_info = metadata["tensormaps_info"]
         self.constants = metadata["constants"]
+        self.cache_key = metadata["cache_key"]
         self.device_type = metadata["device_type"]
         self.device_backend = get_backend(self.device_type) if self.device_type not in ["cuda", "hip"] else None
         # initialize asm dict
@@ -690,6 +696,7 @@ def __getitem__(self, grid):
             self._init_handles()
 
         def runner(*args, stream=None):
+            args_expand = self.assemble_tensormap_to_arg(args, self.constants)
             if stream is None:
                 if self.device_type in ["cuda", "rocm"]:
                     stream = get_cuda_stream()
@@ -697,7 +704,7 @@ def runner(*args, stream=None):
                 stream = get_backend(self.device_type).get_stream(None)
             self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.clusterDims[0], self.clusterDims[1],
                            self.clusterDims[2], self.shared, stream, self.cu_function,
-                           CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args)
+                           CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand)
         return runner
 
     def get_sass(self, fun=None):
diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py
index 8d70e545b2f9..1bf29300c90e 100644
--- a/python/triton/runtime/autotuner.py
+++ b/python/triton/runtime/autotuner.py
@@ -115,7 +115,9 @@ def run(self, *args, **kwargs):
             if config.pre_hook is not None:
                 full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs}
                 config.pre_hook(full_nargs)
-            ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
+            ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
+                              num_ctas=config.num_ctas,
+                              enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs)
         self.nargs = None
         return ret
diff --git a/python/triton/testing.py b/python/triton/testing.py
index b6d217ec4d88..640d7e22388a 100644
--- a/python/triton/testing.py
+++ b/python/triton/testing.py
@@ -207,9 +207,15 @@ def _run(self, bench, save_path, show_plots, print_data):
         y_mean = bench.line_names
         y_min = [f'{x}-min' for x in bench.line_names]
         y_max = [f'{x}-max' for x in bench.line_names]
-        df = pd.DataFrame(columns=[bench.x_names[0]] + y_mean + y_min + y_max)
+        x_names_str = str(bench.x_names)
+        df = pd.DataFrame(columns=[x_names_str] + y_mean + y_min + y_max)
         for x in bench.x_vals:
-            x_args = {x_name: x for x_name in bench.x_names}
+            if not isinstance(x, list):
+                x = [x]
+            if len(x) == 1:
+                x = x * len(bench.x_names)
+            x_str = str(x)
+            x_args = {x_name: x_in for x_name, x_in in zip(bench.x_names, x)}
             row_mean, row_min, row_max = [], [], []
             for y in bench.line_vals:
                 ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args)
@@ -220,11 +226,11 @@ def _run(self, bench, save_path, show_plots, print_data):
                 row_mean += [y_mean]
                 row_min += [y_min]
                 row_max += [y_max]
-            df.loc[len(df)] = [x] + row_mean + row_min + row_max
+            df.loc[len(df)] = [x_str] + row_mean + row_min + row_max
         if bench.plot_name:
             plt.figure()
             ax = plt.subplot()
-            x = bench.x_names[0]
+            x = x_names_str
             for i, y in enumerate(bench.line_names):
                 y_min, y_max = df[y + '-min'], df[y + '-max']
                 col = bench.styles[i][0] if bench.styles else None
@@ -243,7 +249,7 @@ def _run(self, bench, save_path, show_plots, print_data):
             plt.show()
         if save_path:
             plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
-        df = df[[bench.x_names[0]] + bench.line_names]
+        df = df[[x_names_str] + bench.line_names]
         if print_data:
             print(bench.plot_name + ':')
             print(df)
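
Note on the compiler.py metadata change: tensormap descriptors now round-trip through the on-disk metadata cache. On the write path, json.dumps(metadata, default=vars) serializes any object json cannot encode through its attribute dict (vars(obj) is obj.__dict__); on a cache hit, compile() rebuilds InfoFromBackendForTensorMap objects from the decoded dicts. A minimal sketch of that round trip; the class body and the field names here are hypothetical stand-ins, since the patch only shows that the real class (in triton/compiler/utils.py) accepts the decoded dict:

import json

# Hypothetical stand-in for .utils.InfoFromBackendForTensorMap; the real
# class may validate fields. The diff only guarantees dict-in construction.
class InfoFromBackendForTensorMap:
    def __init__(self, e):
        self.__dict__.update(e)  # restore attributes captured by `vars` below

info = InfoFromBackendForTensorMap({"tensorRank": 2, "globalDims": [128, 64]})

# Write path: `default=vars` is invoked for non-JSON objects and returns
# obj.__dict__, so the metadata file stores plain dicts.
blob = json.dumps({"tensormaps_info": [info]}, default=vars)

# Read path (cache hit): rehydrate the dicts, mirroring the compile() change.
metadata = json.loads(blob)
metadata["tensormaps_info"] = [
    InfoFromBackendForTensorMap(e) for e in metadata["tensormaps_info"]]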
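
Note on the testing.py change: Mark._run now lets several x_names vary together. Each entry of x_vals may be a list that is zipped against x_names; a scalar entry is broadcast to every name, and the DataFrame is keyed by str(bench.x_names) instead of the first name only. A usage sketch under those semantics; the benchmark body and its timing value are placeholders:

import triton
import triton.testing

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N'],                     # both axes vary together
        x_vals=[256, [512, 256], [1024, 512]],  # scalar broadcasts, lists zip
        line_arg='provider',
        line_vals=['triton'],
        line_names=['Triton'],
        plot_name='demo',
        args={},
    ))
def bench(M, N, provider):
    # placeholder: time the kernel for this (M, N) and return milliseconds
    return 1.0

bench.run(print_data=True)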