# Source code for tn4ml.embeddings

import abc
import itertools
import math
from numbers import Number
from typing import Collection, Any, Union, Optional, Dict, List, Tuple

import numpy as onp
import jax
import jax.numpy as jnp
from jax import lax
import quimb.tensor as qtn

import tn4ml.util as u
from tn4ml.scipy.special import eval_legendre, eval_laguerre, eval_hermite

class Embedding(abc.ABC):
    """Abstract interface for feature maps used to build tensor-network inputs.

    A concrete embedding turns one input sample into a vector of length
    :attr:`dim`. ``embed`` then uses that vector as the local tensor of a
    product-state MPS.

    Attributes
    ----------
    dtype : :class:`numpy.dtype`
        Data type requested for computations (``numpy.float32`` unless
        overridden).
    """

    def __init__(self, dtype: onp.dtype = onp.float32):
        """Remember the requested computation dtype.

        Parameters
        ----------
        dtype : :class:`numpy.dtype`, optional
            Data type for computations, by default ``numpy.float32``.
        """
        self.dtype = dtype

    @property
    @abc.abstractmethod
    def dim(self) -> int:
        """Length of the vector returned by :meth:`__call__`.

        Returns
        -------
        int
            Output dimension of the embedding.
        """
        ...

    @property
    @abc.abstractmethod
    def input_dim(self) -> int:
        """Number of scalar values consumed per input sample.

        Returns
        -------
        int
            Input dimension (e.g. 1 for a scalar feature).
        """
        ...

    @abc.abstractmethod
    def __call__(self, x: Number) -> jnp.ndarray:
        """Embed a single input sample.

        Parameters
        ----------
        x : Number
            Input sample.

        Returns
        -------
        :class:`jax.numpy.ndarray`
            Embedded vector.
        """
        ...
class ComplexEmbedding(abc.ABC):
    """Base class for embeddings that apply a feature-specific map per feature.

    Note that this is a separate abstract interface; it does NOT inherit from
    :class:`Embedding`. ``embed`` calls an instance once per row of a 2-D
    input, so :meth:`__call__` receives one feature (possibly a vector) at a
    time.

    Attributes
    ----------
    dtype : :class:`numpy.dtype`
        Data type for computations. Defaults to float32.
    """

    def __init__(self, dtype: onp.dtype = onp.float32):
        """Initialize the complex embedding.

        Parameters
        ----------
        dtype : :class:`numpy.dtype`, optional
            Data type for computations, by default float32
        """
        self.dtype = dtype

    @property
    @abc.abstractmethod
    def dims(self) -> Collection[int]:
        """Get the output dimensions for each feature.

        Returns
        -------
        Collection[int]
            List of dimensions for each feature's output
        """
        pass

    @property
    @abc.abstractmethod
    def input_dims(self) -> jnp.ndarray:
        """Get the input dimensions for each feature.

        Returns
        -------
        :class:`jax.numpy.ndarray`
            Array of input dimensions (1 for scalar, 2 for vector)
        """
        pass

    @property
    @abc.abstractmethod
    def embeddings(self) -> Collection["Embedding"]:
        """Get the embedding functions for each feature.

        Returns
        -------
        Collection[Embedding]
            List of embedding functions
        """
        pass

    @abc.abstractmethod
    def __call__(self, x: Number) -> jnp.ndarray:
        """Apply the complex embedding to one feature.

        Parameters
        ----------
        x : Number
            Input data to embed (one feature, possibly a vector)

        Returns
        -------
        :class:`jax.numpy.ndarray`
            Embedded vector
        """
        pass


