Skip to content

Commit

Permalink
Merge pull request #73 from JonasRigo/add_fermions
Browse files Browse the repository at this point in the history
Added fermionic creation-/annihilation operators (and more!)
  • Loading branch information
markusschmitt authored Jun 10, 2024
2 parents d17d618 + 8d07a55 commit e6accab
Show file tree
Hide file tree
Showing 4 changed files with 445 additions and 9 deletions.
306 changes: 306 additions & 0 deletions examples/ex7_fermions.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"#########################\n",
"# Example for using fermionic operators\n",
"# in the jVMC framework\n",
"#########################\n",
"\n",
"# jVMC\n",
"import jVMC\n",
"import jVMC.nets as nets\n",
"import jVMC.operator as op\n",
"from jVMC.operator import number, creation, annihilation\n",
"import jVMC.sampler\n",
"from jVMC.util import ground_state_search, measure\n",
"from jVMC.vqs import NQS\n",
"from jVMC.stats import SampledObs\n",
"from jVMC import global_defs\n",
"\n",
"# python stuff\n",
"import functools\n",
"\n",
"# jax\n",
"import jax\n",
"from jax.config import config\n",
"config.update(\"jax_enable_x64\", True)\n",
"import jax.numpy as jnp\n",
"import flax.linen as nn\n",
"\n",
"import jax.random as random\n",
"\n",
"# numpy\n",
"import numpy as np\n",
"\n",
"# plotting\n",
"import matplotlib.pyplot as plt\n",
"\n",
"#########################\n",
"# check against openfermion\n",
"#########################\n",
"import openfermion as of\n",
"from openfermion.ops import FermionOperator as fop\n",
"from openfermion.linalg import get_sparse_operator"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"##########################\n",
"# custon tarbget wave function\n",
"# specific to openfermion compatibility\n",
"class Target(nn.Module):\n",
" \"\"\"Target wave function, returns a vector with the same dimension as the Hilbert space\n",
"\n",
" Initialization arguments:\n",
" * ``L``: System size\n",
" * ``d``: local Hilbert space dimension\n",
" * ``delta``: small number to avoid log(0)\n",
"\n",
" \"\"\"\n",
" L: int\n",
" d: float = 2.00\n",
" delta: float = 1e-15\n",
"\n",
" @nn.compact\n",
" def __call__(self, s):\n",
" kernel = self.param('kernel',\n",
" nn.initializers.constant(1),\n",
" (int(self.d**self.L)))\n",
" # return amplitude for state s\n",
" idx = ((self.d**jnp.arange(self.L)).dot(s[::-1])).astype(int) # NOTE that the state is reversed to account for different bit conventions used in openfermion\n",
" return jnp.log(abs(kernel[idx]+self.delta)) + 1.j*jnp.angle(kernel[idx]) "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fermionic operators\n",
"\n",
"Fermionic operators have to satisfy the following condition\n",
"$$\n",
"\\lbrace \\hat{c}^\\dagger_i, \\hat{c}_j\\rbrace = \\delta_{ij} \\; ,\n",
"$$\n",
"where $i,j$ are so-called *flavours*.As is done in 'openfermoin' we do not allow for a spin quantum number. In other words, all our fermionic operators can only carry a single flavour. For higher flavour indeces one has combine several distinct fermionis.\n",
"\n",
"The key to realizing fermions is the Jordan-Wigner factor. Every fermionic state is constructed using a filling order, then we have to count how many craetion operatros a given operator has to commute thorugh to arrive at his filling order position.\n",
"We can achieve this as follows.\n",
"$$\n",
"\\hat{c}^\\dagger\\vert 1, 0 \\rangle = (-1)^\\Omega \\vert 1,1\\rangle\n",
"$$\n",
"with \n",
"$$\n",
"\\Omega = \\sum^{j-1}_{i=0} s_i \\; .\n",
"$$\n",
"In the following we construct the repulsive Hubbard Model on a chain as an example and compare it to openfermion\n",
"$$\n",
"H = U \\sum^N_{i=1} \\hat{n}_{\\uparrow i}\\hat{n}_{\\downarrow i} + t \\sum^{N-1}_{i=1} \\hat{c}^\\dagger_{\\sigma i} \\hat{c}_{\\sigma i+1} + h.c. \\;.\n",
"$$ "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"#########################\n",
"# jVMC hamiltonian\n",
"#########################\n",
"t = - 1.0 # hopping\n",
"mu = -2.0 # chemical potential\n",
"V = 4.0 # interaction\n",
"L = 4 # number of sites\n",
"flavour = 2 # number of flavours\n",
"flavourL = flavour*L # number of spins times sites\n",
"\n",
"# initalize the Hamitonian\n",
"hamiltonian = op.BranchFreeOperator()\n",
"# impurity definitions\n",
"site1UP = 0\n",
"site1DO = flavourL-1#//flavour\n",
"# loop over the 1d lattice\n",
"for i in range(0,flavourL//flavour):\n",
" # interaction\n",
" hamiltonian.add(op.scal_opstr( V, ( number(site1UP + i) , number(site1DO - i) ) ) )\n",
" # chemical potential\n",
" hamiltonian.add(op.scal_opstr(mu , ( number(site1UP + i) ,) ) )\n",
" hamiltonian.add(op.scal_opstr(mu , ( number(site1DO - i) ,) ) )\n",
" if i == flavourL//flavour-1:\n",
" continue\n",
" # up chain hopping\n",
" hamiltonian.add(op.scal_opstr( t, ( annihilation(site1UP + i) ,creation(site1UP + i + 1) ) ) )\n",
" hamiltonian.add(op.scal_opstr( t, ( annihilation(site1UP + i + 1) ,creation(site1UP + i) ) ) )\n",
" # down chain hopping\n",
" hamiltonian.add(op.scal_opstr( t, ( annihilation(site1DO - i) ,creation(site1DO - i - 1) ) ) )\n",
" hamiltonian.add(op.scal_opstr( t, ( annihilation(site1DO - i - 1) ,creation(site1DO - i) ) ) )"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"#########################\n",
"# openfermion\n",
"#########################\n",
"\n",
"H = 0.0*fop()\n",
"# loop over the 1d lattice\n",
"for i in range(0,flavourL//flavour):\n",
" H += fop(((site1UP + i,1),(site1UP + i,0),(site1DO - i,1),(site1DO - i,0)),V) \n",
" H += fop(((site1UP + i,1),(site1UP + i,0)),mu) + fop(((site1DO - i,1),(site1DO - i,0)),mu)\n",
" if i == flavourL//flavour-1:\n",
" continue\n",
" # up chain\n",
" H += (fop(((site1UP + i,1),(site1UP + i + 1,0)),t) + fop(((site1UP + i + 1,1),(site1UP + i,0)),t))\n",
" # down chain\n",
" H += (fop(((site1DO - i,1),(site1DO - i - 1,0)),t) + fop(((site1DO - i - 1,1),(site1DO - i,0)),t))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"#########################\n",
"# diagonalize the Openfermion Hamiltonain\n",
"#########################\n",
"\n",
"ham = get_sparse_operator(H)\n",
"a, b = np.linalg.eigh(ham.toarray())\n",
"\n",
"chi_model = Target(L=flavourL, d=2)\n",
"chi = NQS(chi_model)\n",
"chi(jnp.array(jnp.ones((1, 1, flavourL))))\n",
"chi.set_parameters(b[:,0]+1e-14)\n",
"chiSampler = jVMC.sampler.ExactSampler(chi, (flavourL,))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"s, logPsi, p = chiSampler.sample()\n",
"sPrime, _ = hamiltonian.get_s_primes(s)\n",
"Oloc = hamiltonian.get_O_loc(s, chi, logPsi)\n",
"Omean = jVMC.mpi_wrapper.global_mean(Oloc,p)\n",
"\n",
"print(\"Ground state energy: \\njVMC: %.8f, Openfermion: %.8f\"%(Omean.real,a[0]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Finding the ground state brute force"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set up variational wave function\n",
"all_states = Target(L=flavourL, d=2)\n",
"psi = NQS(all_states)\n",
"# initialize NQS\n",
"print(\"Net init: \",psi(jnp.array(jnp.ones((1, 1, flavourL)))))\n",
"# Set up exact sampler\n",
"exactSampler = jVMC.sampler.ExactSampler(psi, flavourL)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set up sampler\n",
"sampler = jVMC.sampler.ExactSampler(psi, flavourL)\n",
"\n",
"# Set up TDVP\n",
"tdvpEquation = jVMC.util.tdvp.TDVP(sampler, rhsPrefactor=1.,diagonalShift=10, makeReal='real')\n",
"\n",
"stepper = jVMC.util.stepper.Euler(timeStep=5e-1) # ODE integrator\n",
"\n",
"n_steps = 500\n",
"res = []\n",
"for n in range(n_steps):\n",
"\n",
" dp, _ = stepper.step(0, tdvpEquation, psi.get_parameters(), hamiltonian=hamiltonian, psi=psi, numSamples=None)\n",
" psi.set_parameters(dp)\n",
"\n",
" print(n, jax.numpy.real(tdvpEquation.ElocMean0), tdvpEquation.ElocVar0)\n",
"\n",
" res.append([n, jax.numpy.real(tdvpEquation.ElocMean0), tdvpEquation.ElocVar0])\n",
"\n",
"res = np.array(res)\n",
"\n",
"fig, ax = plt.subplots(2, 1, sharex=True, figsize=[4.8, 4.8])\n",
"ax[0].semilogy(res[:, 0], res[:, 1] - a[0], '-', label=r\"$L=\" + str(L) + \"$\")\n",
"ax[0].set_ylabel(r'$(E-E_0)/L$')\n",
"\n",
"ax[1].semilogy(res[:, 0], res[:, 2], '-')\n",
"ax[1].set_ylabel(r'Var$(E)/L$')\n",
"ax[0].legend()\n",
"plt.xlabel('iteration')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"s, logPsi, _ = exactSampler.sample()\n",
"var_wf = np.real(np.exp(logPsi))[0]\n",
"# normalizing the wave function\n",
"var_wf /= var_wf.dot(var_wf)**0.5\n",
"\n",
"figure = plt.figure(dpi=100)\n",
"plt.xlabel('state')\n",
"plt.ylabel('amplitude')\n",
"plt.plot(var_wf,label='jVMC')\n",
"plt.plot(np.exp(chi(s)).real[0],'--',label='openfermion')\n",
"plt.legend()\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "jvmc",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.19"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit e6accab

Please sign in to comment.