Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Mosaic TPU] Improve vector layout inference for vector.shape_cast #23792

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5807,7 +5807,8 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
// TODO: b/342235360 - This check is temporary while we increase and test
// support for offsets outside of the first tile. When support is more broad,
// any op without support should check it within their own rule.
if (!isa<vector::BroadcastOp, vector::ExtractStridedSliceOp>(op)) {
if (!isa<vector::BroadcastOp, vector::ExtractStridedSliceOp,
vector::ShapeCastOp>(op)) {
for (const Layout &layout : layouts_in) {
if (layout && layout->offsets()[1].has_value() &&
layout->offsets()[1].value() >= layout->tiling()[1]) {
Expand Down
119 changes: 65 additions & 54 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1354,6 +1354,9 @@ class VectorLayoutInferer {

// TODO(tlongeri): Be smarter about trying implicit dims. We should probably
// only add them when folding dimensions, and remove them when unfolding.
// The ordering of candidate implicit dims is important! Inserting an
// implicit second minor can make a reshape possible, but also very
// inefficient. We should always prefer to try with None first.
SmallVector<ImplicitDim, 3> candidate_implicit_dims;
if (res_shape.size() >= 2) {
candidate_implicit_dims.push_back(ImplicitDim::kNone);
Expand All @@ -1378,65 +1381,73 @@ class VectorLayoutInferer {
}
}

// See if we can do sublane or lane (un)folding.
for (const ImplicitDim implicit_dim : candidate_implicit_dims) {
const std::array<int64_t, 2> res_tiled_ishape =
VectorLayout::getImplicitTiledDims(implicit_dim, res_shape, 1);
// Sublane (un)folding.
if (src_tiled_ishape[1] == res_tiled_ishape[1] &&
src_tiled_ishape[0] % vreg_slice[0] == 0 &&
res_tiled_ishape[0] % vreg_slice[0] == 0) {
// Sublane (un)folding. We attempt to reduce the sublane tiling, which
// might make this reshape a no-op. We use do-while to handle the packed
// 1D tilings that use 1 in the sublane dimension.
int64_t sublane_tiling = vreg_slice[0];
do {
if (src_shape.size() >= 2 && res_shape.size() >= 2 &&
src_shape.back() == res_shape.back() &&
*(src_shape.end() - 2) % sublane_tiling == 0 &&
*(res_shape.end() - 2) % sublane_tiling == 0) {
std::array<int64_t, 2> tiling = {sublane_tiling, target_shape_[1]};
// TODO(b/343808585): We shouldn't force second minor offset to 0 when
// unfolding, it's still a no-op, but we need to add
// support in apply-vector-layout.
const LayoutOffsets offsets = {0, layout.offsets()[1]};
// unfolding, it's still a no-op, but we need to
// add support in apply-vector-layout.
LayoutOffsets offsets = {0, layout.offsets()[1]};
setLayout(op,
VectorLayout(layout.bitwidth(), offsets, layout.tiling(),
layout.implicit_dim()),
VectorLayout(layout.bitwidth(), offsets, layout.tiling(),
implicit_dim));
VectorLayout(layout.bitwidth(), offsets, tiling,
ImplicitDim::kNone),
VectorLayout(layout.bitwidth(), offsets, tiling,
ImplicitDim::kNone));
return success();
}
sublane_tiling /= 2;
} while (sublane_tiling >= layout.packing());

// Lane (un)folding.
if (src_shape.back() != res_shape.back() &&
src_shape.back() % layout.tiling()[1] == 0 &&
res_shape.back() % layout.tiling()[1] == 0) {
const int packing = kNativeBitwidth / bitwidth;
const auto elements_per_vreg = native_tiling[0] * native_tiling[1];
// When we shapecast from input shape
// (..., m * target_shape_[1] * packing) to output shape
// (..., target_shape_[1]), the reshape becomes no-op when input is
// densely packed with tiling (1, target_shape_[1] * packing) and output
// has the native tiling.
if (res_shape.size() >= 2 && res_shape.back() == target_shape_[1] &&
*(res_shape.end() - 2) % native_tiling[0] == 0 &&
src_shape.back() % elements_per_vreg == 0) {
// Inferring in_layout to have tiling (1, 128 * packing) triggers any
// necessary relayout before shapecast.
setLayout(op,
VectorLayout(layout.bitwidth(), {0, 0},
{1, target_shape_[1] * packing},
layout.implicit_dim() == ImplicitDim::kMinor
? ImplicitDim::kSecondMinor
: layout.implicit_dim()),
VectorLayout(layout.bitwidth(), {0, 0}, native_tiling,
ImplicitDim::kNone));
return success();
}
// Lane (un)folding.
if (src_tiled_ishape[1] != res_tiled_ishape[1] &&
src_tiled_ishape[1] % layout.tiling()[1] == 0 &&
res_tiled_ishape[1] % layout.tiling()[1] == 0) {
const int packing = kNativeBitwidth / bitwidth;
const auto elements_per_vreg = native_tiling[0] * native_tiling[1];
// When we shapecast from input shape
// (..., m * target_shape_[1] * packing) to output shape
// (..., target_shape_[1]), the reshape becomes no-op when input is
// densely packed with tiling (1, target_shape_[1] * packing) and output
// has the native tiling.
if (res_tiled_ishape[1] == target_shape_[1] &&
res_tiled_ishape[0] % native_tiling[0] == 0 &&
src_tiled_ishape[1] % elements_per_vreg == 0) {
// Inferring in_layout to have tiling (1, 128 * packing) triggers any
// necessary relayout before shapecast.
setLayout(op,
VectorLayout(layout.bitwidth(), {0, 0},
{1, target_shape_[1] * packing},
layout.implicit_dim()),
VectorLayout(layout.bitwidth(), {0, 0}, native_tiling,
implicit_dim));
return success();
}

// When we shapecast from input shape (..., target_shape_[1]) to output
// shape (..., m * target_shape_[1] * packing), the reshape becomes
// no-op when input has the native tiling and output is densely packed
// with tiling (1, target_shape_[1] * packing).
if (src_tiled_ishape[1] == target_shape_[1] &&
src_tiled_ishape[0] % native_tiling[0] == 0 &&
res_tiled_ishape[1] % elements_per_vreg == 0) {
setLayout(
op,
VectorLayout(layout.bitwidth(), {0, 0}, native_tiling,
layout.implicit_dim()),
VectorLayout(layout.bitwidth(), {0, 0},
{1, target_shape_[1] * packing}, implicit_dim));
return success();
}
// When we shapecast from input shape (..., target_shape_[1]) to output
// shape (..., m * target_shape_[1] * packing), the reshape becomes
// no-op when input has the native tiling and output is densely packed
// with tiling (1, target_shape_[1] * packing).
if (src_shape.size() >= 2 && src_shape.back() == target_shape_[1] &&
*(src_shape.end() - 2) % native_tiling[0] == 0 &&
res_shape.back() % elements_per_vreg == 0) {
setLayout(
op,
VectorLayout(layout.bitwidth(), {0, 0}, native_tiling,
ImplicitDim::kNone),
VectorLayout(layout.bitwidth(), {0, 0},
{1, target_shape_[1] * packing},
res_shape.size() >= 2 ? ImplicitDim::kNone
: ImplicitDim::kSecondMinor));
return success();
}
}

Expand Down