Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MHLO] Init MHLO integration. #1083

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/buildAndTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ jobs:
-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$GITHUB_WORKSPACE" \
-DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR="${GITHUB_WORKSPACE}/external/llvm-external-projects/torch-mlir-dialects" \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DTORCH_MLIR_ENABLE_MHLO=ON \
-DLLVM_TARGETS_TO_BUILD=host
ninja check-torch-mlir-all
- name: RefBackend - TorchScript end-to-end tests
Expand Down Expand Up @@ -81,6 +82,7 @@ jobs:
-DLLVM_ENABLE_PROJECTS=mlir \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DLLVM_TARGETS_TO_BUILD=host \
-DTORCH_MLIR_ENABLE_MHLO=ON \
externals/llvm-project/llvm
ninja -Cllvm-build

Expand All @@ -94,6 +96,7 @@ jobs:
-DMLIR_DIR="$(pwd)/llvm-build/lib/cmake/mlir/" \
-DLLVM_DIR="$(pwd)/llvm-build/lib/cmake/llvm/" \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DTORCH_MLIR_ENABLE_MHLO=ON \
-DPython3_EXECUTABLE=$(which python) \
.
ninja -Cbuild check-torch-mlir-all
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[submodule "external/llvm-project"]
path = externals/llvm-project
url = https://github.com/llvm/llvm-project.git
[submodule "externals/mlir-hlo"]
path = externals/mlir-hlo
url = https://github.com/tensorflow/mlir-hlo.git
19 changes: 19 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,18 @@ macro(torch_mlir_add_llvm_external_project name identifier location)
set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE)
endmacro()

option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
if(TORCH_MLIR_ENABLE_MHLO)
add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
endif()

torch_mlir_add_llvm_external_project(
torch-mlir-dialects
TORCH_MLIR_DIALECTS
${CMAKE_CURRENT_SOURCE_DIR}/externals/llvm-external-projects/torch-mlir-dialects)

if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
message(STATUS "Torch-MLIR out-of-tree build.")
# Out-of-tree build

#-------------------------------------------------------------------------------
Expand Down Expand Up @@ -82,10 +88,14 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
set(BACKEND_PACKAGE_STRING "LLVM ${LLVM_PACKAGE_VERSION}")
add_subdirectory(externals/llvm-external-projects/torch-mlir-dialects)
else()
message(STATUS "Torch-MLIR in-tree build.")
# In-tree build with LLVM_EXTERNAL_PROJECTS=torch-mlir
# FIXME: This should really be inherited from the LLVM tree. In particular,
# it's going to change when cross-compiling.
set(MLIR_TABLEGEN_EXE mlir-tblgen)
if (TORCH_MLIR_ENABLE_MHLO)
set(MLIR_PDLL_TABLEGEN_EXE mlir-pdll)
endif()

option(MLIR_ENABLE_BINDINGS_PYTHON "Enables MLIR Python Bindings" OFF)
option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON)
Expand All @@ -97,6 +107,15 @@ else()
set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}")
endif()

if (TORCH_MLIR_ENABLE_MHLO)
set(MHLO_BUILD_EMBEDDED ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo
${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo
EXCLUDE_FROM_ALL)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo/include)
endif()

set(TORCH_MLIR_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
set(TORCH_MLIR_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
message(STATUS "Building torch-mlir project at ${TORCH_MLIR_SOURCE_DIR} (into ${TORCH_MLIR_BINARY_DIR})")
Expand Down
1 change: 1 addition & 0 deletions externals/mlir-hlo
Submodule mlir-hlo added at eb1042
6 changes: 5 additions & 1 deletion include/torch-mlir/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
if(TORCH_MLIR_ENABLE_MHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
else()
mlir_tablegen(Passes.h.inc -gen-pass-decls)
endif()
add_public_tablegen_target(TorchMLIRConversionPassIncGen)

add_mlir_doc(Passes TorchMLIRConversionPasses ./ -gen-pass-doc)
10 changes: 10 additions & 0 deletions include/torch-mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,14 @@ def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToTMTensorPass()";
}

#ifdef TORCH_MLIR_ENABLE_MHLO
def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
let summary = "Convert Torch ops to MHLO ops";
let description = [{
Convert Torch ops to mhlo ops.
}];
let constructor = "mlir::torch::createConvertTorchToMhloPass()";
}
#endif

#endif // TORCHMLIR_CONVERSION_PASSES
23 changes: 23 additions & 0 deletions include/torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
#define TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToMhloPass();
} // namespace torch
} // namespace mlir

