Merged master:acb69f3b7c83 into amd-gfx:d162b3ea0f06
Local branch amd-gfx d162b3e Merged master:1dea8ed8b7dd into amd-gfx:ee2ed3abef34
Remote branch master acb69f3 [mlir] Change ConvertOpToLLVMPattern::matchAndRewrite argument to concrete operand type.
Sw authored and Sw committed Nov 28, 2020
2 parents d162b3e + acb69f3 commit ec1f53c
Showing 5 changed files with 218 additions and 206 deletions.
@@ -564,14 +564,47 @@ class ConvertToLLVMPattern : public ConversionPattern {
 
 /// Utility class for operation conversions targeting the LLVM dialect that
 /// match exactly one source operation.
-template <typename OpTy>
+template <typename SourceOp>
 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
 public:
   ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                          PatternBenefit benefit = 1)
-      : ConvertToLLVMPattern(OpTy::getOperationName(),
+      : ConvertToLLVMPattern(SourceOp::getOperationName(),
                              &typeConverter.getContext(), typeConverter,
                              benefit) {}
+
+  /// Wrappers around the RewritePattern methods that pass the derived op type.
+  void rewrite(Operation *op, ArrayRef<Value> operands,
+               ConversionPatternRewriter &rewriter) const final {
+    rewrite(cast<SourceOp>(op), operands, rewriter);
+  }
+  LogicalResult match(Operation *op) const final {
+    return match(cast<SourceOp>(op));
+  }
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
+  }
+
+  /// Rewrite and Match methods that operate on the SourceOp type. These must be
+  /// overridden by the derived pattern class.
+  virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
+                       ConversionPatternRewriter &rewriter) const {
+    llvm_unreachable("must override rewrite or matchAndRewrite");
+  }
+  virtual LogicalResult match(SourceOp op) const {
+    llvm_unreachable("must override match or matchAndRewrite");
+  }
+  virtual LogicalResult
+  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const {
+    if (succeeded(match(op))) {
+      rewrite(op, operands, rewriter);
+      return success();
+    }
+    return failure();
+  }
 };
 
 namespace LLVM {
@@ -604,7 +637,7 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
   /// Converts the type of the result to an LLVM type, pass operands as is,
   /// preserve attributes.
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
                                          operands, this->typeConverter,
@@ -621,7 +654,7 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
   using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     static_assert(
         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
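
With the typed hooks above, a pattern derived from ConvertOpToLLVMPattern overrides matchAndRewrite on the concrete source op and no longer needs to cast from Operation *. A minimal sketch of such a derived pattern, assuming a hypothetical mydialect::FooOp whose results can simply be replaced by its converted operands (neither is part of this commit):

class FooOpLowering : public ConvertOpToLLVMPattern<mydialect::FooOp> {
public:
  using ConvertOpToLLVMPattern<mydialect::FooOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(mydialect::FooOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // `op` already has its concrete type, so typed accessors are available
    // directly and no cast<mydialect::FooOp>(op) is needed.
    rewriter.replaceOp(op, operands);
    return success();
  }
};

The GPU patterns in the files below are built on the same base class, which is why their matchAndRewrite signatures change in this commit.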
67 changes: 34 additions & 33 deletions mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -163,7 +163,7 @@ class ConvertHostRegisterOpToGpuRuntimeCallPattern
 
 private:
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -205,7 +205,7 @@ class ConvertWaitOpToGpuRuntimeCallPattern
 
 private:
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(gpu::WaitOp waitOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -219,7 +219,7 @@ class ConvertWaitAsyncOpToGpuRuntimeCallPattern
 
 private:
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(gpu::WaitOp waitOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -251,7 +251,7 @@ class ConvertLaunchFuncOpToGpuRuntimeCallPattern
                                 Location loc, OpBuilder &builder) const;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 
   llvm::SmallString<32> gpuBinaryAnnotation;
@@ -321,14 +321,15 @@ isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
 }
 
 LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
-    Operation *op, ArrayRef<Value> operands,
+    gpu::HostRegisterOp hostRegisterOp, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
+  auto *op = hostRegisterOp.getOperation();
   if (failed(areAllLLVMTypes(op, operands, rewriter)))
     return failure();
 
   Location loc = op->getLoc();
 
-  auto memRefType = cast<gpu::HostRegisterOp>(op).value().getType();
+  auto memRefType = hostRegisterOp.value().getType();
   auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
   auto elementSize = getSizeInBytes(loc, elementType, rewriter);
 
@@ -412,19 +413,19 @@ LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
 // afterwards. In case this isn't correct, we will get a runtime error.
 // Eventually, we will have a pass that guarantees this property.
 LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
-    Operation *op, ArrayRef<Value> operands,
+    gpu::WaitOp waitOp, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
-  if (cast<gpu::WaitOp>(op).asyncToken())
-    return rewriter.notifyMatchFailure(op, "Cannot convert async op.");
+  if (waitOp.asyncToken())
+    return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");
 
-  Location loc = op->getLoc();
+  Location loc = waitOp.getLoc();
 
   for (auto asyncDependency : operands)
     streamSynchronizeCallBuilder.create(loc, rewriter, {asyncDependency});
   for (auto asyncDependency : operands)
     streamDestroyCallBuilder.create(loc, rewriter, {asyncDependency});
 
-  rewriter.eraseOp(op);
+  rewriter.eraseOp(waitOp);
   return success();
 }
 
@@ -435,23 +436,23 @@ LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
 // assumes that there is no other use between the definition and this op, and
 // the plan is to have a pass that guarantees this property.
 LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
-    Operation *op, ArrayRef<Value> operands,
+    gpu::WaitOp waitOp, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
-  if (!cast<gpu::WaitOp>(op).asyncToken())
-    return rewriter.notifyMatchFailure(op, "Can only convert async op.");
+  if (!waitOp.asyncToken())
+    return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");
 
-  Location loc = op->getLoc();
+  Location loc = waitOp.getLoc();
 
   auto insertionPoint = rewriter.saveInsertionPoint();
   SmallVector<Value, 1> events;
-  for (auto pair : llvm::zip(op->getOperands(), operands)) {
+  for (auto pair : llvm::zip(waitOp.asyncDependencies(), operands)) {
     auto token = std::get<0>(pair);
     if (auto *defOp = token.getDefiningOp()) {
       rewriter.setInsertionPointAfter(defOp);
     } else {
       // If we can't find the defining op, we record the event at block start,
       // which is late and therefore misses parallelism, but still valid.
-      rewriter.setInsertionPointToStart(op->getBlock());
+      rewriter.setInsertionPointToStart(waitOp.getOperation()->getBlock());
     }
     auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
     auto stream = std::get<1>(pair);
@@ -464,7 +465,7 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
     streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
   for (auto event : events)
     eventDestroyCallBuilder.create(loc, rewriter, {event});
-  rewriter.replaceOp(op, {stream});
+  rewriter.replaceOp(waitOp, {stream});
 
   return success();
 }
@@ -564,23 +565,21 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
 // If the op is async, the stream corresponds to the (single) async dependency
 // as well as the async token the op produces.
 LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
-    Operation *op, ArrayRef<Value> operands,
+    gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
-  if (failed(areAllLLVMTypes(op, operands, rewriter)))
+  if (failed(areAllLLVMTypes(launchOp, operands, rewriter)))
     return failure();
 
-  auto launchOp = cast<gpu::LaunchFuncOp>(op);
-
   if (launchOp.asyncDependencies().size() > 1)
     return rewriter.notifyMatchFailure(
-        op, "Cannot convert with more than one async dependency.");
+        launchOp, "Cannot convert with more than one async dependency.");
 
   // Fail when the synchronous version of the op has async dependencies. The
   // lowering destroys the stream, and we do not want to check that there is no
   // use of the stream after this op.
   if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty())
     return rewriter.notifyMatchFailure(
-        op, "Cannot convert non-async op with async dependencies.");
+        launchOp, "Cannot convert non-async op with async dependencies.");
 
   Location loc = launchOp.getLoc();
 
@@ -612,31 +611,33 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
       loc, rewriter, {module.getResult(0), kernelName});
   auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                                 rewriter.getI32IntegerAttr(0));
-  auto adaptor = gpu::LaunchFuncOpAdaptor(operands, op->getAttrDictionary());
+  auto adaptor = gpu::LaunchFuncOpAdaptor(
+      operands, launchOp.getOperation()->getAttrDictionary());
   Value stream =
       adaptor.asyncDependencies().empty()
           ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
          : adaptor.asyncDependencies().front();
   // Create array of pointers to kernel arguments.
   auto kernelParams = generateParamsArray(launchOp, operands, rewriter);
   auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
-  launchKernelCallBuilder.create(
-      loc, rewriter,
-      {function.getResult(0), launchOp.gridSizeX(), launchOp.gridSizeY(),
-       launchOp.gridSizeZ(), launchOp.blockSizeX(), launchOp.blockSizeY(),
-       launchOp.blockSizeZ(), /*sharedMemBytes=*/zero, stream, kernelParams,
-       /*extra=*/nullpointer});
+  launchKernelCallBuilder.create(loc, rewriter,
+                                 {function.getResult(0), launchOp.gridSizeX(),
+                                  launchOp.gridSizeY(), launchOp.gridSizeZ(),
+                                  launchOp.blockSizeX(), launchOp.blockSizeY(),
+                                  launchOp.blockSizeZ(),
+                                  /*sharedMemBytes=*/zero, stream, kernelParams,
+                                  /*extra=*/nullpointer});
 
   if (launchOp.asyncToken()) {
     // Async launch: make dependent ops use the same stream.
-    rewriter.replaceOp(op, {stream});
+    rewriter.replaceOp(launchOp, {stream});
   } else {
     // Synchronize with host and destroy stream. This must be the stream created
     // above (with no other uses) because we check that the synchronous version
     // does not have any async dependencies.
     streamSynchronizeCallBuilder.create(loc, rewriter, stream);
     streamDestroyCallBuilder.create(loc, rewriter, stream);
-    rewriter.eraseOp(op);
+    rewriter.eraseOp(launchOp);
   }
   moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));
 
@@ -151,9 +151,9 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
   using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    gpu::LaunchFuncOp launchOp = cast<gpu::LaunchFuncOp>(op);
+    auto *op = launchOp.getOperation();
     MLIRContext *context = rewriter.getContext();
     auto module = launchOp.getParentOfType<ModuleOp>();
 