# Examples of loss functions for supervised and unsupervised learning.
from numbers import Number
from typing import Callable, Optional, Union
import jax.numpy as jnp
import jax
import numpy as np
import optax
import quimb.tensor as qtn
from .models.model import Model
from .embeddings import Embedding, embed, TrigonometricEmbedding
from .models.smpo import SpacedMatrixProductOperator
from .models.mps import MatrixProductState
from .models.mpo import MatrixProductOperator
[docs]
def NegLogLikelihood(model: qtn.MatrixProductState, data: qtn.MatrixProductState) -> Number:
    """Negative log-likelihood of `data` under the MPS `model`.

    The conjugated model is combined with the data, the network is contracted
    to a single value and ``-log(value ** 2)`` is returned.

    Parameters
    ----------
    model : :class:`quimb.tensor.MatrixProductState`
        Matrix Product State model.
    data : :class:`quimb.tensor.MatrixProductState`
        Input MPS. Needs at least as many tensors as `model`.

    Returns
    -------
    float

    Raises
    ------
    AssertionError
        If the last dimension of the first tensor differs between `model` and `data`.
    ValueError
        If `model` has more tensors than `data`.
    """
    # The physical dimensions of model and data have to agree.
    assert model.tensors[0].shape[-1] == data.tensors[0].shape[-1]

    n_model = len(model.tensors)
    n_data = len(data.tensors)

    if n_model > n_data:
        raise ValueError('Number of tensors for input data MPS needs to be higher or equal number of tensors in model.')

    overlap = model.H & data
    if n_model < n_data:
        # Shorter model: contract each physical index `k{i}` explicitly
        # before collapsing whatever is left of the network.
        for site in range(n_data):
            overlap.contract_ind(f'k{site}')
    # With equal sizes, model and data are assumed to use the same names
    # for their physical indices, so a plain full contraction is enough.
    overlap = overlap ^ all

    return -jax.lax.log(jax.lax.pow(overlap, 2))
[docs]
def LogFrobNorm(model) -> Number:
    """Regularization cost - log(Frobenius-norm of `model`).

    Parameters
    ----------
    model : :class:`tn4ml.models.smpo.SpacedMatrixProductOperator` or tensor network with ``norm()``
        Model to regularize. A Spaced Matrix Product Operator is handled by
        applying its conjugate to itself and contracting along the site tags;
        any other model must provide a ``norm()`` method.

    Returns
    -------
    float
    """
    # isinstance (rather than an exact type comparison) so that subclasses of
    # the SMPO follow the SMPO path as well.
    if isinstance(model, SpacedMatrixProductOperator):
        tn = model.H.apply(model)
        norm = tn.contract_cumulative(tn.site_tags)
    else:
        norm = model.norm()
    return jax.lax.log(norm)
[docs]
def LogPowFrobNorm(model) -> Number:
    """Regularization cost - log(squared(Frobenius-norm of `model`)).

    Parameters
    ----------
    model : :class:`tn4ml.models.smpo.SpacedMatrixProductOperator` or tensor network with ``norm()``
        Model to regularize. A Spaced Matrix Product Operator is handled by
        applying its conjugate to itself and contracting along the site tags;
        any other model must provide a ``norm()`` method.

    Returns
    -------
    float
    """
    # isinstance (rather than an exact type comparison) so that subclasses of
    # the SMPO follow the SMPO path as well.
    if isinstance(model, SpacedMatrixProductOperator):
        tn = model.H.apply(model)
        norm = tn.contract_cumulative(tn.site_tags)
    else:
        norm = model.norm()
    return jax.lax.log(jax.lax.pow(norm, 2))
[docs]
def LogReLUFrobNorm(model) -> Number:
    """Regularization cost using ReLU of the log of the Frobenius-norm of `model`.

    The log of the norm is clipped from below at zero, so the cost vanishes
    whenever the log-norm is non-positive.

    Parameters
    ----------
    model : :class:`tn4ml.models.smpo.SpacedMatrixProductOperator` or tensor network with ``norm()``
        Model to regularize. A Spaced Matrix Product Operator is handled by
        applying its conjugate to itself and contracting along the site tags;
        any other model must provide a ``norm()`` method.

    Returns
    -------
    float
    """
    # isinstance (rather than an exact type comparison) so that subclasses of
    # the SMPO follow the SMPO path as well.
    if isinstance(model, SpacedMatrixProductOperator):
        tn = model.H.apply(model)
        norm = tn.contract_cumulative(tn.site_tags)
    else:
        norm = model.norm()
    # ReLU(log(norm)); the log is cast to float64 before the comparison with 0.0.
    return jax.lax.max(0.0, jax.lax.log(norm).astype(jnp.float64))
