Skip to content

Commit

Permalink
Merge branch 'dagger_fix' of github.com:MichaelBroughton/qsim into da…
Browse files Browse the repository at this point in the history
…gger_fix
  • Loading branch information
MichaelBroughton committed Jan 26, 2021
2 parents 255131c + 434816f commit 70bf282
Show file tree
Hide file tree
Showing 12 changed files with 339 additions and 12 deletions.
5 changes: 4 additions & 1 deletion lib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,10 @@ cc_library(
cc_library(
name = "unitary_calculator_basic",
hdrs = ["unitary_calculator_basic.h"],
deps = [":unitaryspace_basic"],
deps = [
":bits",
":unitaryspace_basic"
],
)

### Unitary mux header ###
Expand Down
31 changes: 31 additions & 0 deletions lib/statespace_avx.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,37 @@ class StateSpaceAVX : public StateSpace<StateSpaceAVX<For>, For, float> {
state.get()[k + 8] = im;
}

// Sets state[i] = val where (i & mask) == bits
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits,
const std::complex<fp_type>& val) const {
BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val));
}

// Sets state[i] = complex(re, im) where (i & mask) == bits
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re,
fp_type im) const {
__m256 re_reg = _mm256_set1_ps(re);
__m256 im_reg = _mm256_set1_ps(im);

auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv,
uint64_t bitsv, __m256 re_n, __m256 im_n, fp_type* p) {
__m256 ml =
_mm256_castsi256_ps(detail::GetZeroMaskAVX(8 * i, maskv, bitsv));

__m256 re = _mm256_load_ps(p + 16 * i);
__m256 im = _mm256_load_ps(p + 16 * i + 8);

re = _mm256_blendv_ps(re, re_n, ml);
im = _mm256_blendv_ps(im, im_n, ml);

_mm256_store_ps(p + 16 * i, re);
_mm256_store_ps(p + 16 * i + 8, im);
};

Base::for_.Run(MinSize(state.num_qubits()) / 16, f, mask, bits, re_reg,
im_reg, state.get());
}

// Does the equivalent of dest += src elementwise.
bool Add(const State& src, State& dest) const {
if (src.num_qubits() != dest.num_qubits()) {
Expand Down
22 changes: 22 additions & 0 deletions lib/statespace_basic.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,28 @@ class StateSpaceBasic : public StateSpace<StateSpaceBasic<For, FP>, For, FP> {
state.get()[p + 1] = im;
}

// Sets state[i] = val where (i & mask) == bits
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits,
const std::complex<fp_type>& val) const {
BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val));
}

// Sets state[i] = complex(re, im) where (i & mask) == bits
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re,
fp_type im) const {
auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv,
uint64_t bitsv, fp_type re_n, fp_type im_n, fp_type* p) {
auto s = p + 2 * i;
bool in_mask = (i & maskv) == bitsv;

s[0] = in_mask ? re_n : s[0];
s[1] = in_mask ? im_n : s[1];
};

Base::for_.Run(MinSize(state.num_qubits()) / 2, f, mask, bits, re, im,
state.get());
}

// Does the equivalent of dest += src elementwise.
bool Add(const State& src, State& dest) const {
if (src.num_qubits() != dest.num_qubits()) {
Expand Down
30 changes: 30 additions & 0 deletions lib/statespace_sse.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,36 @@ class StateSpaceSSE : public StateSpace<StateSpaceSSE<For>, For, float> {
state.get()[p + 4] = im;
}

// Sets state[i] = val where (i & mask) == bits
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits,
const std::complex<fp_type>& val) const {
BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val));
}

// Sets state[i] = complex(re, im) where (i & mask) == bits
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re,
fp_type im) const {
__m128 re_reg = _mm_set1_ps(re);
__m128 im_reg = _mm_set1_ps(im);

auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv,
uint64_t bitsv, __m128 re_n, __m128 im_n, fp_type* p) {
__m128 ml = _mm_castsi128_ps(detail::GetZeroMaskSSE(4 * i, maskv, bitsv));

__m128 re = _mm_load_ps(p + 8 * i);
__m128 im = _mm_load_ps(p + 8 * i + 4);

re = _mm_blendv_ps(re, re_n, ml);
im = _mm_blendv_ps(im, im_n, ml);

_mm_store_ps(p + 8 * i, re);
_mm_store_ps(p + 8 * i + 4, im);
};

Base::for_.Run(MinSize(state.num_qubits()) / 8, f, mask, bits, re_reg,
im_reg, state.get());
}

