Sweeping method for TN training

Sweeping method for TN training#

Anomaly detection with Spaced Matrix Product Operator trained using Sweeping method

  • images resized to \(7\times7\) resolution due to slower training of this procedure

[1]:
import os
os.environ["KMP_WARNINGS"] = "0"
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from jax.nn.initializers import *
import matplotlib.pyplot as plt
from sklearn.metrics import auc

from tn4ml.initializers import *
from tn4ml.models.smpo import *
from tn4ml.models.model import *
from tn4ml.embeddings import *
from tn4ml.metrics import *
from tn4ml.eval import *
from tn4ml.util import *
[2]:
jax.config.update("jax_enable_x64", True)
jax.config.update('jax_default_matmul_precision', 'highest')

Load dataset#

  • normalize to \([0, 1]\) range

[3]:
train, test = mnist.load_data()
data = {"X": dict(train=train[0], test=test[0]), "y": dict(train=train[1], test=test[1])}
[4]:
normal_class = 0
[5]:
# training data
X = {
"normal": data["X"]["train"][data["y"]["train"] == normal_class]/255.,
"anomaly": data["X"]["train"][data["y"]["train"] != normal_class]/255.,
}
[6]:
# test data
X_test = {
"normal": data["X"]["test"][data["y"]["test"] == normal_class]/255.,
"anomaly": data["X"]["test"][data["y"]["test"] != normal_class]/255.,
}
[7]:
# reduce size of images for faster training and reduce to 0-1 range
strides = (4,4) # (2,2) for 14x14 images; (4,4) for 7x7 images
pool_size = (2,2)
pool = tf.keras.layers.MaxPooling2D(pool_size=pool_size, strides=strides, padding="same")
[27]:
X_resized = {
    "normal": pool(tf.constant(X["normal"].reshape(-1,28,28,1))).numpy().reshape(-1,7,7),
    "anomaly": pool(tf.constant(X["anomaly"].reshape(-1,28,28,1))).numpy().reshape(-1,7,7),
}

X_test_resized = {
    "normal": pool(tf.constant(X_test["normal"].reshape(-1,28,28,1))).numpy().reshape(-1,7,7),
    "anomaly": pool(tf.constant(X_test["anomaly"].reshape(-1,28,28,1))).numpy().reshape(-1,7,7),
}

Rearrange pixels in zig-zag order#

MPS Params

[28]:
def zigzag_order(data):
    data = np.squeeze(data)
    data_zigzag = []
    for x in data:
        image = []
        for i in x:
            image.extend(i)
        data_zigzag.append(image)
    return np.asarray(data_zigzag)
[29]:
zigzag = True
[30]:
if zigzag:
    train_normal = zigzag_order(X_resized["normal"])
    test_normal = zigzag_order(X_test_resized["normal"])

    train_anomaly = zigzag_order(X_resized["anomaly"])
    test_anomaly = zigzag_order(X_test_resized["anomaly"])
else:
    train_normal = X_resized['normal'].reshape(-1, X_resized['normal'].shape[1]*X_resized['normal'].shape[2])
    test_normal = X_test_resized['normal'].reshape(-1, X_test_resized['normal'].shape[1]*X_test_resized['normal'].shape[2])

    train_anomaly = X_resized['anomaly'].reshape(-1, X_resized['anomaly'].shape[1]*X_resized['anomaly'].shape[2])
    test_anomaly = X_test_resized['anomaly'].reshape(-1, X_test_resized['anomaly'].shape[1]*X_test_resized['anomaly'].shape[2])
[31]:
# take train_size samples from normal class for training
train_size = 2048

indices = list(range(len(train_normal)))
np.random.shuffle(indices)

indices = indices[:train_size]
train_normal = np.take(train_normal, indices, axis=0)

Training setup#

[13]:
# define model parameters
L = 7*7
initializer = gramschmidt('normal', 1e-5)
key = jax.random.key(42)
shape_method = 'noteven'
bond_dim = 10
phys_dim = (2,2)
spacing = 8
add_identity = False
boundary='obc'
[14]:
model = SMPO_initialize(L=L,
                        initializer=initializer,
                        key=key,
                        shape_method=shape_method,
                        spacing=spacing,
                        bond_dim=bond_dim,
                        phys_dim=phys_dim,
                        cyclic=False,
                        compress=True,
                        add_identity=add_identity,
                        boundary=boundary)
