[mlir][sparse] Populate lvlToDim (#68937)
Updates:
1. Infer lvlToDim from dimToLvl.
2. Add more tests for block sparsity.
3. Finish TODOs related to lvlToDim, including adding lvlToDim to the Python bindings.

Verification of a user-provided lvlToDim will be implemented in the next PR.
yinying-lisa-li committed Oct 17, 2023
1 parent 9922aad commit d4088e7
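As a quick orientation before the per-file diff, here is a minimal sketch of the user-visible change, assuming the in-tree MLIR Python bindings; it mirrors the updated mlir/test/python/dialects/sparse_tensor/dialect.py test below, where the new third argument of EncodingAttr.get is lvl_to_dim and passing None asks the builder to infer it from dim_to_lvl:

```python
# Minimal sketch (mirrors mlir/test/python/dialects/sparse_tensor/dialect.py).
# Assumes the in-tree MLIR Python bindings are importable as `mlir`.
from mlir import ir
from mlir.dialects import sparse_tensor as st

with ir.Context(), ir.Location.unknown():
    parsed = ir.Attribute.parse(
        "#sparse_tensor.encoding<{"
        " map = (d0, d1) -> (d1 : dense, d0 : compressed),"
        " posWidth = 8, crdWidth = 32 }>"
    )
    casted = st.EncodingAttr(parsed)
    # The parser now infers lvl_to_dim as the inverse permutation of dim_to_lvl.
    print(casted.lvl_to_dim)  # (d0, d1) -> (d1, d0)
    # Passing None in the new lvl_to_dim slot triggers the same inference.
    created = st.EncodingAttr.get(
        casted.lvl_types, casted.dim_to_lvl, None, 8, 32
    )
    assert created == casted  # inference reproduces the inverse permutation
```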
Showing 13 changed files with 177 additions and 22 deletions.
3 changes: 1 addition & 2 deletions mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -51,11 +51,10 @@ MLIR_CAPI_EXPORTED bool
mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr);

/// Creates a `sparse_tensor.encoding` attribute with the given parameters.
- /// TODO: add a version that supplied lvlToDim when it cannot be inferred
MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet(
MlirContext ctx, intptr_t lvlRank,
enum MlirSparseTensorDimLevelType const *lvlTypes, MlirAffineMap dimToLvl,
- int posWidth, int crdWidth);
+ MlirAffineMap lvlToDim, int posWidth, int crdWidth);

/// Returns the level-rank of the `sparse_tensor.encoding` attribute.
MLIR_CAPI_EXPORTED intptr_t
13 changes: 13 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -160,6 +160,19 @@ inline bool hasAnySparseOperandOrResult(Operation *op) {
return hasAnySparseOperand(op) || hasAnySparseResult(op);
}

//
// Inference.
//

/// Given the dimToLvl map, infers the lvlToDim map, or returns an
/// empty AffineMap when inference fails.
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context);

/// Returns the lvlToDim map for the given dimToLvl map specific
/// to the block sparse cases.
/// Asserts on failure (so only use when known to succeed).
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context);

//
// Reordering.
//
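To make the contract of these two helpers concrete: for a BSR-style map with 2x2 blocks, inverseBlockSparsity rebuilds each dimension as block coordinate times block size plus in-block offset. A plain-Python sanity check of that relationship (illustrative only, no MLIR involved):

```python
# Forward map, as written in an encoding:
# (i, j) -> (i floordiv 2, j floordiv 2, i mod 2, j mod 2)
def dim_to_lvl(i, j):
    return (i // 2, j // 2, i % 2, j % 2)

# Inverse that inverseBlockSparsity reconstructs:
# (il, jl, ii, jj) -> (il * 2 + ii, jl * 2 + jj)
def lvl_to_dim(il, jl, ii, jj):
    return (il * 2 + ii, jl * 2 + jj)

for i in range(6):
    for j in range(6):
        assert lvl_to_dim(*dim_to_lvl(i, j)) == (i, j)
```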
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -307,6 +307,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
"AffineMap":$lvlToDim,
"unsigned":$posWidth,
"unsigned":$crdWidth), [{
if (!lvlToDim) {
lvlToDim = ::mlir::sparse_tensor::inferLvlToDim(dimToLvl, $_ctxt);
}
return $_get($_ctxt, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr>{});
}]>
17 changes: 13 additions & 4 deletions mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -41,16 +41,17 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
.def_classmethod(
"get",
[](py::object cls, std::vector<MlirSparseTensorDimLevelType> lvlTypes,
- std::optional<MlirAffineMap> dimToLvl, int posWidth, int crdWidth,
+ std::optional<MlirAffineMap> dimToLvl,
+ std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
MlirContext context) {
- // TODO: provide dimToLvl
return cls(mlirSparseTensorEncodingAttrGet(
context, lvlTypes.size(), lvlTypes.data(),
- dimToLvl ? *dimToLvl : MlirAffineMap{nullptr}, posWidth,
+ dimToLvl ? *dimToLvl : MlirAffineMap{nullptr},
+ lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth,
crdWidth));
},
py::arg("cls"), py::arg("lvl_types"), py::arg("dim_to_lvl"),
- py::arg("pos_width"), py::arg("crd_width"),
+ py::arg("lvl_to_dim"), py::arg("pos_width"), py::arg("crd_width"),
py::arg("context") = py::none(),
"Gets a sparse_tensor.encoding from parameters.")
.def_property_readonly(
@@ -71,6 +72,14 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
return {};
return ret;
})
.def_property_readonly(
"lvl_to_dim",
[](MlirAttribute self) -> std::optional<MlirAffineMap> {
MlirAffineMap ret = mlirSparseTensorEncodingAttrGetLvlToDim(self);
if (mlirAffineMapIsNull(ret))
return {};
return ret;
})
.def_property_readonly("pos_width",
mlirSparseTensorEncodingAttrGetPosWidth)
.def_property_readonly("crd_width",
7 changes: 3 additions & 4 deletions mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -48,15 +48,14 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
MlirAttribute
mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
MlirSparseTensorDimLevelType const *lvlTypes,
- MlirAffineMap dimToLvl, int posWidth,
- int crdWidth) {
+ MlirAffineMap dimToLvl, MlirAffineMap lvlToDim,
+ int posWidth, int crdWidth) {
SmallVector<DimLevelType> cppLvlTypes;
cppLvlTypes.reserve(lvlRank);
for (intptr_t l = 0; l < lvlRank; ++l)
cppLvlTypes.push_back(static_cast<DimLevelType>(lvlTypes[l]));
- mlir::AffineMap lvlToDim; // TODO: provide in API
return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes,
- unwrap(dimToLvl), lvlToDim,
+ unwrap(dimToLvl), unwrap(lvlToDim),
posWidth, crdWidth));
}

77 changes: 73 additions & 4 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -293,9 +293,8 @@ Type SparseTensorEncodingAttr::getCrdType() const {
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
- // TODO: infer lvlToDim
return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
- /*lvlToDim*/ AffineMap(), getPosWidth(),
+ getLvlToDim(), getPosWidth(),
getCrdWidth());
}

@@ -583,7 +582,8 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
#undef RETURN_ON_FAIL

// Construct struct-like storage for attribute.
- AffineMap lvlToDim; // TODO: infer
+ // TODO: Fetch lvlToDim if user provides one
+ AffineMap lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
return parser.getChecked<SparseTensorEncodingAttr>(
parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
dimSlices);
@@ -749,6 +749,75 @@ mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
return nullptr;
}

AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
MLIRContext *context) {
auto map = static_cast<AffineMap>(dimToLvl);
AffineMap lvlToDim;
// Return an empty lvlToDim when inference is not successful.
if (!map || map.getNumSymbols() != 0) {
lvlToDim = AffineMap();
} else if (map.isPermutation()) {
lvlToDim = inversePermutation(map);
} else {
// TODO: check if it's block sparsity
lvlToDim = inverseBlockSparsity(map, context);
}
return lvlToDim;
}

AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
MLIRContext *context) {
SmallVector<AffineExpr> lvlExprs;
auto numLvls = dimToLvl.getNumResults();
lvlExprs.reserve(numLvls);
// lvlExprComponents stores information of the floordiv and mod operations
// applied to the same dimension, so as to build the lvlToDim map.
std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
for (unsigned i = 0, n = numLvls; i < n; i++) {
auto result = dimToLvl.getResult(i);
if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
if (result.getKind() == AffineExprKind::FloorDiv) {
// Position of the dimension in dimToLvl.
auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
"expected only one floordiv for each dimension");
SmallVector<AffineExpr, 3> components;
// Level variable for floordiv.
components.push_back(getAffineDimExpr(i, context));
// Multiplier.
components.push_back(binOp.getRHS());
// Map key is the position of the dimension.
lvlExprComponents[pos] = components;
} else if (result.getKind() == AffineExprKind::Mod) {
auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
"expected floordiv before mod");
// Add level variable for mod to the same vector
// of the corresponding floordiv.
lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
} else {
assert(false && "expected floordiv or mod");
}
} else {
lvlExprs.push_back(getAffineDimExpr(i, context));
}
}
// Build lvlExprs from lvlExprComponents.
// For example, for il = i floordiv 2 and ii = i mod 2, the components
// would be [il, 2, ii]. It could be used to build the AffineExpr
// i = il * 2 + ii in lvlToDim.
for (auto &components : lvlExprComponents) {
assert(components.second.size() == 3 &&
"expected 3 components to build lvlExprs");
auto mulOp = getAffineBinaryOpExpr(
AffineExprKind::Mul, components.second[0], components.second[1]);
auto addOp =
getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
lvlExprs.push_back(addOp);
}
return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
}

bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
Level startLvl, bool isUnique) {
if (!enc ||
@@ -811,7 +880,7 @@ RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
// default value.
unsigned posWidth = src.getPosWidth();
unsigned crdWidth = src.getCrdWidth();
- AffineMap invPerm; // TODO
+ AffineMap invPerm = src.getLvlToDim();
auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm,
invPerm, posWidth, crdWidth);
return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc);
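To illustrate the bookkeeping in inverseBlockSparsity above, here is a plain-Python sketch of the same algorithm for the 2x2 BSR map; the level names l0..l3 and the string-built expressions are illustrative stand-ins for AffineExprs, not the MLIR API:

```python
# Each dimToLvl result is modeled as (kind, dim position, constant), in level
# order, for (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2).
results = [("floordiv", 0, 2), ("floordiv", 1, 2), ("mod", 0, 2), ("mod", 1, 2)]

# dim position -> [floordiv level var, multiplier, mod level var]
components = {}
for lvl, (kind, dim, const) in enumerate(results):
    if kind == "floordiv":
        assert dim not in components, "expected only one floordiv per dimension"
        components[dim] = [f"l{lvl}", const]
    else:  # "mod": append the in-block level var to the matching floordiv entry
        components[dim].append(f"l{lvl}")

# Build one "floordiv_var * multiplier + mod_var" expression per dimension.
lvl_to_dim = [f"{fd} * {mult} + {md}"
              for fd, mult, md in (components[d] for d in sorted(components))]
print(lvl_to_dim)  # ['l0 * 2 + l2', 'l1 * 2 + l3'], i.e. (l0*2+l2, l1*2+l3)
```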
5 changes: 3 additions & 2 deletions mlir/test/CAPI/sparse_tensor.c
@@ -40,6 +40,8 @@ static int testRoundtripEncoding(MlirContext ctx) {
// CHECK: level_type: 4
// CHECK: level_type: 8
// CHECK: level_type: 8
MlirAffineMap lvlToDim =
mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
enum MlirSparseTensorDimLevelType *lvlTypes =
malloc(sizeof(enum MlirSparseTensorDimLevelType) * lvlRank);
@@ -53,9 +55,8 @@ static int testRoundtripEncoding(MlirContext ctx) {
// CHECK: crdWidth: 64
int crdWidth = mlirSparseTensorEncodingAttrGetCrdWidth(originalAttr);
fprintf(stderr, "crdWidth: %d\n", crdWidth);
- // TODO: lvlToDim
MlirAttribute newAttr = mlirSparseTensorEncodingAttrGet(
- ctx, lvlRank, lvlTypes, dimToLvl, posWidth, crdWidth);
+ ctx, lvlRank, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth);
mlirAttributeDump(newAttr); // For debugging filecheck output.
// CHECK: equal: 1
fprintf(stderr, "equal: %d\n", mlirAttributeEqual(originalAttr, newAttr));
52 changes: 52 additions & 0 deletions mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -160,6 +160,24 @@ func.func private @BSR(%arg0: tensor<?x?xf64, #BSR>) {

// -----

#BCSR = #sparse_tensor.encoding<{
map = ( i, j, k ) ->
( i floordiv 2 : dense,
j floordiv 3 : dense,
k floordiv 4 : compressed,
i mod 2 : dense,
j mod 3 : dense,
k mod 4 : dense
)
}>

// CHECK-LABEL: func private @BCSR(
// CHECK-SAME: tensor<?x?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 floordiv 2 : dense, d1 floordiv 3 : dense, d2 floordiv 4 : compressed, d0 mod 2 : dense, d1 mod 3 : dense, d2 mod 4 : dense) }>>
func.func private @BCSR(%arg0: tensor<?x?x?xf64, #BCSR>) {
return
}

// -----

#BSR_explicit = #sparse_tensor.encoding<{
map =
{il, jl, ii, jj}
@@ -194,3 +212,37 @@ func.func private @BSR_explicit(%arg0: tensor<?x?xf64, #BSR_explicit>) {
func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
return
}

// -----

#NV_24 = #sparse_tensor.encoding<{
map = ( i, j, k ) ->
( i : dense,
j : dense,
k floordiv 4 : dense,
k mod 4 : block2_4
)
}>

// CHECK-LABEL: func private @NV_24(
// CHECK-SAME: tensor<?x?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : block2_4) }>>
func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
return
}

// -----

#NV_24 = #sparse_tensor.encoding<{
map = ( i, j, k ) ->
( i : dense,
k floordiv 4 : dense,
j : dense,
k mod 4 : block2_4
)
}>

// CHECK-LABEL: func private @NV_24(
// CHECK-SAME: tensor<?x?x?xf64, #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : block2_4) }>>
func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
return
}
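For the #BCSR example above, the inference added in this commit should recover (hand-derived here, an assumption since the test only round-trips dimToLvl) lvlToDim = (l0, l1, l2, l3, l4, l5) -> (l0 * 2 + l3, l1 * 3 + l4, l2 * 4 + l5). A quick numeric check of that inverse:

```python
# Forward #BCSR map: (i, j, k) -> (i floordiv 2, j floordiv 3, k floordiv 4,
#                                  i mod 2, j mod 3, k mod 4)
def dim_to_lvl(i, j, k):
    return (i // 2, j // 3, k // 4, i % 2, j % 3, k % 4)

# Hand-derived inverse: each dimension = block coordinate * block size + offset.
def lvl_to_dim(l0, l1, l2, l3, l4, l5):
    return (l0 * 2 + l3, l1 * 3 + l4, l2 * 4 + l5)

for i in range(4):
    for j in range(6):
        for k in range(8):
            assert lvl_to_dim(*dim_to_lvl(i, j, k)) == (i, j, k)
```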
mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
@@ -155,7 +155,7 @@ def main():
for iwidth in [32]:
for e in [True]:
attr = st.EncodingAttr.get(
- level, ordering, pwidth, iwidth
+ level, ordering, None, pwidth, iwidth
)
opt = f"parallelization-strategy=none"
compiler = sparse_compiler.SparseCompiler(
mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
@@ -145,7 +145,7 @@ def main():
for pwidth in bitwidths:
for iwidth in bitwidths:
attr = st.EncodingAttr.get(
- level, ordering, pwidth, iwidth
+ level, ordering, None, pwidth, iwidth
)
build_compile_and_run_SpMM(attr, compiler)
count = count + 1
mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
@@ -91,7 +91,7 @@ def main():
for level in levels:
for ordering in orderings:
for bwidth in bitwidths:
- attr = st.EncodingAttr.get(level, ordering, bwidth, bwidth)
+ attr = st.EncodingAttr.get(level, ordering, None, bwidth, bwidth)
build_compile_and_run_output(attr, compiler)
count = count + 1

mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
@@ -227,7 +227,7 @@ def main():
for pwidth in bitwidths:
for iwidth in bitwidths:
attr = st.EncodingAttr.get(
- level, ordering, pwidth, iwidth
+ level, ordering, None, pwidth, iwidth
)
types.append(ir.RankedTensorType.get(shape, f64, attr))
#
14 changes: 12 additions & 2 deletions mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -32,12 +32,14 @@ def testEncodingAttr1D():
print(f"lvl_types: {casted.lvl_types}")
# CHECK: dim_to_lvl: None
print(f"dim_to_lvl: {casted.dim_to_lvl}")
# CHECK: lvl_to_dim: None
print(f"lvl_to_dim: {casted.lvl_to_dim}")
# CHECK: pos_width: 16
print(f"pos_width: {casted.pos_width}")
# CHECK: crd_width: 32
print(f"crd_width: {casted.crd_width}")

- created = st.EncodingAttr.get(casted.lvl_types, None, 0, 0)
+ created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
# CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
print(created)
# CHECK: created_equal: False
@@ -72,12 +74,20 @@ def testEncodingAttr2D():
print(f"lvl_types: {casted.lvl_types}")
# CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
print(f"dim_to_lvl: {casted.dim_to_lvl}")
# CHECK: lvl_to_dim: (d0, d1) -> (d1, d0)
print(f"lvl_to_dim: {casted.lvl_to_dim}")
# CHECK: pos_width: 8
print(f"pos_width: {casted.pos_width}")
# CHECK: crd_width: 32
print(f"crd_width: {casted.crd_width}")

- created = st.EncodingAttr.get(casted.lvl_types, casted.dim_to_lvl, 8, 32)
+ created = st.EncodingAttr.get(
+     casted.lvl_types,
+     casted.dim_to_lvl,
+     casted.lvl_to_dim,
+     8,
+     32,
+ )
# CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
print(created)
# CHECK: created_equal: True