-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #73 from JonasRigo/add_fermions
Added fermionic creation-/annihilation operators (and more!)
- Loading branch information
Showing
4 changed files
with
445 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,306 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#########################\n", | ||
"# Example for using fermionic operators\n", | ||
"# in the jVMC framework\n", | ||
"#########################\n", | ||
"\n", | ||
"# jVMC\n", | ||
"import jVMC\n", | ||
"import jVMC.nets as nets\n", | ||
"import jVMC.operator as op\n", | ||
"from jVMC.operator import number, creation, annihilation\n", | ||
"import jVMC.sampler\n", | ||
"from jVMC.util import ground_state_search, measure\n", | ||
"from jVMC.vqs import NQS\n", | ||
"from jVMC.stats import SampledObs\n", | ||
"from jVMC import global_defs\n", | ||
"\n", | ||
"# python stuff\n", | ||
"import functools\n", | ||
"\n", | ||
"# jax\n", | ||
"import jax\n", | ||
"from jax.config import config\n", | ||
"config.update(\"jax_enable_x64\", True)\n", | ||
"import jax.numpy as jnp\n", | ||
"import flax.linen as nn\n", | ||
"\n", | ||
"import jax.random as random\n", | ||
"\n", | ||
"# numpy\n", | ||
"import numpy as np\n", | ||
"\n", | ||
"# plotting\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"\n", | ||
"#########################\n", | ||
"# check against openfermion\n", | ||
"#########################\n", | ||
"import openfermion as of\n", | ||
"from openfermion.ops import FermionOperator as fop\n", | ||
"from openfermion.linalg import get_sparse_operator" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"##########################\n", | ||
"# custon tarbget wave function\n", | ||
"# specific to openfermion compatibility\n", | ||
"class Target(nn.Module):\n", | ||
" \"\"\"Target wave function, returns a vector with the same dimension as the Hilbert space\n", | ||
"\n", | ||
" Initialization arguments:\n", | ||
" * ``L``: System size\n", | ||
" * ``d``: local Hilbert space dimension\n", | ||
" * ``delta``: small number to avoid log(0)\n", | ||
"\n", | ||
" \"\"\"\n", | ||
" L: int\n", | ||
" d: float = 2.00\n", | ||
" delta: float = 1e-15\n", | ||
"\n", | ||
" @nn.compact\n", | ||
" def __call__(self, s):\n", | ||
" kernel = self.param('kernel',\n", | ||
" nn.initializers.constant(1),\n", | ||
" (int(self.d**self.L)))\n", | ||
" # return amplitude for state s\n", | ||
" idx = ((self.d**jnp.arange(self.L)).dot(s[::-1])).astype(int) # NOTE that the state is reversed to account for different bit conventions used in openfermion\n", | ||
" return jnp.log(abs(kernel[idx]+self.delta)) + 1.j*jnp.angle(kernel[idx]) " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Fermionic operators\n", | ||
"\n", | ||
"Fermionic operators have to satisfy the following condition\n", | ||
"$$\n", | ||
"\\lbrace \\hat{c}^\\dagger_i, \\hat{c}_j\\rbrace = \\delta_{ij} \\; ,\n", | ||
"$$\n", | ||
"where $i,j$ are so-called *flavours*.As is done in 'openfermoin' we do not allow for a spin quantum number. In other words, all our fermionic operators can only carry a single flavour. For higher flavour indeces one has combine several distinct fermionis.\n", | ||
"\n", | ||
"The key to realizing fermions is the Jordan-Wigner factor. Every fermionic state is constructed using a filling order, then we have to count how many craetion operatros a given operator has to commute thorugh to arrive at his filling order position.\n", | ||
"We can achieve this as follows.\n", | ||
"$$\n", | ||
"\\hat{c}^\\dagger\\vert 1, 0 \\rangle = (-1)^\\Omega \\vert 1,1\\rangle\n", | ||
"$$\n", | ||
"with \n", | ||
"$$\n", | ||
"\\Omega = \\sum^{j-1}_{i=0} s_i \\; .\n", | ||
"$$\n", | ||
"In the following we construct the repulsive Hubbard Model on a chain as an example and compare it to openfermion\n", | ||
"$$\n", | ||
"H = U \\sum^N_{i=1} \\hat{n}_{\\uparrow i}\\hat{n}_{\\downarrow i} + t \\sum^{N-1}_{i=1} \\hat{c}^\\dagger_{\\sigma i} \\hat{c}_{\\sigma i+1} + h.c. \\;.\n", | ||
"$$ " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#########################\n", | ||
"# jVMC hamiltonian\n", | ||
"#########################\n", | ||
"t = - 1.0 # hopping\n", | ||
"mu = -2.0 # chemical potential\n", | ||
"V = 4.0 # interaction\n", | ||
"L = 4 # number of sites\n", | ||
"flavour = 2 # number of flavours\n", | ||
"flavourL = flavour*L # number of spins times sites\n", | ||
"\n", | ||
"# initalize the Hamitonian\n", | ||
"hamiltonian = op.BranchFreeOperator()\n", | ||
"# impurity definitions\n", | ||
"site1UP = 0\n", | ||
"site1DO = flavourL-1#//flavour\n", | ||
"# loop over the 1d lattice\n", | ||
"for i in range(0,flavourL//flavour):\n", | ||
" # interaction\n", | ||
" hamiltonian.add(op.scal_opstr( V, ( number(site1UP + i) , number(site1DO - i) ) ) )\n", | ||
" # chemical potential\n", | ||
" hamiltonian.add(op.scal_opstr(mu , ( number(site1UP + i) ,) ) )\n", | ||
" hamiltonian.add(op.scal_opstr(mu , ( number(site1DO - i) ,) ) )\n", | ||
" if i == flavourL//flavour-1:\n", | ||
" continue\n", | ||
" # up chain hopping\n", | ||
" hamiltonian.add(op.scal_opstr( t, ( annihilation(site1UP + i) ,creation(site1UP + i + 1) ) ) )\n", | ||
" hamiltonian.add(op.scal_opstr( t, ( annihilation(site1UP + i + 1) ,creation(site1UP + i) ) ) )\n", | ||
" # down chain hopping\n", | ||
" hamiltonian.add(op.scal_opstr( t, ( annihilation(site1DO - i) ,creation(site1DO - i - 1) ) ) )\n", | ||
" hamiltonian.add(op.scal_opstr( t, ( annihilation(site1DO - i - 1) ,creation(site1DO - i) ) ) )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#########################\n", | ||
"# openfermion\n", | ||
"#########################\n", | ||
"\n", | ||
"H = 0.0*fop()\n", | ||
"# loop over the 1d lattice\n", | ||
"for i in range(0,flavourL//flavour):\n", | ||
" H += fop(((site1UP + i,1),(site1UP + i,0),(site1DO - i,1),(site1DO - i,0)),V) \n", | ||
" H += fop(((site1UP + i,1),(site1UP + i,0)),mu) + fop(((site1DO - i,1),(site1DO - i,0)),mu)\n", | ||
" if i == flavourL//flavour-1:\n", | ||
" continue\n", | ||
" # up chain\n", | ||
" H += (fop(((site1UP + i,1),(site1UP + i + 1,0)),t) + fop(((site1UP + i + 1,1),(site1UP + i,0)),t))\n", | ||
" # down chain\n", | ||
" H += (fop(((site1DO - i,1),(site1DO - i - 1,0)),t) + fop(((site1DO - i - 1,1),(site1DO - i,0)),t))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#########################\n", | ||
"# diagonalize the Openfermion Hamiltonain\n", | ||
"#########################\n", | ||
"\n", | ||
"ham = get_sparse_operator(H)\n", | ||
"a, b = np.linalg.eigh(ham.toarray())\n", | ||
"\n", | ||
"chi_model = Target(L=flavourL, d=2)\n", | ||
"chi = NQS(chi_model)\n", | ||
"chi(jnp.array(jnp.ones((1, 1, flavourL))))\n", | ||
"chi.set_parameters(b[:,0]+1e-14)\n", | ||
"chiSampler = jVMC.sampler.ExactSampler(chi, (flavourL,))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"s, logPsi, p = chiSampler.sample()\n", | ||
"sPrime, _ = hamiltonian.get_s_primes(s)\n", | ||
"Oloc = hamiltonian.get_O_loc(s, chi, logPsi)\n", | ||
"Omean = jVMC.mpi_wrapper.global_mean(Oloc,p)\n", | ||
"\n", | ||
"print(\"Ground state energy: \\njVMC: %.8f, Openfermion: %.8f\"%(Omean.real,a[0]))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Finding the ground state brute force" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Set up variational wave function\n", | ||
"all_states = Target(L=flavourL, d=2)\n", | ||
"psi = NQS(all_states)\n", | ||
"# initialize NQS\n", | ||
"print(\"Net init: \",psi(jnp.array(jnp.ones((1, 1, flavourL)))))\n", | ||
"# Set up exact sampler\n", | ||
"exactSampler = jVMC.sampler.ExactSampler(psi, flavourL)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Set up sampler\n", | ||
"sampler = jVMC.sampler.ExactSampler(psi, flavourL)\n", | ||
"\n", | ||
"# Set up TDVP\n", | ||
"tdvpEquation = jVMC.util.tdvp.TDVP(sampler, rhsPrefactor=1.,diagonalShift=10, makeReal='real')\n", | ||
"\n", | ||
"stepper = jVMC.util.stepper.Euler(timeStep=5e-1) # ODE integrator\n", | ||
"\n", | ||
"n_steps = 500\n", | ||
"res = []\n", | ||
"for n in range(n_steps):\n", | ||
"\n", | ||
" dp, _ = stepper.step(0, tdvpEquation, psi.get_parameters(), hamiltonian=hamiltonian, psi=psi, numSamples=None)\n", | ||
" psi.set_parameters(dp)\n", | ||
"\n", | ||
" print(n, jax.numpy.real(tdvpEquation.ElocMean0), tdvpEquation.ElocVar0)\n", | ||
"\n", | ||
" res.append([n, jax.numpy.real(tdvpEquation.ElocMean0), tdvpEquation.ElocVar0])\n", | ||
"\n", | ||
"res = np.array(res)\n", | ||
"\n", | ||
"fig, ax = plt.subplots(2, 1, sharex=True, figsize=[4.8, 4.8])\n", | ||
"ax[0].semilogy(res[:, 0], res[:, 1] - a[0], '-', label=r\"$L=\" + str(L) + \"$\")\n", | ||
"ax[0].set_ylabel(r'$(E-E_0)/L$')\n", | ||
"\n", | ||
"ax[1].semilogy(res[:, 0], res[:, 2], '-')\n", | ||
"ax[1].set_ylabel(r'Var$(E)/L$')\n", | ||
"ax[0].legend()\n", | ||
"plt.xlabel('iteration')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"s, logPsi, _ = exactSampler.sample()\n", | ||
"var_wf = np.real(np.exp(logPsi))[0]\n", | ||
"# normalizing the wave function\n", | ||
"var_wf /= var_wf.dot(var_wf)**0.5\n", | ||
"\n", | ||
"figure = plt.figure(dpi=100)\n", | ||
"plt.xlabel('state')\n", | ||
"plt.ylabel('amplitude')\n", | ||
"plt.plot(var_wf,label='jVMC')\n", | ||
"plt.plot(np.exp(chi(s)).real[0],'--',label='openfermion')\n", | ||
"plt.legend()\n", | ||
"plt.show()" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "jvmc", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.19" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.