Source code for tn4ml.embeddings

import abc
import itertools
import math
from numbers import Number
from typing import Collection, Any, Union
import numpy as onp
from autoray import numpy as np
import autoray as a
import jax.numpy as jnp
from jax import lax
import quimb.tensor as qtn
import tn4ml.util as u

class Embedding:
    """Base class for per-feature data embeddings (feature maps).

    Concrete embeddings are expected to provide ``dim``, ``input_dim`` and
    ``__call__``. The class does not use :class:`abc.ABCMeta`, so the
    ``abstractmethod`` markers are informational only and do not block
    instantiation.

    Attributes
    ----------
    dtype: :class:`numpy.dtype`
        Data type stored on the embedding.
    """

    def __init__(self, dtype=onp.float32):
        self.dtype = dtype

    @property
    @abc.abstractmethod
    def dim(self) -> int:
        """Mapping dimension (length of the embedded vector)."""

    @property
    @abc.abstractmethod
    def input_dim(self) -> int:
        """Dimensionality of input feature. 1 = number, 2 = vector."""

    @abc.abstractmethod
    def __call__(self, x: Number) -> jnp.ndarray:
        """Embed a single input feature ``x``."""
class ComplexEmbedding:
    """Base class for embeddings where each feature has its own chosen embedding.

    The class does not use :class:`abc.ABCMeta`, so the ``abstractmethod``
    markers are informational only.

    Attributes
    ----------
    dtype: :class:`numpy.dtype`
        Data type stored on the embedding.
    """

    def __init__(self, dtype=onp.float32):
        self.dtype = dtype

    @property
    @abc.abstractmethod
    def dims(self) -> Collection[int]:
        """Mapping dimensions per feature."""

    @property
    @abc.abstractmethod
    def input_dims(self) -> jnp.ndarray:
        """Dimensionality of each input feature. 1 = number, 2 = vector."""

    @property
    @abc.abstractmethod
    def embeddings(self) -> Collection["Embedding"]:
        """Embedding for each feature."""

    @abc.abstractmethod
    def __call__(self, x: Number) -> jnp.ndarray:
        """Embed the input ``x``."""


