from typing import Any, Tuple
import numpy as np
import autoray as a
from quimb import *
import quimb.tensor as qtn
from quimb.tensor.tensor_1d import MatrixProductOperator
from jax.nn.initializers import Initializer
import jax.numpy as jnp
from .model import Model
[docs]
class MatrixProductOperator(Model, qtn.MatrixProductOperator):
    """A trainable MatrixProductOperator class.

    Combines :class:`Model` with :class:`quimb.tensor.MatrixProductOperator`.
    See :class:`quimb.tensor.tensor_1d.MatrixProductOperator` for explanation
    of other attributes and methods.
    """

    def __init__(self, arrays, **kwargs):
        """Initialize both base classes.

        Parameters
        ----------
        arrays : sequence of arrays
            Tensor data, forwarded to ``quimb.tensor.MatrixProductOperator``.
        **kwargs
            Extra keyword arguments forwarded to
            ``quimb.tensor.MatrixProductOperator``.
        """
        # Both bases are initialized explicitly because they take
        # different constructor arguments.
        Model.__init__(self)
        qtn.MatrixProductOperator.__init__(self, arrays, **kwargs)

    def normalize(self, insert=None):
        """Function for normalizing tensors of :class:`tn4ml.models.mpo.MatrixProductOperator`.

        The tensors are modified in place; nothing is returned.

        Parameters
        ----------
        insert : int
            Index of tensor divided by norm. *Default = None*. When `None` the
            norm division is distributed across all tensors. Ignored for
            systems with more than 200 tensors (see below).

        Notes
        -----
        For more than 200 tensors the global norm is never computed. Instead
        a single left-to-right sweep is done: every tensor except the first
        is divided by the norm of its own data, and every tensor except the
        last is then passed to ``left_canonize_site``.
        """
        if self.L > 200:  # for large systems
            last = self.L - 1
            for i, tensor in enumerate(self.tensors):
                if i > 0:
                    tensor.modify(data=tensor.data / jnp.linalg.norm(tensor.data))
                if i < last:
                    self.left_canonize_site(i)
        else:
            norm = self.norm()
            if insert is None:
                # Each of the L tensors absorbs an equal share norm**(1/L);
                # the factor is loop-invariant, so compute it once.
                factor = a.do("power", norm, 1 / self.L)
                for tensor in self.tensors:
                    tensor.modify(data=tensor.data / factor)
            else:
                target = self.tensors[insert]
                target.modify(data=target.data / norm)
def trainable_wrapper(mps: qtn.MatrixProductOperator, **kwargs) -> MatrixProductOperator:
    """Wrap a ``qtn.MatrixProductOperator`` so that it becomes trainable.

    Parameters
    ----------
    mps : :class:`quimb.tensor.MatrixProductOperator`
        Matrix Product Operator to be trained. Only its ``arrays`` are reused.
    **kwargs
        Keyword arguments passed on to the trainable class constructor.

    Returns
    -------
    :class:`tn4ml.models.mpo.MatrixProductOperator`
    """
    return MatrixProductOperator(mps.arrays, **kwargs)
def generate_shape(method: str,
                   L: int,
                   bond_dim: int = 2,
                   phys_dim: Tuple[int, int] = (2, 2),
                   cyclic: bool = False,
                   position: int = None,
                   ) -> tuple:
    """Returns a shape of tensor.

    Parameters
    ----------
    method : str
        Method on how to create shapes of tensors.
        'even' = exact dimensions as given by parameters, anything else = truncated dimensions.
    L : int
        Number of tensors.
    bond_dim : int
        Dimension of virtual indices between tensors. *Default = 2*.
    phys_dim : tuple(int, int)
        Dimension of physical indices for individual tensor - *up* and *down*.
    cyclic : bool
        Flag for indicating if MatrixProductOperator this tensor is part of is cyclic. *Default=False*.
        Only supported with ``method='even'``.
    position : int
        Position of tensor in MatrixProductOperator, counted from 1 to ``L``.
        Required for truncated shapes.

    Returns
    -------
    tuple
        ``(left, right, *phys_dim)``. In truncated mode the first and last
        tensors have a single bond axis: ``(bond, *phys_dim)``.

    Raises
    ------
    AssertionError
        If ``cyclic`` is set together with a truncated ``method``.
    TypeError
        If ``position`` is ``None`` with a truncated ``method``.
    """
    if method == 'even':
        # The last-site check comes first so that it wins when L == 1.
        if position == L:
            return (bond_dim, 1, *phys_dim)
        if position == 1:
            return (1, bond_dim, *phys_dim)
        return (bond_dim, bond_dim, *phys_dim)

    # not sure is this needed if I can use compress
    if cyclic:
        # Raised explicitly so the check is not stripped under ``python -O``.
        raise AssertionError("Truncated shapes are not supported for cyclic operators.")
    if position is None:
        raise TypeError("`position` must be an int when method is not 'even'.")

    # Combined physical dimension carried by one site (up * down).
    site_dim = phys_dim[0] * phys_dim[1]

    def bond_at(cut: int):
        # Bond between sites `cut` and `cut + 1`: limited by the smaller
        # side of the chain, and capped at `bond_dim`. Computing each bond
        # from its own cut keeps neighbouring tensors consistent, including
        # at the middle site of an odd-length chain.
        return min(bond_dim, site_dim ** min(cut, L - cut))

    chil = bond_at(position - 1)
    chir = bond_at(position)

    if position == 1:
        return (chir, *phys_dim)
    if position == L:
        return (chil, *phys_dim)
    return (chil, chir, *phys_dim)
