Coord refactor #186
base: sycl-develop
Changes from all commits
@@ -166,9 +166,15 @@ struct XE_2D_LD_Unpack {
   CUTE_HOST_DEVICE friend constexpr void
   copy_unpack(Traits_LD_t const &traits, Tensor<TS, SLayout> const &src,
               Tensor<TD, DLayout> &dst) {
-    static_assert(is_rmem<TD>::value);
     using dtype = typename Tensor<TD, DLayout>::value_type;
+    constexpr int dtype_size = sizeof(dtype);
+    constexpr int bits_in_byte = 8;
+
+    static_assert(is_rmem<TD>::value);
+    static_assert(size(SLayout{}) * dtype_size * bits_in_byte == size<1>(typename Traits_LD_t::SrcLayout{}),
+                  "Src tensor size does not match copy atom size");
+    static_assert(size(DLayout{}) * dtype_size * bits_in_byte == size<1>(typename Traits_LD_t::DstLayout{}),
+                  "Dst tensor size does not match copy atom size");

     dtype *base_addr = (dtype *)traits.base_ptr;
@@ -238,6 +244,15 @@ struct XE_2D_LD_Unpack {
                        make_layout(t_shape, t_stride));
   }

+  // Generate the PVC coord tensor
+  template <class GShape>
+  CUTE_HOST_DEVICE constexpr
+  auto
+  get_pvc_tensor(GShape const& g_shape) const {
+    static_assert(rank(GShape{}) == 3, "mismatch rank");
+    return make_counting_tensor(make_layout(g_shape, make_stride(E<0>(), E<1>(), E<2>())));
+  }
+
   template <class... TensorArgs>
   static constexpr auto with(Tensor<TensorArgs...> const &tensor) {
     return Traits_LD_t{tensor};

Review comment (on the new get_pvc_tensor): this seems unrelated to the class it's in. Maybe it shouldn't be a part of copy traits?
@@ -297,18 +312,25 @@ template <class CopyOp, class StrideIndicator = cute::Stride<int64_t, cute::Int<
   CUTE_HOST_DEVICE friend constexpr void
   copy_unpack(Traits_ST_t const &traits, Tensor<TS, SLayout> const &src,
               Tensor<TD, DLayout> &dst) {
-    static_assert(is_rmem<TS>::value);
     using dtype = typename Tensor<TS, SLayout>::value_type;
+    constexpr int dtype_size = sizeof(dtype);
+    constexpr int bits_in_byte = 8;

-    dtype *base_addr = (dtype *)traits.base_ptr;
+    static_assert(is_rmem<TS>::value);
+    static_assert(size(SLayout{}) * dtype_size * bits_in_byte == size<1>(typename Traits_ST_t::SrcLayout{}),
+                  "Src tensor size does not match copy atom size");
+    static_assert(size(DLayout{}) * dtype_size * bits_in_byte == size<1>(typename Traits_ST_t::DstLayout{}),
+                  "Dst tensor size does not match copy atom size");
+
+    dtype *base_addr = (dtype *)traits.base_ptr;

     auto [m, n, l] = dst.data().coord_;

     CopyOp::copy(base_addr + l * traits.stride_l,
-                 (int)(traits.width * sizeof(dtype)), (int)(traits.height),
-                 (int)(traits.pitch * sizeof(dtype)),
-                 intel::coord_t{(int)n, (int)m}, &*src.data());
+                 (int)(traits.width * dtype_size), (int)(traits.height),
+                 (int)(traits.pitch * dtype_size),
+                 intel::coord_t{(int)n, (int)m}, &*src.data());
   }

   template <class Coord, class GShape>

Review comment on lines +320 to +323: As above, use …
@@ -345,6 +367,15 @@ template <class CopyOp, class StrideIndicator = cute::Stride<int64_t, cute::Int<
                        make_layout(t_shape, t_stride));
   }

+  // Generate the PVC coord tensor
+  template <class GShape>
+  CUTE_HOST_DEVICE constexpr
+  auto
+  get_pvc_tensor(GShape const& g_shape) const {
+    static_assert(rank(GShape{}) == 3, "mismatch rank");
+    return make_counting_tensor(make_layout(g_shape, make_stride(E<0>(), E<1>(), E<2>())));
+  }
+
   template <class... TensorArgs>
   static constexpr auto with(Tensor<TensorArgs...> const &tensor) {
     return Traits_ST_t{tensor};
@@ -993,11 +1024,11 @@ struct Copy_Traits<XE_2D_U16x32x32_LD_N, args_t...>
     : XE_2D_LD_Unpack<XE_2D_U16x32x32_LD_N, args_t...> {
   using ThrID = Layout<_16>;
   // Map from (src-thr,src-val) to bit
-  using SrcLayout = Layout<Shape <_16,_16>,
-                           Stride< _0, _1>>;
+  using SrcLayout = Layout<Shape <_16,Shape <_16, _2, _32>>,
+                           Stride<_0,Stride< _1,_256,_512>>>;
   // Map from (dst-thr,dst-val) to bit
   using DstLayout = Layout<Shape <_16,Shape <_16, _2, _32>>,
-                           Stride<_16,Stride< _1,_256,_512>>>;
+                           Stride<_16,Stride< _1,_512,_16>>>;
   // Reference map from (thr,val) to bit
   using RefLayout = DstLayout;

Review comment on lines +1027 to +1028: the formatting is a bit messy here.
@@ -1298,7 +1329,7 @@ struct Copy_Traits<XE_2D_U32x8x16_LD_N, args_t...>
   // Logical thread id to thread idx
   using ThrID = Layout<_16>;
   // Map from (src-thr,src-val) to bit
-  using SrcLayout = Layout<Shape <_16,_32>,
+  using SrcLayout = Layout<Shape <_16,_256>,
                            Stride< _0, _1>>;
   // Map from (dst-thr,dst-val) to bit
   using DstLayout = Layout<Shape <_16,Shape <_32, _8>>,
@@ -1458,7 +1489,7 @@ struct Copy_Traits<XE_2D_U16x32x32_LD_V, args_t...>
   // Logical thread id to thread idx
   using ThrID = Layout<_16>;
   // Map from (src-thr,src-val) to bit
-  using SrcLayout = Layout<Shape <_16,_64>,
+  using SrcLayout = Layout<Shape <_16,_1024>,
                            Stride< _0, _1>>;
   // Map from (dst-thr,dst-val) to bit
   using DstLayout = Layout<Shape <_16,Shape <_16, _2, _2, _16>>,
@@ -1929,8 +1960,8 @@ struct Copy_Traits<XE_2D_U32x8x16_ST_N, args_t...>
   using SrcLayout = Layout<Shape <_16,Shape <_32, _8>>,
                            Stride<_32,Stride< _1,_512>>>;
   // Map from (dst-thr,dst-val) to bit
-  using DstLayout = Layout<Shape <_16,_32>,
-                           Stride< _0, _1>>;
+  using DstLayout = Layout<Shape <_16,_256>,
+                           Stride< _0, _1>>; // 0 here makes all threads in a warp get the same base address
   // Reference map from (thr,val) to bit
   using RefLayout = SrcLayout;
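The trailing comment above points at the key property of this store atom's DstLayout: the thread mode has stride _0, so the layout ignores which thread is asking and every lane computes the same offsets. A small host-side sketch of that property (not part of the PR; assumes CuTe headers are available; the indices are arbitrary):

    #include <cute/tensor.hpp>
    using namespace cute;

    int main() {
      // Thread mode stride is _0 and value mode stride is _1, as in the DstLayout above.
      auto dst = Layout<Shape<_16, _256>, Stride<_0, _1>>{};

      print(dst(0, 7));  print("\n");   // 7
      print(dst(5, 7));  print("\n");   // 7 again: the thread index does not change the offset
      print(dst(15, 7)); print("\n");   // 7 for every thread in the sub-group
    }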
@@ -2138,137 +2169,4 @@ BUILD_XE_NAME(32)
       static_assert(dependent_false<PrefetchTileSize> && "Invalid PrefetchTileSize[0]");
     }
 } // end namespace detail

-template <class TiledCopy, class ThrIdx>
-class Xe2DThrCopy : ThrCopy<TiledCopy, ThrIdx> {
-
-public:
-
-  CUTE_HOST_DEVICE
-  Xe2DThrCopy(ThrIdx const& thr_idx) : ThrCopy<TiledCopy, ThrIdx>(thr_idx) {}
-
-  template <class DTensor>
-  CUTE_HOST_DEVICE
-  auto
-  retile_D(DTensor&& dtensor) {
-    if constexpr (!TiledCopy::is_convention_MN) {
-      return retile_D_nkl(dtensor);
-    } else {
-      return retile_D_mkl(dtensor);
-    }
-  }
-
-  template <class MMA, class MMATensor>
-  CUTE_HOST_DEVICE
-  auto
-  retile_MMA(MMA const&, MMATensor&& mma_tensor) {
-    if constexpr (TiledCopy::is_convention_MN) {
-      static constexpr auto m = decltype(size<1>(mma_tensor.shape()))::value;
-      static constexpr auto k = decltype(size<2>(mma_tensor.shape()))::value;
-      static constexpr auto m_step = size<0>(typename TiledCopy::BlockShape{})
-                                       / size<0>(typename MMA::Shape_MNK{});
-      static constexpr auto k_step = size<1>(typename TiledCopy::BlockShape{})
-                                       / size<2>(typename MMA::Shape_MNK{});
-
-      auto retiled_tensor = make_tensor(mma_tensor.data(),
-                                        make_shape(size<0>(mma_tensor.shape()),
-                                                   Int<m_step>{},
-                                                   Int<k_step>{},
-                                                   Int<m / m_step>{},
-                                                   Int<k / k_step>{}));
-      return make_tensor(mma_tensor.data(), group<2, 4>(group<1, 3>(select<0, 1, 3, 2, 4>(retiled_tensor.layout()))));
-    } else {
-      static constexpr auto k = decltype(size<2>(mma_tensor.shape()))::value;
-      static constexpr auto k_step = size<0>(typename TiledCopy::BlockShape{})
-                                       / size<2>(typename MMA::Shape_MNK{});
-
-      auto retiled_tensor = make_tensor(mma_tensor.data(),
-                                        make_shape(size<0>(mma_tensor.shape()),
-                                                   Int<k_step>{},
-                                                   size<1>(mma_tensor.shape()),
-                                                   Int<k / k_step>{}));
-      return make_tensor(mma_tensor.data(), group<2, 4>(select<0, 2, 1, 3>(retiled_tensor.layout())));
-    }
-  }
-
-private:
-
-  template <class DTensor>
-  CUTE_HOST_DEVICE static
-  auto
-  retile_D_mkl(DTensor&& dtensor) {
-    auto tmp = ThrCopy<TiledCopy, ThrIdx>::retile_D(dtensor);
-    return make_tensor(static_cast<decltype(tmp) &&>(tmp).data(),
-                       tmp.shape());
-  }
-
-  template <class DTensor>
-  CUTE_HOST_DEVICE static
-  auto
-  retile_D_nkl(DTensor&& dtensor) {
-    auto b_tensor = make_tensor(dtensor.data(),
-                                make_shape(size<0>(dtensor.shape()),
-                                           size<2>(dtensor.shape()),
-                                           size<1>(dtensor.shape())));
-    auto tmp = ThrCopy<TiledCopy, ThrIdx>::retile_D(b_tensor);
-    return make_tensor(static_cast<decltype(tmp) &&>(tmp).data(),
-                       make_shape(size<0>(tmp.shape()),
-                                  size<2>(tmp.shape()),
-                                  size<1>(tmp.shape())));
-  }
-};
-
-template <class Copy_Atom,
-          class LayoutCopy_TV,  // (tid,vid) -> coord   [Need not be 2D...]
-          class ShapeTiler_MN>  // coord space
-struct Xe2DTiledCopy : TiledCopy<Copy_Atom, LayoutCopy_TV, ShapeTiler_MN> {
-
-  template <class ThrIdx,
-            __CUTE_REQUIRES(is_integral<ThrIdx>::value)>
-  CUTE_HOST_DEVICE
-  auto
-  get_slice(ThrIdx const& thr_idx) const
-  {
-    return Xe2DThrCopy<Xe2DTiledCopy, ThrIdx>(thr_idx);
-  }
-};
-
-template <class... Args,
-          class ThrLayout,
-          class ValLayout = typename Copy_Atom<Args...>::Value_Layout>
-CUTE_HOST_DEVICE
-auto
-make_xe_2d_copy(Copy_Atom<Args...> const& copy_atom,
-                ThrLayout const& thr_layout = {},  // (m,n) -> thr_idx
-                ValLayout const& val_layout = {})  // (m,n) -> val_idx
-{
-  // Take the raked_products to compute the Layout_MN
-  // (M,N) -> (thr_idx, val_idx)
-  auto layout_mn = raked_product(thr_layout, val_layout);
-  // (thr_idx, val_idx) -> (M,N)
-  auto layout_tv = right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout)));
-  // Tiler for extracting relevant elements
-  // (M,N) -> tensor coord
-  auto tiler = product_each(shape(layout_mn));
-
-#if 0
-  print("thr_layout: "); print(thr_layout); print("\n");
-  print("val_layout: "); print(val_layout); print("\n");
-  print("layout_mn : "); print(layout_mn); print("\n");
-  print("layout_tv : "); print(layout_tv); print("\n");
-  print("tiler     : "); print(tiler); print("\n");
-#endif
-
-  return Xe2DTiledCopy<Copy_Atom<Args...>, decltype(layout_tv), decltype(tiler)>{copy_atom};
-}
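The removed make_xe_2d_copy builds its thread-value layout with the raked_product / right_inverse construction that its inline comments describe. A minimal host-side sketch of that construction on a tiny 2x2 thread layout and 2x2 value layout (not part of the PR; assumes CuTe headers are available; the layouts are chosen only for illustration):

    #include <cute/tensor.hpp>
    using namespace cute;

    int main() {
      auto thr_layout = Layout<Shape<_2,_2>>{};   // (m,n) -> thr_idx
      auto val_layout = Layout<Shape<_2,_2>>{};   // (m,n) -> val_idx

      // (M,N) -> (thr_idx, val_idx): rake threads and values across the tile
      auto layout_mn = raked_product(thr_layout, val_layout);
      // (thr_idx, val_idx) -> (M,N): invert and reshape to (num_threads, num_values)
      auto layout_tv = right_inverse(layout_mn)
                         .with_shape(make_shape(size(thr_layout), size(val_layout)));
      // (M,N) extent covered by one tiled copy
      auto tiler = product_each(shape(layout_mn));

      print(layout_mn); print("\n");
      print(layout_tv); print("\n");
      print(tiler);     print("\n");
    }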
-
-// The number of threads involved in a Xe2DTiledCopy
-template <class... Args>
-CUTE_HOST_DEVICE constexpr
-auto
-size(Xe2DTiledCopy<Args...> const&)
-{
-  return typename Xe2DTiledCopy<Args...>::TiledNumThr{};
-}

 } // end namespace cute
@@ -160,7 +160,7 @@ class CollectiveEpilogue<
     typename FusionCallbacks::Arguments thread{};
     ElementC const* ptr_C;
     StrideC dC;
-    ElementD const* ptr_D;
+    ElementD* ptr_D;
     StrideD dD;
   };
@@ -169,6 +169,10 @@ class CollectiveEpilogue<
     typename FusionCallbacks::Params thread{};
     XE_Copy_C xe_load_c;
     XE_Copy_D xe_store_d;
+    ElementC const* ptr_C = nullptr;
+    StrideC dC{};
+    ElementD* ptr_D = nullptr;
+    StrideD dD{};
   };

   //
@@ -206,7 +210,11 @@ class CollectiveEpilogue<
     return {
       FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace),
       xe_load_c,
-      xe_store_d
+      xe_store_d,
+      args.ptr_C,
+      args.dC,
+      args.ptr_D,
+      args.dD
     };
   }
@@ -296,9 +304,8 @@ class CollectiveEpilogue<
     // Indexing variables
     auto [M, N, K, L] = problem_shape_mnkl;
     auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl;
-    auto m_offset = m_coord * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M;
-    auto n_offset = n_coord * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
-    auto l_offset = l_coord;
+    auto m_sg = get_sub_group_id() / ATOM_N;
+    auto n_sg = get_sub_group_id() % ATOM_N;

     using EpilogueTile = decltype(get<0>(params.xe_store_d.get_layoutS_MN()).shape());
@@ -310,12 +317,27 @@ class CollectiveEpilogue<
     auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord);

     bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();

+    // Represent the full output tensor
+    Tensor mD_mnl = params.xe_store_d.get_pvc_tensor(make_shape(M,N,L));
Review comment: this should be a counting tensor of D I believe, so maybe …
+
+    // Tile the output tensor per CTA
+    Tensor g_cta_D_mnl = local_tile(mD_mnl, CtaTileMNK{}, make_coord(_,_,_), Step<_1,_1, X>{});  // (BLK_M,BLK_N,m,n,l)
+
+    // Slice to get the tile this CTA is responsible for
+    Tensor g_cta_D = g_cta_D_mnl(_,_,m_coord,n_coord,l_coord);                                   // (BLK_M,BLK_N)
Review comment on lines +324 to +328 (with a suggested change): I think this is simpler.

Review comment on lines +325 to +328: I am wondering here, if it should be possible to avoid this and have something like Tensor g_cta_D_mnl = local_tile(mD_mnl, CtaTileMNK{}, make_coord(m_coord,n_coord,l_coord), Step<_1,_1, X>{});
+
+    // Tile the output tensor per warp
+    Tensor gD_mnl = local_tile(g_cta_D, SubgroupTileShape{}, make_coord(_,_,_), Step<_1,_1, X>{});  // (BLK_M,BLK_N,m,n,l)
+
+    // Slice to get the tile this warp is responsible for
+    Tensor gD = gD_mnl(_,_,m_sg,n_sg);                                                              // (BLK_M,BLK_N)
Review comment on lines +330 to +334 (with a suggested change): I think this is correct too.

Review comment on lines +331 to +334: Same here.
+
+    auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx);
+    Tensor tCgD = thread_xe_store_d.partition_D(gD);

     Tensor trC = make_tensor<typename TiledMma::ValTypeC>(Shape<Int<FragmentSize>>{});
     Tensor trD = make_tensor<typename TiledMma::ValTypeD>(Shape<Int<FragmentSize>>{});
-    Tensor rw_coord = params.xe_store_d.get_pvc_tensor(
-                        make_coord(m_offset, n_offset, l_offset),
-                        make_shape(_, Int<FragsM>{}, Int<FragsN>{}));

     // Because Sm90 uses shared memory, they are not tied to using the same accumulator values
     // for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be
@@ -334,7 +356,7 @@ class CollectiveEpilogue<
     // Get the fusion callbacks
     // Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles
     constexpr bool RefSrc = true;
-    auto residue_mn = make_coord(M, N);
+    auto residue_mn = make_coord(M, N); // TODO(Codeplay): this is not correct
     auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{
       problem_shape_mnkl,
       SubgroupTileShape{},
@@ -367,7 +389,8 @@ class CollectiveEpilogue<
     for (int epi_m = 0; epi_m < FragsM; epi_m++) {

       if (is_C_load_needed) {
-        copy(params.xe_load_c, rw_coord(_, epi_m, epi_n), trC);
+        // coordinates for C and D are the same
+        copy(params.xe_load_c, tCgD(_, epi_m, epi_n), trC);
       }

       cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed);
@@ -378,7 +401,7 @@ class CollectiveEpilogue<
       for (int epi_v = 0; epi_v < size(trD_frag); ++epi_v) {
         trD_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
       }
-      copy(params.xe_store_d, trD, rw_coord(_, epi_m, epi_n));
+      copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n));
     }
   }
Review comment: Cutlass provides cutlass::sizeof_bits<dtype> for this.
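That comment most likely refers to the dtype_size * bits_in_byte computation added in copy_unpack. A possible drop-in sketch for the new static_asserts using that helper instead (not part of the PR; assumes cutlass/numeric_types.h is reachable from this header and reuses the names from the load-side copy_unpack above):

    // Sketch only: cutlass::sizeof_bits also handles sub-byte element types,
    // which sizeof(dtype) * 8 cannot express.
    static_assert(size(SLayout{}) * cutlass::sizeof_bits<dtype>::value == size<1>(typename Traits_LD_t::SrcLayout{}),
                  "Src tensor size does not match copy atom size");
    static_assert(size(DLayout{}) * cutlass::sizeof_bits<dtype>::value == size<1>(typename Traits_LD_t::DstLayout{}),
                  "Dst tensor size does not match copy atom size");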