class StateVectorToMPSEmbedding:
    """Base class for embeddings that turn a statevector into a Matrix Product State (MPS).

    Attributes
    ----------
    dtype: :class:`numpy.dtype`
        Data type stored on the embedding.
    max_bond: int or None
        Value stored for use by subclasses when building the MPS.
    """

    def __init__(self, dtype=onp.float32, max_bond=None):
        self.dtype = dtype
        self.max_bond = max_bond

    @property
    @abc.abstractmethod
    def dims(self) -> list:
        """Dimensions of the MPS arrays."""

    # Deliberately a plain method, not a property: it takes an argument, and a
    # property getter is only ever called with ``self``.
    @abc.abstractmethod
    def create_statevector(self, x: jnp.ndarray) -> jnp.ndarray:
        """Create a statevector from the input ``x``.

        The concrete embeddings in this module return a
        ``(statevector, n_qubits)`` pair.
        """

    @abc.abstractmethod
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Convert the input into a Matrix Product State."""
class trigonometric(Embedding):
    r"""Trigonometric feature map.

    For ``k = 1``:
    :math:`\phi(x_j) = \left[ \cos(\frac{\pi}{2}x_j), \sin(\frac{\pi}{2}x_j) \right]`

    For general ``k`` the vector holds the ``k`` cosines followed by the ``k``
    sines of :math:`\pi x_j / 2^i` for :math:`i = 1, \dots, k`, scaled by
    :math:`1/\sqrt{k}`.

    Attributes
    ----------
    k: int
        Custom parameter = ``dim/2``.
    """

    def __init__(self, k: int = 1, **kwargs):
        assert k >= 1
        self.k = k
        super().__init__(**kwargs)

    @property
    def dim(self) -> int:
        """Mapping dimension (``2 * k``)."""
        return 2 * self.k

    @property
    def input_dim(self) -> int:
        """Each call embeds a single number."""
        return 1

    def __call__(self, x: Number) -> jnp.ndarray:
        """Embedding function for trigonometric.

        Parameters
        ----------
        x: :class:`Number`
            Input feature.

        Returns
        -------
        jnp.ndarray
            Embedding vector.
        """
        exponents = range(1, self.k + 1)
        components = []
        # All cosines first, then all sines, each over i = 1..k.
        for fn in (jnp.cos, jnp.sin):
            for i in exponents:
                components.append(fn(onp.pi * x / 2**i))
        return 1 / jnp.sqrt(self.k) * jnp.array(components)
class fourier(Embedding):
    r"""Fourier feature map.

    Component :math:`m` (for :math:`m = 0, \dots, p-1`) of the embedding is

    :math:`\phi(x_j)_m = \frac{1}{p}\left| \sum_{k=0}^{p-1} e^{2\pi i k ((p-1) x_j - m)/p} \right|`

    Attributes
    ----------
    p: int
        Mapping dimension. The value is stored as given, without validation.
    """

    def __init__(self, p: int = 2, **kwargs):
        self.p = p
        super().__init__(**kwargs)

    @property
    def dim(self) -> int:
        """Mapping dimension (``p``)."""
        return self.p

    @property
    def input_dim(self) -> int:
        """Each call embeds a single number."""
        return 1

    def __call__(self, x: Number) -> jnp.ndarray:
        """Embedding function for Fourier.

        Parameters
        ----------
        x: :class:`Number`
            Input feature.

        Returns
        -------
        jnp.ndarray
            Embedding vector of length ``p``.
        """
        p = self.p

        def magnitude(j):
            # |sum over k of the k-th complex phase for output component j|
            phases = (np.exp(1j * 2 * onp.pi * k * ((p - 1) * x - j) / p) for k in range(p))
            return np.abs(sum(phases))

        return 1 / p * jnp.array([magnitude(j) for j in range(p)])
class linear_complement_map(Embedding):
    """Feature map :math:`[x, 1-x]` or :math:`[1, x, 1-x]`, scaled to unit L2 norm.

    ``x`` is a feature in the range [0, 1].

    Attributes
    ----------
    p: int
        Mapping dimension; only 2 and 3 are accepted when the map is applied.
    """

    def __init__(self, p: int = 2, **kwargs):
        self.p = p
        super().__init__(**kwargs)

    @property
    def dim(self) -> int:
        """Mapping dimension (``p``)."""
        return self.p

    @property
    def input_dim(self) -> int:
        """Each call embeds a single number."""
        return 1

    def __call__(self, x: Number) -> jnp.ndarray:
        """Embedding function for the linear complement map.

        Parameters
        ----------
        x: :class:`Number`
            Input feature.

        Returns
        -------
        jnp.ndarray
            Embedding vector divided by its L2 norm.

        Raises
        ------
        ValueError
            If ``p`` is neither 2 nor 3.
        """
        if self.p not in (2, 3):
            raise ValueError('Invalid dimension')

        entries = [x, 1.0 - x]
        if self.p == 3:
            # Three-component variant carries a leading constant term.
            entries.insert(0, 1.0)

        vector = jnp.asarray(entries)
        return vector / jnp.linalg.norm(vector)
class quantum_basis(Embedding):
    """Basis quantum feature map.

    Attributes
    ----------
    basis: dict
        Quantum state map. Example: ``{0: [1, 0], 1: [0, 1]}``.
        Only the entries under keys ``0`` and ``1`` are read when embedding.
    """

    def __init__(self, basis: dict, **kwargs):
        self.basis = basis
        super().__init__(**kwargs)

    @property
    def dim(self) -> int:
        """Mapping dimension (number of entries in ``basis``)."""
        return len(self.basis)

    @property
    def input_dim(self) -> int:
        """Each call embeds a single number."""
        return 1

    def __call__(self, x: Number) -> jnp.ndarray:
        """Embedding function for basis encoding.

        Parameters
        ----------
        x: :class:`Number`
            Input feature.

        Returns
        -------
        jnp.ndarray
            ``basis[0]`` when ``x == 0``, otherwise ``basis[1]``.
        """
        def pick_zero(_):
            return jnp.array(self.basis[0])

        def pick_one(_):
            return jnp.array(self.basis[1])

        # lax.cond keeps the selection traceable; the operand is unused.
        return lax.cond(x == 0, pick_zero, pick_one, None)
class gaussian_rbf(Embedding):
    r"""Gaussian Radial Basis Function feature map.

    Component for center :math:`c`:
    :math:`\phi(x)_c = \exp(-\gamma (x - c)^2)`, with the whole vector then
    divided by its L2 norm.

    Attributes
    ----------
    centers: :class:`numpy.ndarray`
        Gaussian centers.
    gamma: float
        Scaling factor :math:`\gamma=\frac{1}{2\sigma^2}`
    """

    def __init__(self, centers: onp.ndarray = None, gamma: float = None, **kwargs):
        self.centers = centers
        self.gamma = gamma
        super().__init__(**kwargs)

    @property
    def dim(self) -> int:
        """Mapping dimension (total number of centers)."""
        # Cast so callers get a plain Python int, as annotated.
        return int(np.prod(self.centers.shape))

    @property
    def input_dim(self) -> int:
        """Dimensionality of input feature. 1 = number"""
        return 1

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Embedding function for Gaussian RBF.

        Parameters
        ----------
        x : :class:`Number`
            Input feature.

        Returns
        -------
        jnp.ndarray
            Embedding vector with unit L2 norm.
        """
        offsets = jnp.subtract(x, jnp.array(self.centers))
        # The distance must be squared: without it the response is not a
        # Gaussian bump and grows without bound for x below a center.
        vector = jnp.exp(-self.gamma * offsets**2)
        return vector / jnp.linalg.norm(vector)
