Skip to content

Commit

Permalink
[BUILD] MacOS can now build compiler and run MLIR tests (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
ptillet authored Jul 27, 2022
1 parent 0cb567f commit 548d9cd
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 11 deletions.
6 changes: 5 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ if(WIN32)
add_subdirectory(deps/dlfcn-win32/src ${CMAKE_BINARY_DIR}/dlfcn-win32)
endif()

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17")
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden")
if(APPLE)
set(CMAKE_OSX_DEPLOYMENT_TARGET 11.6)
endif()



##########
Expand Down
22 changes: 13 additions & 9 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using namespace mlir::triton::gpu;
// parse an array of integers
static LogicalResult parseIntArrayAttr(AsmParser &parser,
const NamedAttribute &attr,
/*SmallVector<unsigned, 2>*/ auto &res,
SmallVector<unsigned, 2> &res,
StringRef desc) {
auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
if (!arrayAttr) {
Expand Down Expand Up @@ -84,7 +84,8 @@ static Attribute parseBlocked(AsmParser &parser, Type type) {
broadcastAxis);
}

static void printBlocked(AsmPrinter &printer, auto *attr) {
template <class T>
static void printBlocked(AsmPrinter &printer, const T *attr) {
printer << "<{"
<< "threadTileSize = [" << attr->getThreadTileSize() << "]"
<< ", warpTileSize = [" << attr->getWarpTileSize() << "]"
Expand All @@ -95,7 +96,7 @@ static void printBlocked(AsmPrinter &printer, auto *attr) {
}

Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
parseBlocked(parser, type);
return parseBlocked(parser, type);
}

void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
Expand All @@ -104,7 +105,7 @@ void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {

Attribute TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser,
Type type) {
parseBlocked(parser, type);
return parseBlocked(parser, type);
}

void TritonGPUBlockedMulticastEncodingAttr::print(AsmPrinter &printer) const {
Expand Down Expand Up @@ -163,7 +164,7 @@ static Attribute parseMma(AsmParser &parser, Type type) {
shapePerTile, repetitions, contigPerThread, broadcastAxis);
}

static void printMma(AsmPrinter &printer, auto *attr) {
template <class T> static void printMma(AsmPrinter &printer, T *attr) {
printer << "<{"
<< "fragmentPerWarp = [" << attr->getFragmentPerWarp() << "]"
<< ", shapePerWarp = [" << attr->getShapePerWarp() << "]"
Expand Down Expand Up @@ -276,12 +277,14 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
attr.dyn_cast<TritonGPUBlockedMulticastEncodingAttr>()) {
os << "blocked_multicast";
TritonGPUOpAsmInterface::printBlocked(blockedMulticastAttr, os);
return AliasResult::FinalAlias;
}
OpAsmDialectInterface::getAlias(attr, os);
return AliasResult::FinalAlias;
}

private:
static void printMma(const auto &attr, raw_ostream &os) {
static void printMma(const TritonGPUMmaEncodingAttr &attr, raw_ostream &os) {
TritonGPUOpAsmInterface::printArray(attr.getFragmentPerWarp(), os);
TritonGPUOpAsmInterface::printArray(attr.getShapePerWarp(), os);
TritonGPUOpAsmInterface::printArray(attr.getWarpPerTile(), os);
Expand All @@ -290,22 +293,23 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
TritonGPUOpAsmInterface::printArray(attr.getContigPerThread(), os);
}

static void printShared(const auto &attr, raw_ostream &os) {
static void printShared(const TritonGPUSharedEncodingAttr &attr,
raw_ostream &os) {
os << "_" << attr.getVec();
os << "_" << attr.getPerPhase();
os << "_" << attr.getMaxPhase();
TritonGPUOpAsmInterface::printArray(attr.getOrder(), os);
}

static void printBlocked(const auto &attr, raw_ostream &os) {
template <class T> static void printBlocked(const T &attr, raw_ostream &os) {
TritonGPUOpAsmInterface::printArray(attr.getThreadTileSize(), os);
TritonGPUOpAsmInterface::printArray(attr.getWarpTileSize(), os);
TritonGPUOpAsmInterface::printArray(attr.getBlockTileSize(), os);
TritonGPUOpAsmInterface::printArray(attr.getOrder(), os);
TritonGPUOpAsmInterface::printArray(attr.getBroadcastAxis(), os);
}

static void printArray(const auto &array, raw_ostream &os,
static void printArray(const ArrayRef<unsigned> &array, raw_ostream &os,
const std::string &delimiter = "x") {
os << "_";
if (array.empty()) {
Expand Down
4 changes: 3 additions & 1 deletion python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

def get_llvm():
# download if nothing is installed
name = 'clang+llvm-14.0.0-x86_64-linux-gnu-ubuntu-18.04'
system = platform.system()
suffix = {"Linux": "linux-gnu-ubuntu-18.04", "Darwin": "apple-darwin"}[system]
name = f'clang+llvm-14.0.0-x86_64-{suffix}'
dir = '/tmp'
llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name)
llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name)
Expand Down

0 comments on commit 548d9cd

Please sign in to comment.