class StateVectorToMPSEmbedding(abc.ABC):
    """Base class for converting state vectors to Matrix Product States (MPS).

    Subclasses build a dense state vector from the input
    (:meth:`create_statevector`) and turn it into an MPS in :meth:`__call__`.

    Attributes
    ----------
    dtype : :class:`numpy.dtype`
        Data type for computations
    max_bond : Optional[int]
        Maximum bond dimension for MPS decomposition
    """

    def __init__(self, dtype: onp.dtype = onp.float32, max_bond: Optional[int] = None):
        """Initialize the state vector to MPS embedding.

        Parameters
        ----------
        dtype : :class:`numpy.dtype`, optional
            Data type for computations, by default float32
        max_bond : Optional[int], optional
            Maximum bond dimension for MPS decomposition, by default None
        """
        self.dtype = dtype
        self.max_bond = max_bond

    @property
    @abc.abstractmethod
    def dims(self) -> list:
        """Get dimensions of the MPS tensors.

        Returns
        -------
        list
            List of tensor shapes
        """
        pass

    # A plain abstract method, not a property: it takes an argument, and
    # concrete subclasses implement it as a regular method.
    @abc.abstractmethod
    def create_statevector(self, x: jnp.ndarray) -> jnp.ndarray:
        """Create a state vector from input data.

        Parameters
        ----------
        x : :class:`jax.numpy.ndarray`
            Input data

        Returns
        -------
        :class:`jax.numpy.ndarray`
            State vector representation
        """
        pass

    @abc.abstractmethod
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Convert input data to MPS representation.

        Parameters
        ----------
        x : :class:`jax.numpy.ndarray`
            Input data

        Returns
        -------
        :class:`jax.numpy.ndarray`
            MPS representation
        """
        pass


class MPSEmbedding(abc.ABC):
    """Base class for converting input data to Matrix Product State (MPS).

    Subclasses supply their own decomposition strategy via :meth:`decompose`.

    Attributes
    ----------
    dtype : :class:`numpy.dtype`
        Data type for computations
    max_bond : Optional[int]
        Maximum bond dimension for MPS decomposition
    """

    def __init__(self, dtype: onp.dtype = onp.float32, max_bond: Optional[int] = None):
        """Initialize the MPS embedding.

        Parameters
        ----------
        dtype : :class:`numpy.dtype`, optional
            Data type for computations, by default float32
        max_bond : Optional[int], optional
            Maximum bond dimension for MPS decomposition, by default None
        """
        self.dtype = dtype
        self.max_bond = max_bond

    @property
    @abc.abstractmethod
    def dims(self) -> list:
        """Get dimensions of the MPS tensors.

        Returns
        -------
        list
            List of tensor shapes
        """
        pass

    # A plain abstract method, not a property: it takes positional arguments.
    @abc.abstractmethod
    def decompose(self, x: Any, *args) -> jnp.ndarray:
        """Decompose input data into MPS format.

        Parameters
        ----------
        x : Any
            Input data
        *args : Any
            Additional arguments for decomposition

        Returns
        -------
        :class:`jax.numpy.ndarray`
            Decomposed data in MPS format
        """
        pass

    @abc.abstractmethod
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Convert input data to MPS representation.

        Parameters
        ----------
        x : :class:`jax.numpy.ndarray`
            Input data

        Returns
        -------
        :class:`jax.numpy.ndarray`
            MPS representation
        """
        pass
class TrigonometricEmbedding(Embedding):
    """Trigonometric feature map with ``k`` frequency components.

    A scalar ``x`` is mapped to
    :math:`\\phi(x) = \\frac{1}{\\sqrt{k}}[\\cos(\\frac{\\pi}{2}x), ..., \\cos(\\frac{\\pi}{2^k}x), \\sin(\\frac{\\pi}{2}x), ..., \\sin(\\frac{\\pi}{2^k}x)]`,
    i.e. all ``k`` cosine components (angles ``pi * x / 2**i`` for
    ``i = 1..k``) followed by the ``k`` matching sine components.

    Attributes
    ----------
    k : int
        Number of frequency components (``dim == 2 * k``).
    """

    def __init__(self, k: int = 1, **kwargs):
        """Set up the trigonometric feature map.

        Parameters
        ----------
        k : int, optional
            Number of frequency components, by default 1
        **kwargs : Any
            Forwarded to :class:`Embedding` (e.g. ``dtype``).

        Raises
        ------
        AssertionError
            If ``k < 1``.
        """
        assert k >= 1, "k must be at least 1"
        self.k = k
        super().__init__(**kwargs)

    @property
    def dim(self) -> int:
        """Output length: two components (cos, sin) per frequency."""
        return self.k * 2

    @property
    def input_dim(self) -> int:
        """A single scalar is consumed per sample."""
        return 1

    def __call__(self, x: Number) -> jnp.ndarray:
        """Embed the scalar ``x``.

        Parameters
        ----------
        x : Number
            Input value.

        Returns
        -------
        :class:`jax.numpy.ndarray`
            Vector ``[cos..., sin...]`` of length ``2 * k`` scaled by ``1/sqrt(k)``.
        """
        scale = 1 / jnp.sqrt(self.k)
        angles = [onp.pi * x / 2**i for i in range(1, self.k + 1)]
        components = [jnp.cos(a) for a in angles] + [jnp.sin(a) for a in angles]
        return scale * jnp.array(components)
class FourierEmbedding(Embedding):
    """Fourier (discrete-Fourier-magnitude) feature map with ``p`` components.

    A scalar ``x`` (intended range ``[0, 1]``) is mapped to ``p`` non-negative
    amplitudes

    :math:`\\phi_j(x) = \\frac{1}{p}\\left|\\sum_{k=0}^{p-1} e^{i 2\\pi k ((p-1)x - j)/p}\\right|, \\quad j = 0, \\dots, p-1`.

    For ``p >= 2`` and ``x = m / (p - 1)`` with integer ``m`` in ``[0, p-1]``,
    the sum vanishes for every ``j != m``, so the output is the one-hot vector
    ``e_m``. Between those grid points the amplitude is spread across
    neighbouring components.

    Attributes
    ----------
    p : int
        Number of components (equal to :attr:`dim`).
    """

    def __init__(self, p: int = 2, **kwargs):
        """Initialize the Fourier embedding.

        Parameters
        ----------
        p : int, optional
            Number of components, by default 2
        **kwargs : Any
            Additional arguments passed to parent class

        Raises
        ------
        AssertionError
            If p < 1
        """
        assert p >= 1, "Number of frequency components must be at least 1"
        self.p = p
        super().__init__(**kwargs)

    @property
    def dim(self) -> int:
        """Get the output dimension (``p``, one amplitude per component)."""
        return self.p

    @property
    def input_dim(self) -> int:
        """Get the input dimension (1 for scalar input)."""
        return 1

    def __call__(self, x: Number) -> jnp.ndarray:
        """Apply the Fourier embedding to a scalar.

        Parameters
        ----------
        x : Number
            Input value, intended to lie in [0, 1]

        Returns
        -------
        :class:`jax.numpy.ndarray`
            Length-``p`` vector of DFT magnitudes scaled by ``1/p``
        """
        p = self.p
        return 1 / p * jnp.array([
            jnp.abs(sum(jnp.exp(1j * 2 * jnp.pi * k * ((p - 1) * x - j) / p) for k in range(p)))
            for j in range(p)
        ])