[15]:
alpha = 0.4
def loss_combined(*args, **kwargs):
    error = LogQuadNorm
    reg = LogReLUFrobNorm
    return CombinedLoss(*args, **kwargs, error=error, reg=lambda P: alpha*reg(P))
[ ]:
# define training parameters
epochs = 100
batch_size = 256
optimizer = optax.adam
strategy = 'sweeps'
loss = loss_combined
train_type = TrainingType.UNSUPERVISED
embedding = TrigonometricEmbedding()
learning_rate = 1e-4
earlystop = EarlyStopping(min_delta=0, patience=10, monitor='loss', mode='min')

model.configure(optimizer=optimizer, strategy=strategy, loss=loss, train_type=train_type, learning_rate=learning_rate)
[ ]:
history = model.train(train_normal,
            epochs=epochs,
            batch_size=batch_size,
            embedding = embedding,
            normalize = True,
            dtype = jnp.float64,
            earlystop = earlystop,
            )

epoch:   3%|▎          3/100 , loss=4898.5405
Waiting for 1 epochs.
epoch:   4%|▍          4/100 , loss=4916.9816
Waiting for 2 epochs.
epoch:   5%|▌          5/100 , loss=4906.0683
Waiting for 3 epochs.
epoch:   8%|▊          8/100 , loss=4882.4867
Waiting for 1 epochs.
epoch:  10%|█          10/100 , loss=4880.3916
Waiting for 1 epochs.
epoch:  11%|█          11/100 , loss=4881.7900
Waiting for 2 epochs.
epoch:  12%|█▏         12/100 , loss=4898.7627
Waiting for 3 epochs.
epoch:  15%|█▌         15/100 , loss=4872.9330
Waiting for 1 epochs.
epoch:  16%|█▌         16/100 , loss=4881.6928
Waiting for 2 epochs.
epoch:  20%|██         20/100 , loss=4858.8395
Waiting for 1 epochs.
epoch:  21%|██         21/100 , loss=4862.0073
Waiting for 2 epochs.
epoch:  24%|██▍        24/100 , loss=4847.9034
Waiting for 1 epochs.
epoch:  26%|██▌        26/100 , loss=4845.9629
Waiting for 1 epochs.
epoch:  28%|██▊        28/100 , loss=4842.7239
Waiting for 1 epochs.
epoch:  29%|██▉        29/100 , loss=4844.3363
Waiting for 2 epochs.
epoch:  30%|███        30/100 , loss=4848.4053
Waiting for 3 epochs.
epoch:  31%|███        31/100 , loss=4847.3821
Waiting for 4 epochs.
epoch:  32%|███▏       32/100 , loss=4858.1263
Waiting for 5 epochs.
epoch:  33%|███▎       33/100 , loss=4852.0553
Waiting for 6 epochs.
epoch:  34%|███▍       34/100 , loss=4847.9779
Waiting for 7 epochs.
epoch:  38%|███▊       38/100 , loss=4827.4997
Waiting for 1 epochs.
epoch:  39%|███▉       39/100 , loss=4828.9107
Waiting for 2 epochs.
epoch:  40%|████       40/100 , loss=4829.7398
Waiting for 3 epochs.
epoch:  42%|████▏      42/100 , loss=4822.7272
Waiting for 1 epochs.
epoch:  43%|████▎      43/100 , loss=4832.1619
Waiting for 2 epochs.
epoch:  44%|████▍      44/100 , loss=4830.2362
Waiting for 3 epochs.
epoch:  45%|████▌      45/100 , loss=4825.7658
Waiting for 4 epochs.
epoch:  46%|████▌      46/100 , loss=4837.0979
Waiting for 5 epochs.
epoch:  47%|████▋      47/100 , loss=4830.5170
Waiting for 6 epochs.
epoch:  48%|████▊      48/100 , loss=4834.8636
Waiting for 7 epochs.
epoch:  50%|█████      50/100 , loss=4817.3382
Waiting for 1 epochs.
epoch:  51%|█████      51/100 , loss=4819.6702
Waiting for 2 epochs.
epoch:  53%|█████▎     53/100 , loss=4814.1478
Waiting for 1 epochs.
epoch:  58%|█████▊     58/100 , loss=4802.9595
Waiting for 1 epochs.
epoch:  59%|█████▉     59/100 , loss=4812.1689
Waiting for 2 epochs.
epoch:  60%|██████     60/100 , loss=4803.7250
Waiting for 3 epochs.
epoch:  61%|██████     61/100 , loss=4826.8882
Waiting for 4 epochs.
epoch:  62%|██████▏    62/100 , loss=4811.7140
Waiting for 5 epochs.
epoch:  63%|██████▎    63/100 , loss=4805.7986
Waiting for 6 epochs.
epoch:  64%|██████▍    64/100 , loss=4809.9926
Waiting for 7 epochs.
epoch:  66%|██████▌    66/100 , loss=4798.8627
Waiting for 1 epochs.
epoch:  67%|██████▋    67/100 , loss=4802.2691
Waiting for 2 epochs.
epoch:  68%|██████▊    68/100 , loss=4804.7045
Waiting for 3 epochs.
epoch:  69%|██████▉    69/100 , loss=4812.5455
Waiting for 4 epochs.
epoch:  70%|███████    70/100 , loss=4808.0869
Waiting for 5 epochs.
epoch:  71%|███████    71/100 , loss=4800.1221
Waiting for 6 epochs.
epoch:  72%|███████▏   72/100 , loss=4805.0641
Waiting for 7 epochs.
epoch:  73%|███████▎   73/100 , loss=4803.9018
Waiting for 8 epochs.
epoch:  74%|███████▍   74/100 , loss=4807.6348
Waiting for 9 epochs.
epoch:  75%|███████▌   75/100 , loss=4811.4392
Training stopped by EarlyStopping on epoch: 65
epoch:  76%|███████▌   76/100 , loss=4806.6447
[18]:
plot_loss(model.history, validation=False, figsize=(8, 6))
../../_images/source_examples_mnist_ad_sweeps_21_0.png
[19]:
# plot loss
plt.figure(figsize=(8, 6))
plt.loglog(range(len(history['loss'])), history['loss'], label='train')
plt.legend()
plt.show()
../../_images/source_examples_mnist_ad_sweeps_22_0.png

