from typing import Any, Collection
import numpy as np
import autoray as a
import math
from quimb import *
import quimb.tensor as qtn
from jax.nn.initializers import Initializer
import jax.numpy as jnp
import jax
from .model import Model
from .tn import TensorNetwork
from ..initializers import randn, rand_unitary
[docs]
class MatrixProductState(Model, qtn.MatrixProductState):
"""A Trainable MatrixProductState class.
See :class:`quimb.tensor.tensor_1d.MatrixProductState` for explanation of other attributes and methods.
"""
[docs]
def __init__(self,
arrays,
**kwargs):
"""Initializes the MatrixProductState.
Parameters
----------
arrays : list of array_like
The list of tensors, each of shape ``(D, D, d)``, where ``D`` is the bond dimension and ``d`` is the physical dimension.
**kwargs : dict
Additional arguments to be passed to the parent class.
"""
Model.__init__(self)
qtn.MatrixProductState.__init__(self, arrays, **kwargs)
[docs]
def normalize(self, insert=None):
if self.L > 200: # for large systems
for i, tensor in enumerate(self.tensors):
if i == 0:
self.left_canonize_site(i)
elif i == self.L - 1:
tensor.modify(data=tensor.data / jnp.linalg.norm(tensor.data))
else:
tensor.modify(data=tensor.data / jnp.linalg.norm(tensor.data))
self.left_canonize_site(i)
else:
norm = self.norm()
if insert == None:
for tensor in self.tensors:
tensor.modify(data=tensor.data / a.do("power", norm, 1 / self.L))
else:
self.tensors[insert].modify(data=self.tensors[insert].data / norm)
def trainable_wrapper(mps: qtn.MatrixProductState, **kwargs) -> MatrixProductState:
""" Creates a wrapper around qtn.MatrixProductState so it can be trainable.
Parameters
----------
mps : :class:`quimb.tensor.MatrixProductState`
Matrix Product State to be trained.
Returns
-------
:class:`tn4ml.models.mps.MatrixProductState`
"""
tensors = mps.arrays
return MatrixProductState(tensors, **kwargs)
def generate_shape(method: str,
L: int,
bond_dim: int = 2,
phys_dim: int = 2,
cyclic: bool = False,
position: int = None,
class_index: int = None,
class_dim: int = None,
) -> tuple:
"""Returns a shape of tensor .
Parameters
----------
method : str
Method on how to create shapes of tensors.
'even' = exact dimensions as given by parameters, anything else = truncated dimensions.
L : int
Number of tensors.
bond_dim : int
Dimension of virtual indices between tensors. *Default = 4*.
phys_dim : int
Dimension of physical index for individual tensor.
cyclic : bool
Flag for indicating if MatrixProductState this tensor is part of is cyclic. *Default=False*.
position : int
Position of tensor in MatrixProductState.
class_index : int
Index of tensor that is the output node. For classification tasks only.
class_dim : int
Dimension of output node, or number of classes for classification.
Returns
-------
tuple
"""
if method == 'even':
shape = (bond_dim, bond_dim, phys_dim, class_dim) if class_index is not None and position == class_index else (bond_dim, bond_dim, phys_dim)
if position == 1:
shape = (1, bond_dim, phys_dim, class_dim) if class_index is not None and position == class_index else (1, bond_dim, phys_dim)
if position == L:
shape = (bond_dim, 1, phys_dim, class_dim) if class_index is not None and position == class_index else (bond_dim, 1, phys_dim)
else:
assert not cyclic
j = (L + 1 - abs(2*position - L - 1)) // 2 if position > L // 2 else position
chir = min(bond_dim, phys_dim**j)
chil = min(bond_dim, phys_dim**(j-1))
if position > L // 2:
(chil, chir) = (chir, chil)
if position == 1:
shape = (1, chir, phys_dim, class_dim) if class_index is not None and position == class_index else (1, chir, phys_dim)
elif position == L:
shape = (chil, 1, phys_dim, class_dim) if class_index is not None and position == class_index else (chil, 1, phys_dim)
else:
shape = (chil, chir, phys_dim, class_dim) if class_index is not None and position == class_index else (chil, chir, phys_dim)
return shape
def generate_ind(L: int, shape: tuple, position: int, cyclic: bool = False, class_index: int = None) -> tuple:
"""
Returns the names of the tensor indices.
Parameters
----------
shape : tuple
Shape of tensor.
position : int
Position of tensor in MatrixProductState. Goes from 1 to L included.
cyclic : bool
Flag for indicating if MatrixProductState this tensor is part of is cyclic. *Default=False*.
class_index : int
Index of tensor that is the output node (that is having index for number of classes). For classification tasks only.
Returns
-------
tuple
String names of indices.
"""
if len(shape) == 3:
if position == 1:
if class_index == position:
ind = (f'bond_{position-1}', f'k{position-1}', f'b{position-1}')
else:
ind = (f'bond_{position-2}', f'bond_{position-1}', f'k{position-1}')
elif position == L:
if cyclic and class_index != position:
raise ValueError('Cyclic MPS cannot have class_dim')
ind = (f'bond_{position-2}', f'k{position-1}', f'b{position-1}') if class_index == position else (f'bond_{position-2}', f'bond_{position-1}', f'k{position-1}')
else:
ind = (f'bond_{position-2}', f'bond_{position-1}', f'k{position-1}')
else:
ind = (f'bond_{position-2}', f'bond_{position-1}', f'k{position-1}', f'b{position-1}')
return ind
def MPS_initialize(L: int,
arrays: list = None,
initializer: Initializer = None,
key: Any = None,
dtype: Any = jnp.float_,
shape_method: str = 'even',
bond_dim: int = 4,
phys_dim: int = 2,
cyclic: bool = False,
add_identity: bool = False,
add_to_output: bool = False,
boundary: str = 'obc',
class_index: int = None,
class_dim: int = None,
tags_id: str = 'I{}',
compress: bool = False,
insert: int = None,
canonical_center: int = None,
**kwargs):
"""Initializes :class:`tn4ml.models.mps.MatrixProductState`.
Parameters
----------
L : int
Number of tensors.
initializer : :class:`jax.nn.initializers.Initializer``
Type of tensor initialization function.
key : Array
Argument key is a PRNG key (e.g. from `jax.random.key()`), used to generate random numbers to initialize the array.
dtype : Any
Type of tensor data (from `jax.numpy.float_`)
shape_method : str
Method to generate shapes for tensors.
bond_dim : int
Dimension of virtual indices between tensors. *Default = 4*.
phys_dim : int
Dimension of physical index for individual tensor.
cyclic : bool
Flag for indicating if MatrixProductState is cyclic. *Default=False*.
add_identity : bool
Flag to add identity to tensors diagonal elements.
add_to_output : bool
Flag for adding identity to diagonal elements of tensors with output indices. *Default=False*.
boundary : str
Boundary condition of MatrixProductState. *Default = 'obc'*. obc = open boundary condition. pbc = periodic boundary condition.
class_index : int
Index of tensor that is the output node for class. For classification tasks only.
class_dim : int
Dimension of output node, or number of classes for classification.
compress : bool
Flag to truncate bond dimensions.
insert : int
Index of tensor divided by norm. When `None` the norm division is distributed across all tensors
canonical_center : int
If not `None` then create canonical form around canonical center index.
kwargs : dict
Additional arguments.
Returns
-------
:class:`tn4ml.models.mps.MatrixProductState`
"""
if cyclic and shape_method != 'even':
raise NotImplementedError("Change shape_method to 'even'.")
if initializer is not None and callable(initializer) and 'rand_unitary' in initializer.__qualname__:
if add_identity:
raise ValueError("rand_unitary initializer does not support add_identity.")
if compress:
raise ValueError("rand_unitary initializer does not support compress.")
if insert:
raise ValueError("rand_unitary initializer does not support insert.")
if boundary == 'obc':
boundary = None
if arrays is not None:
# This means MPS for classification needs to be created with qtn.tensor_1d.TensorNetwork1DFlat class
assert class_index is not None # class_index is required when arrays or shapes are provided
if initializer is None:
initializer = randn()
if class_index is not None:
# MPS for classification
if class_index > L:
raise ValueError("class_index should be less than L.")
tensors = []
if arrays is not None:
for i, array in enumerate(arrays):
ind = generate_ind(L, array.shape, i+1, cyclic, class_index)
tensors.append(qtn.Tensor(array,
inds=ind,
tags=tags_id.format(i)))
else:
for i in range(1, L+1):
shape = generate_shape(shape_method, L, bond_dim, phys_dim, cyclic, i, class_index, class_dim)
ind = generate_ind(L, shape, i, cyclic, class_index)
if callable(initializer) and 'rand_unitary' in initializer.__qualname__:
if i < class_index or i > class_index:
array = initializer(key, shape, dtype)
elif i == class_index:
# Output node
array = jnp.asarray(np.random.normal(0., 1., shape), dtype)
else:
raise ValueError("Check value of class_index. It should be less than L.")
else:
array = initializer(key, shape, dtype)
if add_identity:
if len(array.shape) == 3:
copy_array = jnp.copy(array)
copy_array = copy_array.at[:, :, 0].add(jnp.eye(array.shape[0],
array.shape[1],
dtype=dtype))
array = copy_array
elif len(array.shape) == 4: # output node
if add_to_output:
copy_array = jnp.copy(array)
identity = jnp.eye(array.shape[0],
array.shape[1],
dtype=dtype)
identity = jnp.expand_dims(identity, axis=2)
identity = jnp.broadcast_to(identity, (copy_array.shape[0], copy_array.shape[1], copy_array.shape[3]))
copy_array = copy_array.at[:, :, 0, :].add(identity)
array = copy_array
else:
raise ValueError("Tensors need to always be 3D or 4D in MPS for classification.")
if boundary == 'obc':
aux_array = jnp.zeros(array.shape, dtype=dtype)
if i == 1:
# Left node
aux_array = aux_array.at[:,0,:].set(array[:,0,:])
array = aux_array
elif i == L:
# Right node
aux_array = aux_array.at[0,:,:].set(array[0,:,:])
array = aux_array
tensors.append(qtn.Tensor(array,
inds=ind,
tags=tags_id.format(i-1)))
mps = TensorNetwork(tensors, cyclic=cyclic, site_tag_id=tags_id, **kwargs)
if L > 200: # for large systems
for i, tensor in enumerate(mps.tensors):
if i == 0:
mps.left_canonize_site(i)
elif i == L - 1:
tensor.modify(data=tensor.data / jnp.linalg.norm(tensor.data))
else:
tensor.modify(data=tensor.data / jnp.linalg.norm(tensor.data))
mps.left_canonize_site(i)
if canonical_center is not None:
mps.canonicalize(canonical_center, inplace=True)
mps.normalize(insert=canonical_center)
else:
# normalize
if canonical_center is None:
mps.normalize()
else:
mps.canonize(canonical_center, inplace=True)
mps.normalize(insert = canonical_center)
else:
# MPS for regression
if arrays is not None:
tensors = []
for i, array in enumerate(arrays):
tensors.append(jnp.squeeze(array))
else:
tensors = []
for i in range(1, L+1):
shape = generate_shape(shape_method, L, bond_dim, phys_dim, cyclic, i)
tensor = initializer(key, shape, dtype)
if callable(initializer) and 'rand_unitary' not in initializer.__qualname__:
if add_identity:
if len(tensor.shape) == 3:
copy_tensor = jnp.copy(tensor)
copy_tensor.at[:, :, 0].add(jnp.eye(tensor.shape[0],
tensor.shape[1],
dtype=dtype))
tensor = copy_tensor
else:
raise ValueError("There was an error in generating shape. They should be 3D")
if boundary == 'obc':
aux_tensor = jnp.zeros(tensor.shape, dtype=dtype)
if i == 1:
# Left node
aux_tensor = aux_tensor.at[:,0,:].set(tensor[:,0,:])
tensor = aux_tensor
elif i == L:
# Right node
aux_tensor = aux_tensor.at[0,:,:].set(tensor[0,:,:])
tensor = aux_tensor
tensors.append(jnp.squeeze(tensor))
if not (callable(initializer) and 'rand_unitary' in initializer.__qualname__):
if insert and insert < L and shape_method == 'even':
tensors[insert] /= jnp.sqrt(phys_dim)
mps = MatrixProductState(tensors, **kwargs)
if compress:
if shape_method == 'even':
mps.compress(form="flat", max_bond=bond_dim) # limit bond_dim
else:
raise ValueError('Compress only works with shape_method = "even".')
if L > 200: # for large systems
for i, tensor in enumerate(mps.tensors):
if i == 0:
mps.left_canonize_site(i)
elif i == L - 1:
tensor.modify(data=tensor.data / jnp.linalg.norm(tensor.data))
else:
tensor.modify(data=tensor.data / jnp.linalg.norm(tensor.data))
mps.left_canonize_site(i)
if canonical_center is not None:
mps.canonicalize(canonical_center, inplace=True)
mps.normalize(insert=canonical_center)
else:
if canonical_center is None:
norm = mps.norm()
for tensor in mps.tensors:
tensor.modify(data=tensor.data / a.do("power", norm, 1 / L))
else:
mps.canonicalize(canonical_center, inplace=True)
mps.normalize(insert = canonical_center)
return mps