def MPO_initialize(L: int,
                   initializer: Initializer,
                   key: Any,
                   dtype: Any = jnp.float_,
                   shape_method: str = 'even',
                   bond_dim: int = 4,
                   phys_dim: Tuple[int, int] = (2, 2),
                   add_identity: bool = False,
                   boundary: str = 'obc',
                   cyclic: bool = False,
                   compress: bool = False,
                   insert: int = None,
                   canonical_center: int = None,
                   **kwargs):
    """Generates :class:`tn4ml.models.mpo.MatrixProductOperator`.

    Parameters
    ----------
    L : int
        Number of tensors.
    initializer : :class:`jax.nn.initializers.Initializer`
        Type of tensor initialization function.
    key : Array
        Argument key is a PRNG key (e.g. from `jax.random.key()`), used to generate random numbers to initialize the array.
    dtype : Any
        Type of tensor data (from `jax.numpy.float_`)
    shape_method : str
        Method to generate shapes for tensors.
    bond_dim : int
        Dimension of virtual indices between tensors. *Default = 4*.
    phys_dim : tuple(int, int)
        Dimension of physical indices for individual tensor - *up* and *down*.
    add_identity : bool
        Flag for adding identity to tensor diagonal elements. *Default = False*.
    boundary : str
        Boundary conditions for the MatrixProductOperator. *Default = 'obc'*.
        obc = open boundary conditions. pbc = periodic boundary conditions.
        Only ``'obc'`` triggers extra processing here: the first and last
        tensors are zeroed except for one slice.
    cyclic : bool
        Flag for indicating if MatrixProductOperator is cyclic. *Default=False*.
        Only allowed with ``shape_method='even'``.
    compress : bool
        Flag to truncate bond dimensions. Only applied when ``shape_method='even'``.
    insert : int
        Index of a tensor that is divided by ``sqrt(min(bond_dim, phys_dim[0]))``
        before the operator is built. Only used when ``shape_method='even'``
        and ``insert < L``. *Default = None* (no tensor is rescaled).
    canonical_center : int
        If not `None` then create canonical form around canonical center index;
        the same index is then passed to ``normalize`` as ``insert``.
    **kwargs
        Forwarded to the :class:`MatrixProductOperator` constructor.

    Returns
    -------
    :class:`tn4ml.models.mpo.MatrixProductOperator`

    Raises
    ------
    NotImplementedError
        If ``cyclic`` is requested with a ``shape_method`` other than ``'even'``.
    """
    if cyclic and shape_method != 'even':
        raise NotImplementedError("Change shape_method to 'even'.")

    tensors = []
    for i in range(1, L + 1):
        shape = generate_shape(shape_method, L, bond_dim, phys_dim, cyclic, i)
        # NOTE(review): the same `key` is used for every site, so tensors of
        # equal shape start with identical values. Confirm this is intended
        # before switching to per-site keys (it would change seeded results).
        tensor = initializer(key, shape, dtype)
        ndim = len(tensor.shape)

        if add_identity:
            # JAX arrays are immutable: `.at[...].set(...)` returns a new
            # array, so the result has to be assigned back.
            if ndim == 3:
                identity = jnp.eye(tensor.shape[0], tensor.shape[1], dtype=dtype)
                tensor = tensor.at[:, :, 0].set(identity)
            elif ndim == 4:  # output node
                identity = jnp.eye(tensor.shape[0], tensor.shape[1], dtype=dtype)
                identity = jnp.broadcast_to(
                    jnp.expand_dims(identity, axis=2),
                    (tensor.shape[0], tensor.shape[1], tensor.shape[3]),
                )
                tensor = tensor.at[:, :, 0, :].set(identity)

        if boundary == 'obc' and ndim in (3, 4) and (i == 1 or i == L):
            # Keep a single slice of the edge tensor and zero the rest:
            # index 0 of axis 1 on the left node, index 0 of axis 0 on the
            # right node. The left-node rule wins when L == 1.
            keep = (slice(None), 0) if i == 1 else (0,)
            tensor = jnp.zeros(tensor.shape, dtype=dtype).at[keep].set(tensor[keep])

        tensors.append(tensor)

    # `is not None` rather than truthiness, so that index 0 is honoured.
    if insert is not None and insert < L and shape_method == 'even':
        tensors[insert] /= np.sqrt(min(bond_dim, phys_dim[0]))

    mpo = MatrixProductOperator(tensors, **kwargs)

    if compress and shape_method == 'even':
        mpo.compress(form="flat", max_bond=bond_dim)  # limit bond_dim

    # For large systems (L > 200) `normalize` performs the site-by-site
    # sweep, and it runs even when a canonical center is requested.
    # For smaller systems the plain normalization is only needed when no
    # canonical center is given.
    if L > 200 or canonical_center is None:
        mpo.normalize()

    if canonical_center is not None:
        mpo.canonicalize(canonical_center, inplace=True)
        mpo.normalize(insert=canonical_center)

    return mpo