Skip to content

Commit

Permalink
xe: conv: keep reorder for plain layouts when beneficial
Browse files Browse the repository at this point in the history
  • Loading branch information
echeresh committed Jan 30, 2025
1 parent 4727d3c commit c509501
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions src/gpu/intel/jit/conv/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,9 @@ struct nc_block_t {
nc_block_t(int n_block, int c_block)
: n_block_(n_block), c_block_(c_block) {}

int n_block() const { return n_block_; }
int c_block() const { return c_block_; }

std::string tag() const {
std::vector<int> idxs = {1, 0};
return build_tag({n_block_, c_block_}, {1, 1}, {'a', 'b'}, idxs);
Expand Down Expand Up @@ -583,6 +586,18 @@ void maybe_set_plain_weights(const conv_config_t &cfg, bool src_dst_axb,
if (user_wei_tag.empty()) user_wei_tag = user_wei_req;
}

bool is_plain_tag_optimal_for_output(
const std::string &tag, const std::string &user_tag) {
// NHWC is OK with output as C is used for blocking and C is dense.
if (user_tag == "axb") return true;
// NCHW is OK only when blocked by W (not N).
if (user_tag == "abx") {
bool is_n_blocked = (tag.find("A") != std::string::npos);
return !is_n_blocked;
}
return false;
}

void init_data_tags(const conv_config_t &cfg, const memory_desc_t &src_md,
const memory_desc_t &wei_md, const memory_desc_t &dst_md,
std::string &src_tag, std::string &wei_tag, std::string &dst_tag,
Expand Down Expand Up @@ -651,9 +666,11 @@ void init_data_tags(const conv_config_t &cfg, const memory_desc_t &src_md,
if (src_abx && !src_matches) user_src_tag = "abx";
if (dst_abx && !dst_matches) user_dst_tag = "abx";

// Use plain tag for output to avoid extra reorders.
if (src_output) src_tag = user_src_tag;
if (dst_output) dst_tag = user_dst_tag;
// Use plain tag for output to avoid extra reorders when beneficial.
if (src_output && is_plain_tag_optimal_for_output(src_tag, user_src_tag))
src_tag = user_src_tag;
if (dst_output && is_plain_tag_optimal_for_output(dst_tag, user_dst_tag))
dst_tag = user_dst_tag;

if (user_src_req == "user") src_tag = user_src_tag = "user";
if (user_wei_req == "user") wei_tag = user_wei_tag = "user";
Expand Down

0 comments on commit c509501

Please sign in to comment.