Skip to content

Commit

Permalink
simd_builder: fix constant locations
Browse files Browse the repository at this point in the history
  • Loading branch information
Nekotekina committed Aug 29, 2022
1 parent e287070 commit 10b07f8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 14 deletions.
31 changes: 19 additions & 12 deletions Utilities/JIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ asmjit::simd_builder::simd_builder(CodeHolder* ch) noexcept
: native_asm(ch)
{
_init(true);
// Pre-register a label for the all-ones 128-bit constant so that code which
// indexes consts[~u128()] directly (e.g. the masked last-iteration bzhi load)
// always finds a label that was created by this emitter.
consts[~u128()] = this->newLabel();
}

void asmjit::simd_builder::_init(bool full)
Expand Down Expand Up @@ -402,6 +403,16 @@ void asmjit::simd_builder::_init(bool full)
}
}

// Finalization step: writes out every constant recorded in `consts`.
// For each (value, label) entry the output is aligned to 16 bytes as data,
// the entry's label is bound at that position, and the 16 bytes of the
// value are embedded there, so instructions referencing the label can load it.
void asmjit::simd_builder::operator()() noexcept
{
	for (auto it = consts.begin(), last = consts.end(); it != last; ++it)
	{
		const auto& value = it->first; // 128-bit constant (map key)
		auto& label = it->second;      // label that emitted code refers to

		// Data alignment of 16 bytes before placing the constant
		this->align(AlignMode::kData, 16);
		this->bind(label);
		this->embed(&value, 16);
	}
}

void asmjit::simd_builder::vec_cleanup_ret()
{
if (utils::has_avx() && vsize > 16)
Expand Down Expand Up @@ -437,23 +448,19 @@ void asmjit::simd_builder::vec_set_const(const Operand& v, const v128& val)
return vec_set_all_zeros(v);
if (!~val._u)
return vec_set_all_ones(v);

if (uptr(&val) < 0x8000'0000)
else
{
// Assume the constant comes from a code or data segment (unsafe)
Label co = consts[val._u];
if (!co.isValid())
co = consts[val._u] = this->newLabel();
if (x86::Zmm zr(v.id()); zr == v)
this->vbroadcasti32x4(zr, x86::oword_ptr(uptr(&val)));
this->vbroadcasti32x4(zr, x86::oword_ptr(co));
else if (x86::Ymm yr(v.id()); yr == v)
this->vbroadcasti128(yr, x86::oword_ptr(uptr(&val)));
this->vbroadcasti128(yr, x86::oword_ptr(co));
else if (utils::has_avx())
this->vmovaps(x86::Xmm(v.id()), x86::oword_ptr(uptr(&val)));
this->vmovaps(x86::Xmm(v.id()), x86::oword_ptr(co));
else
this->movaps(x86::Xmm(v.id()), x86::oword_ptr(uptr(&val)));
}
else
{
// TODO
fmt::throw_exception("Unexpected constant location");
this->movaps(x86::Xmm(v.id()), x86::oword_ptr(co));
}
}

Expand Down
13 changes: 11 additions & 2 deletions Utilities/JIT.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,17 @@ namespace asmjit
#if defined(ARCH_X64)
struct simd_builder : native_asm
{
std::unordered_map<u128, Label> consts;

Operand v0, v1, v2, v3, v4, v5;

uint vsize = 16;
uint vmask = 0;

simd_builder(CodeHolder* ch) noexcept;

void operator()() noexcept;

void _init(bool full);
void vec_cleanup_ret();
void vec_set_all_zeros(const Operand& v);
Expand Down Expand Up @@ -312,8 +316,7 @@ namespace asmjit
if (vmask)
{
// Build single last iteration (masked)
static constexpr u64 all_ones = -1;
this->bzhi(reg_cnt, x86::Mem(uptr(&all_ones)), reg_cnt);
this->bzhi(reg_cnt, x86::Mem(consts[~u128()], 0), reg_cnt);
this->kmovq(x86::k7, reg_cnt);
vmask = 7;
build();
Expand Down Expand Up @@ -427,6 +430,12 @@ inline FT build_function_asm(std::string_view name, F&& builder, ::jit_runtime*
builder(compiler, args);
}

if constexpr (std::is_invocable_r_v<void, Asm>)
{
// Finalization
compiler();
}

const auto result = rt._add(&code);
jit_announce(result, code.codeSize(), name);
return reinterpret_cast<FT>(uptr(result));
Expand Down

0 comments on commit 10b07f8

Please sign in to comment.