Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update and correct pow function for PauliSum #6019

Merged
merged 25 commits into from
Mar 21, 2023

Conversation

TarunSinghania
Copy link
Contributor

@TarunSinghania TarunSinghania commented Feb 24, 2023

Updated pow function for PauliSum to use binary exponentiation.

Updated pow function for PauliSum to use binary exponentiation
@google-cla
Copy link

google-cla bot commented Feb 24, 2023

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

+Corrected formatting
+Corrected initialisation of remainder to identity
+Corrected formatting
+Corrected initialisation of remainder to identity
Copy link
Collaborator

@tanujkhattar tanujkhattar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a new test to cirq-core/cirq/ops/linear_combinations_test.py which covers this change, i.e. the test should have failed on the old code and pass on new code.

cirq-core/cirq/ops/linear_combinations.py Outdated Show resolved Hide resolved
@tanujkhattar
Copy link
Collaborator

closes #6017

@tanujkhattar tanujkhattar self-assigned this Feb 27, 2023
@CirqBot CirqBot added the size: S 10< lines changed <50 label Mar 2, 2023
@TarunSinghania
Copy link
Contributor Author

TarunSinghania commented Mar 14, 2023

Thanks everyone for your inputs, It was helpful in developing the below understanding:

There are two ways to implement the pow function i.e either use a linear multiplication approach or binary exponentiation.

To understand the performance aspect of both of them. I am leaving three comments.
First is a quick revision on PauliSum and its implementation of multiplication. #6019 (comment)
Second iterates over various ways to implement the exponentiation #6019 (comment)
Third is runtime tests and analysis. #6019 (comment)

@TarunSinghania
Copy link
Contributor Author

PauliSum:

PauliSum is a linear combination of PauliStrings. The way PauliSum stores it is as LinearDict[PauliString] = coefficient. The LinearDict implementation can be found in linear_dict.py

Multiplication: PauliSum*PauliSum : __mul__ -> __imul__
Flow of multiplication

Post Multiplication the number of terms can range from [1,N*M]

PauliString:

Tensor product of single qubit (non identity) pauli operations, each acting on a different qubit.If more than one pauli operation acts on the same set of qubits, their composition is immediately reduced to an equivalent (possibly multi-qubit) Pauli operator

The way we store PauliStrings are Dict[TKey, 'cirq.Pauli'], where Tkey = TypeVar('TKey', bound=raw_types.Qid) i.e a quantum object (qubit, qudit, resonator) and Where cirq.Pauli is a pauli gate operation which can be X,Y, Z Gates.

Multiplication of two PauliStrings: PauliString1*PauliString2
__mul__ -> constructor -> _inplace_left_multiply_by -> _imul_helper_checkpoint -> _imul_helper -> _imul_atom_helper

Pauli String Multiplication

if we multiply two paulistring with length a and b respectively, we will be left with a string of length [1,a+b] depending on the underlying objects which make up the Paulistrings

The time complexity of the multiplication operation is o(a + b) in average case considering pop is o(1)

Summary:Two PauliSum with N pauli strings and M pauliStrings take
O(N*M)*(Cost of multiplying PauliStrings) where cost of multiplying two PauliStrings of length a and b is O(a+b)

@TarunSinghania
Copy link
Contributor Author

TarunSinghania commented Mar 14, 2023

Various Approaches:

If we have a PauliSum with n PauliStrings each of length l acting on x qubits. Time complexity to exponentiate it to the power e

Linear Multiplication

def __pow__(self, exponent: int):
       if not isinstance(exponent, numbers.Integral):
           return NotImplemented
       if exponent == 0:
           return PauliSum(value.LinearDict({frozenset(): 1 + 0j}))
       if exponent > 0:
           base = self.copy()
           for _ in range(exponent - 1):
               base *= self.copy()
           return base
       return NotImplemented

Time Complexity: n + n**2 + n**3 + …n***e

Binary exponentiation