[docs]
def QuadFrobNorm(model) -> Number:
    """Regularization cost: quadratic penalty, centered at 1, on the log of the Frobenius-norm of `model`.

    Computes ``(log(norm) - 1) ** 2``.

    Parameters
    ----------
    model : :class:`tn4ml.models.smpo.SpacedMatrixProductOperator` or tensor network with ``norm()``
        Model to regularize. A Spaced Matrix Product Operator is handled by
        applying its conjugate to itself and contracting along the site tags;
        any other model must provide a ``norm()`` method.

    Returns
    -------
    float
    """
    # isinstance (rather than an exact type comparison) so that subclasses of
    # the SMPO follow the SMPO path as well.
    if isinstance(model, SpacedMatrixProductOperator):
        tn = model.H.apply(model)
        norm = tn.contract_cumulative(tn.site_tags)
    else:
        norm = model.norm()
    return jax.lax.pow(jax.lax.log(norm) - 1.0, 2)
[docs]
def LogQuadNorm(model: SpacedMatrixProductOperator, data: qtn.MatrixProductState) -> Number:
    """Squared deviation from 1 of the log of ``TransformedSquaredNorm(model, data)``.

    Example of an error measure obtained when applying a
    :class:`tn4ml.models.smpo.SpacedMatrixProductOperator` `P` to `data`.

    Parameters
    ----------
    model : :class:`tn4ml.models.smpo.SpacedMatrixProductOperator`
        Spaced Matrix Product Operator.
    data : :class:`quimb.tensor.MatrixProductState`
        Input MPS.

    Returns
    -------
    float
    """
    log_norm = jax.lax.log(TransformedSquaredNorm(model, data))
    return jax.lax.pow(log_norm - 1.0, 2)
[docs]
def QuadNorm(model: SpacedMatrixProductOperator, data: qtn.MatrixProductState) -> Number:
    """Squared deviation from 1 of ``TransformedSquaredNorm(model, data)``.

    Example of an error measure obtained when applying a
    :class:`tn4ml.models.smpo.SpacedMatrixProductOperator` `P` to `data`.

    Parameters
    ----------
    model : :class:`tn4ml.models.smpo.SpacedMatrixProductOperator`
        Spaced Matrix Product Operator.
    data : :class:`quimb.tensor.MatrixProductState`
        Input MPS.

    Returns
    -------
    float
    """
    deviation = TransformedSquaredNorm(model, data) - 1.0
    return jax.lax.pow(deviation, 2)
[docs]
def SemiSupervisedLoss(model: SpacedMatrixProductOperator, data: qtn.MatrixProductState, y_true: Number, **kwargs) -> Number:
    """Loss function for semi-supervised learning.

    A penalty ``LogQuadNorm(model, data) + 0.3 * LogReLUFrobNorm(model)`` is
    blended with its inverse using `y_true` as the mixing weight, and the
    blend is squared.

    Parameters
    ----------
    model : :class:`tn4ml.models.smpo.SpacedMatrixProductOperator`
        Spaced Matrix Product Operator.
    data : :class:`quimb.tensor.MatrixProductState`
        Input Matrix Product State.
    y_true : :class:`Number`
        Target class percentage.
    kwargs : dict
        Accepted but not used by this function.

    Returns
    -------
    float
        First element of the squared blend.
    """
    penalty = LogQuadNorm(model, data) + 0.3 * LogReLUFrobNorm(model)
    inverse_term = y_true * (1 / penalty)
    direct_term = (1 - y_true) * penalty
    squared = jax.lax.pow(inverse_term + direct_term, 2)
    return squared[0]
