Skip to content

Commit

Permalink
[MLIR][mesh] moving shardinginterfaceimpl for tensor to tensor extens…
Browse files Browse the repository at this point in the history
…ion lib (llvm#104913)

Follow-up to llvm#102598 : as discussed, move tensor sharding implementation
into separate tensor extension lib.

@sogartar @yaochengji, could you take a look at this PR?
  • Loading branch information
fschlimb authored and cjdb committed Aug 23, 2024
1 parent 6663d2f commit 1d65d53
Show file tree
Hide file tree
Showing 12 changed files with 101 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- ShardingInterfaceImpl.h - ------------------------------------------===//
//===- MeshShardingExtensions.h - -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand Down
30 changes: 30 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/Extensions/AllExtensions.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//===- AllExtensions.h - All Tensor Extensions ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a common entry point for registering all extensions to the
// Tensor dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H
#define MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H

namespace mlir {
class DialectRegistry;

namespace tensor {
/// Register all extensions of the Tensor dialect. This should generally only be
/// used by tools, or other use cases that really do want *all* extensions of
/// the dialect. All other cases should prefer to instead register the specific
/// extensions they intend to take advantage of.
void registerAllExtensions(DialectRegistry &registry);
} // namespace tensor

} // namespace mlir

#endif // MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//===- MeshShardingExtensions.h - -------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TENSOR_EXTENSIONS_SHARDINGEXTENSIONS_H_
#define MLIR_DIALECT_TENSOR_EXTENSIONS_SHARDINGEXTENSIONS_H_

namespace mlir {

class DialectRegistry;

namespace tensor {

void registerShardingInterfaceExternalModels(DialectRegistry &registry);

} // namespace tensor
} // namespace mlir

#endif // MLIR_DIALECT_TENSOR_EXTENSIONS_SHARDINGEXTENSIONS_H_
2 changes: 0 additions & 2 deletions mlir/include/mlir/InitAllDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
Expand Down Expand Up @@ -182,7 +181,6 @@ inline void registerAllDialects(DialectRegistry &registry) {
tensor::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
tensor::registerInferTypeOpInterfaceExternalModels(registry);
tensor::registerShardingInterfaceExternalModels(registry);
tensor::registerSubsetOpInterfaceExternalModels(registry);
tensor::registerTilingInterfaceExternalModels(registry);
tensor::registerValueBoundsOpInterfaceExternalModels(registry);
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllExtensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
Expand All @@ -60,6 +61,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
registerConvertComplexToLLVMInterface(registry);
cf::registerConvertControlFlowToLLVMInterface(registry);
func::registerAllExtensions(registry);
tensor::registerAllExtensions(registry);
registerConvertFuncToLLVMInterface(registry);
index::registerConvertIndexToLLVMInterface(registry);
registerConvertMathToLLVMInterface(registry);
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
add_mlir_library(MLIRShardingInterface
ShardingInterface.cpp
TensorShardingInterfaceImpl.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(Extensions)
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(TransformOps)
Expand Down
16 changes: 16 additions & 0 deletions mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//===- AllExtensions.cpp - All Tensor Dialect Extensions ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
#include "mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h"

using namespace mlir;

void mlir::tensor::registerAllExtensions(DialectRegistry &registry) {
registerShardingInterfaceExternalModels(registry);
}
26 changes: 26 additions & 0 deletions mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
set(LLVM_OPTIONAL_SOURCES
AllExtensions.cpp
MeshShardingExtensions.cpp
)

add_mlir_extension_library(MLIRTensorMeshShardingExtensions
MeshShardingExtensions.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions

LINK_LIBS PUBLIC
MLIRTensorDialect
MLIRIR
MLIRShardingInterface
)

add_mlir_extension_library(MLIRTensorAllExtensions
AllExtensions.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions

LINK_LIBS PUBLIC
MLIRTensorMeshShardingExtensions
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectRegistry.h"
#include "llvm/Support/Debug.h"
Expand Down
1 change: 1 addition & 0 deletions mlir/tools/mlir-lsp-server/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ set(LIBS
MLIRLspServerLib
MLIRParser
MLIRPass
MLIRTensorAllExtensions
MLIRTransforms
MLIRTransformUtils
MLIRSupport
Expand Down

0 comments on commit 1d65d53

Please sign in to comment.