Skip to content

Commit

Permalink
builtin::module -> spirv::module
Browse files Browse the repository at this point in the history
  • Loading branch information
Hugobros3 committed Nov 18, 2024
1 parent a4dcdc7 commit d156cac
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,52 @@ void populateConvertToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
}

/// Pattern to convert a builtin.module to a spirv.module.
class ModuleConversion final : public OpConversionPattern<mlir::ModuleOp> {
public:
using OpConversionPattern<mlir::ModuleOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(mlir::ModuleOp moduleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

LogicalResult ModuleConversion::matchAndRewrite(
mlir::ModuleOp moduleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
const spirv::TargetEnv &targetEnv = typeConverter->getTargetEnv();
spirv::AddressingModel addressingModel = spirv::getAddressingModel(
targetEnv, typeConverter->getOptions().use64bitIndex);
FailureOr<spirv::MemoryModel> memoryModel = spirv::getMemoryModel(targetEnv);
if (failed(memoryModel))
return moduleOp.emitRemark("cannot deduce memory model from 'spirv.target_env'");

// Add a keyword to the module name to avoid symbolic conflict.
std::string spvModuleName = moduleOp.getName()->str();
auto spvModule = rewriter.create<spirv::ModuleOp>(
moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
StringRef(spvModuleName));

// Move the region from the module op into the SPIR-V module.
Region &spvModuleRegion = spvModule.getRegion();
rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
spvModuleRegion.begin());
// The spirv.module build method adds a block. Remove that.
rewriter.eraseBlock(&spvModuleRegion.back());

// Some of the patterns call `lookupTargetEnv` during conversion and they
// will fail if called after GPUModuleConversion and we don't preserve
// `TargetEnv` attribute.
// Copy TargetEnvAttr only if it is attached directly to the GPUModuleOp.
if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
spirv::getTargetEnvAttrName()))
spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);

rewriter.eraseOp(moduleOp);
return success();
}

/// A pass to perform the SPIR-V conversion.
struct ConvertToSPIRVPass final
: impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
Expand Down Expand Up @@ -98,6 +144,7 @@ struct ConvertToSPIRVPass final
mapToMemRef(op, targetAttr);
populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
patterns);
patterns.add<ModuleConversion>(typeConverter, context);
if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
return;
Expand Down

0 comments on commit d156cac

Please sign in to comment.