Skip to content

Commit

Permalink
Merge pull request #274 from MichaelBroughton/exclude_support
Browse files Browse the repository at this point in the history
Add inclusion/exclusion support on bulksetampl.
  • Loading branch information
95-martin-orion authored Jan 26, 2021
2 parents 3824fb2 + f6d1444 commit 8e2781a
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 29 deletions.
27 changes: 18 additions & 9 deletions lib/statespace_avx.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,22 +239,31 @@ class StateSpaceAVX : public StateSpace<StateSpaceAVX<For>, For, float> {
state.get()[k + 8] = im;
}

// Sets state[i] = val where (i & mask) == bits
// Sets state[i] = complex(re, im) where (i & mask) == bits.
// if `exclude` is true then the criteria becomes (i & mask) != bits.
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits,
const std::complex<fp_type>& val) const {
BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val));
const std::complex<fp_type>& val,
bool exclude = false) const {
BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val), exclude);
}

// Sets state[i] = complex(re, im) where (i & mask) == bits
// Sets state[i] = complex(re, im) where (i & mask) == bits.
// if `exclude` is true then the criteria becomes (i & mask) != bits.
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re,
fp_type im) const {
fp_type im, bool exclude = false) const {
__m256 re_reg = _mm256_set1_ps(re);
__m256 im_reg = _mm256_set1_ps(im);

__m256i exclude_reg = _mm256_setzero_si256();
if (exclude) {
exclude_reg = _mm256_cmpeq_epi32(exclude_reg, exclude_reg);
}

auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv,
uint64_t bitsv, __m256 re_n, __m256 im_n, fp_type* p) {
__m256 ml =
_mm256_castsi256_ps(detail::GetZeroMaskAVX(8 * i, maskv, bitsv));
uint64_t bitsv, __m256 re_n, __m256 im_n, __m256i exclude_n,
fp_type* p) {
__m256 ml = _mm256_castsi256_ps(_mm256_xor_si256(
detail::GetZeroMaskAVX(8 * i, maskv, bitsv), exclude_n));

__m256 re = _mm256_load_ps(p + 16 * i);
__m256 im = _mm256_load_ps(p + 16 * i + 8);
Expand All @@ -267,7 +276,7 @@ class StateSpaceAVX : public StateSpace<StateSpaceAVX<For>, For, float> {
};

Base::for_.Run(MinSize(state.num_qubits()) / 16, f, mask, bits, re_reg,
im_reg, state.get());
im_reg, exclude_reg, state.get());
}

// Does the equivalent of dest += src elementwise.
Expand Down
20 changes: 12 additions & 8 deletions lib/statespace_basic.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,26 +96,30 @@ class StateSpaceBasic : public StateSpace<StateSpaceBasic<For, FP>, For, FP> {
state.get()[p + 1] = im;
}

// Sets state[i] = val where (i & mask) == bits
// Sets state[i] = complex(re, im) where (i & mask) == bits.
// if `exclude` is true then the criteria becomes (i & mask) != bits.
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits,
const std::complex<fp_type>& val) const {
BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val));
const std::complex<fp_type>& val,
bool exclude = false) const {
BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val), exclude);
}

// Sets state[i] = complex(re, im) where (i & mask) == bits
// Sets state[i] = complex(re, im) where (i & mask) == bits.
// if `exclude` is true then the criteria becomes (i & mask) != bits.
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re,
fp_type im) const {
fp_type im, bool exclude = false) const {
auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv,
uint64_t bitsv, fp_type re_n, fp_type im_n, fp_type* p) {
uint64_t bitsv, fp_type re_n, fp_type im_n, bool excludev,
fp_type* p) {
auto s = p + 2 * i;
bool in_mask = (i & maskv) == bitsv;

in_mask ^= excludev;
s[0] = in_mask ? re_n : s[0];
s[1] = in_mask ? im_n : s[1];
};

Base::for_.Run(MinSize(state.num_qubits()) / 2, f, mask, bits, re, im,
state.get());
exclude, state.get());
}

// Does the equivalent of dest += src elementwise.
Expand Down
23 changes: 16 additions & 7 deletions lib/statespace_sse.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,21 +200,30 @@ class StateSpaceSSE : public StateSpace<StateSpaceSSE<For>, For, float> {
state.get()[p + 4] = im;
}

// Sets state[i] = val where (i & mask) == bits
// Sets state[i] = complex(re, im) where (i & mask) == bits.
// if `exclude` is true then the criteria becomes (i & mask) != bits.
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits,
const std::complex<fp_type>& val) const {
const std::complex<fp_type>& val,
bool exclude = false) const {
BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val));
}

// Sets state[i] = complex(re, im) where (i & mask) == bits
// Sets state[i] = complex(re, im) where (i & mask) == bits.
// if `exclude` is true then the criteria becomes (i & mask) != bits.
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re,
fp_type im) const {
fp_type im, bool exclude = false) const {
__m128 re_reg = _mm_set1_ps(re);
__m128 im_reg = _mm_set1_ps(im);
__m128i exclude_reg = _mm_setzero_si128();
if (exclude) {
exclude_reg = _mm_cmpeq_epi32(exclude_reg, exclude_reg);
}

auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv,
uint64_t bitsv, __m128 re_n, __m128 im_n, fp_type* p) {
__m128 ml = _mm_castsi128_ps(detail::GetZeroMaskSSE(4 * i, maskv, bitsv));
uint64_t bitsv, __m128 re_n, __m128 im_n, __m128i exclude_n,
fp_type* p) {
__m128 ml = _mm_castsi128_ps(_mm_xor_si128(
detail::GetZeroMaskSSE(4 * i, maskv, bitsv), exclude_n));

__m128 re = _mm_load_ps(p + 8 * i);
__m128 im = _mm_load_ps(p + 8 * i + 4);
Expand All @@ -227,7 +236,7 @@ class StateSpaceSSE : public StateSpace<StateSpaceSSE<For>, For, float> {
};

Base::for_.Run(MinSize(state.num_qubits()) / 8, f, mask, bits, re_reg,
im_reg, state.get());
im_reg, exclude_reg, state.get());
}

