MNIST classification#
[1]:
import os
os.environ["KMP_WARNINGS"] = "0"
import numpy as np
import jax.numpy as jnp
import jax
import optax
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from jax.nn.initializers import *
import quimb.tensor as qtn
from tn4ml.initializers import *
from tn4ml.models.mps import *
from tn4ml.models.model import *
from tn4ml.embeddings import *
from tn4ml.metrics import *
from tn4ml.strategy import *
from tn4ml.util import *
from tn4ml.eval import *
[2]:
# Run everything in 64-bit floats with full-precision matrix multiplications.
for option, value in (("jax_enable_x64", True), ("jax_default_matmul_precision", "highest")):
    jax.config.update(option, value)
Load dataset#
MNIST images are grayscale images
size: 28×28 pixels
classes: handwritten digits 0–9
[3]:
# Each split is an (images, labels) pair: index 0 holds the images, index 1 the labels.
(train, test) = mnist.load_data()
[4]:
# Split the training pair into images (kept as N x 28 x 28) and integer labels.
train_images, train_labels = train[0].reshape(-1, 28, 28), train[1]
[5]:
import matplotlib.pyplot as plt

hfont = {'fontname': 'Courier New', 'fontsize': 15, 'fontweight': 'bold'}

# Show the first training example of each digit class on a 2 x 5 grid.
# The explicit figure/axes API avoids drawing onto whatever pyplot's "current" figure is.
fig, axes = plt.subplots(2, 5)
for digit, ax in enumerate(axes.flat):  # axes.flat walks the grid row by row, like subplot(2, 5, i + 1)
    first_index = np.flatnonzero(train_labels == digit)[0]
    ax.imshow(train_images[first_index], cmap='gray')
    ax.set_title(f'Label: {digit}', **hfont)
    ax.axis('off')
fig.tight_layout()  # lay out once, after all panels exist
plt.show()
[6]:
# Regroup the raw splits by role: data["X"][split] -> images, data["y"][split] -> labels.
data = {
    "X": {"train": train[0], "test": test[0]},
    "y": {"train": train[1], "test": test[1]},
}
Downsample the images from 28×28 to 14×14 pixels (area interpolation) and scale pixel values to [0, 1]
[7]:
def resize_images(images, size=(14, 14), method=tf.image.ResizeMethod.AREA):
    """Resize a batch of images with TensorFlow and return a NumPy array.

    Args:
        images: batch accepted by ``tf.image.resize``; in this notebook it is a
            (N, 28, 28, 1) array (batch, height, width, channels).
        size: target (height, width). Defaults to (14, 14), the size used here.
        method: a ``tf.image.ResizeMethod``. Defaults to AREA, which averages the
            source pixels covered by each output pixel (suited to downsampling).

    Returns:
        ``np.ndarray`` of shape (N, size[0], size[1], channels). Per the TF docs the
        result is float32 for every method except NEAREST_NEIGHBOR.
    """
    resized_images = tf.image.resize(images, size, method=method)
    return resized_images.numpy()
[8]:
# Downsample both splits to 14 x 14 (a channel axis is added for tf.image.resize, then
# dropped again) and scale the 0-255 pixel values into [0, 1].
X_resized, X_test_resized = (
    resize_images(data['X'][split].reshape(-1, 28, 28, 1)).reshape(-1, 14, 14) / 255.0
    for split in ('train', 'test')
)
Rearrange pixels in zig-zag order#
Each 14×14 image is flattened row by row into a sequence of 196 pixels, one per MPS tensor.

[9]:
def zigzag_order(data):
    """Flatten each 2-D image into a 1-D pixel sequence in row-major (raster) order.

    Rows are concatenated left-to-right, top-to-bottom. The path this traces over
    the image is what the notebook calls "zig-zag"; rows are NOT reversed
    alternately (this is not a boustrophedon/snake ordering).

    Args:
        data: array-like of images, shape (n_samples, height, width).

    Returns:
        A new ``np.ndarray`` of shape (n_samples, height * width) with the input's
        dtype. It never aliases ``data``. Empty input yields a 2-D empty array.
    """
    images = np.array(data)  # copy, so callers can mutate the result safely
    n_features = int(np.prod(images.shape[1:]))  # explicit width so 0-sample input still reshapes
    # Vectorized reshape replaces the per-pixel Python loop (~12M appends for MNIST).
    return images.reshape(images.shape[0], n_features)