class LinearComplementEmbedding(Embedding):
    """Linear complement feature map with unit-norm output.

    ``x`` (intended range ``[0, 1]``) becomes ``[x, 1-x]`` when ``p == 2`` or
    ``[1, x, 1-x]`` when ``p == 3``. The vector is then divided by its
    Euclidean norm.

    Attributes
    ----------
    p : int
        Output dimension (2 or 3).
    """

    def __init__(self, p: int = 2, **kwargs):
        """Configure the output dimension.

        Parameters
        ----------
        p : int, optional
            Output dimension, either 2 or 3; by default 2.
        **kwargs : Any
            Forwarded to :class:`Embedding` (e.g. ``dtype``).

        Raises
        ------
        ValueError
            If ``p`` is neither 2 nor 3.
        """
        if p not in (2, 3):
            raise ValueError('p must be 2 or 3')
        self.p = p
        super().__init__(**kwargs)

    @property
    def dim(self) -> int:
        """Output length (``p``)."""
        return self.p

    @property
    def input_dim(self) -> int:
        """A single scalar is consumed per sample."""
        return 1

    def __call__(self, x: Number) -> jnp.ndarray:
        """Embed ``x`` as a normalized linear-complement vector.

        Parameters
        ----------
        x : Number
            Input value, intended to lie in [0, 1].

        Returns
        -------
        :class:`jax.numpy.ndarray`
            Unit-norm ``[x, 1-x]`` or ``[1, x, 1-x]``.
        """
        components = [x, 1.0 - x] if self.p == 2 else [1.0, x, 1.0 - x]
        vector = jnp.asarray(components)
        return vector / jnp.linalg.norm(vector)
class QuantumBasisEmbedding(Embedding):
    """Quantum basis feature map using a dictionary of quantum states.

    Input ``0`` maps to ``basis[0]``; any other input maps to ``basis[1]``.

    Attributes
    ----------
    basis : Dict[int, List[float]]
        Dictionary mapping input values to quantum states (all states are
        expected to have the same length)
    """

    def __init__(self, basis: Dict[int, List[float]], **kwargs):
        """Initialize the quantum basis embedding.

        Parameters
        ----------
        basis : Dict[int, List[float]]
            Dictionary mapping input values to quantum states; must contain
            keys 0 and 1
        **kwargs : Any
            Additional arguments passed to parent class
        """
        self.basis = basis
        super().__init__(**kwargs)

    @property
    def dim(self) -> int:
        """Get the output dimension (length of an embedded state vector).

        This is the length of ``basis[0]``, not the number of keys. ``embed``
        reshapes each output with ``dim``, so ``dim`` must match the vector
        length. The two only coincide for square bases such as
        ``{0: [1, 0], 1: [0, 1]}``.
        """
        return len(self.basis[0])

    @property
    def input_dim(self) -> int:
        """Get the input dimension (1 for scalar input)."""
        return 1

    def __call__(self, x: Number) -> jnp.ndarray:
        """Apply quantum basis embedding to input.

        Parameters
        ----------
        x : Number
            Input value (0 or 1)

        Returns
        -------
        :class:`jax.numpy.ndarray`
            ``basis[0]`` if ``x == 0`` else ``basis[1]``
        """
        true_fun = lambda _: jnp.array(self.basis[0])
        false_fun = lambda _: jnp.array(self.basis[1])
        # lax.cond keeps the selection traceable when x is a traced value.
        return lax.cond(x == 0, true_fun, false_fun, None)