// Does the equivalent of dest += src elementwise.
Expand Down
10 changes: 9 additions & 1 deletion tests/statespace_avx_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,18 @@ TEST(StateSpaceAVXTest, InvalidStateSize) {
TestInvalidStateSize<StateSpaceAVX<For>>();
}

TEST(StateSpaceBasicTest, BulkSetAmpl) {
TEST(StateSpaceAVXTest, BulkSetAmpl) {
TestBulkSetAmplitude<StateSpaceAVX<For>>();
}

TEST(StateSpaceAVXTest, BulkSetAmplExclude) {
TestBulkSetAmplitudeExclusion<StateSpaceAVX<For>>();
}

TEST(StateSpaceAVXTest, BulkSetAmplDefault) {
TestBulkSetAmplitudeDefault<StateSpaceAVX<For>>();
}

} // namespace qsim

int main(int argc, char** argv) {
Expand Down
8 changes: 8 additions & 0 deletions tests/statespace_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ TEST(StateSpaceBasicTest, BulkSetAmpl) {
TestBulkSetAmplitude<StateSpaceBasic<For, float>>();
}

TEST(StateSpaceBasicTest, BulkSetAmplExclude) {
TestBulkSetAmplitudeExclusion<StateSpaceBasic<For, float>>();
}

TEST(StateSpaceBasicTest, BulkSetAmplDefault) {
TestBulkSetAmplitudeDefault<StateSpaceBasic<For, float>>();
}

} // namespace qsim

int main(int argc, char** argv) {
Expand Down
10 changes: 9 additions & 1 deletion tests/statespace_sse_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,18 @@ TEST(StateSpaceSSETest, InvalidStateSize) {
TestInvalidStateSize<StateSpaceSSE<For>>();
}

TEST(StateSpaceBasicTest, BulkSetAmpl) {
TEST(StateSpaceSSETest, BulkSetAmpl) {
TestBulkSetAmplitude<StateSpaceSSE<For>>();
}

TEST(StateSpaceSSETest, BulkSetAmplExclude) {
TestBulkSetAmplitudeExclusion<StateSpaceSSE<For>>();
}

TEST(StateSpaceSSETest, BulkSetAmplDefault) {
TestBulkSetAmplitudeDefault<StateSpaceSSE<For>>();
}

} // namespace qsim

int main(int argc, char** argv) {
Expand Down
89 changes: 86 additions & 3 deletions tests/statespace_testfixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,7 @@ void TestBulkSetAmplitude() {
for(int i = 0; i < 8; i++) {
state_space.SetAmpl(state, i, 1, 1);
}
state_space.BulkSetAmpl(state, 1, 0, 0, 0);
state_space.BulkSetAmpl(state, 1, 0, 0, 0, false);
EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex<float>(0, 0));
Expand All @@ -833,7 +833,7 @@ void TestBulkSetAmplitude() {
for(int i = 0; i < 8; i++) {
state_space.SetAmpl(state, i, 1, 1);
}
state_space.BulkSetAmpl(state, 2, 0, 0, 0);
state_space.BulkSetAmpl(state, 2, 0, 0, 0, false);
EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex<float>(1, 1));
Expand All @@ -846,7 +846,7 @@ void TestBulkSetAmplitude() {
for(int i = 0; i < 8; i++) {
state_space.SetAmpl(state, i, 1, 1);
}
state_space.BulkSetAmpl(state, 4, 0, 0, 0);
state_space.BulkSetAmpl(state, 4, 0, 0, 0, false);
EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex<float>(0, 0));
Expand All @@ -856,6 +856,89 @@ void TestBulkSetAmplitude() {
EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex<float>(1, 1));

for(int i = 0; i < 8; i++) {
state_space.SetAmpl(state, i, 1, 1);
}
state_space.BulkSetAmpl(state, 4 | 1, 4, 0, 0, false);
EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex<float>(1, 1));
}

template <typename StateSpace>
void TestBulkSetAmplitudeExclusion() {
using State = typename StateSpace::State;
unsigned num_qubits = 3;

StateSpace state_space(1);

State state = state_space.Create(num_qubits);
for(int i = 0; i < 8; i++) {
state_space.SetAmpl(state, i, 1, 1);
}
state_space.BulkSetAmpl(state, 1, 0, 0, 0, true);
EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex<float>(0, 0));

for(int i = 0; i < 8; i++) {
state_space.SetAmpl(state, i, 1, 1);
}
state_space.BulkSetAmpl(state, 2, 0, 0, 0, true);
EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex<float>(0, 0));

for(int i = 0; i < 8; i++) {
state_space.SetAmpl(state, i, 1, 1);
}
state_space.BulkSetAmpl(state, 4, 0, 0, 0, true);
EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex<float>(0, 0));

for(int i = 0; i < 8; i++) {
state_space.SetAmpl(state, i, 1, 1);
}
state_space.BulkSetAmpl(state, 4 | 1, 4, 0, 0, true);
EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex<float>(0, 0));
}

template <typename StateSpace>
void TestBulkSetAmplitudeDefault() {
using State = typename StateSpace::State;
unsigned num_qubits = 3;

StateSpace state_space(1);

State state = state_space.Create(num_qubits);
for(int i = 0; i < 8; i++) {
state_space.SetAmpl(state, i, 1, 1);
}
Expand Down

0 comments on commit 8e2781a

Please sign in to comment.