class polynomial(Embedding):
    """Polynomial feature map.

    Produces every monomial of the input components with total degree from 1
    up to ``degree`` (combinations with repetition), optionally preceded by a
    constant ``1.0``.

    Attributes
    ----------
    degree : int
        Degree of polynomial.
    n : int
        Number of features.
    include_bias : bool
        Include bias term.
    """

    def __init__(self, degree: int, n: int, include_bias: bool = False, **kwargs):
        if degree < 1:
            raise ValueError("Degree of polynomial embedding must be at least 1.")
        self.degree = degree
        self.n = n
        self.include_bias = include_bias
        super().__init__(**kwargs)

    @property
    def dim(self) -> int:
        """Mapping dimension: number of monomials (plus one with bias)."""
        lowest_degree = 0 if self.include_bias else 1
        return sum(
            math.comb(self.input_dim + k - 1, k)
            for k in range(lowest_degree, self.degree + 1)
        )

    @property
    def input_dim(self) -> int:
        """Dimensionality of input feature (``n``)."""
        return self.n

    def __call__(self, x: Union[Number, onp.ndarray, jnp.ndarray]) -> jnp.ndarray:
        """Embedding function for polynomial.

        Parameters
        ----------
        x : :class:`Number` or array
            Input feature. A scalar (Python number or 0-d array) is treated
            as a one-component vector.

        Returns
        -------
        jnp.ndarray
            Embedding vector.
        """
        # Plain Python numbers have no ``ndim``; normalise every scalar-like
        # input to a 1-element array up front.
        x = jnp.atleast_1d(jnp.asarray(x))

        features = [1.0] if self.include_bias else []

        # Index combinations with repetition enumerate each monomial once.
        for d in range(1, self.degree + 1):
            for combination in itertools.combinations_with_replacement(range(len(x)), d):
                features.append(jnp.prod(x[jnp.array(combination)]))

        return jnp.array(features)
class jax_arrays(Embedding):
    """Input arrays to JAX arrays. No embedding.

    Optional: prepend a constant one to the input array.

    Attributes
    ----------
    add_bias: bool
        Add bias term (1.0)
    """

    def __init__(self, dim: int = None, add_bias: bool = False, **kwargs):
        super().__init__(**kwargs)
        self._dim = dim
        self.add_bias = add_bias

    @property
    def dim(self) -> int:
        """Mapping dimension, exactly as passed to the constructor."""
        return self._dim

    def __call__(self, x: Any) -> jnp.ndarray:
        """Embedding function for JAX arrays.

        Parameters
        ----------
        x: list
            List of input features (any array-like, including a scalar).

        Returns
        -------
        jnp.ndarray
            Embedding vector; with ``add_bias`` a leading ``1.`` is prepended.
        """
        values = jnp.array(x)
        if self.add_bias:
            # Convert first: concatenate needs arrays (not lists) and cannot
            # join 0-d scalars, so promote those to length-1 vectors.
            return jnp.concatenate([jnp.array([1.]), jnp.atleast_1d(values)])
        return values