class GaussianRBFEmbedding(Embedding):
    """Gaussian Radial Basis Function embedding.

    Maps ``x`` to :math:`\\phi_i(x) = \\exp(-\\gamma (x - c_i)^2)` for every
    center :math:`c_i`. The resulting vector is normalized to unit Euclidean
    norm.

    Attributes
    ----------
    centers : onp.ndarray
        Centers for Gaussian RBFs
    gamma : float
        Scaling factor :math:`\\gamma=\\frac{1}{2\\sigma^2}`
    """

    def __init__(self, centers: Optional[onp.ndarray] = None, gamma: Optional[float] = None, **kwargs):
        """Initialize the Gaussian RBF embedding.

        Parameters
        ----------
        centers : Optional[onp.ndarray], optional
            Centers for Gaussian RBFs, by default None (must be set before use)
        gamma : Optional[float], optional
            Scaling factor, by default None (must be set before use)
        **kwargs : Any
            Additional arguments passed to parent class
        """
        self.centers = centers
        self.gamma = gamma
        super().__init__(**kwargs)

    @property
    def dim(self) -> int:
        """Get the output dimension (number of elements in ``centers``) as a Python int."""
        # A Python int (not a jax scalar) so it can be used directly in reshape shapes.
        return int(onp.prod(onp.shape(self.centers)))

    @property
    def input_dim(self) -> int:
        """Get the input dimension (1 for scalar input)."""
        return 1

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Apply Gaussian RBF embedding to input.

        Parameters
        ----------
        x : :class:`jax.numpy.ndarray`
            Input value (broadcast against ``centers``)

        Returns
        -------
        :class:`jax.numpy.ndarray`
            Unit-norm Gaussian RBF features, same shape as ``centers``

        Notes
        -----
        If ``x`` is far from every center, all exponentials can underflow to
        zero and the normalization then yields NaN.
        """
        diff = jnp.subtract(x, jnp.asarray(self.centers))
        # Squared distance: this is what makes the features Gaussian bumps
        # (gamma = 1 / (2 sigma^2)) rather than one-sided exponentials.
        vector = jnp.exp(-self.gamma * jnp.square(diff))
        return vector / jnp.linalg.norm(vector)
class PolynomialEmbedding(Embedding):
    """Polynomial feature map.

    Maps an input vector ``x`` of length ``n`` to monomials of degree
    ``1..degree``. There is an optional constant term, and cross terms such as
    ``x1 * x2`` are included only when requested.

    Attributes
    ----------
    degree : int
        Maximum polynomial degree
    n : int
        Number of input features
    include_bias : bool
        Whether to include constant term
    include_cross_terms : bool
        Whether to include monomials mixing different input features
    """

    def __init__(self, degree: int, n: int, include_bias: bool = False, include_cross_terms: bool = False, **kwargs):
        """Initialize the PolynomialEmbedding.

        Parameters
        ----------
        degree : int
            Maximum polynomial degree
        n : int
            Number of input features
        include_bias : bool, optional
            Whether to include constant term, by default False
        include_cross_terms : bool, optional
            Whether to include mixed monomials, by default False
        **kwargs : Any
            Additional arguments passed to parent class

        Raises
        ------
        ValueError
            If degree < 1
        """
        if degree < 1:
            raise ValueError("Degree of PolynomialEmbedding must be at least 1.")
        self.degree = degree
        self.n = n
        self.include_bias = include_bias
        self.include_cross_terms = include_cross_terms
        super().__init__(**kwargs)

    @property
    def dim(self) -> int:
        """Get the output dimension based on degree, bias, and cross terms options."""
        total_features = 0
        if self.include_bias:
            total_features += 1
        for d in range(1, self.degree + 1):
            if self.include_cross_terms:
                # All monomials of total degree d in input_dim variables.
                total_features += math.comb(self.input_dim + d - 1, d)
            else:
                # Only pure powers x_i^d, one per input variable.
                total_features += self.input_dim
        return total_features

    @property
    def input_dim(self) -> int:
        """Get the input dimension."""
        return self.n

    def __call__(self, x: Union[Number, onp.array]) -> jnp.ndarray:
        """Apply PolynomialEmbedding to input.

        Parameters
        ----------
        x : Union[Number, :class:`jax.numpy.ndarray`, :class:`numpy.ndarray`]
            Input features; a scalar is treated as a length-1 vector

        Returns
        -------
        :class:`jax.numpy.ndarray`
            Polynomial features (bias first if enabled, then by increasing degree)
        """
        # Accept Python scalars and lists as well as arrays (they have no .ndim).
        x = jnp.atleast_1d(jnp.asarray(x))

        features = [1.0] if self.include_bias else []
        for d in range(1, self.degree + 1):
            for combination in itertools.combinations_with_replacement(range(len(x)), d):
                is_cross_term = len(set(combination)) > 1
                if is_cross_term and not self.include_cross_terms:
                    continue
                features.append(jnp.prod(x[jnp.array(combination)]))
        return jnp.array(features)
class LegendreEmbedding(Embedding):
    """Legendre polynomial feature map.

    ``x`` becomes ``[P_0(x), P_1(x), ..., P_degree(x)]`` as computed by
    ``eval_legendre``.

    Attributes
    ----------
    degree : int
        Highest polynomial order included.
    """

    def __init__(self, degree: int = 2, **kwargs):
        """Store the maximum order.

        Parameters
        ----------
        degree : int, optional
            Highest polynomial order, by default 2
        **kwargs : Any
            Forwarded to :class:`Embedding` (e.g. ``dtype``).
        """
        self.degree = degree
        super().__init__(**kwargs)

    @property
    def dim(self) -> int:
        """Number of polynomials evaluated (``degree + 1``)."""
        return self.degree + 1

    @property
    def input_dim(self) -> int:
        """A single scalar is consumed per sample."""
        return 1

    def __call__(self, x: Number) -> jnp.ndarray:
        """Evaluate the Legendre polynomials at ``x``.

        Parameters
        ----------
        x : Number
            Input value, intended to lie in [-1, 1].

        Returns
        -------
        :class:`jax.numpy.ndarray`
            Values of orders ``0..degree``.
        """
        return jnp.array([eval_legendre(order, x) for order in range(self.dim)])


class LaguerreEmbedding(Embedding):
    """Laguerre polynomial feature map damped by ``exp(-x/2)``.

    ``x`` becomes ``[e^{-x/2} L_0(x), ..., e^{-x/2} L_degree(x)]``. With the
    standard Laguerre polynomials these functions are orthonormal on
    ``[0, inf)``.

    Attributes
    ----------
    degree : int
        Highest polynomial order included.
    """

    def __init__(self, degree: int = 2, **kwargs):
        """Store the maximum order.

        Parameters
        ----------
        degree : int, optional
            Highest polynomial order, by default 2
        **kwargs : Any
            Forwarded to :class:`Embedding` (e.g. ``dtype``).
        """
        self.degree = degree
        super().__init__(**kwargs)

    @property
    def dim(self) -> int:
        """Number of functions evaluated (``degree + 1``)."""
        return self.degree + 1

    @property
    def input_dim(self) -> int:
        """A single scalar is consumed per sample."""
        return 1

    def __call__(self, x: Number) -> jnp.ndarray:
        """Evaluate the damped Laguerre functions at ``x``.

        Parameters
        ----------
        x : Number
            Input value, intended to lie in [0, inf).

        Returns
        -------
        :class:`jax.numpy.ndarray`
            Damped values of orders ``0..degree``.
        """
        damping = jnp.exp(-x / 2)
        return jnp.array([damping * eval_laguerre(order, x) for order in range(self.dim)])


class HermiteEmbedding(Embedding):
    """Hermite polynomial feature map damped by ``exp(-x**2 / 2)``.

    ``x`` becomes ``[e^{-x^2/2} H_0(x), ..., e^{-x^2/2} H_degree(x)]``.
    NOTE(review): if ``eval_hermite`` follows SciPy's physicists' convention,
    these functions are orthogonal but not unit-norm; confirm whether a
    normalization factor is intended.

    Attributes
    ----------
    degree : int
        Highest polynomial order included.
    """

    def __init__(self, degree: int = 2, **kwargs):
        """Store the maximum order.

        Parameters
        ----------
        degree : int, optional
            Highest polynomial order, by default 2
        **kwargs : Any
            Forwarded to :class:`Embedding` (e.g. ``dtype``).
        """
        self.degree = degree
        super().__init__(**kwargs)

    @property
    def dim(self) -> int:
        """Number of functions evaluated (``degree + 1``)."""
        return self.degree + 1

    @property
    def input_dim(self) -> int:
        """A single scalar is consumed per sample."""
        return 1

    def __call__(self, x: Number) -> jnp.ndarray:
        """Evaluate the damped Hermite functions at ``x``.

        Parameters
        ----------
        x : Number
            Input value (any real number).

        Returns
        -------
        :class:`jax.numpy.ndarray`
            Damped values of orders ``0..degree``.
        """
        gaussian = jnp.exp(-0.5 * x**2)
        return jnp.array([gaussian * eval_hermite(order, x) for order in range(self.dim)])
class JaxArraysEmbedding(Embedding):
    """Simple embedding that converts input arrays to JAX arrays.

    Optionally prepends a constant ``1.0`` bias entry.

    Attributes
    ----------
    dim : Optional[int]
        Output dimension (user-supplied; not inferred)
    add_bias : bool
        Whether to prepend a bias term
    input_dim : Optional[int]
        Input dimension (user-supplied; not inferred)
    """

    def __init__(self, dim: Optional[int] = None, add_bias: bool = False, input_dim: Optional[int] = None, **kwargs):
        """Initialize the JAX arrays embedding.

        Parameters
        ----------
        dim : Optional[int], optional
            Output dimension, by default None
        add_bias : bool, optional
            Whether to add bias term, by default False
        input_dim : Optional[int], optional
            Input dimension, by default None
        **kwargs : Any
            Additional arguments passed to parent class
        """
        super().__init__(**kwargs)
        self._dim = dim
        self.add_bias = add_bias
        self._input_dim = input_dim

    @property
    def dim(self) -> int:
        """Get the output dimension."""
        return self._dim

    @property
    def input_dim(self) -> int:
        """Get the input dimension."""
        return self._input_dim

    def __call__(self, x: Any) -> jnp.ndarray:
        """Convert input to JAX array, optionally prepending a bias entry.

        Parameters
        ----------
        x : Any
            Input data (scalar, list, NumPy or JAX array)

        Returns
        -------
        :class:`jax.numpy.ndarray`
            JAX array; with ``add_bias`` it is ``[1.0, *x]``
        """
        if not self.add_bias:
            return jnp.array(x)
        # jnp.concatenate rejects Python lists and 0-d inputs, so convert
        # and promote to at least 1-D before prepending the bias.
        values = jnp.atleast_1d(jnp.asarray(x))
        return jnp.concatenate([jnp.array([1.]), values])
class TrigonometricEmbeddingChain(ComplexEmbedding):
    """Trigonometric feature map applied to every component of a feature.

    Each feature is a vector of ``input_shape[1]`` scalars. Every scalar is
    embedded with :class:`TrigonometricEmbedding` and the results are
    concatenated.

    Attributes
    ----------
    k : int
        Number of frequency components per dimension
    input_shape : tuple
        Shape of input (n_features, n_dims_per_feature)
    """

    def __init__(self, k: int = 1, input_shape: tuple = (2, 2), **kwargs):
        """Initialize the TrigonometricEmbedding chain embedding.

        Parameters
        ----------
        k : int, optional
            Number of frequency components, by default 1
        input_shape : tuple, optional
            Input shape (n_features, n_dims_per_feature), by default (2, 2)
        **kwargs : Any
            Additional arguments passed to parent class

        Raises
        ------
        AssertionError
            If k < 1
        """
        assert k >= 1, "k must be at least 1"
        self.k = k
        self.input_shape = input_shape
        super().__init__(**kwargs)

    @property
    def dims(self) -> Collection[int]:
        """Get output dimensions for each feature (``2 * k * n_dims_per_feature`` each)."""
        return [self.k * 2 * self.input_shape[1]] * self.input_shape[0]

    @property
    def input_dims(self) -> jnp.ndarray:
        """Get input dimensions for each feature."""
        # NOTE(review): the length here is ``k``, not ``input_shape[0]``;
        # kept unchanged because the callers that read it are not visible here.
        return jnp.array([1] * self.k)

    @property
    def embeddings(self) -> Collection[Embedding]:
        """Get one TrigonometricEmbedding per dimension of a feature."""
        return [TrigonometricEmbedding(k=self.k) for _ in range(self.input_shape[1])]

    def __call__(self, x: Collection) -> jnp.ndarray:
        """Apply TrigonometricEmbedding chain embedding to one feature.

        Parameters
        ----------
        x : Collection
            One feature vector of length ``n_dims_per_feature``

        Returns
        -------
        :class:`jax.numpy.ndarray`
            Concatenated TrigonometricEmbedding features
        """
        embedded = []
        for f, xi in zip(self.embeddings, x):
            embedded.extend(f(xi))
        return jnp.array(embedded)


class TrigonometricEmbeddingAvg(ComplexEmbedding):
    """TrigonometricEmbedding feature map for the mean of a feature.

    Each feature vector is averaged and the mean is embedded with
    :class:`TrigonometricEmbedding`.

    Attributes
    ----------
    k : int
        Number of frequency components
    input_shape : tuple
        Shape of input (n_features, n_dims_per_feature)
    """

    def __init__(self, k: int = 1, input_shape: tuple = (2, 2), **kwargs):
        """Initialize the TrigonometricEmbedding average embedding.

        Parameters
        ----------
        k : int, optional
            Number of frequency components, by default 1
        input_shape : tuple, optional
            Input shape (n_features, n_dims_per_feature), by default (2, 2)
        **kwargs : Any
            Additional arguments passed to parent class

        Raises
        ------
        AssertionError
            If k < 1
        """
        assert k >= 1, "k must be at least 1"
        self.k = k
        self.input_shape = input_shape
        super().__init__(**kwargs)

    @property
    def dims(self) -> List[int]:
        """Get output dimensions for each feature (``2 * k`` each)."""
        return [self.k * 2] * self.input_shape[0]

    @property
    def input_dims(self) -> jnp.ndarray:
        """Get input dimensions for each feature."""
        # NOTE(review): the length here is ``k``, not ``input_shape[0]``;
        # kept unchanged because the callers that read it are not visible here.
        return jnp.array([1] * self.k)

    @property
    def embeddings(self) -> Embedding:
        """Get the single TrigonometricEmbedding applied to feature means."""
        return TrigonometricEmbedding(k=self.k)

    def __call__(self, features: Collection) -> jnp.ndarray:
        """Apply TrigonometricEmbedding to the mean of one feature.

        Parameters
        ----------
        features : Collection
            One feature vector

        Returns
        -------
        :class:`jax.numpy.ndarray`
            TrigonometricEmbedding features of the mean
        """
        return self.embeddings(jnp.mean(features))


class BasePatchEmbedding(StateVectorToMPSEmbedding):
    """Base class for patch-based embeddings that convert input data to MPS.

    The image is split into ``k x k`` patches. Each patch is turned into a
    state vector (:meth:`create_statevector`), decomposed into MPS tensors,
    and the per-patch MPSs are chained into one MPS. Patches are joined by
    dummy bonds of size 1.

    Attributes
    ----------
    k : int
        Kernel size of patch window (k×k)
    mps : Optional[qtn.MatrixProductState]
        MPS produced by the most recent call
    """

    def __init__(self, k: int = 2, **kwargs):
        """Initialize the base patch embedding.

        Parameters
        ----------
        k : int, optional
            Kernel size of patch window, by default 2
        **kwargs : Any
            Additional arguments passed to parent class
        """
        super().__init__(**kwargs)
        self.k = k
        self.mps = None

    @property
    def dims(self) -> list:
        """Get dimensions of the MPS tensors (requires a prior call)."""
        return list([tensor.shape for tensor in self.mps.tensors])

    def pad_or_truncate_statevector(self, statevector: jnp.ndarray, target_size: int) -> jnp.ndarray:
        """Zero-pad or truncate a 1-D statevector to ``target_size`` entries.

        Parameters
        ----------
        statevector : :class:`jax.numpy.ndarray`
            Input statevector
        target_size : int
            Target size

        Returns
        -------
        :class:`jax.numpy.ndarray`
            Padded or truncated statevector
        """
        current_size = statevector.shape[0]
        if current_size < target_size:
            padding = [(0, target_size - current_size)]
            statevector = jnp.pad(statevector, padding, mode='constant')
        else:
            statevector = statevector[:target_size]
        return statevector

    def combine_mps_patches(self, mps_patches: onp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Concatenate per-patch MPS arrays into one list of MPS arrays.

        The first and last tensor of every patch get a size-1 dummy bond when
        they are 2-D: the first on axis 0, the last on the final axis. This
        makes all tensors share the same rank.

        Parameters
        ----------
        mps_patches : onp.ndarray
            Sequence of per-patch MPS array lists
        n_qubits : int
            Number of qubits per patch (kept for interface compatibility; the
            boundary positions are taken from each patch's own length)

        Returns
        -------
        :class:`jax.numpy.ndarray`
            Combined list of MPS arrays
        """
        new_arrays = []
        for patch in mps_patches:
            # ``i`` restarts at 0 for every patch, so the boundaries must be
            # detected with patch-local positions. Comparing against a global
            # offset only ever matched the first patch.
            last = len(patch) - 1
            for i, arr in enumerate(patch):
                if i == 0 and len(arr.shape) == 2:
                    new_arrays.append(jnp.expand_dims(arr, axis=0))
                elif i == last and len(arr.shape) == 2:
                    new_arrays.append(jnp.expand_dims(arr, axis=-1))
                else:
                    new_arrays.append(arr)
        return new_arrays

    # A plain abstract method, not a property: it takes an argument and the
    # concrete subclasses implement it as a regular method.
    @abc.abstractmethod
    def create_statevector(self, x: jnp.ndarray) -> jnp.ndarray:
        """Create statevector from input data.

        Parameters
        ----------
        x : :class:`jax.numpy.ndarray`
            Input data

        Returns
        -------
        :class:`jax.numpy.ndarray`
            Statevector representation (subclasses return ``(statevector, n_qubits)``)
        """
        pass

    def __call__(self, x: jnp.ndarray) -> qtn.MatrixProductState:
        """Convert input data to MPS representation.

        Parameters
        ----------
        x : :class:`jax.numpy.ndarray`
            Input data (a square 2-D image)

        Returns
        -------
        qtn.MatrixProductState
            MPS representation (also stored in ``self.mps``)

        Raises
        ------
        ValueError
            If input is not square or patch size is too large
        """
        H, W = x.shape
        if H != W:
            raise ValueError(f"Only square matrices are supported, got {H}x{W} image.")
        if self.k > H:
            raise ValueError(f"Patch dimension k = {self.k} is too large for {H}x{W} images.")

        patches = u.divide_into_patches(x, self.k)
        mps_patches = []
        for patch in patches:
            # Subclasses that define flatten_snake use it; otherwise row-major order.
            patch_data = patch.ravel() if not hasattr(self, 'flatten_snake') else self.flatten_snake(patch)
            statevector, n_qubits = self.create_statevector(patch_data)
            mps_arrays = u.from_dense_to_mps(statevector, n_qubits, self.max_bond)
            mps_patches.append(mps_arrays)

        new_arrays = self.combine_mps_patches(mps_patches, n_qubits)

        # Recreate the MPS with the reshaped arrays
        self.mps = qtn.MatrixProductState(new_arrays, shape='lrp')
        return self.mps