def __pow__(self, exponent: int):
       if not isinstance(exponent, numbers.Integral):
           return NotImplemented
       if exponent == 0:
           return PauliSum(value.LinearDict({frozenset(): 1 + 0j}))
       if exponent > 0:
           remainder = PauliSum.from_pauli_strings([PauliString(coefficient=1)])
	    base = self.copy()
           while exponent > 0 :
               if exponent&1 :
                   remainder = remainder * base
               base *= base
               exponent = exponent >> 1
           return remainder
       return NotImplemented

Comparing Linear and Binary exponentiation

Exponent e Linear Binary exponentiation
2 n + n**2 n + n**2 + n**2 + n**4
3 n + n**2 + n**3 n + n + n**2 + n**3 + n**4
4 n + n**2 + n**3 + n**4 n + n**2 + n**4 + n**4 + n**16
5 n + n**2 + n**3 + n**4 + n**5 n + n + n**2 + n**4 + n**5 + n**16

The reason why this performs worse than linear multiplication: we go in the loop when e reduces to 1 and perform a extra base*base which costs very high
In order to prevent this the algorithm can be modified as below:

Modified Binary exponentiation 1

def __pow__(self, exponent: int):
       if not isinstance(exponent, numbers.Integral):
           return NotImplemented
       if exponent == 0:
           return PauliSum(value.LinearDict({frozenset(): 1 + 0j}))
       if exponent > 1:
           remainder = PauliSum.from_pauli_strings([PauliString(coefficient=1)])
	    base = self.copy()
           while exponent > 1 :
               if exponent&1 :
                   remainder = remainder * base
		exponent = exponent >> 1
               base *= base
	   return remainder * base     
       return NotImplemented
Exponent e Linear Binary exponentiation
2 n + n**2 n + n**2 + n**2 + n**4
3 n + n**2 + n**3 n + n + n**2 + n**3 + n**4
4 n + n**2 + n**3 + n**4 n + n**2 + n**4 + n**4 + n**16
5 n + n**2 + n**3 + n**4 + n**5 n + n + n**2  + n**4 + n**5 + n**16

We are still multiplying if remainder equals identity when we exit the loop, which can be removed

Modified Binary exponentiation 2

def __pow__(self, exponent: int):
       if not isinstance(exponent, numbers.Integral):
           return NotImplemented
       if exponent == 0:
           return PauliSum(value.LinearDict({frozenset(): 1 + 0j}))
       if exponent > 1:
           remainder = PauliSum.from_pauli_strings([PauliString(coefficient=1)])
	    base = self.copy()
           while exponent > 1 :
               if exponent&1 :
                   remainder = remainder * base
		exponent = exponent >> 1
               base *= base
	   return remainder * base     
       return NotImplemented
Exponent e Linear Binary exponentiation
2 n + n**2 n + n**2 + n**2
3 n + n**2 + n**3 n + n + n**2 + n**3
4 n + n**2 + n**3 + n**4 n + n**2 + n**4 + n**4
5 n + n**2 + n**3 + n**4 + n**5 n + n + n**2 + n**4 + n**5

We are still multiplying if remainder equals identity when we exit the loop, which can be removed

Modified Binary exponentiation 3

def __pow__(self, exponent: int):
       if not isinstance(exponent, numbers.Integral):
           return NotImplemented
       if exponent == 0:
           return PauliSum(value.LinearDict({frozenset(): 1 + 0j}))
       if exponent > 1:
           remainder = PauliSum.from_pauli_strings([PauliString(coefficient=1)])
	    base = self.copy()
	    identity = remainder.copy()
           while exponent > 1 :
               if exponent&1 :
                   remainder = remainder * base
		exponent = exponent >> 1
               base*=base
	   if(remainder != identity) base*=remainder
	   return base     
       return NotImplemented
e Linear Binary exponentiation
2 n + n**2 n + n**2
3 n + n**2 + n**3 n + n + n**2 + n**3
4 n + n**2 + n**3 + n**4 n + n**2 + n**4
5 n + n**2 + n**3 + n**4 + n**5 n + n + n**2  + n**4 + n**5
6 n + n**2 + n**3 + n**4 + n**5 + n**6 n + n**2 + n**2 + n**4 +n**6
7 n + n**2 + n**3 + n**4 + n**5 + n**6 + n**7 n + n + n**2 + n**3 + n**4 + n**7

