Coord refactor #186

Open · wants to merge 11 commits into base: sycl-develop
192 changes: 45 additions & 147 deletions include/cute/atom/copy_traits_xe.hpp
@@ -166,9 +166,15 @@ struct XE_2D_LD_Unpack {
CUTE_HOST_DEVICE friend constexpr void
copy_unpack(Traits_LD_t const &traits, Tensor<TS, SLayout> const &src,
Tensor<TD, DLayout> &dst) {
- static_assert(is_rmem<TD>::value);

using dtype = typename Tensor<TD, DLayout>::value_type;
+ constexpr int dtype_size = sizeof(dtype);
+ constexpr int bits_in_byte = 8;
[Review comment (Collaborator) on lines +170 to +171]
Cutlass provides cutlass::sizeof_bits<dtype> for this.

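A minimal sketch of that suggestion (assuming cutlass/numeric_types.h is visible from this header; cutlass::sizeof_bits<T>::value is the bit width of T):

    // Hedged rewrite of the asserts below, replacing dtype_size * bits_in_byte
    // with cutlass::sizeof_bits (sketch, not part of the PR):
    static_assert(size(SLayout{}) * cutlass::sizeof_bits<dtype>::value
                      == size<1>(typename Traits_LD_t::SrcLayout{}),
                  "Src tensor size does not match copy atom size");
    static_assert(size(DLayout{}) * cutlass::sizeof_bits<dtype>::value
                      == size<1>(typename Traits_LD_t::DstLayout{}),
                  "Dst tensor size does not match copy atom size");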

+ static_assert(is_rmem<TD>::value);
+ static_assert(size(SLayout{}) * dtype_size * bits_in_byte == size<1>(typename Traits_LD_t::SrcLayout{}),
+ "Src tensor size does not match copy atom size");
+ static_assert(size(DLayout{}) * dtype_size * bits_in_byte == size<1>(typename Traits_LD_t::DstLayout{}),
+ "Dst tensor size does not match copy atom size");

dtype *base_addr = (dtype *)traits.base_ptr;

@@ -238,6 +244,15 @@ struct XE_2D_LD_Unpack {
make_layout(t_shape, t_stride));
}

+ // Generate the PVC coord tensor
+ template <class GShape>
[Review comment (Collaborator)]
This seems unrelated to the class it's in. Maybe it shouldn't be a part of copy traits?
+ CUTE_HOST_DEVICE constexpr
+ auto
+ get_pvc_tensor(GShape const& g_shape) const {
+ static_assert(rank(GShape{}) == 3, "mismatch rank");
+ return make_counting_tensor(make_layout(g_shape, make_stride(E<0>(), E<1>(), E<2>())));
[Review comment (Collaborator)]
get_tma_tensor uses g_stride_ for the 2nd arg to make_layout here. Is there any loss of generality with this simpler approach?
+ }
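For context, a hedged sketch of the refactor the reviewer hints at (the free-function name is assumed, not part of the PR). The helper only builds an identity coordinate tensor: with basis strides E<0>(), E<1>(), E<2>(), the counting tensor's element at (m,n,l) is the coordinate (m,n,l) itself, which copy_unpack later reads back via dst.data().coord_, so it need not be a member of the copy traits:

    // Hypothetical free function in namespace cute (sketch only):
    template <class GShape>
    CUTE_HOST_DEVICE constexpr auto
    make_pvc_coord_tensor(GShape const& g_shape) {
      static_assert(rank(GShape{}) == 3, "GShape must be rank-3 (M,N,L)");
      // Identity layout: each mode counts in its own basis, so tensor(m,n,l) == (m,n,l).
      return make_counting_tensor(make_layout(g_shape, make_stride(E<0>(), E<1>(), E<2>())));
    }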

template <class... TensorArgs>
static constexpr auto with(Tensor<TensorArgs...> const &tensor) {
return Traits_LD_t{tensor};
@@ -297,18 +312,25 @@ template <class CopyOp, class StrideIndicator = cute::Stride<int64_t, cute::Int<
CUTE_HOST_DEVICE friend constexpr void
copy_unpack(Traits_ST_t const &traits, Tensor<TS, SLayout> const &src,
Tensor<TD, DLayout> &dst) {
- static_assert(is_rmem<TS>::value);

using dtype = typename Tensor<TS, SLayout>::value_type;
+ constexpr int dtype_size = sizeof(dtype);
+ constexpr int bits_in_byte = 8;

- dtype *base_addr = (dtype *)traits.base_ptr;
+ static_assert(is_rmem<TS>::value);
+ static_assert(size(SLayout{}) * dtype_size * bits_in_byte == size<1>(typename Traits_ST_t::SrcLayout{}),
+ "Src tensor size does not match copy atom size");
+ static_assert(size(DLayout{}) * dtype_size * bits_in_byte == size<1>(typename Traits_ST_t::DstLayout{}),
+ "Dst tensor size does not match copy atom size");
[Review comment (Collaborator) on lines +320 to +323]
As above, use cutlass::sizeof_bits<dtype> I think.

+ dtype *base_addr = (dtype *)traits.base_ptr;

auto [m, n, l] = dst.data().coord_;

CopyOp::copy(base_addr + l * traits.stride_l,
- (int)(traits.width * sizeof(dtype)), (int)(traits.height),
- (int)(traits.pitch * sizeof(dtype)),
- intel::coord_t{(int)n, (int)m}, &*src.data());
+ (int)(traits.width * dtype_size), (int)(traits.height),
+ (int)(traits.pitch * dtype_size),
+ intel::coord_t{(int)n, (int)m}, &*src.data());
}

template <class Coord, class GShape>
@@ -345,6 +367,15 @@ template <class CopyOp, class StrideIndicator = cute::Stride<int64_t, cute::Int<
make_layout(t_shape, t_stride));
}

+ // Generate the PVC coord tensor
+ template <class GShape>
+ CUTE_HOST_DEVICE constexpr
+ auto
+ get_pvc_tensor(GShape const& g_shape) const {
+ static_assert(rank(GShape{}) == 3, "mismatch rank");
+ return make_counting_tensor(make_layout(g_shape, make_stride(E<0>(), E<1>(), E<2>())));
+ }

template <class... TensorArgs>
static constexpr auto with(Tensor<TensorArgs...> const &tensor) {
return Traits_ST_t{tensor};
@@ -993,11 +1024,11 @@ struct Copy_Traits<XE_2D_U16x32x32_LD_N, args_t...>
: XE_2D_LD_Unpack<XE_2D_U16x32x32_LD_N, args_t...> {
using ThrID = Layout<_16>;
// Map from (src-thr,src-val) to bit
- using SrcLayout = Layout<Shape <_16,_16>,
- Stride< _0, _1>>;
+ using SrcLayout = Layout<Shape <_16,Shape <_16, _2, _32>>,
+ Stride<_0,Stride< _1,_256,_512>>>;
[Review comment (Collaborator) on lines +1027 to +1028]
The formatting is a bit messy here.
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape <_16,Shape <_16, _2, _32>>,
- Stride<_16,Stride< _1,_256,_512>>>;
+ Stride<_16,Stride< _1,_512,_16>>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;

@@ -1298,7 +1329,7 @@ struct Copy_Traits<XE_2D_U32x8x16_LD_N, args_t...>
// Logical thread id to thread idx
using ThrID = Layout<_16>;
// Map from (src-thr,src-val) to bit
- using SrcLayout = Layout<Shape <_16,_32>,
+ using SrcLayout = Layout<Shape <_16,_256>,
Stride< _0, _1>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape <_16,Shape <_32, _8>>,
@@ -1458,7 +1489,7 @@ struct Copy_Traits<XE_2D_U16x32x32_LD_V, args_t...>
// Logical thread id to thread idx
using ThrID = Layout<_16>;
// Map from (src-thr,src-val) to bit
- using SrcLayout = Layout<Shape <_16,_64>,
+ using SrcLayout = Layout<Shape <_16,_1024>,
Stride< _0, _1>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape <_16,Shape <_16, _2, _2, _16>>,
@@ -1929,8 +1960,8 @@ struct Copy_Traits<XE_2D_U32x8x16_ST_N, args_t...>
using SrcLayout = Layout<Shape <_16,Shape <_32, _8>>,
Stride<_32,Stride< _1,_512>>>;
// Map from (dst-thr,dst-val) to bit
- using DstLayout = Layout<Shape <_16,_32>,
- Stride< _0, _1>>;
+ using DstLayout = Layout<Shape <_16,_256>,
+ Stride< _0, _1>>; // 0 here makes all threads in a warp get the same base address
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;

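A hedged aside on that stride-0 comment (standalone sketch, not part of the PR): in CuTe, a zero stride on the thread mode maps every thread to the same bit offsets, which is how all 16 threads of the sub-group end up sharing one base address.

    // Stride _0 collapses the thread mode: layout(t, b) == b for any t in [0,16).
    using Broadcast = Layout<Shape<_16, _256>, Stride<_0, _1>>;
    static_assert(Broadcast{}(0, 5) == Broadcast{}(15, 5), "same offset for every thread");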
@@ -2138,137 +2169,4 @@ BUILD_XE_NAME(32)
static_assert(dependent_false<PrefetchTileSize> && "Invalid PrefetchTileSize[0]");
}
} // end namespace detail

- template <class TiledCopy, class ThrIdx>
- class Xe2DThrCopy : ThrCopy<TiledCopy, ThrIdx> {
-
- public:
-
- CUTE_HOST_DEVICE
- Xe2DThrCopy(ThrIdx const& thr_idx) : ThrCopy<TiledCopy, ThrIdx> (thr_idx) {}
-
- template <class DTensor>
- CUTE_HOST_DEVICE
- auto
- retile_D(DTensor&& dtensor) {
- if constexpr (!TiledCopy::is_convention_MN) {
- return retile_D_nkl(dtensor);
- } else {
- return retile_D_mkl(dtensor);
- }
- }
-
- template <class MMA, class MMATensor>
- CUTE_HOST_DEVICE
- auto
- retile_MMA(MMA const&, MMATensor&& mma_tensor) {
- if constexpr (TiledCopy::is_convention_MN) {
- static constexpr auto m = decltype(size<1>(mma_tensor.shape()))::value;
- static constexpr auto k = decltype(size<2>(mma_tensor.shape()))::value;
- static constexpr auto m_step = size<0>(typename TiledCopy::BlockShape{})
- / size<0>(typename MMA::Shape_MNK{});
- static constexpr auto k_step = size<1>(typename TiledCopy::BlockShape{})
- / size<2>(typename MMA::Shape_MNK{});
-
- auto retiled_tensor = make_tensor(mma_tensor.data(),
- make_shape(size<0>(mma_tensor.shape()),
- Int<m_step>{},
- Int<k_step>{},
- Int<m / m_step>{},
- Int<k / k_step>{}));
- return make_tensor(mma_tensor.data(),group<2, 4>(group<1, 3>(select<0, 1, 3, 2, 4>(retiled_tensor.layout()))));
- } else {
- static constexpr auto k = decltype(size<2>(mma_tensor.shape()))::value;
- static constexpr auto k_step = size<0>(typename TiledCopy::BlockShape{})
- / size<2>(typename MMA::Shape_MNK{});
-
- auto retiled_tensor = make_tensor(mma_tensor.data(),
- make_shape(size<0>(mma_tensor.shape()),
- Int<k_step>{},
- size<1>(mma_tensor.shape()),
- Int<k / k_step>{}));
- return make_tensor(mma_tensor.data(),group<2, 4>(select<0, 2, 1, 3>(retiled_tensor.layout())));
- }
- }
-
- private:
-
- template <class DTensor>
- CUTE_HOST_DEVICE static
- auto
- retile_D_mkl(DTensor&& dtensor) {
- auto tmp = ThrCopy<TiledCopy, ThrIdx>::retile_D(dtensor);
- return make_tensor(static_cast<decltype(tmp) &&>(tmp).data(),
- tmp.shape());
- }
-
- template <class DTensor>
- CUTE_HOST_DEVICE static
- auto
- retile_D_nkl(DTensor&& dtensor) {
- auto b_tensor = make_tensor(dtensor.data(),
- make_shape(size<0>(dtensor.shape()),
- size<2>(dtensor.shape()),
- size<1>(dtensor.shape())));
- auto tmp = ThrCopy<TiledCopy, ThrIdx>::retile_D(b_tensor);
- return make_tensor(static_cast<decltype(tmp) &&>(tmp).data(),
- make_shape(size<0>(tmp.shape()),
- size<2>(tmp.shape()),
- size<1>(tmp.shape())));
- }
- };
-
- template <class Copy_Atom,
- class LayoutCopy_TV, // (tid,vid) -> coord [Need not be 2D...]
- class ShapeTiler_MN> // coord space
- struct Xe2DTiledCopy : TiledCopy<Copy_Atom, LayoutCopy_TV, ShapeTiler_MN>{
-
- template <class ThrIdx,
- __CUTE_REQUIRES(is_integral<ThrIdx>::value)>
- CUTE_HOST_DEVICE
- auto
- get_slice(ThrIdx const& thr_idx) const
- {
- return Xe2DThrCopy<Xe2DTiledCopy, ThrIdx>(thr_idx);
- }
- };
-
- template <class... Args,
- class ThrLayout,
- class ValLayout = typename Copy_Atom<Args...>::Value_Layout>
- CUTE_HOST_DEVICE
- auto
- make_xe_2d_copy(Copy_Atom<Args...> const& copy_atom,
- ThrLayout const& thr_layout = {}, // (m,n) -> thr_idx
- ValLayout const& val_layout = {}) // (m,n) -> val_idx
- {
- // Take the raked_products to compute the Layout_MN
- // (M,N) -> (thr_idx, val_idx)
- auto layout_mn = raked_product(thr_layout, val_layout);
- // (thr_idx, val_idx) -> (M,N)
- auto layout_tv = right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout)));
- // Tiler for extracting relevant elements
- // (M,N) -> tensor coord
- auto tiler = product_each(shape(layout_mn));
-
- #if 0
- print("thr_layout: "); print(thr_layout); print("\n");
- print("val_layout: "); print(val_layout); print("\n");
- print("layout_mn : "); print(layout_mn); print("\n");
- print("layout_tv : "); print(layout_tv); print("\n");
- print("tiler : "); print(tiler); print("\n");
- #endif
-
- return Xe2DTiledCopy<Copy_Atom<Args...>, decltype(layout_tv), decltype(tiler)>{copy_atom};
- }
-
- // The number of threads involved in a Xe2DTiledCopy
- template <class... Args>
- CUTE_HOST_DEVICE constexpr
- auto
- size(Xe2DTiledCopy<Args...> const&)
- {
- return typename Xe2DTiledCopy<Args...>::TiledNumThr{};
- }

} // end namespace cute
45 changes: 34 additions & 11 deletions include/cutlass/epilogue/collective/xe_epilogue.hpp
@@ -160,7 +160,7 @@ class CollectiveEpilogue<
typename FusionCallbacks::Arguments thread{};
ElementC const* ptr_C;
StrideC dC;
- ElementD const* ptr_D;
+ ElementD* ptr_D;
StrideD dD;
};

@@ -169,6 +169,10 @@ class CollectiveEpilogue<
typename FusionCallbacks::Params thread{};
XE_Copy_C xe_load_c;
XE_Copy_D xe_store_d;
+ ElementC const* ptr_C = nullptr;
+ StrideC dC{};
+ ElementD* ptr_D = nullptr;
+ StrideD dD{};
};

//
@@ -206,7 +210,11 @@ class CollectiveEpilogue<
return {
FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace),
xe_load_c,
- xe_store_d
+ xe_store_d,
+ args.ptr_C,
+ args.dC,
+ args.ptr_D,
+ args.dD
};
}

@@ -296,9 +304,8 @@ class CollectiveEpilogue<
// Indexing variables
auto [M, N, K, L] = problem_shape_mnkl;
auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl;
- auto m_offset = m_coord * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M;
- auto n_offset = n_coord * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
- auto l_offset = l_coord;
+ auto m_sg = get_sub_group_id() / ATOM_N;
+ auto n_sg = get_sub_group_id() % ATOM_N;

using EpilogueTile = decltype(get<0>(params.xe_store_d.get_layoutS_MN()).shape());

@@ -310,12 +317,27 @@ class CollectiveEpilogue<
auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord);

bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();

+ // Represent the full output tensor
+ Tensor mD_mnl = params.xe_store_d.get_pvc_tensor(make_shape(M,N,L));
[Review comment (Collaborator)]
This should be a counting tensor of D I believe, so maybe cD would be more appropriate? (I'm not really sure.)
+ // Tile the output tensor per CTA
+ Tensor g_cta_D_mnl = local_tile(mD_mnl, CtaTileMNK{}, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)

+ // Slice to get the tile this CTA is responsible for
+ Tensor g_cta_D = g_cta_D_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
[Review comment (Collaborator) on lines +324 to +328]
Suggested change:
- // Tile the output tensor per CTA
- Tensor g_cta_D_mnl = local_tile(mD_mnl, CtaTileMNK{}, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
- // Slice to get the tile this CTA is responsible for
- Tensor g_cta_D = g_cta_D_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
+ // Tile the output tensor per CTA
+ Tensor g_cta_D = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (BLK_M,BLK_N)
I think this is simpler. Maybe it should be cta_cD?
[Review comment (Collaborator) on lines +325 to +328]
I am wondering here if it should be possible to avoid this and have something like:
    Tensor g_cta_D_mnl = local_tile(mD_mnl, CtaTileMNK{}, make_coord(m_coord,n_coord,l_coord), Step<_1,_1, X>{});
+ // Tile the output tensor per warp
+ Tensor gD_mnl = local_tile(g_cta_D, SubgroupTileShape{}, make_coord(_,_,_), Step<_1,_1, X>{}); // (SG_M,SG_N,m,n)

+ // Slice to get the tile this warp is responsible for
+ Tensor gD = gD_mnl(_,_,m_sg,n_sg); // (SG_M,SG_N)
[Review comment (Collaborator) on lines +330 to +334]
Suggested change:
- // Tile the output tensor per warp
- Tensor gD_mnl = local_tile(g_cta_D, SubgroupTileShape{}, make_coord(_,_,_), Step<_1,_1, X>{}); // (SG_M,SG_N,m,n)
- // Slice to get the tile this warp is responsible for
- Tensor gD = gD_mnl(_,_,m_sg,n_sg); // (SG_M,SG_N)
+ // Tile the output tensor per warp
+ Tensor gD = local_tile(g_cta_D, SubgroupTileShape{}, make_coord(m_sg,n_sg)); // (SG_M, SG_N)
I think this is correct too.
[Review comment (Collaborator) on lines +331 to +334]
Same here.

+ auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx);
+ Tensor tCgD = thread_xe_store_d.partition_D(gD);

Tensor trC = make_tensor<typename TiledMma::ValTypeC>(Shape<Int<FragmentSize>>{});
Tensor trD = make_tensor<typename TiledMma::ValTypeD>(Shape<Int<FragmentSize>>{});
- Tensor rw_coord = params.xe_store_d.get_pvc_tensor(
- make_coord(m_offset, n_offset, l_offset),
- make_shape(_, Int<FragsM>{}, Int<FragsN>{}));

// Because Sm90 uses shared memory, they are not tied to using the same accumulator values
// for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be
@@ -334,7 +356,7 @@ class CollectiveEpilogue<
// Get the fusion callbacks
// Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles
constexpr bool RefSrc = true;
- auto residue_mn = make_coord(M, N);
+ auto residue_mn = make_coord(M, N); // TODO(Codeplay): this is not correct
auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{
problem_shape_mnkl,
SubgroupTileShape{},
@@ -367,7 +389,8 @@ class CollectiveEpilogue<
for (int epi_m = 0; epi_m < FragsM; epi_m++) {

if (is_C_load_needed) {
- copy(params.xe_load_c, rw_coord(_, epi_m, epi_n), trC);
+ // coordinates for C and D are the same
+ copy(params.xe_load_c, tCgD(_, epi_m, epi_n), trC);
}

cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed);
Expand All @@ -378,7 +401,7 @@ class CollectiveEpilogue<
for (int epi_v = 0; epi_v < size(trD_frag); ++epi_v) {
trD_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
}
- copy(params.xe_store_d, trD, rw_coord(_, epi_m, epi_n));
+ copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n));
}
}
