Initializers
- tn4ml.initializers.zeros(std=1e-09, dtype=<class 'jax.numpy.float64'>)
Builds an initializer that fills tensors with zeros plus a small amount of noise, whose scale is set by std.
Examples
>>> import jax, jax.numpy as jnp
>>> from tn4ml.initializers import zeros
>>> initializer = zeros()
>>> initializer(jax.random.key(42), (2, 2), jnp.float32)
Array([[0., 0.],
       [0., 0.]], dtype=float32)
- Parameters:
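std (Any (Optional). Default = 1e-9.) – Scale of the small noise added to the zeros.
dtype (Any (Optional). Default = jnp.float_.) – The initializer’s default dtype.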
- Return type:
Initializer
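As the example shows, the returned initializer is a callable taking (key, shape, dtype). A minimal sketch of how a zeros-plus-noise initializer of this kind could be built, assuming the noise is Gaussian with standard deviation std. The name zeros_with_noise is hypothetical and this is not tn4ml's source; the sketch defaults to float32 because float64 arrays require enabling JAX's x64 mode.

```python
import jax
import jax.numpy as jnp


def zeros_with_noise(std=1e-9, dtype=jnp.float32):
    """Hypothetical sketch: returns an init(key, shape, dtype) callable."""
    def init(key, shape, dtype=dtype):
        # Start from exact zeros...
        base = jnp.zeros(shape, dtype)
        # ...and add Gaussian noise with standard deviation `std` (assumption).
        noise = std * jax.random.normal(key, shape, dtype)
        return base + noise
    return init


initializer = zeros_with_noise()
tensor = initializer(jax.random.PRNGKey(42), (2, 2), jnp.float32)
print(tensor.shape, float(jnp.max(jnp.abs(tensor))))
```

Building the initializer as a closure keeps std and dtype fixed at construction time, so the training code only has to supply a key and a shape.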
- tn4ml.initializers.ones(std=1e-09, dtype=<class 'jax.numpy.float64'>)
Builds an initializer that fills tensors with ones plus a small amount of noise, whose scale is set by std.
Examples
>>> import jax, jax.numpy as jnp
>>> from tn4ml.initializers import ones
>>> initializer = ones()
>>> initializer(jax.random.key(42), (2, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.]], dtype=float32)
- Parameters:
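std (Any (Optional). Default = 1e-9.) – Scale of the small noise added to the ones.
dtype (Any (Optional). Default = jnp.float_.) – The initializer’s default dtype.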
- Return type:
Initializer
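For comparison, JAX ships a noiseless counterpart with the same (key, shape, dtype) calling convention: jax.nn.initializers.constant(1.0). The sketch below adds Gaussian noise on top of it to mimic "ones plus small noise"; treating tn4ml's noise as Gaussian with standard deviation std is an assumption, and std=1e-3 is chosen only so the deviation is visible.

```python
import jax
import jax.numpy as jnp

# JAX's built-in noiseless initializer: a callable init(key, shape, dtype).
exact_ones = jax.nn.initializers.constant(1.0)

key = jax.random.PRNGKey(42)
base = exact_ones(key, (2, 2), jnp.float32)

# Assumption: the noise is Gaussian with standard deviation `std`.
std = 1e-3
noisy = base + std * jax.random.normal(key, (2, 2), jnp.float32)

# The deviation from exact ones is on the order of `std`.
print(float(jnp.max(jnp.abs(noisy - 1.0))))
```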
- tn4ml.initializers.gramschmidt(dist, scale=0.01, dtype=<class 'jax.numpy.float64'>)
Builds an initializer that initializes tensors using the Gram-Schmidt orthogonalization procedure. Arrays are first sampled from a uniform or normal distribution (selected by the dist argument) and scaled by scale; the result is then orthonormalized with the Gram-Schmidt procedure.
- Parameters:
dist (str) – Distribution the arrays are sampled from. Options: 'uniform', 'normal'.
scale (Any (Optional). Default = 1e-2.) – Scaling factor for the sampled arrays.
dtype (Any (Optional)) – The initializer’s default dtype.
- Return type:
Initializer
Examples
>>> import jax, jax.numpy as jnp
>>> from tn4ml.initializers import gramschmidt
>>> initializer = gramschmidt('normal')
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)
Array([[ 0.35777482,  0.65598017,  0.6645954 ],
       [-0.57674366, -0.40450865,  0.70974606]], dtype=float32)
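In the example output the rows have unit norm and are mutually orthogonal. A minimal sketch of the sample-then-orthonormalize procedure for 2-D shapes with rows <= columns, using the modified Gram-Schmidt variant on the rows. The names gram_schmidt_rows and gramschmidt_sketch are hypothetical, and sampling 'uniform' from [0, 1) is an assumption; the exact ranges and details of tn4ml's implementation may differ.

```python
import jax
import jax.numpy as jnp


def gram_schmidt_rows(a):
    """Orthonormalize the rows of a 2-D array (modified Gram-Schmidt)."""
    rows = []
    for v in a:
        # Subtract the projections onto the rows accepted so far.
        for q in rows:
            v = v - jnp.dot(q, v) * q
        rows.append(v / jnp.linalg.norm(v))
    return jnp.stack(rows)


def gramschmidt_sketch(dist, scale=1e-2, dtype=jnp.float32):
    """Hypothetical sketch: sample, scale, then orthonormalize the rows."""
    def init(key, shape, dtype=dtype):
        if dist == "normal":
            a = scale * jax.random.normal(key, shape, dtype)
        elif dist == "uniform":
            a = scale * jax.random.uniform(key, shape, dtype)  # [0, 1) assumed
        else:
            raise ValueError(f"unknown dist: {dist!r}")
        # The scale cancels here, since every row is normalized.
        return gram_schmidt_rows(a)
    return init


w = gramschmidt_sketch("normal")(jax.random.PRNGKey(42), (2, 3))
print(w @ w.T)  # approximately the 2 x 2 identity
```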
- tn4ml.initializers.identity(type, std=None, dtype=<class 'jax.numpy.float64'>)
Builds an initializer that initializes tensors as an identity, placed either on the diagonal elements or across the bond dimensions.
- Parameters:
type (str. Options: 'copy', 'bond') – 'copy' places the identity on the diagonal elements; 'bond' places it on the bond-dimension elements.
std (Any (Optional). Default = None.) – Standard deviation of the additional noise.
dtype (Any (Optional). Default = jnp.float_.) – The initializer’s default dtype.
- Return type:
Initializer
Examples
>>> import jax, jax.numpy as jnp
>>> from tn4ml.initializers import identity
>>> initializer = identity('copy', 1e-2)
>>> initializer(jax.random.key(42), (3, 2), jnp.float32)
Array([[ 1.0061227 ,  0.01122588],
       [ 0.01137332,  0.99187267],
       [-0.00890405,  0.00126231]], dtype=float32)
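The 'copy' example above is jnp.eye(3, 2) plus noise of scale 1e-2. A minimal sketch of that case, assuming that for tensors of rank above 2 "diagonal elements" means the positions (i, i, ..., i), i.e. a copy tensor, and that std adds Gaussian noise. The names copy_tensor and identity_copy_sketch are hypothetical, and the 'bond' option is not sketched because its exact layout is not specified here.

```python
import jax
import jax.numpy as jnp


def copy_tensor(shape, dtype=jnp.float32):
    """Ones at positions (i, i, ..., i), zeros elsewhere."""
    idx = jnp.arange(min(shape))
    # Index every axis with the same arange, selecting the diagonal positions.
    return jnp.zeros(shape, dtype).at[tuple(idx for _ in shape)].set(1)


def identity_copy_sketch(std=None, dtype=jnp.float32):
    """Hypothetical sketch of the 'copy' variant with optional noise."""
    def init(key, shape, dtype=dtype):
        t = copy_tensor(shape, dtype)
        if std is not None:
            # Assumption: `std` is the standard deviation of Gaussian noise.
            t = t + std * jax.random.normal(key, shape, dtype)
        return t
    return init


print(copy_tensor((3, 2)))  # same layout as jnp.eye(3, 2)
t3 = copy_tensor((2, 2, 2))  # ones only at [0, 0, 0] and [1, 1, 1]
```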
- tn4ml.initializers.randn(std=1.0, mean=0.0, noise_std=None, noise_mean=None, dtype=<class 'jax.numpy.float64'>)
Builds an initializer that draws tensor values from a normal distribution.
- Parameters:
std (Any (Optional). Default = 1.0.) – Standard deviation of the normal distribution.
mean (Any (Optional). Default = 0.0.) – Mean of the normal distribution.
noise_std (Any (Optional). Default = None.) – The standard deviation of the noise distribution (normal).
noise_mean (Any (Optional). Default = None.) – The mean of the noise distribution (normal).
dtype (Any (Optional). Default = jnp.float_.) – The initializer’s default dtype.
- Return type:
Initializer
Examples
>>> import jax, jax.numpy as jnp
>>> from tn4ml.initializers import randn
>>> initializer = randn(1e-2)
>>> initializer(jax.random.key(42), (2, 2), jnp.float32)
Array([[ 0.00186935,  0.01065333],
       [-0.01559313, -0.01535296]], dtype=float32)
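A minimal sketch of how the four parameters could combine, assuming the optional noise is a second, independent normal draw N(noise_mean, noise_std²) added to the main draw N(mean, std²). The name randn_sketch is hypothetical; because it splits the key, it will not reproduce the exact numbers in the example above.

```python
import jax
import jax.numpy as jnp


def randn_sketch(std=1.0, mean=0.0, noise_std=None, noise_mean=None,
                 dtype=jnp.float32):
    """Hypothetical sketch: N(mean, std^2) plus optional N(noise_mean, noise_std^2)."""
    def init(key, shape, dtype=dtype):
        key_main, key_noise = jax.random.split(key)
        # Main draw from N(mean, std^2).
        t = mean + std * jax.random.normal(key_main, shape, dtype)
        if noise_std is not None:
            # Independent noise draw; noise_mean defaults to 0 here (assumption).
            nm = 0.0 if noise_mean is None else noise_mean
            t = t + nm + noise_std * jax.random.normal(key_noise, shape, dtype)
        return t
    return init


sample = randn_sketch(std=2.0, mean=1.0)(jax.random.PRNGKey(0), (100_000,))
print(float(sample.mean()), float(sample.std()))  # close to 1.0 and 2.0
```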
- tn4ml.initializers.unitary_matrix(key, shape, dtype=<class 'jax.numpy.float64'>)
Adapted from @joserapa98/tensorkrowch.
Generates a random n x n unitary matrix sampled from the Haar measure.
The matrix is constructed as described in this paper.
- Parameters:
key (Any) – Random key.
shape (core.Shape) – Shape of the tensor.
dtype (Any) – Data type of the tensor.
- Returns:
Random unitary matrix.
- Return type:
jnp.ndarray
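One standard way to sample a Haar-distributed unitary is to take the QR decomposition of a complex Gaussian (Ginibre) matrix and then fix the phase ambiguity of QR using the diagonal of R. Whether this is exactly the construction in the paper referenced above is an assumption; haar_unitary is a hypothetical name and the sketch uses complex64 only.

```python
import jax
import jax.numpy as jnp


def haar_unitary(key, n):
    """Hypothetical sketch: n x n Haar unitary via QR with phase correction."""
    key_re, key_im = jax.random.split(key)
    # Complex Gaussian matrix with i.i.d. entries (float32 parts -> complex64).
    z = jax.random.normal(key_re, (n, n)) + 1j * jax.random.normal(key_im, (n, n))
    q, r = jnp.linalg.qr(z)
    # QR is unique only up to phases; scaling column j of Q by d_j / |d_j|,
    # with d = diag(R), makes the result Haar-distributed.
    d = jnp.diagonal(r)
    return q * (d / jnp.abs(d))


u = haar_unitary(jax.random.PRNGKey(0), 4)
print(jnp.allclose(u @ u.conj().T, jnp.eye(4), atol=1e-5))
```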
- tn4ml.initializers.rand_unitary(dtype=<class 'jax.numpy.float64'>)
Builds an initializer that initializes a tensor with a stack of random unitary matrices.
- Parameters:
dtype (Any (Optional). Default = jnp.float_.) – The initializer’s default dtype.
- Return type:
Initializer
Examples
>>> import jax, jax.numpy as jnp
>>> from tn4ml.initializers import rand_unitary
>>> initializer = rand_unitary()
>>> initializer(jax.random.key(42), (2, 2), jnp.float32)
Array([[ 0.11903083,  0.99289054],
       [-0.99289054,  0.11903088]], dtype=float32)
>>> tensor = initializer(jax.random.key(42), (2, 2), jnp.float32)
>>> jnp.allclose(tensor @ tensor.T.conj(), jnp.eye(2), atol=1e-6)
Array(True, dtype=bool)
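For a real dtype, as in the example above, the unitary matrices are real orthogonal matrices. A minimal sketch of building such a stack, assuming the last two axes of shape are the (square) matrix axes and every leading index gets an independent matrix. The names random_orthogonal and stack_of_unitaries are hypothetical; how tn4ml actually arranges the stack is not stated here.

```python
import jax
import jax.numpy as jnp


def random_orthogonal(key, n, dtype=jnp.float32):
    """Real analogue of a Haar unitary: QR of a Gaussian matrix with a sign fix."""
    q, r = jnp.linalg.qr(jax.random.normal(key, (n, n), dtype))
    return q * jnp.sign(jnp.diagonal(r))


def stack_of_unitaries(key, shape, dtype=jnp.float32):
    """Hypothetical sketch: shape = (..., n, n), one matrix per leading index."""
    *batch, n, m = shape
    if n != m:
        raise ValueError("the last two axes must be square")
    count = 1
    for b in batch:
        count *= b
    # One independent key per matrix, then vectorize the sampler over keys.
    keys = jax.random.split(key, count)
    mats = jax.vmap(lambda k: random_orthogonal(k, n, dtype))(keys)
    return mats.reshape(shape)


stack = stack_of_unitaries(jax.random.PRNGKey(42), (3, 2, 2))
# Every 2 x 2 slice U satisfies U @ U^T = I.
print(jnp.allclose(stack @ jnp.swapaxes(stack, -1, -2), jnp.eye(2), atol=1e-5))
```

Using jax.vmap over split keys keeps the matrices independent while avoiding a Python loop over the stack.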