class trigonometric_chain(ComplexEmbedding):
    """Trigonometric feature map for each dimension of feature.

    Sample = [[x11, x12, x13], [x21, x22, x23]]
    ==> [[cos(x11), sin(x11), cos(x12), sin(x12), cos(x13), sin(x13)],
         [cos(x21), sin(x21), cos(x22), sin(x22), cos(x23), sin(x23)]]

    - dims = 6 for each feature.
    - input_dims = 1 for each feature.

    Attributes
    ----------
    k: int
        Custom parameter = ``dim/2``.
    input_shape: tuple
        Input shape: number of features and number of dimensions per feature.
    """

    def __init__(self, k: int = 1, input_shape: tuple = (2, 2), **kwargs):
        assert k >= 1
        self.k = k
        self.input_shape = input_shape
        super().__init__(**kwargs)

    @property
    def dims(self) -> Collection[int]:
        """Mapping dimensions per feature (a list with one entry per feature)."""
        return [self.k * 2 * self.input_shape[1]] * self.input_shape[0]

    @property
    def input_dims(self) -> jnp.ndarray:
        """Array of ones of length ``k``.

        NOTE(review): the length follows ``k`` rather than the number of
        features; confirm this is intended.
        """
        return jnp.array([1] * self.k)

    @property
    def embeddings(self) -> Collection[Embedding]:
        """One ``trigonometric`` map per dimension of a feature."""
        return [trigonometric(k=self.k) for _ in range(self.input_shape[1])]

    def __call__(self, x: Collection) -> jnp.ndarray:
        """Embedding function for trigonometric chain.

        Parameters
        ----------
        x: list
            List of input features.

        Returns
        -------
        jnp.ndarray
            Embedding vector: the per-component embeddings joined end to end.
        """
        pieces = [phi(xi) for phi, xi in zip(self.embeddings, x)]
        if not pieces:
            # Nothing to embed; concatenate rejects an empty sequence.
            return jnp.array([])
        return jnp.concatenate(pieces)


class trigonometric_avg(ComplexEmbedding):
    """Trigonometric feature map for mean(features).

    Attributes
    ----------
    k: int
        Custom parameter = ``dim/2``.
    input_shape: tuple
        Input shape; only its first entry is used, to size ``dims``.
    """

    def __init__(self, k: int = 1, input_shape: tuple = (2, 2), **kwargs):
        assert k >= 1
        self.k = k
        self.input_shape = input_shape
        super().__init__(**kwargs)

    @property
    def dims(self) -> Collection[int]:
        """Mapping dimensions per feature (a list with one entry per feature)."""
        return [self.k * 2] * self.input_shape[0]

    @property
    def input_dims(self) -> jnp.ndarray:
        """Array of ones of length ``k``.

        NOTE(review): the length follows ``k`` rather than the number of
        features; confirm this is intended.
        """
        return jnp.array([1] * self.k)

    @property
    def embeddings(self) -> Embedding:
        """The single ``trigonometric`` map applied to the mean."""
        return trigonometric(k=self.k)

    def __call__(self, features: Collection) -> jnp.ndarray:
        """Embedding function for average of input features.

        Parameters
        ----------
        features: list
            List of input features.

        Returns
        -------
        :class:`jax.numpy.ndarray`
            Embedded vector.
        """
        return self.embeddings(jnp.mean(features))
