diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index c849291b4f2b..7b6a770cc062 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -1236,9 +1236,31 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, const VectorType res_ty = concatenate_op.getResult().getType(); const uint32_t dimension = concatenate_op.getDimension(); if (dimension - res_ty.getRank() >= -2) { - return op.emitOpError( - "Not implemented: Concatenation along the last two dimensions"); + if (!layout.hasNaturalTopology(ctx.target_shape) || + layout.offsets() != LayoutOffsets{0, 0}) { + return op.emitOpError( + "Only native tiling with offset (0, 0) is supported when " + "concatenation along tiling dims."); + } + // Check if shapes of src and res are aligned to native tiling. + auto check_aligned = [&](const VectorType &vty) { + return vty.getRank() >= 2 && + *(vty.getShape().end() - 2) % *(layout.tiling().end() - 2) == 0 && + *(vty.getShape().end() - 1) % *(layout.tiling().end() - 1) == 0; + }; + bool is_aligned = check_aligned(res_ty); + int op_idx = 0; + while (is_aligned && op_idx < op.getNumOperands()) { + auto vty = dyn_cast(op.getOperand(op_idx++).getType()); + is_aligned = check_aligned(vty); + } + if (!is_aligned) { + return op.emitOpError( + "Only aligned shapes are supported when concatenation along tiling " + "dims"); + } } + SmallVector> tiles; tiles.reserve(concatenate_op->getNumOperands()); for (Value operand : concatenate_op.getOperands()) { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 7957c0b8529d..841c73ae8fb8 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -615,14 +615,25 @@ class VectorLayoutInferer { } LogicalResult infer(tpu::ConcatenateOp op) { - TPU_CHECK_OP(op.getDimension() - op.getType().getRank() < -2, - "Concatenation is not supported along the last two axes"); TPU_CHECK_OP(!op.getSources().empty(), "Need at least one vector to concatenate"); - // Fix all the layouts to the layout of the first operand. - // This might not be the best strategy, but it works. - SmallVector in_layouts(op.getNumOperands(), - getLayout(op.getSources().front())); + auto res_rank = op.getType().getRank(); + auto dimension = op.getDimension(); + TPU_CHECK_OP(0 <= dimension && dimension < res_rank, + "Expect a valid concatenate dimension"); + if (res_rank == 1) { + NYI("Support concatenation with 1D vectors"); + } + auto res_ty = op.getResult().getType(); + int8_t bitwidth = res_ty.getElementTypeBitWidth(); + if (bitwidth != 32) { + NYI("Support concatenation with non 32-bit data"); + } + auto layout = (dimension >= res_rank - 2) + ? VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth), + ImplicitDim::kNone) + : getLayout(op.getSources().front()); + SmallVector in_layouts(op->getNumOperands(), layout); setLayout(op, in_layouts, in_layouts.back()); return success(); } diff --git a/jaxlib/mosaic/python/apply_vector_layout.py b/jaxlib/mosaic/python/apply_vector_layout.py index 7fe7a0b3e36d..0a1f42545e62 100644 --- a/jaxlib/mosaic/python/apply_vector_layout.py +++ b/jaxlib/mosaic/python/apply_vector_layout.py @@ -2093,8 +2093,24 @@ def _tpu_concatenate_rule( raise NotImplementedError res_ty = ir.VectorType(op.result.type) dimension = ir.IntegerAttr(op.dimension).value - if dimension - res_ty.rank >= -2: - raise NotImplementedError("Concatenation along the last two dimensions") + + if dimension >= res_ty.rank - 2: + if (not layout.has_natural_topology) or layout.offsets != (0, 0): + raise NotImplementedError( + "Only native tiling with offset (0, 0) is supported when" + " concatenation along tiling dims." + ) + # Check if shapes of src and res are aligned to native tiling. + for vty in [res_ty] + [ir.VectorType(src.type) for src in op.operands]: + if ( + vty.rank < 2 + or vty.shape[-2] % layout.tiling[-2] != 0 + or vty.shape[-1] % layout.tiling[-1] != 0 + ): + raise NotImplementedError( + "Only aligned shapes are supported when concatenation along tiling" + " dims." + ) tiles = [disassemble(layout, x) for x in op.operands] res_tiles = np.concatenate(tiles, axis=dimension) ctx.replace(op, assemble(res_ty, layout, res_tiles))