// Does the equivalent of dest += src elementwise.
bool Add(const State& src, State& dest) const {
if (src.num_qubits() != dest.num_qubits()) {
Expand Down
105 changes: 101 additions & 4 deletions lib/unitary_calculator_basic.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <cstdint>
#include <vector>

#include "bits.h"
#include "unitaryspace_basic.h"

namespace qsim {
Expand Down Expand Up @@ -69,8 +70,9 @@ class UnitaryCalculatorBasic final {
const fp_type* matrix, Unitary& state) const {
if (qs.size() == 1) {
ApplyControlledGate1(qs[0], cqs, cmask, matrix, state);
} else if (qs.size() == 2) {
ApplyControlledGate2(qs[0], qs[1], cqs, cmask, matrix, state);
}
// Implement 2 qubit version.
}

private:
Expand Down Expand Up @@ -164,6 +166,101 @@ class UnitaryCalculatorBasic final {
emaskh, rstate);
}

void ApplyControlledGate2(unsigned q0, unsigned q1,
const std::vector<unsigned>& cqs, uint64_t cmask,
const fp_type* matrix, State& state) const {
uint64_t xs[2];
uint64_t ms[3];

xs[0] = uint64_t{1} << (q0 + 1);
ms[0] = (uint64_t{1} << q0) - 1;

xs[1] = uint64_t{1} << (q1 + 1);
ms[1] = ((uint64_t{1} << q1) - 1) ^ (xs[0] - 1);

ms[2] = ((uint64_t{1} << num_qubits_) - 1) ^ (xs[1] - 1);

uint64_t xss[4];
for (unsigned i = 0; i < 4; ++i) {
uint64_t a = 0;
for (uint64_t k = 0; k < 2; ++k) {
if (((i >> k) & 1) == 1) {
a += xs[k];
}
}
xss[i] = a;
}

uint64_t emaskh = 0;

for (auto q : cqs) {
emaskh |= uint64_t{1} << q;
}

uint64_t cmaskh = bits::ExpandBits(cmask, num_qubits_, emaskh);

emaskh |= uint64_t{1} << q0;
emaskh |= uint64_t{1} << q1;

emaskh = ~emaskh;

auto f = [](unsigned n, unsigned m, uint64_t ii, const fp_type* v,
const uint64_t* ms, const uint64_t* xss, unsigned n_qb,
unsigned sqrt_size, uint64_t cmaskh, uint64_t emaskh,
fp_type* rstate) {
fp_type rn, in;
fp_type rs[16], is[16];

auto row_size = uint64_t{1} << n_qb;

uint64_t i = ii % sqrt_size;
uint64_t j = ii / sqrt_size;

uint64_t col_loc = (1 * i & ms[0]) | (2 * i & ms[1]) | (4 * i & ms[2]);
uint64_t row_loc = bits::ExpandBits(j, n_qb, emaskh) | cmaskh;

auto p0 = rstate + row_size * 2 * row_loc + 2 * col_loc;

for (unsigned l = 0; l < 4; ++l) {
for (unsigned k = 0; k < 4; ++k) {
rs[4 * l + k] = *(p0 + xss[l] * row_size + xss[k]);
is[4 * l + k] = *(p0 + xss[l] * row_size + xss[k] + 1);
}
}

for (unsigned l = 0; l < 4; l++) {
uint64_t j = 0;
for (unsigned k = 0; k < 4; ++k) {
rn = rs[l] * v[j] - is[l] * v[j + 1];
in = rs[l] * v[j + 1] + is[l] * v[j];
j += 2;

for (unsigned p = 1; p < 4; ++p) {
rn += rs[4 * p + l] * v[j] - is[4 * p + l] * v[j + 1];
in += rs[4 * p + l] * v[j + 1] + is[4 * p + l] * v[j];

j += 2;
}
*(p0 + xss[k] * row_size + xss[l]) = rn;
*(p0 + xss[k] * row_size + xss[l] + 1) = in;
}
}
};

fp_type* rstate = state.get();

unsigned k = 2 + cqs.size();
unsigned n = num_qubits_ > k ? num_qubits_ - k : 0;
uint64_t size = uint64_t{1} << n;

unsigned kk = 2;
unsigned nn = num_qubits_ > kk ? num_qubits_ - kk : 0;
uint64_t size2 = uint64_t{1} << nn;

for_.Run(size * size2, f, matrix, ms, xss, num_qubits_, size2, cmaskh,
emaskh, rstate);
}

void ApplyGate1(unsigned q0, const fp_type* matrix, Unitary& state) const {
uint64_t xs[1];
uint64_t ms[2];
Expand Down Expand Up @@ -259,9 +356,9 @@ class UnitaryCalculatorBasic final {
xss[i] = a;
}

auto f = [q0, q1](unsigned n, unsigned m, uint64_t ii, const fp_type* v,
const uint64_t* ms, const uint64_t* xss, unsigned n_qb,
unsigned sqrt_size, fp_type* rstate) {
auto f = [](unsigned n, unsigned m, uint64_t ii, const fp_type* v,
const uint64_t* ms, const uint64_t* xss, unsigned n_qb,
unsigned sqrt_size, fp_type* rstate) {
fp_type rn, in;
fp_type rs[16], is[16];

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def build_extension(self, ext):
# README file as long_description.
long_description = open('README.md', encoding='utf-8').read()

__version__ = '0.6.0'
__version__ = '0.7.1'

setup(
name='qsimcirq',
Expand Down
4 changes: 4 additions & 0 deletions tests/statespace_avx_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ TEST(StateSpaceAVXTest, InvalidStateSize) {
TestInvalidStateSize<StateSpaceAVX<For>>();
}

TEST(StateSpaceBasicTest, BulkSetAmpl) {
TestBulkSetAmplitude<StateSpaceAVX<For>>();
}

} // namespace qsim

int main(int argc, char** argv) {
Expand Down
4 changes: 4 additions & 0 deletions tests/statespace_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ TEST(StateSpaceBasicTest, InvalidStateSize) {
TestInvalidStateSize<StateSpaceBasic<For, float>>();
}

TEST(StateSpaceBasicTest, BulkSetAmpl) {
TestBulkSetAmplitude<StateSpaceBasic<For, float>>();
}

} // namespace qsim

int main(int argc, char** argv) {
Expand Down
4 changes: 4 additions & 0 deletions tests/statespace_sse_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ TEST(StateSpaceSSETest, InvalidStateSize) {
TestInvalidStateSize<StateSpaceSSE<For>>();
}

TEST(StateSpaceBasicTest, BulkSetAmpl) {
TestBulkSetAmplitude<StateSpaceSSE<For>>();
}

} // namespace qsim

int main(int argc, char** argv) {
Expand Down
61 changes: 61 additions & 0 deletions tests/statespace_testfixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,67 @@ void TestInvalidStateSize() {
EXPECT_FALSE(!std::isnan(state_space.RealInnerProduct(state1, state2)));
}

template <typename StateSpace>
void TestBulkSetAmplitude() {
using State = typename StateSpace::State;
unsigned num_qubits = 3;

StateSpace state_space(1);

State state = state_space.Create(num_qubits);
for(int i = 0; i < 8; i++) {
state_space.SetAmpl(state, i, 1, 1);
}
state_space.BulkSetAmpl(state, 1, 0, 0, 0);
EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex<float>(1, 1));

for(int i = 0; i < 8; i++) {
state_space.SetAmpl(state, i, 1, 1);
}
state_space.BulkSetAmpl(state, 2, 0, 0, 0);
EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex<float>(1, 1));

for(int i = 0; i < 8; i++) {
state_space.SetAmpl(state, i, 1, 1);
}
state_space.BulkSetAmpl(state, 4, 0, 0, 0);
EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex<float>(1, 1));

for(int i = 0; i < 8; i++) {
state_space.SetAmpl(state, i, 1, 1);
}
state_space.BulkSetAmpl(state, 4 | 1, 4, 0, 0);
EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex<float>(1, 1));
}

} // namespace qsim

#endif // STATESPACE_TESTFIXTURE_H_
14 changes: 8 additions & 6 deletions tests/unitary_calculator_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "unitary_calculator_testfixture.h"

#include "gtest/gtest.h"
#include "../lib/unitary_calculator_basic.h"

#include "../lib/formux.h"
#include "../lib/unitaryspace_basic.h"
#include "../lib/unitary_calculator_basic.h"
#include "gtest/gtest.h"
#include "unitary_calculator_testfixture.h"

namespace qsim {

namespace unitary {
namespace {

Expand All @@ -37,11 +35,15 @@ TEST(UnitaryCalculatorTest, ApplyGate2) {
TestApplyGate2<UnitaryCalculatorBasic<For, float>>();
}

TEST(UnitaryCalculatorTest, ApplyControlledGate2) {
TestApplyControlledGate2<UnitaryCalculatorBasic<For, float>>();
}

TEST(UnitaryCalculatorTest, ApplyFusedGate) {
TestApplyFusedGate<UnitaryCalculatorBasic<For, float>>();
}

} // namspace
} // namespace
} // namespace unitary
} // namespace qsim

Expand Down
Loading

0 comments on commit 70bf282

Please sign in to comment.