[10]:
# One row of 196 pixel values per image, for both splits.
train_data, test_data = map(zigzag_order, (X_resized, X_test_resized))
One-hot encoding of labels#
0 → [1 0 0 0 0 0 0 0 0 0]
1 → [0 1 0 0 0 0 0 0 0 0]
…
9 → [0 0 0 0 0 0 0 0 0 1]
[11]:
n_classes: int = 10  # MNIST digits 0-9
[12]:
# Encode integer labels as length-n_classes one-hot vectors (train first, then test).
y_train, y_test = (
    integer_to_one_hot(data['y'][split], n_classes) for split in ('train', 'test')
)
Take a 10% subset of the training set and split it into training and validation sets (the MNIST test set is used for testing)
[13]:
from sklearn.model_selection import train_test_split
[ ]:
# Keep 10% of the training set (to speed up training), then hold out 20% of that
# subset for validation. The subset gets its own names so this cell is idempotent:
# re-running it always starts from the full `train_data` / `y_train` rather than
# re-splitting an already-reduced `train_inputs`.
subset_inputs, _, subset_targets, _ = train_test_split(
    train_data, y_train, test_size=0.9, random_state=42
)
train_inputs, val_inputs, train_targets, val_targets = train_test_split(
    subset_inputs, subset_targets, test_size=0.2, random_state=42
)
TN as ML model#
Specify parameters and initialize a tensor network

[16]:
L = 14 * 14  # number of MPS tensors: one per pixel of the flattened 14 x 14 image
initializer = randn(1e-2)  # random-normal initialization of the MPS tensors (1e-2 presumably the std — confirm in tn4ml)
key = jax.random.key(42)  # PRNG key for the initializer
shape_method = 'noteven'
bond_dim = 10  # bond dimension of the MPS
phys_dim = 3  # features per pixel: p = 3 with the polynomial embedding, p = 2 with the trigonometric one
class_dim = n_classes  # size of the output (class) index; reuse n_classes rather than repeating 10
index_class = (L + 1) // 2  # ceil(L / 2): class tensor sits at the middle of the chain
cyclic = False
add_identity = True
boundary = 'obc'  # open boundary conditions
[17]:
# Build the MPS classifier from the hyper-parameters defined above.
model = MPS_initialize(
    L,
    initializer=initializer,
    key=key,
    shape_method=shape_method,
    bond_dim=bond_dim,
    phys_dim=phys_dim,
    cyclic=cyclic,  # honour the `cyclic` setting instead of a hard-coded False
    add_identity=add_identity,
    class_dim=class_dim,
    class_index=index_class,
    canonical_center=index_class,  # canonical center placed on the class tensor
    boundary=boundary,
    dtype=jnp.float64,
)
Define training parameters
[18]:
def cross_entropy_loss(*args, **kwargs):
    """Batch-averaged softmax cross-entropy.

    All arguments are forwarded unchanged to
    ``OptaxWrapper(optax.softmax_cross_entropy)``. That is presumably tn4ml's
    adapter from an optax loss to its training interface (confirm in tn4ml).
    The values it returns are averaged with ``.mean()``.
    """
    wrapped_loss = OptaxWrapper(optax.softmax_cross_entropy)
    per_sample_losses = wrapped_loss(*args, **kwargs)
    return per_sample_losses.mean()