class PatchEmbedding(StateVectorToMPSEmbedding):
    """Embed a square image patch-by-patch into a Matrix Product State.

    Each ``k x k`` patch is flattened in snake order, encoded as a statevector
    with one colour qubit per address, converted to MPS arrays and the arrays
    of all patches are chained into one MPS.
    """

    def __init__(self, k = 2, **kwargs):
        """
        Initialize the PatchEmbedding class.

        Parameters
        ----------
        k: int
            The kernel size of the patch window kxk.

        Returns
        -------
        None
        """
        super().__init__(**kwargs)
        self.k = k
        # Populated by ``__call__``; ``dims`` is only usable afterwards.
        self.mps = None

    @property
    def dims(self) -> list:
        """Shapes of the tensors of the most recently built MPS."""
        return [tensor.shape for tensor in self.mps.tensors]

    def pad_or_truncate_statevector(self, statevector: jnp.ndarray, target_size: int) -> jnp.ndarray:
        """
        Pad or truncate the statevector to a target size.

        Parameters
        ----------
        statevector: :class:`jax.numpy.ndarray`
            The input statevector.
        target_size: int
            The desired size of the statevector.

        Returns
        -------
        :class:`jax.numpy.ndarray`
            A statevector of the target size.
        """
        current_size = statevector.shape[0]
        if current_size < target_size:
            # Too short: append zeros at the end.
            return jnp.pad(statevector, [(0, target_size - current_size)], mode='constant')
        # Long enough: keep the leading entries.
        return statevector[:target_size]

    def create_statevector(self, x: jnp.ndarray) -> tuple:
        """
        Create a statevector representation of an input array (vector like).

        Parameters
        ----------
        x: :class:`jax.numpy.ndarray`
            An array of patch pixel intensities flattened from original patch k x k.

        Returns
        -------
        tuple
            ``(statevector, n_qubits)``: the normalised statevector of length
            ``2**n_qubits`` and the number of qubits it spans.
        """
        # Number of pixels (N = 16 for a 4x4 patch)
        N = len(x)
        # Address qubits: ceil(log2(N)); plus one colour qubit.
        n_address_qubits = int(np.ceil(np.log2(N)))
        n_color_qubit = 1
        n_qubits = n_address_qubits + n_color_qubit

        # The last bit of each basis index is the colour, the rest the address.
        state_indices = jnp.arange(2**n_qubits)
        color_bits = state_indices % 2
        address_indices = state_indices // 2

        cos_values = jnp.cos(math.pi * x / 2)
        sin_values = jnp.sin(math.pi * x / 2)

        # NOTE(review): when N is not a power of two, address_indices runs
        # past len(x); JAX does not raise on out-of-bounds gathers, so confirm
        # which amplitudes the surplus addresses are meant to get.
        statevector = jnp.where(
            color_bits == 0,
            cos_values[address_indices],
            sin_values[address_indices],
        )
        statevector = statevector / jnp.linalg.norm(statevector)

        fixed_size = 2**n_qubits
        padded_statevector = self.pad_or_truncate_statevector(statevector.flatten(), fixed_size)
        return padded_statevector, n_qubits

    def flatten_snake(self, image: jnp.ndarray) -> jnp.ndarray:
        """
        Flatten an image in a snake-like fashion.

        Parameters
        ----------
        image: :class:`jax.numpy.ndarray`
            A 2D array of pixel intensities.

        Returns
        -------
        :class:`jax.numpy.ndarray`
            A 1D array of pixel intensities in snake-like order.
        """
        # Odd-indexed rows are read right-to-left, even ones left-to-right.
        odd_rows = jnp.arange(image.shape[0])[:, None] % 2 == 1
        snaked = jnp.where(odd_rows, jnp.flip(image, axis=1), image)
        return snaked.reshape(-1)

    def combine_mps_patches(self, mps_patches: onp.ndarray, n_qubits: int) -> jnp.ndarray:
        """
        Combine arrays of each MPS patch into a single MPS.

        Parameters
        ----------
        mps_patches: :class:`numpy.ndarray`
            List of MPS patches (nested lists of arrays).
        n_qubits: int
            Number of qubits.

        Returns
        -------
        :class:`jax.numpy.ndarray`
            A list of arrays for combined MPS.
        """
        # NOTE(review): ``position`` restarts at 0 for every patch while
        # ``interval`` accumulates across patches, so (assuming n_qubits
        # arrays per patch) the two boundary tests can only match inside the
        # first patch. Confirm against ``u.from_dense_to_mps``.
        combined = []
        interval = 0
        for patch in mps_patches:
            for position, tensor in enumerate(patch):
                is_matrix = len(tensor.shape) == 2
                if is_matrix and position == interval * n_qubits:
                    # Interval start: add a new leading axis.
                    combined.append(jnp.expand_dims(tensor, axis=0))
                elif is_matrix and position == (interval + 1) * n_qubits - 1:
                    # Interval end: add a new trailing axis.
                    combined.append(jnp.expand_dims(tensor, axis=-1))
                    interval += 1
                else:
                    combined.append(tensor)
        return combined

    def __call__(self, x: jnp.ndarray) -> qtn.MatrixProductState:
        """
        Convert a square image into a Matrix Product State (MPS).

        Parameters
        ----------
        x: :class:`jax.numpy.ndarray`
            A 2D square array of pixel intensities.

        Returns
        -------
        :class:`quimb.tensor.MatrixProductState`
            A Matrix Product State built from the image patches. It is also
            stored on ``self.mps``.

        Raises
        ------
        ValueError
            If the input is not square, or if ``k`` exceeds the image size.
        """
        H, W = x.shape  # H: height, W: width
        if H != W:
            raise ValueError("Only square matrix input is supported.")
        # Same guard as PatchAmplitudeEmbedding: fail early with a clear
        # message instead of further down the pipeline.
        if self.k > H:
            raise ValueError(f"Patch dimension k = {self.k} is too large for {H}x{W} images.")

        patches = u.divide_into_patches(x, self.k)

        mps_patches = []
        for patch in patches:
            patch_pixels = self.flatten_snake(patch)
            statevector, n_qubits = self.create_statevector(patch_pixels)
            mps_patches.append(u.from_dense_to_mps(statevector, n_qubits, self.max_bond))

        new_arrays = self.combine_mps_patches(mps_patches, n_qubits)

        # Recreate the MPS with the reshaped arrays
        self.mps = qtn.MatrixProductState(new_arrays, shape='lrp')
        return self.mps
