Commit

Fix more front end problems. Add a workaround for InsertAsyncV2 mask handling. (triton-lang#5)

Started a TODO.md file to keep track of the tasks to do before merging.
ThomasRaoux committed Jul 15, 2023
1 parent 6aa244d commit 9d423e0
Showing 9 changed files with 105 additions and 26 deletions.
1 change: 0 additions & 1 deletion lib/Analysis/Alias.cpp
@@ -30,7 +30,6 @@ void SharedMemoryAliasAnalysis::visitOperation(
   // XXX(Keren): the following ops are always aliasing for now
   if (isa<triton::gpu::ExtractSliceOp, triton::TransOp,
           triton::nvidia_gpu::ExtractMBarrierOp>(op)) {
-    llvm::errs() << "ExtractSliceOp, TransOp, ExtractMBarrierOp\n";
     // extract_slice %src
     // trans %src
     aliasInfo = AliasInfo(operands[0]->getValue());
6 changes: 5 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -1330,8 +1330,12 @@ struct InsertSliceAsyncV2OpConversion
     Value pred = icmp_eq(threadId, i32_val(0));

     auto mask = adaptor.getMask();
-    if (mask)
+    if (mask) {
+      // TODO(thomas): What is the right implementation for this case?
+      assert(mask.getType().isInteger(1) &&
+             "need to implement cases with tensor mask");
       pred = rewriter.create<arith::AndIOp>(loc, pred, mask);
+    }

     Value mcastMask = getMCastMask(sharedLayout, rewriter, loc, clusterCTAId);
54 changes: 43 additions & 11 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "triton/Analysis/Allocation.h"
 #include "triton/Analysis/AxisInfo.h"
 #include "triton/Analysis/Membar.h"
@@ -62,6 +63,28 @@ class TritonLLVMFunctionConversionTarget : public ConversionTarget {
   }
 };

+class FoldSplatMaskInInsertAsync : public mlir::RewritePattern {
+
+public:
+  FoldSplatMaskInInsertAsync(mlir::MLIRContext *context)
+      : mlir::RewritePattern(
+            triton::nvidia_gpu::InsertSliceAsyncV2Op::getOperationName(), 1,
+            context) {}
+
+  LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto insertOp = cast<triton::nvidia_gpu::InsertSliceAsyncV2Op>(op);
+    auto splatOp = insertOp.getMask().getDefiningOp<triton::SplatOp>();
+    if (!splatOp)
+      return failure();
+    rewriter.updateRootInPlace(insertOp, [&]() {
+      insertOp.getMaskMutable().assign(splatOp->getOperand(0));
+    });
+    return mlir::success();
+  }
+};
+
 struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
   using ConvertOpToLLVMPattern<triton::ReturnOp>::ConvertOpToLLVMPattern;

@@ -152,6 +175,18 @@ struct FuncOpConversion : public FuncOpConversionBase {
     if (!allocation.isRoot(funcOp))
       amendedFuncOp = amendFuncOp(funcOp, rewriter);

+    // Collect TMA informations.
+    unsigned numTMALoad = 0;
+    funcOp.walk(
+        [&numTMALoad](triton::nvidia_gpu::InsertSliceAsyncV2Op insertSliceOp) {
+          numTMALoad++;
+        });
+    unsigned numTMAStore = 0;
+    funcOp.walk([&numTMAStore](triton::nvidia_gpu::StoreAsyncOp storeAsyncOp) {
+      numTMAStore++;
+    });
+    unsigned numTMA = numTMALoad + numTMAStore;
+
     auto newFuncOp = convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter);
     if (!newFuncOp) {
       return failure();
@@ -177,17 +212,6 @@ struct FuncOpConversion : public FuncOpConversionBase {
     // The call graph is updated by mapping the old function to the new one.
     allocation.mapFuncOp(funcOp, newFuncOp);

-    unsigned numTMALoad = 0;
-    funcOp.walk(
-        [&numTMALoad](triton::nvidia_gpu::InsertSliceAsyncV2Op insertSliceOp) {
-          numTMALoad++;
-        });
-    unsigned numTMAStore = 0;
-    funcOp.walk([&numTMAStore](triton::nvidia_gpu::StoreAsyncOp storeAsyncOp) {
-      numTMAStore++;
-    });
-    unsigned numTMA = numTMALoad + numTMAStore;
-
     // Append arguments to receive TMADesc in global memory in the runtime
     auto i8PtrTy = LLVM::LLVMPointerType::get(
         this->getTypeConverter()->convertType(rewriter.getI8Type()), 1);
@@ -384,6 +408,14 @@ class ConvertTritonGPUToLLVM
       }
     });

+    // Hack: cleanup
+    {
+      RewritePatternSet patterns(context);
+      patterns.add<FoldSplatMaskInInsertAsync>(context);
+      if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed())
+        signalPassFailure();
+    }
+
     // Lower functions
     {
       mlir::LowerToLLVMOptions option(context);
22 changes: 21 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
@@ -41,7 +41,27 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(

 Type TritonGPUToLLVMTypeConverter::convertTritonPointerType(
     triton::PointerType type) {
-  // Recursively translate pointee type
+  auto ctx = type.getContext();
+  auto pointeeType = type.getPointeeType();
+  if (pointeeType.isa<RankedTensorType>()) {
+    auto rankedTensorType = pointeeType.cast<RankedTensorType>();
+    // struct { offset0, offset1, shape0, shape1, stride0,
+    // stride1, base_ptr};
+    auto eleType = rankedTensorType.getElementType();
+    auto shape = rankedTensorType.getShape();
+    SmallVector<Type, 4> types;
+    // offsets
+    for (size_t i = 0; i < shape.size(); ++i)
+      types.push_back(IntegerType::get(ctx, 32));
+    // shapes, strides
+    for (size_t i = 0; i < 2 * shape.size(); ++i)
+      types.push_back(IntegerType::get(ctx, 64));
+
+    types.push_back(
+        LLVM::LLVMPointerType::get(eleType, type.getAddressSpace()));
+
+    return LLVM::LLVMStructType::getLiteral(ctx, types);
+  }
   return LLVM::LLVMPointerType::get(convertType(type.getPointeeType()),
                                     type.getAddressSpace());
 }
2 changes: 2 additions & 0 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -697,6 +697,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
   MLIRContext *context = patterns.getContext();
   patterns
       .insert< // TODO: view should have custom pattern that views the layout
+          TritonGenericPattern<triton::AdvanceOp>,
+          TritonGenericPattern<triton::MakeTensorPtrOp>,
           TritonGenericPattern<triton::ViewOp>,
           TritonGenericPattern<triton::BitcastOp>,
           TritonGenericPattern<triton::FpToFpOp>,
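
For context, the two new patterns cover Triton's block-pointer ops. A minimal Python sketch of the front-end feature that produces them, assuming the usual tl.make_block_ptr / tl.advance API (the kernel below is illustrative and not part of this commit; it also assumes N is at least 2 * BLOCK_N):

import triton
import triton.language as tl

@triton.jit
def copy_two_tiles(in_ptr, out_ptr, M, N, stride_m, BLOCK_N: tl.constexpr):
    pid = tl.program_id(0)
    # tl.make_block_ptr builds the block pointer that lowers to triton::MakeTensorPtrOp.
    src = tl.make_block_ptr(base=in_ptr, shape=(M, N), strides=(stride_m, 1),
                            offsets=(pid, 0), block_shape=(1, BLOCK_N), order=(1, 0))
    dst = tl.make_block_ptr(base=out_ptr, shape=(M, N), strides=(stride_m, 1),
                            offsets=(pid, 0), block_shape=(1, BLOCK_N), order=(1, 0))
    tl.store(dst, tl.load(src))
    # tl.advance shifts a block pointer by an element offset and lowers to triton::AdvanceOp.
    src = tl.advance(src, (0, BLOCK_N))
    dst = tl.advance(dst, (0, BLOCK_N))
    tl.store(dst, tl.load(src))

Launched with a grid of (M,), each program copies two consecutive row tiles of width BLOCK_N.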
13 changes: 10 additions & 3 deletions lib/Target/PTX/PTXTranslation.cpp
@@ -45,9 +45,16 @@ static bool findAndReplace(std::string &str, const std::string &begin,

 static void linkExternal(llvm::Module &module) {
   namespace fs = std::filesystem;
-  static const std::filesystem::path path =
-      std::filesystem::path(__BUILD_DIR__) / "lib" / "Hopper" /
-      "libhopper_helpers.bc";
+
+  // TODO: enable generating bc file from clang.
+  static const auto this_file_path = std::filesystem::path(__FILE__);
+  static const auto path =
+      this_file_path.parent_path().parent_path().parent_path().parent_path() /
+      "python" / "triton" / "hopper_lib" / "libhopper_helpers.bc";
+
+  // static const std::filesystem::path path =
+  //     std::filesystem::path(__BUILD_DIR__) / "lib" / "Hopper" /
+  //     "libhopper_helpers.bc";
   if (mlir::triton::linkExternLib(module, "libhopper_helpers", path.string(),
                                   /*isROCM*/ false))
     llvm::errs() << "Link failed for: libhopper_helpers.bc";
13 changes: 10 additions & 3 deletions python/triton/compiler/compiler.py
@@ -28,7 +28,8 @@
 from ..tools.disasm import extract
 from .code_generator import ast_to_ttir
 from .make_launcher import make_stub
-from .utils import TensorMapManager, get_ids_of_tensormaps, parse_tma_info
+from .utils import (InfoFromBackendForTensorMap, TensorMapManager,
+                    get_ids_of_tensormaps, parse_tma_info)


 def inline_triton_ir(mod):
@@ -507,6 +508,9 @@ def compile(fn, **kwargs):
     if metadata_path is not None:
         with open(metadata_path) as f:
             metadata = json.load(f)
+        if 'tensormaps_info' in metadata:
+            metadata['tensormaps_info'] = [
+                InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']]
     else:
         metadata = {"num_warps": num_warps,
                     "num_ctas": num_ctas,
@@ -522,6 +526,7 @@

     # Add device type to meta information
     metadata["device_type"] = device_type
+    metadata["cache_key"] = fn_cache_manager.key

     first_stage = list(stages.keys()).index(ext)
     asm = dict()
@@ -597,7 +602,7 @@ def compile(fn, **kwargs):
     so_path = _device_backend.make_launcher_stub(name, signature, constants, ids)
     # write-back metadata, if it didn't come from the cache
     if metadata_path is None:
-        metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata), metadata_filename, binary=False)
+        metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False)
     fn_cache_manager.put_group(metadata_filename, metadata_group)

     # return handle to compiled kernel
@@ -631,6 +636,7 @@ def __init__(self, fn, so_path, metadata, asm):
         if "tensormaps_info" in metadata:
             self.tensormaps_info = metadata["tensormaps_info"]
         self.constants = metadata["constants"]
+        self.cache_key = metadata["cache_key"]
         self.device_type = metadata["device_type"]
         self.device_backend = get_backend(self.device_type) if self.device_type not in ["cuda", "hip"] else None
         # initialize asm dict
@@ -690,14 +696,15 @@ def __getitem__(self, grid):
         self._init_handles()

         def runner(*args, stream=None):
+            args_expand = self.assemble_tensormap_to_arg(args, self.constants)
             if stream is None:
                 if self.device_type in ["cuda", "rocm"]:
                     stream = get_cuda_stream()
                 else:
                     stream = get_backend(self.device_type).get_stream(None)
             self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.clusterDims[0],
                            self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function,
-                           CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args)
+                           CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand)
         return runner

     def get_sass(self, fun=None):
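
The compiler.py changes serialize the metadata with json.dumps(..., default=vars), so any non-JSON object (here the TMA descriptor info) is written out as its __dict__, and rebuild the objects on a cache hit by passing each dict back to InfoFromBackendForTensorMap. A minimal sketch of that round-trip, using a stand-in class rather than the real one from .utils (its actual field names may differ):

import json

class TensorMapInfoStub:
    """Stand-in for InfoFromBackendForTensorMap; field names are illustrative."""
    def __init__(self, fields=None, **kwargs):
        for name, value in {**(fields or {}), **kwargs}.items():
            setattr(self, name, value)

# Write path: default=vars turns the object into a plain dict in the cached JSON.
metadata = {"num_warps": 4,
            "tensormaps_info": [TensorMapInfoStub(globalAddressArgIdx=0, tensorRank=2)]}
text = json.dumps(metadata, default=vars)

# Read path (cache hit): plain dicts come back and are wrapped in objects again.
metadata = json.loads(text)
if "tensormaps_info" in metadata:
    metadata["tensormaps_info"] = [TensorMapInfoStub(e) for e in metadata["tensormaps_info"]]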
4 changes: 3 additions & 1 deletion python/triton/runtime/autotuner.py
@@ -115,7 +115,9 @@ def run(self, *args, **kwargs):
         if config.pre_hook is not None:
             full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs}
             config.pre_hook(full_nargs)
-        ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
+        ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,
+                          num_ctas=config.num_ctas,
+                          enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs)
         self.nargs = None
         return ret

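
With this change each autotuning candidate forwards its num_ctas and enable_warp_specialization settings to the kernel launch, not just num_warps and num_stages. A hedged usage sketch, assuming triton.Config on this branch accepts those two keyword arguments (the attribute accesses above imply the fields exist on Config):

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=3,
                      num_ctas=1, enable_warp_specialization=False),
        triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=4,
                      num_ctas=2, enable_warp_specialization=True),
    ],
    key=["N"],
)
@triton.jit
def scale(x_ptr, y_ptr, N, BLOCK_N: tl.constexpr):
    # Doubles x into y; the autotuner picks the best config per value of N.
    pid = tl.program_id(0)
    offs = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = offs < N
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask) * 2.0, mask=mask)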
16 changes: 11 additions & 5 deletions python/triton/testing.py
@@ -207,9 +207,15 @@ def _run(self, bench, save_path, show_plots, print_data):
         y_mean = bench.line_names
         y_min = [f'{x}-min' for x in bench.line_names]
         y_max = [f'{x}-max' for x in bench.line_names]
-        df = pd.DataFrame(columns=[bench.x_names[0]] + y_mean + y_min + y_max)
+        x_names_str = str(bench.x_names)
+        df = pd.DataFrame(columns=[x_names_str] + y_mean + y_min + y_max)
         for x in bench.x_vals:
-            x_args = {x_name: x for x_name in bench.x_names}
+            if not isinstance(x, list):
+                x = [x]
+            if len(x) == 1:
+                x = x * len(bench.x_names)
+            x_str = str(x)
+            x_args = {x_name: x_in for x_name, x_in in zip(bench.x_names, x)}
             row_mean, row_min, row_max = [], [], []
             for y in bench.line_vals:
                 ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args)
@@ -220,11 +226,11 @@
                 row_mean += [y_mean]
                 row_min += [y_min]
                 row_max += [y_max]
-            df.loc[len(df)] = [x] + row_mean + row_min + row_max
+            df.loc[len(df)] = [x_str] + row_mean + row_min + row_max
         if bench.plot_name:
             plt.figure()
             ax = plt.subplot()
-            x = bench.x_names[0]
+            x = x_names_str
             for i, y in enumerate(bench.line_names):
                 y_min, y_max = df[y + '-min'], df[y + '-max']
                 col = bench.styles[i][0] if bench.styles else None
@@ -243,7 +249,7 @@
             plt.show()
         if save_path:
             plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
-        df = df[[bench.x_names[0]] + bench.line_names]
+        df = df[[x_names_str] + bench.line_names]
         if print_data:
             print(bench.plot_name + ':')
             print(df)
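
The _run changes let a Benchmark sweep several x_names at once: each entry of x_vals may be a list with one value per name, and a scalar (or single-element list) is broadcast to every name. A small sketch of how a report could use this; the benchmarked function and values are illustrative, not from this commit:

import torch
import triton
import triton.testing

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["M", "N"],
        x_vals=[[512, 512], [1024, 2048], 4096],  # 4096 is broadcast to M == N == 4096
        line_arg="provider",
        line_vals=["torch"],
        line_names=["torch"],
        ylabel="ms",
        plot_name="matmul-shape-sweep",
        args={},
    )
)
def bench_matmul(M, N, provider):
    a = torch.randn(M, N, device="cuda")
    b = torch.randn(N, M, device="cuda")
    return triton.testing.do_bench(lambda: a @ b)

bench_matmul.run(show_plots=False, print_data=True)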