[19]:
# --- Training configuration (first run) ---
optimizer, learning_rate = optax.adam, 5e-4  # optimizer passed uncalled; configure() gets learning_rate separately
strategy = 'global'  # global gradient descent
loss = cross_entropy_loss
train_type = 1  # 1 = supervised, 0 = unsupervised
# Degree-2 polynomial feature map per pixel; presumably (1, x, x^2) -> 3 features,
# which is what phys_dim = 3 assumes — confirm against tn4ml.embeddings.polynomial.
embedding = polynomial(degree=2, n=1, include_bias=True)
device = 'cpu'
[20]:
# Attach optimizer, loss and training scheme to the model.
model.configure(
    optimizer=optimizer,
    strategy=strategy,
    loss=loss,
    train_type=train_type,
    learning_rate=learning_rate,
    device=device,
)
[21]:
epochs, batch_size = 100, 128  # first training run
To obtain the scalar loss value, contract the network:

[22]:
# Supervised training; validation loss/accuracy are tracked each epoch.
history = model.train(
    train_inputs,
    targets=train_targets,
    val_inputs=val_inputs,
    val_targets=val_targets,
    epochs=epochs,
    batch_size=batch_size,
    val_batch_size=batch_size,
    canonize=(True, index_class),  # canonize around the class tensor
    embedding=embedding,
    normalize=True,
    cache=False,  # caching does not yet work for classification
    display_val_acc=True,
    eval_metric=cross_entropy_loss,
    dtype=jnp.float64,
)
epoch: 100%|██████████ 100/100 , loss=1.4337, val_loss=1.4886, val_acc=0.9453
[23]:
plot_loss(history, figsize=(8, 6), validation=True)  # training and validation loss curves
[24]:
# Plot the accuracy recorded in `history` by model.train().
plot_accuracy(history, figsize=(8, 6))
Save model
[25]:
# Save as "model" under results/mnist_class; tn=True because the classification MPS is a TensorNetwork object.
model.save("model", "results/mnist_class", tn=True)
Evaluate
Compute the classification accuracy on the test set
[26]:
model.accuracy(test_data, y_test, batch_size=64, embedding=embedding)  # test-set accuracy after the first run
[26]:
0.9537259615384616
Continue training with an exponentially decaying learning rate#
[27]:
# Learning-rate schedule: starts at 1e-3 and decays continuously by a factor of
# decay_rate = 0.01 per 1000 steps (i.e. 100x smaller after 1000 steps).
scheduler = optax.exponential_decay(init_value=1e-3, transition_steps=1000, decay_rate=0.01)

# Update pipeline, applied in list order (presumably chained by tn4ml the way
# optax.chain would be — confirm):
#   1. clip the raw gradient to a global norm of at most 1.0
#   2. rescale with Adam's moment estimates
#   3. multiply by the scheduled learning rate
#   4. negate, because optax.apply_updates adds updates and we want to descend
gradient_transforms = [
    optax.clip_by_global_norm(1.0),
    optax.scale_by_adam(),
    optax.scale_by_schedule(scheduler),
    optax.scale(-1.0),
]
[28]:
# Re-configure the same model with the scheduled gradient-transform pipeline.
# learning_rate is kept for parity with the first configure() call; with a schedule
# inside gradient_transforms it is presumably not used — TODO confirm in tn4ml.
model.configure(
    gradient_transforms=gradient_transforms,
    strategy=strategy,
    loss=loss,
    train_type=train_type,
    learning_rate=learning_rate,
    device=device,  # same device as the first run, instead of configure()'s default
)
[29]:
epochs, batch_size = 50, 256  # second run: fewer epochs, larger batches
[30]:
# Continue training with the decaying learning rate; same data and embedding as before.
history = model.train(
    train_inputs,
    targets=train_targets,
    val_inputs=val_inputs,
    val_targets=val_targets,
    epochs=epochs,
    batch_size=batch_size,
    val_batch_size=64,
    canonize=(True, index_class),  # canonize around the class tensor
    embedding=embedding,
    normalize=True,
    cache=False,  # caching does not yet work for classification
    display_val_acc=True,
    eval_metric=cross_entropy_loss,
    dtype=jnp.float64,
)
epoch: 100%|██████████ 50/50 , loss=1.4263, val_loss=1.4836, val_acc=0.9470
[31]:
model.accuracy(test_data, y_test, batch_size=64, embedding=embedding)  # test-set accuracy after the second run
[31]:
0.9556290064102564