[docs]
def SemiSupervisedNLL(model: SpacedMatrixProductOperator, data: qtn.MatrixProductState, y_true: Optional[jnp.array] = None, **kwargs) -> Number:
    """Loss function for semi-supervised learning with a cross-entropy term.

    Combines a softmax cross-entropy on the two-component model output with
    a norm-based penalty and the ``LogReLUFrobNorm`` regularizer.

    Parameters
    ----------
    model : :class:`tn4ml.models.smpo.SpacedMatrixProductOperator`
        Spaced Matrix Product Operator.
    data : :class:`quimb.tensor.MatrixProductState`
        Input Matrix Product State.
    y_true : :class:`jnp.array`
        Label passed, after squeezing, as the integer label to
        ``optax.softmax_cross_entropy_with_integer_labels``.
    kwargs : dict
        Accepted but not used by this function.

    Returns
    -------
    float
    """
    # Penalty: (log(sum of all entries of model.apply(data)) - 1) ** 2
    transformed = model.apply(data)
    entry_sum = jnp.array(transformed.arrays).sum()
    penalty = jax.lax.pow(jax.lax.log(entry_sum) - 1.0, 2)

    # Class scores: full contraction of the conjugated model with the data,
    # flattened to exactly two components.
    scores = ((model.H & data) ^ all).data.reshape((2,))

    label = jnp.squeeze(y_true)
    class_error = optax.softmax_cross_entropy_with_integer_labels(scores, label)

    return class_error + scores * (1 / penalty) + (1 - scores) * penalty + 0.3 * LogReLUFrobNorm(model)
[docs]
def Softmax(z, position) -> Number:
    """Softmax function evaluated for a single class.

    Computes ``exp(z[position]) / sum(exp(z))``. The maximum of `z` is
    subtracted before exponentiating; this leaves the result unchanged
    mathematically but prevents overflow (and the resulting NaN) for large
    entries of `z`.

    Parameters
    ----------
    z : :class:`jnp.array`
        Predicted scores (array-like inputs are converted to a JAX array).
    position : int
        Indicates for which class we are calculating softmax value.

    Returns
    -------
    float
    """
    z = jnp.asarray(z)
    # Softmax is invariant under a constant shift of its input, so the shift
    # is excluded from differentiation.
    shift = jax.lax.stop_gradient(jnp.max(z))
    exp_z = jnp.exp(z - shift)
    return exp_z[position] / jnp.sum(exp_z)
[docs]
def CrossEntropySoftmax(model: SpacedMatrixProductOperator, data: qtn.MatrixProductState, targets: jnp.array) -> Number:
    """Cross-entropy loss function for supervised learning.

    The model output is flattened to ``len(targets)`` components, divided by
    its norm, and the negative log of the softmax value at the target class
    (``argmax(targets)``) is returned.

    Parameters
    ----------
    model : :class:`tn4ml.models.smpo.SpacedMatrixProductOperator`
        Spaced Matrix Product Operator.
    data : :class:`quimb.tensor.MatrixProductState`
        Input Matrix Product State.
    targets : :class:`numpy.ndarray`
        Target class vector. Example = [1 0 0 0] for n_classes = 4.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If `model` has more tensors than `data`.
    """
    n_model = len(model.tensors)
    n_data = len(data.tensors)

    if n_model > n_data:
        raise ValueError('Number of tensors for input data MPS needs to be higher or equal number of tensors in model.')

    if n_model < n_data:
        # Shorter model: contract each physical index `k{i}` explicitly
        # before collapsing the rest of the network.
        network = model.H & data
        for site in range(n_data):
            network.contract_ind(f'k{site}')
        output = network ^ all
    elif hasattr(model, 'apply'):
        output = model.apply(data) ^ all
    else:
        output = (model.H & data) ^ all

    scores = output.data.reshape((len(targets),))
    scores = scores / jnp.linalg.norm(scores)
    target_class = jnp.argmax(targets)
    return -jnp.log(Softmax(scores, target_class))