Evaluate

[32]:
indices = list(range(len(test_anomaly)))
np.random.shuffle(indices)

indices = indices[:len(test_normal)]
test_anomaly = np.take(test_anomaly, indices, axis=0)
[33]:
test_anomaly.shape, test_normal.shape
[33]:
((980, 49), (980, 49))
[ ]:
loss = LogQuadNorm

anomaly_score = model.evaluate(test_anomaly, evaluate_type=train_type, return_list=True, dtype=jnp.float64, batch_size=128, embedding=embedding, metric = loss)
normal_score = model.evaluate(test_normal, evaluate_type=train_type, return_list=True, dtype=jnp.float64, batch_size=128, embedding=embedding, metric = loss)
[35]:
anomaly_score.shape, normal_score.shape
[35]:
((896,), (896,))
[39]:
fpr, tpr = get_roc_curve_data(anomaly_score, normal_score, anomaly_det=True)
auc_value = auc(fpr, tpr)

Plot anomaly scores and ROC curve

[40]:
plt.figure()
plt.hist(anomaly_score, bins=100, histtype='step', label='anomaly')
plt.hist(normal_score, bins=100, histtype='step', label='normal')
plt.title('Anomaly score distribution')
plt.legend()
plt.text(0.5, -0.1, f'AUC Value: {auc_value}', ha='center', transform=plt.gca().transAxes)
plt.legend()
plt.show()
../../_images/source_examples_mnist_ad_sweeps_30_0.png
[41]:
# Plot ROC curve
plot_ROC_curve_from_data(fpr, tpr)

../../_images/source_examples_mnist_ad_sweeps_31_0.png