Initializers

tn4ml.initializers.zeros(std=1e-09, dtype=<class 'jax.numpy.float64'>)

Builds an initializer that fills tensors with zeros plus small random noise, whose scale is set by std.

Examples

>>> import jax, jax.numpy as jnp
>>> from tn4ml.initializers import zeros
>>> initializer = zeros()
>>> initializer(jax.random.key(42), (2, 2), jnp.float32)
Array([[0., 0.],
       [0., 0.]], dtype=float32)

Parameters:
  • std (Any) – Standard deviation of the added noise.

  • dtype (Any) – The initializer’s default dtype.

Return type:

Initializer
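The builders in this module follow the JAX initializer convention: the builder returns a function that is called as initializer(key, shape, dtype), as in the example above. The sketch below is a hypothetical stand-in for zeros (zeros_with_noise is an illustrative name, not part of tn4ml), and it assumes the noise is drawn from a normal distribution with standard deviation std. It defaults to float32 because JAX truncates float64 requests to float32 unless jax_enable_x64 is set.

```python
import jax
import jax.numpy as jnp


def zeros_with_noise(std=1e-9, dtype=jnp.float32):
    """Hypothetical builder: returns init(key, shape, dtype) that produces
    zeros plus N(0, std^2) noise (normal noise is an assumption)."""

    def init(key, shape, dtype=dtype):
        # Draw the noise with the caller's key so results are reproducible.
        noise = std * jax.random.normal(key, shape, dtype)
        return jnp.zeros(shape, dtype) + noise

    return init


initializer = zeros_with_noise()
tensor = initializer(jax.random.PRNGKey(42), (2, 2))
```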

tn4ml.initializers.ones(std=1e-09, dtype=<class 'jax.numpy.float64'>)

Builds an initializer that fills tensors with ones plus small random noise, whose scale is set by std.

Examples

>>> import jax, jax.numpy as jnp
>>> from tn4ml.initializers import ones
>>> initializer = ones()
>>> initializer(jax.random.key(42), (2, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.]], dtype=float32)

Parameters:
  • std (Any) – Standard deviation of the added noise.

  • dtype (Any) – The initializer’s default dtype.

Return type:

Initializer
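A hypothetical ones_with_noise sketch (illustrative name, assuming normally distributed noise) that also shows two practical points. First, split the key so that each tensor receives independent noise. Second, in float32 the default std=1e-09 is far below the spacing between float32 values near 1.0 (about 1.2e-7), so the noise rounds away and the result is exactly one.

```python
import jax
import jax.numpy as jnp


def ones_with_noise(std=1e-9, dtype=jnp.float32):
    """Hypothetical builder: init(key, shape, dtype) -> ones + N(0, std^2) noise."""

    def init(key, shape, dtype=dtype):
        return jnp.ones(shape, dtype) + std * jax.random.normal(key, shape, dtype)

    return init


# Initialize several tensors from independent keys so their noise differs.
keys = jax.random.split(jax.random.PRNGKey(0), 3)
tensors = [ones_with_noise(std=1e-2)(k, (2, 2)) for k in keys]

# With float32, 1e-9 noise is lost to rounding: every entry is exactly 1.0.
tiny = ones_with_noise()(jax.random.PRNGKey(0), (2, 2))
```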

tn4ml.initializers.gramschmidt(dist, scale=0.01, dtype=<class 'jax.numpy.float64'>)

Builds an initializer that initializes tensors using the Gram-Schmidt orthogonalization procedure. Arrays are first sampled from a uniform or normal distribution (selected by the dist argument) and then orthogonalized.

Parameters:
  • dist (str) – Sampling distribution of arrays. Options: uniform, normal.

  • scale (Any (Optional). Default = 1e-2.) – Scaling factor for the sampled arrays.

  • dtype (Any (Optional)) – The initializer’s default dtype.

Return type:

Initializer

Examples

>>> import jax, jax.numpy as jnp
>>> from tn4ml.initializers import gramschmidt
>>> initializer = gramschmidt('normal')
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)
Array([[ 0.35777482,  0.65598017,  0.6645954 ],
       [-0.57674366, -0.40450865,  0.70974606]], dtype=float32)

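A minimal sketch of the procedure described above, for a 2-D shape with no more rows than columns. gramschmidt_sketch is a hypothetical name. Sampling uniform values in [-1, 1) and orthonormalizing rows rather than columns are assumptions; the rows of the example output above are orthonormal, which is why rows are used here. The projection loop runs twice so that float32 rounding does not erode orthogonality. Because each row is normalized at the end, scale has no effect on this sketch's output.

```python
import jax
import jax.numpy as jnp


def gramschmidt_sketch(key, shape, dist="normal", scale=1e-2, dtype=jnp.float32):
    """Hypothetical: sample a (rows, cols) array, then orthonormalize its rows."""
    rows, cols = shape
    if rows > cols:
        raise ValueError("need rows <= cols to obtain orthonormal rows")
    if dist == "normal":
        a = scale * jax.random.normal(key, shape, dtype)
    elif dist == "uniform":
        # Assumed range [-1, 1); the documented range is not specified.
        a = scale * jax.random.uniform(key, shape, dtype, minval=-1.0, maxval=1.0)
    else:
        raise ValueError(f"unknown dist: {dist!r}")

    basis = []
    for i in range(rows):
        v = a[i]
        for _ in range(2):  # second pass re-orthogonalizes against rounding error
            for u in basis:
                v = v - jnp.dot(u, v) * u  # remove the component along an earlier row
        basis.append(v / jnp.linalg.norm(v))
    return jnp.stack(basis)


q = gramschmidt_sketch(jax.random.PRNGKey(42), (2, 3))
```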
tn4ml.initializers.identity(type, std=None, dtype=<class 'jax.numpy.float64'>)

Builds an initializer that initializes tensors with the identity, either on the diagonal elements or across the bond dimensions.

Parameters:
  • type (str. Options: 'copy', 'bond') – 'copy' = identity on the diagonal elements, 'bond' = identity in the bond dimensions.

  • std (Any (Optional)) – Standard deviation of the additional noise.

  • dtype (Any (Optional). Default = jnp.float_.) – The initializer’s default dtype.

Return type:

Initializer

Examples

>>> import jax, jax.numpy as jnp
>>> from tn4ml.initializers import identity
>>> initializer = identity('copy', 1e-2)
>>> initializer(jax.random.key(42), (3, 2), jnp.float32)
Array([[ 1.0061227 ,  0.01122588],
       [ 0.01137332,  0.99187267],
       [-0.00890405,  0.00126231]], dtype=float32)

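A sketch of the 'copy' mode only, assuming it sets every entry whose indices are all equal (the generalized diagonal, as in a copy tensor) to 1 and optionally adds normal noise with standard deviation std. This matches the (3, 2) example above up to noise. The exact layout used by 'bond' is not specified here, so it is not sketched. identity_copy_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def identity_copy_sketch(key, shape, std=None, dtype=jnp.float32):
    """Hypothetical 'copy' initializer: ones where all indices are equal."""
    n = min(shape)
    idx = jnp.arange(n)
    # Index with (idx, idx, ..., idx) to hit positions (i, i, ..., i).
    t = jnp.zeros(shape, dtype).at[(idx,) * len(shape)].set(1.0)
    if std is not None:
        t = t + std * jax.random.normal(key, shape, dtype)
    return t


t = identity_copy_sketch(jax.random.PRNGKey(42), (3, 2), std=1e-2)
```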
tn4ml.initializers.randn(std=1.0, mean=0.0, noise_std=None, noise_mean=None, dtype=<class 'jax.numpy.float64'>)

Builds an initializer that samples tensor values from a normal distribution.

Parameters:
  • std (Any (Optional). Default = 1.0.) – Standard deviation of the normal distribution.

  • mean (Any (Optional). Default = 0.0.) – Mean of the normal distribution.

  • noise_std (Any (Optional). Default = None.) – The standard deviation of the noise distribution (normal).

  • noise_mean (Any (Optional). Default = None.) – The mean of the noise distribution (normal).

  • dtype (Any (Optional). Default = jnp.float_.) – The initializer’s default dtype.

Return type:

Initializer

Examples

>>> import jax, jax.numpy as jnp
>>> from tn4ml.initializers import randn
>>> initializer = randn(1e-2)
>>> initializer(jax.random.key(42), (2, 2), jnp.float32)
Array([[ 0.00186935,  0.01065333],
       [-0.01559313, -0.01535296]], dtype=float32)

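A hypothetical randn_sketch illustrating the documented parameters. How tn4ml combines the optional noise term is not stated; this sketch assumes the noise, drawn from N(noise_mean, noise_std²) with an independent key, is added on top of the N(mean, std²) samples, and that noise_mean defaults to 0 when only noise_std is given.

```python
import jax
import jax.numpy as jnp


def randn_sketch(std=1.0, mean=0.0, noise_std=None, noise_mean=None, dtype=jnp.float32):
    """Hypothetical builder: N(mean, std^2) values plus optional normal noise."""

    def init(key, shape, dtype=dtype):
        # Separate keys keep the base samples and the noise independent.
        k_val, k_noise = jax.random.split(key)
        values = mean + std * jax.random.normal(k_val, shape, dtype)
        if noise_std is not None:
            nm = 0.0 if noise_mean is None else noise_mean  # assumed default
            values = values + nm + noise_std * jax.random.normal(k_noise, shape, dtype)
        return values

    return init


samples = randn_sketch(std=0.5, mean=2.0)(jax.random.PRNGKey(0), (200, 500))
```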
tn4ml.initializers.unitary_matrix(key, shape, dtype=<class 'jax.numpy.float64'>)
From @joserapa98/tensorkrowch.

Generates a random n x n unitary matrix sampled from the Haar measure.

The unitary matrix is constructed as described in this paper.

Parameters:
  • key (Any) – Random key.

  • shape (core.Shape) – Shape of the tensor.

  • dtype (Any) – Data type of the tensor.

Returns:

Random unitary matrix.

Return type:

jnp.ndarray
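A common way to sample from the Haar measure is to QR-decompose a matrix of Gaussian entries (complex for unitary matrices, real for orthogonal ones) and multiply each column of Q by the phase of the matching diagonal entry of R; without that correction the result is not Haar-distributed. Whether the referenced paper uses exactly this construction is an assumption, and haar_unitary_sketch, which takes n instead of a shape, is a hypothetical stand-in for unitary_matrix.

```python
import jax
import jax.numpy as jnp


def haar_unitary_sketch(key, n, dtype=jnp.float32):
    """Hypothetical: Haar-random n x n orthogonal (real dtype) or unitary
    (complex dtype) matrix via QR with phase correction."""
    k_re, k_im = jax.random.split(key)
    z = jax.random.normal(k_re, (n, n))
    if jnp.issubdtype(dtype, jnp.complexfloating):
        # Complex Gaussian entries with unit variance.
        z = (z + 1j * jax.random.normal(k_im, (n, n))) / 2.0 ** 0.5
    q, r = jnp.linalg.qr(z)
    d = jnp.diagonal(r)
    # Scale column j of Q by the phase (or sign) of r[j, j].
    q = q * (d / jnp.abs(d))
    return q.astype(dtype)


u = haar_unitary_sketch(jax.random.PRNGKey(0), 4, jnp.complex64)
```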

tn4ml.initializers.rand_unitary(dtype=<class 'jax.numpy.float64'>)

Builds an initializer that initializes a tensor with a stack of random unitary matrices.

Parameters:
  • dtype (Any (Optional). Default = jnp.float_.) – The initializer’s default dtype.

Return type:

Initializer

Examples

>>> import jax, jax.numpy as jnp
>>> from tn4ml.initializers import rand_unitary
>>> initializer = rand_unitary()
>>> initializer(jax.random.key(42), (2, 2), jnp.float32)
Array([[ 0.11903083,  0.99289054],
       [-0.99289054,  0.11903088]], dtype=float32)
>>> tensor = initializer(jax.random.key(42), (2, 2), jnp.float32)
>>> jnp.allclose(tensor @ tensor.T.conj(), jnp.eye(2), atol=1e-6)
True
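JAX already provides jax.random.orthogonal, which samples Haar-random orthogonal matrices (unitary ones for complex dtypes) and accepts leading batch dimensions through its shape argument. The sketch below uses it to build a stack. Treating the last two axes of shape as the matrix dimensions and the remaining axes as the stack is an assumption about rand_unitary, and rand_unitary_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def rand_unitary_sketch(dtype=jnp.float32):
    """Hypothetical builder: init(key, shape, dtype) fills a (..., n, n)
    tensor with independent Haar-random orthogonal/unitary matrices."""

    def init(key, shape, dtype=dtype):
        *batch, n, m = shape
        if n != m:
            raise ValueError("the last two dimensions must be equal")
        # shape= sets the leading (stack) dimensions of the result.
        return jax.random.orthogonal(key, n, shape=tuple(batch), dtype=dtype)

    return init


stack = rand_unitary_sketch()(jax.random.PRNGKey(42), (3, 4, 4))
```

Each slice stack[i] satisfies stack[i] @ stack[i].T ≈ I, the same check used in the example above.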