Source code for tn4ml.models.tn

import copy
from typing import Any, Collection
import numpy as np
import autoray as a

from quimb import *
import quimb.tensor as qtn
from jax.nn.initializers import Initializer
import jax.numpy as jnp

from .model import Model
from ..initializers import *

class TensorNetwork(Model, qtn.tensor_1d.TensorNetwork1DFlat):
    """A Trainable TensorNetwork class.

    See :class:`quimb.tensor.tensor_core.TensorNetwork` for explanation of other attributes and methods.
    """

    # Extra instance attributes that quimb transfers when the network is copied/cast.
    _EXTRA_PROPS = ("_L", "_site_tag_id", "cyclic")

    def __init__(self, tensors, site_tag_id: str = "I{}", cyclic: bool = False, **kwargs):
        """Initializes :class:`tn4ml.models.tn.TensorNetwork`.

        Parameters
        ----------
        tensors : list or TensorNetwork
            List of :class:`quimb.tensor.tensor_core.Tensor`, or an existing
            :class:`tn4ml.models.tn.TensorNetwork`. In the latter case only
            :class:`Model` is initialized and the network structure is NOT built
            here; this is the path used by :meth:`copy`, which fills it in.
        site_tag_id : str
            Format string for site tags, stored as ``_site_tag_id``. *Default = "I{}"*.
        cyclic : bool
            Stored as the ``cyclic`` attribute (quimb's 1D boundary flag). *Default = False*.
        kwargs : dict
            Additional arguments forwarded to
            :class:`quimb.tensor.tensor_1d.TensorNetwork1DFlat`.
        """
        Model.__init__(self)
        if isinstance(tensors, TensorNetwork):
            # Copy path: structure and attributes are transferred by ``copy``.
            return
        qtn.tensor_1d.TensorNetwork1DFlat.__init__(self, tensors, **kwargs)
        self._L = len(self.tensors)
        self.cyclic = cyclic
        self._site_tag_id = site_tag_id

    def canonize(self, where, cur_orthog='calc', info=None, bra=None, inplace=False):
        """Canonizes the tensor network around site ``where``.

        Thin wrapper over ``canonicalize``; all arguments are forwarded unchanged.

        Returns
        -------
        Whatever ``canonicalize`` returns. In quimb this is the canonicalized
        network: ``self`` when ``inplace=True``, a new copy otherwise.
        Previously the result was discarded, so ``inplace=False`` had no effect.
        """
        return self.canonicalize(where, cur_orthog=cur_orthog, info=info, bra=bra, inplace=inplace)

    def copy(self, virtual: bool = False, deep: bool = False):
        """Copies the model.

        Parameters
        ----------
        virtual : bool
            If True, the copy shares the underlying tensor objects with ``self``.
            Otherwise each tensor is copied. *Default = False*.
        deep : bool
            If True, return ``copy.deepcopy(self)``. *Default = False*.

        Returns
        -------
        Model of the same type.
        """
        if deep:
            return copy.deepcopy(self)
        model = self.__class__(self, virtual=virtual)
        # Carry over every attribute (model settings, extra props, ...). This is
        # a shallow transfer: non-network attributes are shared with ``self``.
        model.__dict__.update(self.__dict__)
        # The update above also makes ``model`` share ``tensor_map``/``tag_map``/
        # ``ind_map`` with ``self``, so edits to the copy would corrupt the
        # original. Rebuild the network structure with quimb's copy constructor,
        # which creates fresh maps and copies tensors unless ``virtual``.
        qtn.TensorNetwork.__init__(model, self, virtual=virtual)
        return model

    def norm(self, **contract_opts) -> float:
        """Calculates norm of :class:`tn4ml.models.tn.TensorNetwork`.

        Parameters
        ----------
        contract_opts : Optional
            Arguments passed to ``contract()``.

        Returns
        -------
        float
            Norm of :class:`tn4ml.models.tn.TensorNetwork`, i.e. the square root
            of the contracted ``<self|self>`` network.
        """
        overlap = self.conj() & self
        return overlap.contract(**contract_opts) ** 0.5

    def normalize(self, insert=None) -> None:
        """Function for normalizing tensors of :class:`tn4ml.models.tn.TensorNetwork`.

        Parameters
        ----------
        insert : int
            Index of tensor divided by norm. *Default = None*.
            When `None` the norm division is distributed across all tensors.
            Ignored for networks with more than 200 sites, where a
            canonization sweep is used instead.

        Raises
        ------
        ValueError
            If the network has no tensors.
        IndexError
            If ``insert`` is outside ``[0, number of tensors)``.
        """
        if not self.tensors:
            raise ValueError("The tensor network is empty.")

        if self.L > 200:
            # For large systems avoid contracting the full norm. Instead, sweep
            # left-to-right, left-canonizing sites 0..L-2 and rescaling tensors
            # by their own Frobenius norm. Dividing the last site by its norm
            # sets the overall normalization, presumably because the preceding
            # sites are left-isometric after canonization.
            for i, tensor in enumerate(self.tensors):
                if i == 0:
                    self.left_canonize_site(i)
                elif i == self.L - 1:
                    tensor.modify(data=tensor.data / jnp.linalg.norm(tensor.data))
                else:
                    tensor.modify(data=tensor.data / jnp.linalg.norm(tensor.data))
                    self.left_canonize_site(i)
            return

        norm = self.norm()
        if insert is None:
            # Spread the division evenly: each of the L tensors is divided by norm**(1/L).
            factor = a.do("power", norm, 1 / self.L)
            for tensor in self.tensors:
                tensor.modify(data=tensor.data / factor)
        else:
            if not (0 <= insert < len(self.tensors)):
                raise IndexError(f"Insert index {insert} is out of bounds for the tensor list.")
            self.tensors[insert].modify(data=self.tensors[insert].data / norm)