#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ void createTorchBackendToTosaBackendPipeline(
OpPassManager &pm,
const torch::Torch::TorchLoweringPipelineOptions &options);

// Do not register the torch-to-mhlo pipeline if mhlo target is disabled
#ifdef TORCH_MLIR_ENABLE_MHLO
void createTorchBackendToMhloBackendPipeline(
OpPassManager &pm,
const torch::Torch::TorchLoweringPipelineOptions &options);
#endif

std::unique_ptr<OperationPass<ModuleOp>>
createVerifyInvariantsBeforeBackendLoweringPass();

Expand Down
20 changes: 13 additions & 7 deletions lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,22 @@ add_subdirectory(TorchToLinalg)
add_subdirectory(TorchToSCF)
add_subdirectory(TorchToStd)
add_subdirectory(TorchToTosa)
if(TORCH_MLIR_ENABLE_MHLO)
add_subdirectory(TorchToMhlo)
endif()
add_subdirectory(TorchToTMTensor)
add_subdirectory(Utils)

# TODO: Automate this with add_torch_mlir_conversion_library.
#get_property(torch_mlir_conversion_libs GLOBAL PROPERTY TORCH_MLIR_CONVERSION_LIBS)
set(linked_libs TorchMLIRTorchToLinalg
TorchMLIRTorchToSCF
TorchMLIRTorchToStd
TorchMLIRTorchToTosa
TorchMLIRTorchToTMTensor
TorchMLIRConversionUtils)
if(TORCH_MLIR_ENABLE_MHLO)
list(APPEND linked_libs TorchMLIRTorchToMhlo)
endif()

add_mlir_library(TorchMLIRConversionPasses
Passes.cpp
Expand All @@ -18,11 +29,6 @@ add_mlir_library(TorchMLIRConversionPasses
Core

LINK_LIBS PUBLIC
TorchMLIRTorchToLinalg
TorchMLIRTorchToSCF
TorchMLIRTorchToStd
TorchMLIRTorchToTosa
TorchMLIRTorchToTMTensor
TorchMLIRConversionUtils
${linked_libs}
#${torch_mlir_conversion_libs}
)
1 change: 1 addition & 0 deletions lib/Conversion/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"

//===----------------------------------------------------------------------===//
Expand Down
71 changes: 71 additions & 0 deletions lib/Conversion/TorchToMhlo/BasicOp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"

#include "../PassDetail.h"
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include <iostream>
#include <numeric>

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;


namespace {
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
ZihengJiang marked this conversation as resolved.
Show resolved Hide resolved
} // namespace

// AtenTanhOp
namespace {
template <>
LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
AtenTanhOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.self();
auto selfTy = self.getType().cast<TensorType>();
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
rewriter.replaceOpWithNewOp<mhlo::TanhOp>(
op, getTypeConverter()->convertType(op.getType()), self);
return success();
} else {
return op.emitError(
"Only floating-point datatype legalization currently supported");
}
}
} // namespace

void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
MLIRContext *context = patterns.getContext();

#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
INSERT_ATENOP_PATTERN(AtenTanhOp);
#undef INSERT_ATENOP_PATTERN

}
22 changes: 22 additions & 0 deletions lib/Conversion/TorchToMhlo/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
add_mlir_conversion_library(TorchMLIRTorchToMhlo
TorchToMhlo.cpp
BasicOp.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo

DEPENDS
MhloDialect
TorchMLIRConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MhloDialect
TorchMLIRTorchDialect
)

torch_mlir_target_includes(TorchMLIRTorchToMhlo)
27 changes: 27 additions & 0 deletions lib/Conversion/TorchToMhlo/PopulatePatterns.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
#define TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H

#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace torch {
namespace torch_to_mhlo {

void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);

} // namespace torch_to_mhlo
} // namespace torch
} // namespace mlir

#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
66 changes: 66 additions & 0 deletions lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"

#include "../PassDetail.h"
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {

class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<mhlo::MhloDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithmeticDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<mhlo::MhloDialect, tensor::TensorDialect,
arith::ArithmeticDialect, Torch::TorchDialect>();

TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);

RewritePatternSet patterns(context);

torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
target);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
return signalPassFailure();
}
}
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToMhloPass() {
return std::make_unique<ConvertTorchToMhlo>();
}
Loading