class PatchAmplitudeEmbedding(StateVectorToMPSEmbedding):
    """Embed a square image patch-by-patch into an MPS using amplitude encoding.

    Each ``k x k`` patch is flattened, its pixel square roots are normalised
    into a statevector, converted to MPS arrays, and the arrays of all patches
    are chained into one MPS.
    """

    def __init__(self, k = 2, **kwargs):
        """
        Initialize the PatchAmplitudeEmbedding class.

        Parameters
        ----------
        k: int
            The kernel size of the patch window kxk.

        Returns
        -------
        None
        """
        super().__init__(**kwargs)
        self.k = k
        # Populated by ``__call__``; ``dims`` is only usable afterwards.
        self.mps = None

    @property
    def dims(self) -> list:
        """Shapes of the tensors of the most recently built MPS."""
        shapes = []
        for tensor in self.mps.tensors:
            shapes.append(tensor.shape)
        return shapes

    def create_statevector(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Create a statevector representation of an input array (vector like).

        Parameters
        ----------
        x: :class:`jax.numpy.ndarray`
            An array of patch pixel intensities flattened from original image.

        Returns
        -------
        tuple
            ``(statevector, n_qubits)``: the statevector zero-padded to length
            ``2**n_qubits`` and the number of qubits.
        """
        # Number of pixels (N = 784 for a 28x28 image)
        N = len(x)
        # Qubits needed to address N entries: ceil(log2(N)) (10 for N = 784)
        n_qubits = int(np.ceil(np.log2(N)))

        # Amplitudes are the square roots of the pixel values, L2-normalised.
        # NOTE(review): an all-zero patch has zero norm, so this division
        # yields NaN; confirm callers never pass fully black patches.
        amplitudes = jnp.sqrt(x)
        amplitudes = amplitudes / jnp.linalg.norm(amplitudes)

        padded = self.pad_or_truncate_statevector(amplitudes.flatten(), 2**n_qubits)
        return padded, n_qubits

    def pad_or_truncate_statevector(self, statevector: jnp.ndarray, target_size: int) -> jnp.ndarray:
        """
        Pad or truncate the statevector to a target size.

        Parameters
        ----------
        statevector: :class:`jax.numpy.ndarray`
            The input statevector.
        target_size: int
            The desired size of the statevector.

        Returns
        -------
        :class:`jax.numpy.ndarray`
            A statevector of the target size.
        """
        missing = target_size - statevector.shape[0]
        if missing > 0:
            # Too short: append zeros at the end.
            return jnp.pad(statevector, [(0, missing)], mode='constant')
        # Long enough: keep the leading entries.
        return statevector[:target_size]

    def combine_mps_patches(self, mps_patches: onp.ndarray, n_qubits: int) -> jnp.ndarray:
        """
        Combine arrays of each MPS patch into a single MPS.

        Parameters
        ----------
        mps_patches: :class:`numpy.ndarray`
            List of MPS patches (nested lists of arrays).
        n_qubits: int
            Number of qubits.

        Returns
        -------
        :class:`jax.numpy.ndarray`
            A list of arrays for combined MPS.
        """
        combined = []
        interval = 0
        for patch in mps_patches:
            for position, tensor in enumerate(patch):
                is_matrix = len(tensor.shape) == 2
                if is_matrix and position == interval * n_qubits:
                    # Interval start: add a new leading axis.
                    combined.append(jnp.expand_dims(tensor, axis=0))
                elif is_matrix and position == (interval + 1) * n_qubits - 1:
                    # Interval end: add a new trailing axis, then advance.
                    combined.append(jnp.expand_dims(tensor, axis=-1))
                    interval += 1
                else:
                    combined.append(tensor)
        return combined

    def __call__(self, x: jnp.ndarray) -> qtn.MatrixProductState:
        """
        Convert a square image into a Matrix Product State (MPS).

        Parameters
        ----------
        x: :class:`jax.numpy.ndarray`
            A 2D square array of pixel intensities.

        Returns
        -------
        :class:`quimb.tensor.MatrixProductState`
            A Matrix Product State built from the image patches. It is also
            stored on ``self.mps``.

        Raises
        ------
        ValueError
            If the input is not square, or if ``k`` exceeds the image size.
        """
        H, W = x.shape  # H: height, W: width
        if H != W:
            # TODO: Discuss about rectangular images (they could be supported, at least in principle)
            raise ValueError("Only square matrix input is supported.")
        if self.k > H:
            raise ValueError(f"Patch dimension k = {self.k} is too large for {H}x{W} images.")

        mps_patches = []
        for patch in u.divide_into_patches(x, self.k):
            statevector, n_qubits = self.create_statevector(patch.ravel())
            mps_patches.append(u.from_dense_to_mps(statevector, n_qubits, self.max_bond))

        new_arrays = self.combine_mps_patches(mps_patches, n_qubits)

        # Recreate the MPS with the reshaped arrays
        self.mps = qtn.MatrixProductState(new_arrays, shape='lrp')
        return self.mps
def embed(x: onp.ndarray, phi: Union[Embedding, ComplexEmbedding, StateVectorToMPSEmbedding], **mps_opts):
    """Create a normalised MPS from features ``x`` using the feature map ``phi``.

    Works only if features are separated and not correlated (this check you
    need to do yourself).

    Parameters
    ----------
    x: :class:`numpy.ndarray`
        Vector or Matrix of features.
    phi: :class:`tn4ml.embeddings.Embedding` or :class:`tn4ml.embeddings.ComplexEmbedding` or :class:`tn4ml.embeddings.StateVectorToMPSEmbedding`
        Feature map for each feature.
    mps_opts: Optional parameters.
        Additional arguments passed to MatrixProductState class.

    Returns
    -------
    The resulting MPS. With more than 200 tensors it is normalised site by
    site; otherwise every tensor is divided by ``norm ** (1 / n_tensors)``.

    Raises
    ------
    TypeError
        If ``phi`` is none of the three supported embedding kinds.
    """
    phi_type = type(phi)
    is_simple = issubclass(phi_type, Embedding)
    is_complex = issubclass(phi_type, ComplexEmbedding)
    is_statevector = issubclass(phi_type, StateVectorToMPSEmbedding)
    if not (is_complex or is_simple or is_statevector):
        raise TypeError('Invalid embedding type')

    if is_simple:
        # One site per feature; the two boundary sites drop a bond axis.
        arrays = [phi(xi).reshape((1, 1, phi.dim)) for xi in x]
        for edge in (0, -1):
            arrays[edge] = arrays[edge].reshape((1, phi.dim))
        mps = qtn.MatrixProductState(arrays, **mps_opts)
    elif is_complex and x.ndim == 2:
        dims = phi.dims
        # A single int applies to every site; otherwise one entry per site.
        site_dims = [dims] * len(x) if type(dims) is int else dims
        arrays = [phi(xi).reshape((1, 1, site_dims[i])) for i, xi in enumerate(x)]
        for edge in (0, -1):
            arrays[edge] = arrays[edge].reshape((1, site_dims[edge]))
        mps = qtn.MatrixProductState(arrays, **mps_opts)
    else:
        # Remaining cases: the embedding's own call result is used as the MPS.
        mps = phi(x)

    # normalize
    tensors = mps.tensors
    n_sites = len(tensors)
    if n_sites > 200:
        # Large systems: rescale every site but the first by its own norm and
        # left-canonize every site but the last, sweeping left to right.
        last = n_sites - 1
        for i, tensor in enumerate(tensors):
            if i != 0:
                tensor.modify(data=tensor.data / jnp.linalg.norm(tensor.data))
            if i != last:
                mps.left_canonize_site(i)
    else:
        # Spread the global norm evenly over all tensors.
        scale = a.do("power", mps.norm(), 1 / n_sites)
        for tensor in tensors:
            tensor.modify(data=tensor.data / scale)

    return mps