Observations:
We can see from e = 7 for example we are preventing n5 + n6 - n PauliString multiplications if we opt for binary exponentiation.
Even though the number of iterations inside the loop is reduced to log(n) each iteration the cost of multiplication grows exponentially. The improvements for low e values (2-20) and various values of n are indicated in the attached sheet
For lower values of N (2-10) the improvements can go upto 40-50% for certain values of e.
As N increases the improvements reduce to 6% for n = 16 and 1-2% for n= 50
The only place where the binary method performs bad w.r.t linear is when e equals 3 where we do extra n steps. This can be handled as for n <=3
We can apply regular linear multiplication and for n>3 binary exponentiaiton.

Final Binary exponentiation

def __pow__(self, exponent: int):
       if not isinstance(exponent, numbers.Integral):
           return NotImplemented
       identity = PauliSum.from_pauli_strings([PauliString(coefficient=1)])
       remainder = identity.copy()
       base = self.copy()
       if exponent == 0:
           return PauliSum(value.LinearDict({frozenset(): 1 + 0j}))
       if exponent > 0 and exponent <= 3:
           for _ in range(exponent - 1):
              base *= self.copy()
           return base
       if exponent > 3:
           while exponent > 1:
               if exponent & 1:
                   remainder = remainder * base
               base *= base
               exponent = exponent >> 1
           if remainder != identity:
               base*=remainder
           return base
       return NotImplemented

In all of these we are missing that, every time we multiply PauliSum of n terms with itself - we have n PauliStrings multiplying with each other taking o(n**2) steps. The resultant PauliSum can have number of terms between [1, n**2 - n + 1]. As we exponentiate, more terms will merge to identity and the number of terms in PauliSum decreases

Later as we run tests we will observe its effects.

@TarunSinghania
Copy link
Contributor Author

TarunSinghania commented Mar 14, 2023

Example test: (One can vary n, e, and construction of PauliSum)
N is the number of terms in a PauliSum and e is the exponent to which we are raising it
Our PauliSum is like 1 + X(q0) .... X(qn) [Why I choose this is explained later]

import cirq
import time
n = 7
e = 20
qubits = cirq.LineQubit.range(n)
pstrings = []
for qubit in qubits:
   pstrings.append(cirq.PauliString(1, cirq.X(qubit)))
psum = cirq.PauliSum.from_pauli_strings(pstrings)
for n in range (2, e):
   st = time.process_time()
   psum2 = (psum**n)
   et = time.process_time()
   res = et - st
   print(psum2)
   print(res)
   #print('CPU Execution time:', res, 'seconds')

From test runs we see, for n = 7 and e = 7 the binary exponentiation takes twice as time as linear which is counter intuitive as n**5 + n**6 - n PauliString multiplications should have been prevented in the binary exponentiation approach.

e Linear Binary exponentiation
7 n + n**2 + n**3 + n**4 + n**5 + n**6 + n**7 n + n + n**2 + n**3 + n**4 + n**7

Looking at more details

For Binary
Time inside while loop for first n**2 + n was 0.001870999999999956
n here has 7 paulistrings, after multiplication n**2 has 21 PauliStrings
Time inside while loop for n**4 + n**3 = 0.018809000000000076
n**4 now has 56 paulistrings while n**3 has 41 paulistrings
Final multiplication costs: 0.08067499999999983

Overall total multiplications performed = 7 + 7 + 49 + 21*7 + 21*21 + 56*41 = 2947

For Linear

1st iteration in while loop was 0.0014419999999999433
n**2 has 21 terms ....
Total time = 0.05953900000000001

Overall total multiplications performed = 7 + 7*7 + 21*7 + 41*7 + 56*7 + 62*7 + 63*7 + 64*7 = 2205

The rate at which the number of PauliStrings that decreased with increasing power when multipled by constant number of Paulistrings gave better results.

Overall we observe:
pow fun behaviour