class PatchEmbedding(BasePatchEmbedding):
    """Embedding that converts image patches to MPS using basis encoding.

    Every pixel gets an address register plus one color qubit. The
    amplitudes for address ``a`` are ``cos(pi * x_a / 2)`` (color bit 0) and
    ``sin(pi * x_a / 2)`` (color bit 1).
    """

    def flatten_snake(self, image: jnp.ndarray) -> jnp.ndarray:
        """Flatten image in snake-like fashion (odd rows reversed).

        Parameters
        ----------
        image : :class:`jax.numpy.ndarray`
            Input image

        Returns
        -------
        :class:`jax.numpy.ndarray`
            Flattened image in snake-like order
        """
        image = jnp.where(
            jnp.arange(image.shape[0])[:, None] % 2 == 1,
            jnp.flip(image, axis=1),
            image
        )
        return image.reshape(-1)

    def create_statevector(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, int]:
        """Create statevector using basis encoding.

        Parameters
        ----------
        x : :class:`jax.numpy.ndarray`
            Flattened patch data (N pixels)

        Returns
        -------
        Tuple[:class:`jax.numpy.ndarray`, int]
            Normalized statevector of length ``2**n_qubits`` and number of qubits
        """
        # Number of pixels (N = 16 for a 4x4 image)
        N = len(x)
        # ceil(log2(N)) address qubits plus one color qubit
        n_address_qubits = int(onp.ceil(onp.log2(N)))
        n_qubits = n_address_qubits + 1

        state_indices = jnp.arange(2**n_qubits)
        # Color qubit is the least significant bit; address qubits are the rest.
        color_bits = state_indices % 2
        address_indices = state_indices // 2

        # When N is not a power of two, addresses >= N exist but hold no pixel.
        # JAX clamps out-of-bounds gathers, which would silently copy the last
        # pixel there, so mask those entries to zero (zero padding instead).
        is_pixel = address_indices < N
        safe_addresses = jnp.minimum(address_indices, N - 1)

        cos_values = jnp.cos(math.pi * x / 2)
        sin_values = jnp.sin(math.pi * x / 2)
        amplitudes = jnp.where(
            color_bits == 0,
            cos_values[safe_addresses],
            sin_values[safe_addresses]
        )
        statevector = jnp.where(is_pixel, amplitudes, 0.0)

        # Each pixel contributes cos^2 + sin^2 = 1, so the norm is sqrt(N) > 0.
        statevector /= jnp.linalg.norm(statevector)

        fixed_size = 2**n_qubits
        padded_statevector = self.pad_or_truncate_statevector(statevector.flatten(), fixed_size)
        return padded_statevector, n_qubits