def trainable_wrapper(tn: qtn.tensor_1d.TensorNetwork1DFlat, **kwargs) -> qtn.tensor_1d.TensorNetwork1DFlat:
    """Creates a wrapper around qtn.tensor_1d.TensorNetwork1DFlat so it can be trainable.

    Parameters
    ----------
    tn : :class:`quimb.tensor.TensorNetwork`
        Tensor Network to be trained.
    kwargs : dict
        Additional arguments forwarded to :class:`tn4ml.models.tn.TensorNetwork`
        (e.g. ``site_tag_id``, ``cyclic``); these are not read from ``tn``.

    Returns
    -------
    :class:`tn4ml.models.tn.TensorNetwork`
    """
    return TensorNetwork(tn.tensors, **kwargs)


def TN_initialize(arrays: list = None,
                  shapes: list = None,
                  key: Any = None,
                  initializer: Initializer = None,
                  inds: Collection[Collection[str]] = None,
                  tags_id: str = 'I{}',
                  cyclic: bool = False,
                  dtype: Any = jnp.float_,
                  **kwargs) -> TensorNetwork:
    """Initializes a TensorNetwork.

    Parameters
    ----------
    arrays : list
        List of arrays to be used as tensors. *Default = None*.
        If None, shapes must be provided.
    shapes : list
        List of shapes of tensors. *Default = None*.
        If None, arrays must be provided. Each shape must list its dimensions
        in the same order as the corresponding entry of ``inds``
        (e.g. (left, right, physical) for a chain).
    key : Any
        Random key for initialization. *Default = None*.
        When given together with ``initializer``, it is split into one subkey
        per tensor so that tensors are initialized independently.
    initializer : from `tn4ml.initializers` or `jax.nn.initializers`
        Initializer for tensors. *Default = None*.
        If None, tensors are drawn from a standard normal distribution using
        NumPy's global random state. Only used if arrays is None.
    inds : sequence of arrays of str
        List of indices for tensors. *Default = None*.
        Needs to be provided because it defines the connectivity between tensors.
        Example for TN with 3 tensors:
        >>> inds = [['bond0', 'k0'], ['bond0', 'bond1', 'k1'], ['bond1', 'k2']]
    tags_id : str
        Tag identifier for tensors. *Default = 'I{}'*.
        The tag identifier should have a single placeholder for tag number.
    cyclic : bool
        Forwarded to :class:`tn4ml.models.tn.TensorNetwork`. *Default = False*.
    dtype : Any
        Data type for tensors. *Default = jnp.float_*.
    kwargs : dict
        Additional arguments forwarded to :class:`tn4ml.models.tn.TensorNetwork`.

    Returns
    -------
    :class:`tn4ml.models.tn.TensorNetwork`
        The network, normalized via its ``normalize()`` method.

    Raises
    ------
    ValueError
        If neither ``arrays`` nor ``shapes`` is given, if ``inds`` is missing,
        or if the number of tensors and index lists differ.
    """
    if arrays is None and shapes is None:
        raise ValueError("Provide either arrays or shapes to create Tensor Network.")
    if inds is None:
        raise ValueError("Provide indices for tensors - connectivity map between tensors.")

    tensors = []
    if arrays is not None:
        if len(arrays) != len(inds):
            raise ValueError("Number of tensors and indices should be same.")
        for i, (array, tensor_inds) in enumerate(zip(arrays, inds)):
            tensors.append(qtn.Tensor(array, inds=tensor_inds, tags=tags_id.format(i)))
    else:
        if len(shapes) != len(inds):
            raise ValueError("Number of tensors and indices should be same.")

        if initializer is not None and key is not None:
            import jax

            # One subkey per tensor: JAX initializers are deterministic in
            # (key, shape), so reusing a single key would make every
            # equally-shaped tensor identical.
            keys = list(jax.random.split(key, len(shapes)))
        else:
            # No key to split (e.g. key-free initializers); pass it through as before.
            keys = [key] * len(shapes)

        for i, (shape, tensor_inds, tensor_key) in enumerate(zip(shapes, inds, keys)):
            if initializer is not None:
                array = initializer(tensor_key, shape, dtype)
            else:
                array = np.asarray(np.random.normal(0., 1., shape), dtype)
            tensors.append(qtn.Tensor(array, inds=tensor_inds, tags=tags_id.format(i)))

    tn = TensorNetwork(tensors, cyclic=cyclic, site_tag_id=tags_id, **kwargs)
    tn.normalize()
    return tn