Merge pull request #8 from qsimulate/mps-subclass
Make MPS subclassable
hczhai authored Oct 23, 2023
2 parents c1c5501 + 5ae09fc commit 39b0732
Showing 1 changed file with 46 additions and 46 deletions.
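
A note on the pattern applied throughout the diff below: the factory methods ones, zeros, and random become classmethods that build cls(...), and instance methods construct type(self)(...) instead of hard-coding the MPS constructor, so every operation on a subclass of MPS now returns an instance of that subclass. A minimal sketch of what this enables (MyMPS and my_info are hypothetical names used only for illustration; setting up an actual MPSInfo is elided):

# Illustrative only: MyMPS is a made-up subclass and my_info a made-up
# MPSInfo instance; its construction depends on the system being studied.
from pyblock3.algebra.mps import MPS

class MyMPS(MPS):
    """A subclass that reuses all MPS machinery unchanged."""
    pass

# With the classmethod factories from this commit, the subclass survives:
#   psi = MyMPS.random(my_info)   # isinstance(psi, MyMPS) is now True
#   phi = psi.canonicalize(0)     # type(self)(...) keeps the MyMPS type
# Before this change, both calls handed back plain MPS objects.
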
92 changes: 46 additions & 46 deletions pyblock3/algebra/mps.py
@@ -187,10 +187,10 @@ def T(self):
            tr = (0, *tuple(range(d + 1, d + d + 1)),
                  *tuple(range(1, d + 1)), d + d + 1)
            tensors[i] = self.tensors[i].transpose(tr)
-        return MPS(tensors=tensors, const=self.const, opts=self.opts, dq=self.dq)
+        return type(self)(tensors=tensors, const=self.const, opts=self.opts, dq=self.dq)

-    @staticmethod
-    def ones(info, dtype=float, opts=None):
+    @classmethod
+    def ones(cls, info, dtype=float, opts=None):
        """Construct unfused MPS from MPSInfo, with identity matrix elements."""
        tensors = [None] * info.n_sites
        for i in range(info.n_sites):
@@ -200,10 +200,10 @@ def ones(info, dtype=float, opts=None):
            else:
                tensors[i] = FlatSparseTensor.ones(
                    (info.left_dims[i], info.basis[i], info.left_dims[i + 1]), dtype=dtype)
-        return MPS(tensors=tensors, opts=opts)
+        return cls(tensors=tensors, opts=opts)

-    @staticmethod
-    def zeros(info, dtype=float, opts=None):
+    @classmethod
+    def zeros(cls, info, dtype=float, opts=None):
        """Construct unfused MPS from MPSInfo, with zero matrix elements."""
        tensors = [None] * info.n_sites
        for i in range(info.n_sites):
@@ -213,10 +213,10 @@ def zeros(info, dtype=float, opts=None):
            else:
                tensors[i] = FlatSparseTensor.zeros(
                    (info.left_dims[i], info.basis[i], info.left_dims[i + 1]), dtype=dtype)
-        return MPS(tensors=tensors, opts=opts)
+        return cls(tensors=tensors, opts=opts)

-    @staticmethod
-    def random(info, low=0, high=1, dtype=float, opts=None):
+    @classmethod
+    def random(cls, info, low=0, high=1, dtype=float, opts=None):
        """Construct unfused MPS from MPSInfo, with random matrix elements."""
        tensors = [None] * info.n_sites
        for i in range(info.n_sites):
@@ -228,8 +228,8 @@ def random(info, low=0, high=1, dtype=float, opts=None):
                tensors[i] = FlatSparseTensor.random(
                    (info.left_dims[i], info.basis[i], info.left_dims[i + 1]),
                    dtype=dtype) * (high - low) + low
-        return MPS(tensors=tensors, opts=opts)
+        return cls(tensors=tensors, opts=opts)

    def fix_pattern(self, pattern=None):
        dq = None
        for i in range(self.n_sites):
@@ -294,14 +294,14 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
            out.tensors = tensors
            out.const = const
            out.opts = self.opts
-        return MPS(tensors=tensors, const=const, opts=self.opts, dq=self.dq)
+        return type(self)(tensors=tensors, const=const, opts=self.opts, dq=self.dq)

    def __array_function__(self, func, types, args, kwargs):
        if func not in _mps_numpy_func_impls:
            return NotImplemented
        if not all(issubclass(t, self.__class__) for t in types):
            return NotImplemented
-        return _mps_numpy_func_impls[func](*args, **kwargs)
+        return _mps_numpy_func_impls[func](type(self), *args, **kwargs)

    def canonicalize(self, center):
        """
@@ -323,7 +323,7 @@ def canonicalize(self, center):
            l, q = tensors[i].right_canonicalize()
            tensors[i] = q
            tensors[i - 1] = np.tensordot(tensors[i - 1], l, axes=1)
-        return MPS(tensors=tensors, opts=self.opts, const=self.const, dq=self.dq)
+        return type(self)(tensors=tensors, opts=self.opts, const=self.const, dq=self.dq)

    def compress(self, **opts):
        """
@@ -365,22 +365,22 @@ def compress(self, **opts):
            tensors[i] = ls
            tensors[i + 1] = np.tensordot(r, tensors[i + 1], axes=1)
            merror = max(merror, err)
-        return MPS(tensors=tensors, const=self.const, opts=self.opts, dq=self.dq), merror
+        return type(self)(tensors=tensors, const=self.const, opts=self.opts, dq=self.dq), merror

-    @staticmethod
-    def _add(a, b):
+    @classmethod
+    def _add(cls, a, b):
        """Add two MPS."""
-        assert isinstance(a, MPS) and isinstance(b, MPS)
+        assert isinstance(a, cls) and isinstance(b, cls)
        assert a.n_sites == b.n_sites
        n_sites = a.n_sites

        if n_sites == 1:
-            return MPS(tensors=[a[0] + b[0]], const=a.const + b.const, opts=a.opts, dq=a.dq)
+            return cls(tensors=[a[0] + b[0]], const=a.const + b.const, opts=a.opts, dq=a.dq)

        if any([t.n_blocks == 0 for t in a.tensors]):
-            return MPS(tensors=b.tensors, const=a.const + b.const, opts=b.opts, dq=b.dq)
+            return cls(tensors=b.tensors, const=a.const + b.const, opts=b.opts, dq=b.dq)
        elif any([t.n_blocks == 0 for t in b.tensors]):
-            return MPS(tensors=a.tensors, const=a.const + b.const, opts=a.opts, dq=a.dq)
+            return cls(tensors=a.tensors, const=a.const + b.const, opts=a.opts, dq=a.dq)

        ainfos = [t.infos for t in a.tensors]
        binfos = [t.infos for t in b.tensors]
@@ -396,7 +396,7 @@ def _add(a, b):
        for i in range(n_sites):
            tensors.append(a.tensors[i].kron_add(
                b.tensors[i], infos=(sum_bonds[i], sum_bonds[i + 1])))
-        return MPS(tensors=tensors, const=a.const + b.const, opts=a.opts, dq=a.dq)
+        return cls(tensors=tensors, const=a.const + b.const, opts=a.opts, dq=a.dq)

    def __getitem__(self, i):
        return self.tensors[i]
@@ -407,17 +407,17 @@ def __setitem__(self, i, ts):
    def conj(self):
        return np.conj(self)

-    @staticmethod
+    @classmethod
    @implements(np.copy)
-    def _copy(x):
-        return MPS(tensors=[t.copy() for t in x.tensors], const=x.const, opts=x.opts.copy(), dq=x.dq)
+    def _copy(cls, x):
+        return cls(tensors=[t.copy() for t in x.tensors], const=x.const, opts=x.opts.copy(), dq=x.dq)

    def copy(self):
        return np.copy(self)

-    @staticmethod
+    @classmethod
    @implements(np.dot)
-    def _dot(a, b, out=None):
+    def _dot(cls, a, b, out=None):
        if isinstance(a, numbers.Number) or isinstance(b, numbers.Number):
            if isinstance(a, numbers.Number) and a == 0:
                return 0.0
@@ -426,7 +426,7 @@ def _dot(a, b, out=None):
            else:
                return np.multiply(a, b, out=out)

-        assert isinstance(a, MPS) and isinstance(b, MPS)
+        assert isinstance(a, cls) and isinstance(b, cls)
        assert a.n_sites == b.n_sites

        left = np.array(0)
@@ -446,7 +446,7 @@ def _dot(a, b, out=None):
            else:
                lbra = np.tensordot(left, a.tensors[i], axes=([0], [0]))
                left = np.tensordot(lbra, b.tensors[i], axes=(cidx, cidx))

        r = left if isinstance(left, float) else left.item()

        if out is not None:
@@ -457,23 +457,23 @@ def dot(self, b, out=None):
    def dot(self, b, out=None):
        return np.dot(self, b, out=out)

-    @staticmethod
+    @classmethod
    @implements(np.linalg.norm)
-    def _norm(x):
+    def _norm(cls, x):
        d = np.conj(x).dot(x)
        assert (abs(d.real) > 1E-10 and abs(d.imag) / abs(d.real) < 1E-10) or abs(d.imag) < 1E-10
        return np.sqrt(abs(d.real) if abs(d.real) < 1E-10 else d.real)

    def norm(self):
        return np.linalg.norm(self)

-    @staticmethod
+    @classmethod
    @implements(np.matmul)
-    def _matmul(a, b, out=None):
+    def _matmul(cls, a, b, out=None):
        if isinstance(a, numbers.Number) or isinstance(b, numbers.Number):
            return np.multiply(a, b, out=out)

-        assert isinstance(a, MPS) and isinstance(b, MPS)
+        assert isinstance(a, cls) and isinstance(b, cls)

        opts = {**a.opts, **b.opts}
@@ -516,7 +516,7 @@ def _matmul(a, b, out=None):
            tensors[i] = tensors[i].fuse(-2, -1, info=prod_bonds[i + 1]
                                         ).fuse(0, 1, info=prod_bonds[i])

-        r = MPS(tensors=tensors)
+        r = cls(tensors=tensors)

        # const terms
        if a.const != 0 and b.const == 0:
@@ -530,7 +530,7 @@ def _matmul(a, b, out=None):
        # compression
        if len(opts) != 0:
            if r.n_sites > 1:
-                r, _ = MPS.compress(r, **opts)
+                r, _ = cls.compress(r, **opts)
            r.opts = opts

        if out is not None:
@@ -592,8 +592,8 @@ def symmetry_fuse(self, symm_map, info=None):
            sinfos = (bonds[i], *minfos, bonds[i + 1])
            finfos = [BondFusingInfo.get_symmetry_fusing_info(i, symm_map) for i in sinfos]
            tensors[i] = self[i].symmetry_fuse(finfos, symm_map)
-        return MPS(tensors=tensors, const=self.const, opts=self.opts, dq=self.dq)
+        return type(self)(tensors=tensors, const=self.const, opts=self.opts, dq=self.dq)

    @staticmethod
    def _to_sliceable(a, info=None):
@@ -629,7 +629,7 @@ def to_sliceable(self, info=None):
            sinfos = (bonds[i], *minfos, bonds[i + 1])
            tensors[i] = self[i].to_sliceable(infos=sinfos)

-        return MPS(tensors=tensors, const=self.const, opts=self.opts, dq=self.dq)
+        return type(self)(tensors=tensors, const=self.const, opts=self.opts, dq=self.dq)

    @staticmethod
    def _to_sparse(a):
@@ -642,8 +642,8 @@ def to_sparse(self):
                tensors[it] = ts
            else:
                tensors[it] = ts.to_sparse(dq=None if it == 0 else self.dq)
-        return MPS(tensors=tensors, const=self.const, opts=self.opts, dq=self.dq)
+        return type(self)(tensors=tensors, const=self.const, opts=self.opts, dq=self.dq)

    @staticmethod
    def _to_ad_sparse(a):
        return a.to_sparse()
@@ -658,7 +658,7 @@ def to_ad_sparse(self):
                tensors[it] = ADFT.from_non_ad(ts, pattern='++--')
            else:
                assert False
-        return MPS(tensors=tensors, const=self.const, opts=self.opts, dq=self.dq)
+        return type(self)(tensors=tensors, const=self.const, opts=self.opts, dq=self.dq)

    @staticmethod
    def _to_symbolic(a):
Expand Down Expand Up @@ -689,7 +689,7 @@ def to_symbolic(self):
else:
tensors[it] = SymbolicSparseTensor.from_sparse(ts,
pos=pos, infos=bonds[it:it + 2], dq=None if it == 0 else self.dq)
return MPS(tensors=tensors, const=self.const, opts=self.opts, dq=self.dq)
return type(self)(tensors=tensors, const=self.const, opts=self.opts, dq=self.dq)

@staticmethod
def _to_non_flat(a):
@@ -706,7 +706,7 @@ def to_non_flat(self):
                tensors[it] = ts
            else:
                tensors[it] = ts.to_sparse()
-        return MPS(tensors=tensors, const=self.const, opts=self.opts, dq=self.dq)
+        return type(self)(tensors=tensors, const=self.const, opts=self.opts, dq=self.dq)

    @staticmethod
    def _to_flat(a):
@@ -723,7 +723,7 @@ def to_flat(self):
                tensors[it] = FlatFermionTensor.from_fermion(ts)
            else:
                tensors[it] = ts.to_flat()
-        return MPS(tensors=tensors, const=self.const, opts=self.opts, dq=self.dq)
+        return type(self)(tensors=tensors, const=self.const, opts=self.opts, dq=self.dq)

    @staticmethod
    def _simplify(a):
@@ -747,7 +747,7 @@ def simplify(self):
            tensors[i] = tensors[i].simplify(r, left=True)
            tensors[i - 1] = tensors[i - 1].simplify(r, left=False)

-        return MPS(tensors=tensors, const=self.const, opts=self.opts, dq=self.dq)
+        return type(self)(tensors=tensors, const=self.const, opts=self.opts, dq=self.dq)

    @staticmethod
    def _amplitude(a, det):
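The __array_function__ and @implements changes above use NumPy's standard function-override protocol: a registry maps NumPy functions to class-level implementations, and threading type(self) through the dispatch lets np.copy, np.dot, np.linalg.norm, and np.matmul build the caller's class instead of a fixed one. A self-contained sketch of that dispatch pattern in isolation (Box, FancyBox, and _impls are illustrative stand-ins, not pyblock3 code):

import numpy as np

_impls = {}  # maps NumPy functions to class-level implementations

def implements(np_func):
    """Register a function as the handler for a NumPy function."""
    def decorator(func):
        _impls[np_func] = func
        return func
    return decorator

class Box:
    """Tiny container that opts into NumPy's __array_function__ protocol."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        if func not in _impls:
            return NotImplemented
        if not all(issubclass(t, self.__class__) for t in types):
            return NotImplemented
        # Pass the concrete class so subclasses come back as themselves.
        return _impls[func](type(self), *args, **kwargs)

    @classmethod
    @implements(np.copy)
    def _copy(cls, x):
        return cls(np.copy(x.data))

class FancyBox(Box):
    pass

b = np.copy(FancyBox([1, 2, 3]))
print(type(b).__name__)  # prints FancyBox, not Box

This mirrors why _copy, _dot, _norm, and _matmul gain a cls parameter in the diff: the registered implementations now receive the dispatching class as their first argument.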