class PatchAmplitudeEmbedding(BasePatchEmbedding):
    """Embedding that converts image patches to MPS using amplitude encoding.

    Pixel ``i`` gets amplitude proportional to ``sqrt(x_i)``. The state is
    zero-padded to the next power of two.
    """

    def create_statevector(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, int]:
        """Create statevector using amplitude encoding.

        Parameters
        ----------
        x : :class:`jax.numpy.ndarray`
            Flattened patch data (N pixel intensities; negative entries give
            NaN through ``jnp.sqrt``)

        Returns
        -------
        Tuple[:class:`jax.numpy.ndarray`, int]
            Normalized statevector of length ``2**n_qubits`` and number of qubits
        """
        N = len(x)
        n_qubits = int(onp.ceil(onp.log2(N)))

        amplitudes = jnp.sqrt(x)
        norm = jnp.linalg.norm(amplitudes)
        # An all-zero patch would give 0/0 = NaN. Fall back to the uniform state
        # over the N pixel entries, which is what any constant non-zero patch
        # maps to. The inner where keeps the division NaN-free (also for grads).
        has_signal = norm > 0
        statevector = jnp.where(
            has_signal,
            amplitudes / jnp.where(has_signal, norm, 1.0),
            jnp.ones_like(amplitudes) / jnp.sqrt(N),
        )

        fixed_size = 2**n_qubits
        padded_statevector = self.pad_or_truncate_statevector(statevector.flatten(), fixed_size)
        return padded_statevector, n_qubits
def embed(x: onp.ndarray, phi: Union[Embedding, ComplexEmbedding, StateVectorToMPSEmbedding], **mps_opts) -> qtn.MatrixProductState:
    """Create a normalized MPS from a feature vector or matrix.

    For :class:`Embedding` and :class:`ComplexEmbedding` the result is a
    product state (bond dimension 1). This works only if features are
    separated and not correlated. Any other accepted embedding (e.g. a
    :class:`StateVectorToMPSEmbedding`) is called on ``x`` directly and must
    return an MPS.

    Parameters
    ----------
    x : onp.ndarray
        Vector of features, or matrix with one feature per row
    phi : Union[Embedding, ComplexEmbedding, StateVectorToMPSEmbedding]
        Feature map for each feature
    **mps_opts : Any
        Additional arguments passed to MatrixProductState

    Returns
    -------
    qtn.MatrixProductState
        Normalized state

    Raises
    ------
    TypeError
        If phi is not a valid embedding type
    """
    if not isinstance(phi, (ComplexEmbedding, Embedding, StateVectorToMPSEmbedding)):
        raise TypeError('Invalid embedding type')

    if isinstance(phi, Embedding):
        arrays = [phi(xi).reshape((1, 1, phi.dim)) for xi in x]
        # Boundary sites carry only one virtual bond.
        for i in [0, -1]:
            arrays[i] = arrays[i].reshape((1, phi.dim))
        mps = qtn.MatrixProductState(arrays, **mps_opts)
    elif isinstance(phi, ComplexEmbedding) and x.ndim == 2:
        dims = phi.dims
        if isinstance(dims, (int, onp.integer)):
            # Same output dimension for every feature.
            arrays = [phi(xi).reshape((1, 1, dims)) for xi in x]
            for i in [0, -1]:
                arrays[i] = arrays[i].reshape((1, dims))
        else:
            # One output dimension per feature.
            arrays = [phi(xi).reshape((1, 1, dims[i])) for i, xi in enumerate(x)]
            for i in [0, -1]:
                arrays[i] = arrays[i].reshape((1, dims[i]))
        mps = qtn.MatrixProductState(arrays, **mps_opts)
    else:
        mps = phi(x)

    # Normalize
    if len(mps.tensors) > 200:
        # For large systems, sweep left-canonizing so the norm accumulates in
        # the last tensor. This avoids computing a global norm whose L-th root
        # could under/overflow.
        for i, tensor in enumerate(mps.tensors):
            if i == 0:
                mps.left_canonize_site(i)
            elif i == len(mps.tensors) - 1:
                tensor.modify(data=tensor.data / jnp.linalg.norm(tensor.data))
            else:
                tensor.modify(data=tensor.data / jnp.linalg.norm(tensor.data))
                mps.left_canonize_site(i)
    else:
        # Spread the global norm evenly over all tensors.
        norm = mps.norm()
        for tensor in mps.tensors:
            tensor.modify(data=tensor.data / jnp.power(norm, 1 / len(mps.tensors)))
    return mps