Until the exponent saturates(meeting point) the cost of binary exponentiation is higher.This is because:
In Binary:
self*self : At each step where we multiply is getting exponentially expensive
Both left and right have large number of terms with each term consisting of multiple qubits
In Linear:
base*self.copy(): self.copy() is fixed and doesn't grow. Each PualiString in self.copy() is also smaller. On every step the number of terms added to base increases but the rate of growth is slower.

If n = x meeting point is generally 2**(x-1) for the above category of PauliSum. Question is can we say for sure that meeting point cant be behind 2**(x-1) for other inputs as well. If we can prove that it ensures most delayed saturation we can probably implement the pow function as

if e < 2 ** (n - 1)
  linear multiplication exponentiation 
else 
  binary exponentiation 

Intuitively, I chose each PauliString in the PauliSum referencing a unique qubit as it tends to delay thes saturation. For saturation we need X2 = Y2 = Z**2 = -iXYZ = I

However, given that we dont have a formal proof and the improvements are for larger values of e which itself increases as n increases. The percentage of cases in which it performs betters looks small. Will be going ahead and pushing the implementation of regular multiplication.

Copy link
Collaborator

@pavoljuhas pavoljuhas left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nearly there, but we should also test with exponent = 1.

cirq-core/cirq/ops/linear_combinations_test.py Outdated Show resolved Hide resolved
@tanujkhattar tanujkhattar removed their assignment Mar 14, 2023
@pavoljuhas
Copy link
Collaborator

Hi Tarun, when updating the PR, please add new changes as new commits on top of your branch rather than amending old commits, which may result in complicated commit graphs with duplicate original and amended commits. It is also not necessary to merge-in the mainline master in every update - provided the master does not touch files worked on in this PR.

It is not a big deal as in the end we squash each PR to a single commit on master,
but while the PR is open it is easier to follow its progress with a simple commit history.

@TarunSinghania
Copy link
Contributor Author

TarunSinghania commented Mar 19, 2023

Hi Tarun, when updating the PR, please add new changes as new commits on top of your branch rather than amending old commits, which may result in complicated commit graphs with duplicate original and amended commits. It is also not necessary to merge-in the mainline master in every update - provided the master does not touch files worked on in this PR.

It is not a big deal as in the end we squash each PR to a single commit on master, but while the PR is open it is easier to follow its progress with a simple commit history.

Thanks for the comment. I am new to using git, I thought amending a commit would edit the last commit. Now I understand the behavior, that it only amends the last commit on the local repository (on my laptop). The commit on the remote (forked repository) is already written. e.g:
Locally: commit-1 -> commit-2 -> push to remote (forked cirq repository)
Now if I amend commit-2 locally, the history locally will be commit1->commit3. Locally the branch would be one commit (commit-2) behind the main branch We will have two diverging histories for the same branch, which we will have to either-
a)Merge b) Rebase. Using github desktop and generating a Pull request automatically uses the merge option and led to complicated commit graphs with duplicate original and amended commits.

I will be adding in new commits on top of my branch (which is main branch of my local repository for this issue, but will create separate branches from now on).

Copy link
Collaborator

@pavoljuhas pavoljuhas left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the updates Tarun.
LGTM with 2 small suggestions, also please fix the formatting issues reported by the CI check.
Otherwise it is good to go.

Copy link
Collaborator

@pavoljuhas pavoljuhas left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@tanujkhattar tanujkhattar merged commit 402260c into quantumlib:master Mar 21, 2023
harry-phasecraft pushed a commit to PhaseCraft/Cirq that referenced this pull request Oct 31, 2024
* Update and correct pow function for PauliSum

Updated pow function for PauliSum to use binary exponentiation

* Format and remainder initialisation correction

+Corrected formatting
+Corrected initialisation of remainder to identity

* Format and remainder initialisation correction

+Corrected formatting
+Corrected initialisation of remainder to identity

* using local variable instead of self and reformatting

* reverting to regular linear multiplication and adding tests

* Adding test for range (1,9)

* trailing blank

* remove trailing blank test

* Getting rid of extraneous variables

* Fix up final formatting issue

---------

Co-authored-by: Tanuj Khattar <tanujkhattar@google.com>
Co-authored-by: Pavol Juhas <juhas@google.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
size: S 10< lines changed <50
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants