diff --git a/lib/BUILD b/lib/BUILD index 445effd3..04773beb 100644 --- a/lib/BUILD +++ b/lib/BUILD @@ -414,7 +414,10 @@ cc_library( cc_library( name = "unitary_calculator_basic", hdrs = ["unitary_calculator_basic.h"], - deps = [":unitaryspace_basic"], + deps = [ + ":bits", + ":unitaryspace_basic" + ], ) ### Unitary mux header ### diff --git a/lib/statespace_avx.h b/lib/statespace_avx.h index 5ff17a83..f21dd3f8 100644 --- a/lib/statespace_avx.h +++ b/lib/statespace_avx.h @@ -239,6 +239,37 @@ class StateSpaceAVX : public StateSpace, For, float> { state.get()[k + 8] = im; } + // Sets state[i] = val where (i & mask) == bits + void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, + const std::complex& val) const { + BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val)); + } + + // Sets state[i] = complex(re, im) where (i & mask) == bits + void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re, + fp_type im) const { + __m256 re_reg = _mm256_set1_ps(re); + __m256 im_reg = _mm256_set1_ps(im); + + auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv, + uint64_t bitsv, __m256 re_n, __m256 im_n, fp_type* p) { + __m256 ml = + _mm256_castsi256_ps(detail::GetZeroMaskAVX(8 * i, maskv, bitsv)); + + __m256 re = _mm256_load_ps(p + 16 * i); + __m256 im = _mm256_load_ps(p + 16 * i + 8); + + re = _mm256_blendv_ps(re, re_n, ml); + im = _mm256_blendv_ps(im, im_n, ml); + + _mm256_store_ps(p + 16 * i, re); + _mm256_store_ps(p + 16 * i + 8, im); + }; + + Base::for_.Run(MinSize(state.num_qubits()) / 16, f, mask, bits, re_reg, + im_reg, state.get()); + } + // Does the equivalent of dest += src elementwise. bool Add(const State& src, State& dest) const { if (src.num_qubits() != dest.num_qubits()) { diff --git a/lib/statespace_basic.h b/lib/statespace_basic.h index c4c41bdd..2cdab2c8 100644 --- a/lib/statespace_basic.h +++ b/lib/statespace_basic.h @@ -96,6 +96,28 @@ class StateSpaceBasic : public StateSpace, For, FP> { state.get()[p + 1] = im; } + // Sets state[i] = val where (i & mask) == bits + void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, + const std::complex& val) const { + BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val)); + } + + // Sets state[i] = complex(re, im) where (i & mask) == bits + void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re, + fp_type im) const { + auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv, + uint64_t bitsv, fp_type re_n, fp_type im_n, fp_type* p) { + auto s = p + 2 * i; + bool in_mask = (i & maskv) == bitsv; + + s[0] = in_mask ? re_n : s[0]; + s[1] = in_mask ? im_n : s[1]; + }; + + Base::for_.Run(MinSize(state.num_qubits()) / 2, f, mask, bits, re, im, + state.get()); + } + // Does the equivalent of dest += src elementwise. bool Add(const State& src, State& dest) const { if (src.num_qubits() != dest.num_qubits()) { diff --git a/lib/statespace_sse.h b/lib/statespace_sse.h index 9f95217e..5c1ecd41 100644 --- a/lib/statespace_sse.h +++ b/lib/statespace_sse.h @@ -200,6 +200,36 @@ class StateSpaceSSE : public StateSpace, For, float> { state.get()[p + 4] = im; } + // Sets state[i] = val where (i & mask) == bits + void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, + const std::complex& val) const { + BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val)); + } + + // Sets state[i] = complex(re, im) where (i & mask) == bits + void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re, + fp_type im) const { + __m128 re_reg = _mm_set1_ps(re); + __m128 im_reg = _mm_set1_ps(im); + + auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv, + uint64_t bitsv, __m128 re_n, __m128 im_n, fp_type* p) { + __m128 ml = _mm_castsi128_ps(detail::GetZeroMaskSSE(4 * i, maskv, bitsv)); + + __m128 re = _mm_load_ps(p + 8 * i); + __m128 im = _mm_load_ps(p + 8 * i + 4); + + re = _mm_blendv_ps(re, re_n, ml); + im = _mm_blendv_ps(im, im_n, ml); + + _mm_store_ps(p + 8 * i, re); + _mm_store_ps(p + 8 * i + 4, im); + }; + + Base::for_.Run(MinSize(state.num_qubits()) / 8, f, mask, bits, re_reg, + im_reg, state.get()); + } + // Does the equivalent of dest += src elementwise. bool Add(const State& src, State& dest) const { if (src.num_qubits() != dest.num_qubits()) { diff --git a/lib/unitary_calculator_basic.h b/lib/unitary_calculator_basic.h index 7afabd52..b26e6e95 100644 --- a/lib/unitary_calculator_basic.h +++ b/lib/unitary_calculator_basic.h @@ -18,6 +18,7 @@ #include #include +#include "bits.h" #include "unitaryspace_basic.h" namespace qsim { @@ -69,8 +70,9 @@ class UnitaryCalculatorBasic final { const fp_type* matrix, Unitary& state) const { if (qs.size() == 1) { ApplyControlledGate1(qs[0], cqs, cmask, matrix, state); + } else if (qs.size() == 2) { + ApplyControlledGate2(qs[0], qs[1], cqs, cmask, matrix, state); } - // Implement 2 qubit version. } private: @@ -164,6 +166,101 @@ class UnitaryCalculatorBasic final { emaskh, rstate); } + void ApplyControlledGate2(unsigned q0, unsigned q1, + const std::vector& cqs, uint64_t cmask, + const fp_type* matrix, State& state) const { + uint64_t xs[2]; + uint64_t ms[3]; + + xs[0] = uint64_t{1} << (q0 + 1); + ms[0] = (uint64_t{1} << q0) - 1; + + xs[1] = uint64_t{1} << (q1 + 1); + ms[1] = ((uint64_t{1} << q1) - 1) ^ (xs[0] - 1); + + ms[2] = ((uint64_t{1} << num_qubits_) - 1) ^ (xs[1] - 1); + + uint64_t xss[4]; + for (unsigned i = 0; i < 4; ++i) { + uint64_t a = 0; + for (uint64_t k = 0; k < 2; ++k) { + if (((i >> k) & 1) == 1) { + a += xs[k]; + } + } + xss[i] = a; + } + + uint64_t emaskh = 0; + + for (auto q : cqs) { + emaskh |= uint64_t{1} << q; + } + + uint64_t cmaskh = bits::ExpandBits(cmask, num_qubits_, emaskh); + + emaskh |= uint64_t{1} << q0; + emaskh |= uint64_t{1} << q1; + + emaskh = ~emaskh; + + auto f = [](unsigned n, unsigned m, uint64_t ii, const fp_type* v, + const uint64_t* ms, const uint64_t* xss, unsigned n_qb, + unsigned sqrt_size, uint64_t cmaskh, uint64_t emaskh, + fp_type* rstate) { + fp_type rn, in; + fp_type rs[16], is[16]; + + auto row_size = uint64_t{1} << n_qb; + + uint64_t i = ii % sqrt_size; + uint64_t j = ii / sqrt_size; + + uint64_t col_loc = (1 * i & ms[0]) | (2 * i & ms[1]) | (4 * i & ms[2]); + uint64_t row_loc = bits::ExpandBits(j, n_qb, emaskh) | cmaskh; + + auto p0 = rstate + row_size * 2 * row_loc + 2 * col_loc; + + for (unsigned l = 0; l < 4; ++l) { + for (unsigned k = 0; k < 4; ++k) { + rs[4 * l + k] = *(p0 + xss[l] * row_size + xss[k]); + is[4 * l + k] = *(p0 + xss[l] * row_size + xss[k] + 1); + } + } + + for (unsigned l = 0; l < 4; l++) { + uint64_t j = 0; + for (unsigned k = 0; k < 4; ++k) { + rn = rs[l] * v[j] - is[l] * v[j + 1]; + in = rs[l] * v[j + 1] + is[l] * v[j]; + j += 2; + + for (unsigned p = 1; p < 4; ++p) { + rn += rs[4 * p + l] * v[j] - is[4 * p + l] * v[j + 1]; + in += rs[4 * p + l] * v[j + 1] + is[4 * p + l] * v[j]; + + j += 2; + } + *(p0 + xss[k] * row_size + xss[l]) = rn; + *(p0 + xss[k] * row_size + xss[l] + 1) = in; + } + } + }; + + fp_type* rstate = state.get(); + + unsigned k = 2 + cqs.size(); + unsigned n = num_qubits_ > k ? num_qubits_ - k : 0; + uint64_t size = uint64_t{1} << n; + + unsigned kk = 2; + unsigned nn = num_qubits_ > kk ? num_qubits_ - kk : 0; + uint64_t size2 = uint64_t{1} << nn; + + for_.Run(size * size2, f, matrix, ms, xss, num_qubits_, size2, cmaskh, + emaskh, rstate); + } + void ApplyGate1(unsigned q0, const fp_type* matrix, Unitary& state) const { uint64_t xs[1]; uint64_t ms[2]; @@ -259,9 +356,9 @@ class UnitaryCalculatorBasic final { xss[i] = a; } - auto f = [q0, q1](unsigned n, unsigned m, uint64_t ii, const fp_type* v, - const uint64_t* ms, const uint64_t* xss, unsigned n_qb, - unsigned sqrt_size, fp_type* rstate) { + auto f = [](unsigned n, unsigned m, uint64_t ii, const fp_type* v, + const uint64_t* ms, const uint64_t* xss, unsigned n_qb, + unsigned sqrt_size, fp_type* rstate) { fp_type rn, in; fp_type rs[16], is[16]; diff --git a/setup.py b/setup.py index b4af4784..4459cbc7 100644 --- a/setup.py +++ b/setup.py @@ -74,7 +74,7 @@ def build_extension(self, ext): # README file as long_description. long_description = open('README.md', encoding='utf-8').read() -__version__ = '0.6.0' +__version__ = '0.7.1' setup( name='qsimcirq', diff --git a/tests/statespace_avx_test.cc b/tests/statespace_avx_test.cc index 90776091..ca72a13b 100644 --- a/tests/statespace_avx_test.cc +++ b/tests/statespace_avx_test.cc @@ -62,6 +62,10 @@ TEST(StateSpaceAVXTest, InvalidStateSize) { TestInvalidStateSize>(); } +TEST(StateSpaceBasicTest, BulkSetAmpl) { + TestBulkSetAmplitude>(); +} + } // namespace qsim int main(int argc, char** argv) { diff --git a/tests/statespace_basic_test.cc b/tests/statespace_basic_test.cc index 349108d6..6a789f74 100644 --- a/tests/statespace_basic_test.cc +++ b/tests/statespace_basic_test.cc @@ -62,6 +62,10 @@ TEST(StateSpaceBasicTest, InvalidStateSize) { TestInvalidStateSize>(); } +TEST(StateSpaceBasicTest, BulkSetAmpl) { + TestBulkSetAmplitude>(); +} + } // namespace qsim int main(int argc, char** argv) { diff --git a/tests/statespace_sse_test.cc b/tests/statespace_sse_test.cc index 24a8bc68..45f61e59 100644 --- a/tests/statespace_sse_test.cc +++ b/tests/statespace_sse_test.cc @@ -62,6 +62,10 @@ TEST(StateSpaceSSETest, InvalidStateSize) { TestInvalidStateSize>(); } +TEST(StateSpaceBasicTest, BulkSetAmpl) { + TestBulkSetAmplitude>(); +} + } // namespace qsim int main(int argc, char** argv) { diff --git a/tests/statespace_testfixture.h b/tests/statespace_testfixture.h index c9f7b42f..28ca0cfd 100644 --- a/tests/statespace_testfixture.h +++ b/tests/statespace_testfixture.h @@ -809,6 +809,67 @@ void TestInvalidStateSize() { EXPECT_FALSE(!std::isnan(state_space.RealInnerProduct(state1, state2))); } +template +void TestBulkSetAmplitude() { + using State = typename StateSpace::State; + unsigned num_qubits = 3; + + StateSpace state_space(1); + + State state = state_space.Create(num_qubits); + for(int i = 0; i < 8; i++) { + state_space.SetAmpl(state, i, 1, 1); + } + state_space.BulkSetAmpl(state, 1, 0, 0, 0); + EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex(1, 1)); + + for(int i = 0; i < 8; i++) { + state_space.SetAmpl(state, i, 1, 1); + } + state_space.BulkSetAmpl(state, 2, 0, 0, 0); + EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex(1, 1)); + + for(int i = 0; i < 8; i++) { + state_space.SetAmpl(state, i, 1, 1); + } + state_space.BulkSetAmpl(state, 4, 0, 0, 0); + EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex(1, 1)); + + for(int i = 0; i < 8; i++) { + state_space.SetAmpl(state, i, 1, 1); + } + state_space.BulkSetAmpl(state, 4 | 1, 4, 0, 0); + EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex(1, 1)); + EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex(0, 0)); + EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex(1, 1)); +} + } // namespace qsim #endif // STATESPACE_TESTFIXTURE_H_ diff --git a/tests/unitary_calculator_basic_test.cc b/tests/unitary_calculator_basic_test.cc index c7e9588f..de89fcaa 100644 --- a/tests/unitary_calculator_basic_test.cc +++ b/tests/unitary_calculator_basic_test.cc @@ -12,16 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "unitary_calculator_testfixture.h" - -#include "gtest/gtest.h" +#include "../lib/unitary_calculator_basic.h" #include "../lib/formux.h" #include "../lib/unitaryspace_basic.h" -#include "../lib/unitary_calculator_basic.h" +#include "gtest/gtest.h" +#include "unitary_calculator_testfixture.h" namespace qsim { - namespace unitary { namespace { @@ -37,11 +35,15 @@ TEST(UnitaryCalculatorTest, ApplyGate2) { TestApplyGate2>(); } +TEST(UnitaryCalculatorTest, ApplyControlledGate2) { + TestApplyControlledGate2>(); +} + TEST(UnitaryCalculatorTest, ApplyFusedGate) { TestApplyFusedGate>(); } -} // namspace +} // namespace } // namespace unitary } // namespace qsim diff --git a/tests/unitary_calculator_testfixture.h b/tests/unitary_calculator_testfixture.h index 4980cdef..da0b03bd 100644 --- a/tests/unitary_calculator_testfixture.h +++ b/tests/unitary_calculator_testfixture.h @@ -320,6 +320,75 @@ void TestApplyGate2() { EUnitaryEQ(us, u, n_qubits, expected_mat_02); } +template +void TestApplyControlledGate2() { + const int n_qubits = 3; + UC uc(n_qubits, 1); + using UnitarySpace = typename UC::UnitarySpace; + using Unitary = typename UC::Unitary; + + UnitarySpace us(n_qubits, 1); + Unitary u = us.CreateUnitary(); + + // clang-format off + float ref_gate[] = {1,2,3,4,5,6,7,8, + 9,10,11,12,13,14,15,16, + 17,18,19,20,21,22,23,24, + 25,26,27,28,29,30,31,32}; + // clang-format on + + // Test applying on qubit 0, 1 + FillMatrix(us, u, n_qubits); + // clang-format off + float expected_mat_01[] = { + 0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0, + 16.0,17.0,18.0,19.0,20.0,21.0,22.0,23.0,24.0,25.0,26.0,27.0,28.0,29.0,30.0,31.0, + 32.0,33.0,34.0,35.0,36.0,37.0,38.0,39.0,40.0,41.0,42.0,43.0,44.0,45.0,46.0,47.0, + 48.0,49.0,50.0,51.0,52.0,53.0,54.0,55.0,56.0,57.0,58.0,59.0,60.0,61.0,62.0,63.0, + -372.0,3504.0,-380.0,3576.0,-388.0,3648.0,-396.0,3720.0,-404.0,3792.0,-412.0,3864.0,-420.0,3936.0,-428.0,4008.0, + -404.0,9168.0,-412.0,9368.0,-420.0,9568.0,-428.0,9768.0,-436.0,9968.0,-444.0,10168.0,-452.0,10368.0,-460.0,10568.0, + -436.0,14832.0,-444.0,15160.0,-452.0,15488.0,-460.0,15816.0,-468.0,16144.0,-476.0,16472.0,-484.0,16800.0,-492.0,17128.0, + -468.0,20496.0,-476.0,20952.0,-484.0,21408.0,-492.0,21864.0,-500.0,22320.0,-508.0,22776.0,-516.0,23232.0,-524.0,23688.0, + }; + // clang-format on + uc.ApplyControlledGate({0, 1}, {2}, 1, ref_gate, u); + EUnitaryEQ(us, u, n_qubits, expected_mat_01); + + // Test applying on qubit 1, 2 + FillMatrix(us, u, n_qubits); + // clang-format off + float expected_mat_12[] = { + 0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0, + -276.0,2960.0,-284.0,3032.0,-292.0,3104.0,-300.0,3176.0,-308.0,3248.0,-316.0,3320.0,-324.0,3392.0,-332.0,3464.0, + 32.0,33.0,34.0,35.0,36.0,37.0,38.0,39.0,40.0,41.0,42.0,43.0,44.0,45.0,46.0,47.0, + -308.0,7088.0,-316.0,7288.0,-324.0,7488.0,-332.0,7688.0,-340.0,7888.0,-348.0,8088.0,-356.0,8288.0,-364.0,8488.0, + 64.0,65.0,66.0,67.0,68.0,69.0,70.0,71.0,72.0,73.0,74.0,75.0,76.0,77.0,78.0,79.0, + -340.0,11216.0,-348.0,11544.0,-356.0,11872.0,-364.0,12200.0,-372.0,12528.0,-380.0,12856.0,-388.0,13184.0,-396.0,13512.0, + 96.0,97.0,98.0,99.0,100.0,101.0,102.0,103.0,104.0,105.0,106.0,107.0,108.0,109.0,110.0,111.0, + -372.0,15344.0,-380.0,15800.0,-388.0,16256.0,-396.0,16712.0,-404.0,17168.0,-412.0,17624.0,-420.0,18080.0,-428.0,18536.0, + }; + // clang-format on + uc.ApplyControlledGate({1, 2}, {0}, 1, ref_gate, u); + EUnitaryEQ(us, u, n_qubits, expected_mat_12); + + // Test applying on qubit 0, 2 + FillMatrix(us, u, n_qubits); + // clang-format off + float expected_mat_02[] = { + 0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0, + 16.0,17.0,18.0,19.0,20.0,21.0,22.0,23.0,24.0,25.0,26.0,27.0,28.0,29.0,30.0,31.0, + -308.0,3184.0,-316.0,3256.0,-324.0,3328.0,-332.0,3400.0,-340.0,3472.0,-348.0,3544.0,-356.0,3616.0,-364.0,3688.0, + -340.0,7824.0,-348.0,8024.0,-356.0,8224.0,-364.0,8424.0,-372.0,8624.0,-380.0,8824.0,-388.0,9024.0,-396.0,9224.0, + 64.0,65.0,66.0,67.0,68.0,69.0,70.0,71.0,72.0,73.0,74.0,75.0,76.0,77.0,78.0,79.0, + 80.0,81.0,82.0,83.0,84.0,85.0,86.0,87.0,88.0,89.0,90.0,91.0,92.0,93.0,94.0,95.0, + -372.0,12464.0,-380.0,12792.0,-388.0,13120.0,-396.0,13448.0,-404.0,13776.0,-412.0,14104.0,-420.0,14432.0,-428.0,14760.0, + -404.0,17104.0,-412.0,17560.0,-420.0,18016.0,-428.0,18472.0,-436.0,18928.0,-444.0,19384.0,-452.0,19840.0,-460.0,20296.0, + }; + // clang-format on + uc.ApplyControlledGate({0, 2}, {1}, 1, ref_gate, u); + EUnitaryEQ(us, u, n_qubits, expected_mat_02); +} + template void TestApplyFusedGate() { using UnitarySpace = typename UnitaryCalculator::UnitarySpace;