[docs]
def MeanSquaredError(model: SpacedMatrixProductOperator, data: qtn.MatrixProductState, targets: jnp.array) -> Number:
    """Mean Squared Error loss function for supervised learning.

    The model output is flattened to ``len(targets)`` components, divided by
    its norm, and compared with `targets` via the mean of squared differences.

    Parameters
    ----------
    model : :class:`tn4ml.models.smpo.SpacedMatrixProductOperator`
        Spaced Matrix Product Operator
    data: :class:`quimb.tensor.MatrixProductState`
        Input Matrix Product State
    targets: :class:`numpy.ndarray`
        Target class vector. Example = [1 0 0 0] for n_classes = 4.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If `model` has more tensors than `data`.
    """
    if len(model.tensors) < len(data.tensors):
        # Shorter model: contract every physical index `k{i}` explicitly,
        # then collapse the rest of the network.
        inds_contract = []
        for i in range(len(data.tensors)):
            inds_contract.append(f'k{i}')
        output = (model.H & data)
        for index in inds_contract:
            output.contract_ind(index)
        output = output^all
    elif len(model.tensors) == len(data.tensors):
        if hasattr(model, 'apply'):
            output = model.apply(data)^all
        else:
            # Manual contraction for models without `apply`: first contract
            # over the outer indices of the data ...
            output = model | data
            for ind in data.outer_inds():
                output.contract_ind(ind=ind)
            # NOTE(review): the original indentation of this section was lost;
            # the nesting below (tag handling inside this `else`, `drop_tags`
            # after each loop) is the most plausible reading - confirm.
            tags = list(qtn.tensor_core.get_tags(output))
            tags_to_drop = []
            # ... then merge neighbouring tensors starting from the first tag
            # and moving towards the middle ...
            for j in range(len(model.tensors)//2-1):
                output.contract_between(tags[j], tags[j + 1])
                tags_to_drop.extend([tags[j]])
            output.drop_tags(tags_to_drop)
            output.fuse_multibonds_()
            tags_to_drop=[]
            # ... and finally from the last tag back down to the middle.
            for j in range(len(model.tensors)-1, len(model.tensors)//2-1, -1):
                output.contract_between(tags[j], tags[j - 1])
                tags_to_drop.extend([tags[j]])
            output.drop_tags(tags_to_drop)
    else:
        raise ValueError('Number of tensors for input data MPS needs to be higher or equal number of tensors in model.')
    # NOTE(review): this reads `.tensors[0]`, whereas CrossEntropySoftmax reads
    # `.data` directly from the result of `^all` - confirm that the `^all`
    # branches above return an object that exposes `.tensors`.
    output = output.tensors[0].data.reshape((len(targets), ))
    output = output/jnp.linalg.norm(output)
    return jnp.mean(jnp.square(output - targets))
[docs]
def OptaxWrapper(optax_loss = None) -> Callable:
    """Wrapper around optax loss functions for supervised learning. Make sure you got all inputs to loss function correct.

    Refer to documentation for each loss to https://optax.readthedocs.io/en/latest/api/losses.html .
    Make sure SMPO has only one output with dimension = number of classes.

    Parameters
    ----------
    optax_loss : Callable
        Loss function called as ``optax_loss(y_pred, y_true, **kwargs)``, or as
        ``optax_loss(y_pred, **kwargs)`` when no targets are given. Must not be None.

    Returns
    -------
    Callable
        Loss function with signature ``(model, data, y_true=None, **kwargs)``.
    """
    assert optax_loss != None
    def loss_optax(model: Model, data: qtn.MatrixProductState, y_true: Optional[jnp.array] = None, **kwargs) -> Number:
        """Loss function for learning. Make sure you got all inputs to loss function correct.

        Parameters
        ----------
        model : :class:`tn4ml.models.model.Model`
            Tensor Network model.
        data: :class:`quimb.tensor.MatrixProductState`
            Input Matrix Product State
        y_true: :class:`numpy.ndarray`
            Target class vector. Example = [1 0 0 0] for n_classes = 4.
        kwargs : dict
            Additional arguments for optax loss function.

        Returns
        -------
        float
        """
        if isinstance(model, SpacedMatrixProductOperator):
            if len(model.tensors) < len(data.tensors):
                # Shorter model: contract every physical index `k{i}`
                # explicitly, then collapse the rest of the network.
                inds_contract = []
                for i in range(len(data.tensors)):
                    inds_contract.append(f'k{i}')
                output = (model.H & data)
                for index in inds_contract:
                    output.contract_ind(index)
                output = output^all
                # This path uses len(y_true), so it requires targets.
                output = output.data.reshape((len(y_true), ))
                y_pred = jnp.log(output)
            else:
                output = model.apply(data)
                if len(output.tensors) > 1:
                    output = output^all
                    y_pred = output.data
                else:
                    # Single output tensor: squeeze it and add a leading axis.
                    y_pred = jnp.expand_dims(jnp.squeeze(output.tensors[0].data), axis=0)
                # NOTE(review): original indentation was lost; this reshaping of
                # y_true may instead belong inside the inner `else` above or one
                # level further out - confirm.
                if y_true is not None:
                    y_true = jnp.expand_dims(jnp.squeeze(y_true), axis=0)
        elif isinstance(model, MatrixProductState):
            y_pred = (model & data)^all
        else:
            y_pred = (model & data)^all
            # NOTE(review): assumed to belong to this `else` branch only - confirm.
            y_pred = jnp.expand_dims(jnp.squeeze(y_pred.data), axis=0)
        # normalize
        y_pred = y_pred/jnp.linalg.norm(y_pred)
        if y_true is not None:
            # Add a leading axis to 1-D targets before calling the loss.
            if len(y_true.shape) == 1:
                y_true = jnp.expand_dims(y_true, axis=0)
            return optax_loss(y_pred, y_true, **kwargs)
        else:
            return optax_loss(y_pred, **kwargs)
    return loss_optax
[docs]
def CrossEntropyWeighted(class_weights: jnp.array = None) -> Callable:
    """Build a class-weighted softmax cross-entropy loss.

    Parameters
    ----------
    class_weights : :class:`numpy.ndarray`, optional
        Class weights, shape (num_classes,). If None (the default), no
        weighting is applied and the plain cross-entropy is returned.

    Returns
    -------
    Callable
        Loss function with signature ``(model, data, y_true=None, **kwargs)``.
    """
    def cross_entropy(model: Model, data: MatrixProductState, y_true: jnp.array = None, **kwargs) -> Number:
        """
        Compute the weighted cross-entropy loss.

        Parameters
        ----------
        model : :class:`tn4ml.models.model.Model`
            Tensor Network model.
        data: :class:`quimb.tensor.MatrixProductState`
            Input Matrix Product State
        y_true: :class:`jax.numpy.ndarray`
            Target class vector. Example = [1 0 0 0] for n_classes = 4.
            Required whenever `class_weights` is given.
        kwargs : dict
            Accepted for interface compatibility; not forwarded to the optax loss.

        Returns
        -------
        Weighted cross-entropy loss (not averaged).
        """
        # Softmax cross-entropy of the normalized model output.
        loss = OptaxWrapper(optax.softmax_cross_entropy)(model, data, y_true)
        if class_weights is None:
            # No weights supplied: the default must not crash in jnp.array(None).
            return jnp.array(loss)
        # Per-sample weight = weight of the sample's class, selected through
        # the target vector.
        sample_weights = jnp.sum(y_true * jnp.array(class_weights), axis=-1)
        return jnp.array(loss * sample_weights)
    return cross_entropy
[docs]
def CombinedLoss(model: Model,
                 data: Union[qtn.MatrixProductState, np.ndarray],
                 y_true: Optional[jnp.array] = None,
                 error: Callable = LogQuadNorm,
                 reg: Callable = NoReg,
                 embedding: Optional[Embedding] = None
                 ) -> Number:
    """
    Unified Loss function combining error computation and regularization.

    Parameters
    ----------
    model : :class:`tn4ml.models.Model`
        Tensor Network with parametrized tensors.
    data : Union[:class:`qtn.MatrixProductState`, :class:`numpy.ndarray`]
        Data used for computing the loss value. NumPy arrays are embedded
        sample by sample with `embedding` before the error is computed.
    y_true : Optional[:class:`jnp.array`]
        True labels for supervised learning tasks. When given, it is passed
        to `error` as a third argument.
    error : function
        Function for error calculation.
    reg : function
        Function for regularization value calculation.
    embedding : :class:`tn4ml.embeddings.Embedding`, optional
        Data embedding function. Required when `data` is a NumPy array.

    Returns
    -------
    float
        Computed loss value: mean of the error plus the regularization term.

    Raises
    ------
    ValueError
        If `data` is None, or if `data` is a NumPy array and no `embedding` is given.
    """
    if data is None:
        raise ValueError("Provide input data!")

    if isinstance(data, np.ndarray):
        if not embedding:
            # This error used to be constructed without `raise`, so raw arrays
            # silently reached `error` without being embedded.
            raise ValueError("Provide embedding function for NumPy array data to embed it into Tensor Network representation.")
        # Embed each sample into its Tensor Network representation.
        data = [embed(sample, embedding) for sample in data]

    if y_true is not None:
        loss = jnp.mean(error(model, data, y_true)) + reg(model)
    else:
        loss = jnp.mean(error(model, data)) + reg(model)
    return loss