Hands-on 07: Autoencoders for anomaly detection#

This week, we will look at autoencoders for anomaly detection.

The goal is to train an autoencoder to reconstruct QCD (background) jets, which are plentiful at the LHC. Then we will apply it to top quark (signal) jets to see if the reconstruction error is larger. The reconstruction error can then be used as an anomaly score in real data.

This autoencoder architecture used is inspired by this paper: https://arxiv.org/abs/1808.08992. Autoencoder

You may need to install the JetNet library if you don’t already have it.

pip install --user jetnet tables==3.8.0

Download the dataset#

We will use a validation dataset of 400k jets, which is plenty for our purposes. The full dataset is available at https://doi.org/10.5281/zenodo.2603255.

import jetnet

im_size = 16
jet_r = 0.8
max_jets = 50000

# download the validation data (400k jets, which is plenty for our purposes)
# full dataset is available here: https://doi.org/10.5281/zenodo.2603255
data = jetnet.datasets.TopTagging(
    jet_type="all",
    particle_features=["E", "px", "py", "pz"],
    jet_features=["type"],
    split="valid",
    data_dir="data/",
    particle_transform=jetnet.utils.cartesian_to_relEtaPhiPt,
    download=True,
)
/home/runner/miniconda3/envs/phys139/lib/python3.10/site-packages/jetnet/utils/utils.py:7: FutureWarning: In version 2024.7.0 (target date: 2024-06-30 11:59:59-05:00), this will be an error.
To raise these warnings as errors (and get stack traces to find out where they're called), run
    import warnings
    warnings.filterwarnings("error", module="coffea.*")
after the first `import coffea` or use `@pytest.mark.filterwarnings("error:::coffea.*")` in pytest.
Issue: coffea.nanoevents.methods.vector will be removed and replaced with scikit-hep vector. Nanoevents schemas internal to coffea will be migrated. Otherwise please consider using that package!.
  from coffea.nanoevents.methods import vector
Downloading val dataset to data/val.h5
Downloading dataset
[..................................................] 0%
[..................................................] 1%
[..................................................] 1%
[..................................................] 1%
[..................................................] 2%
[..................................................] 2%
[█.................................................] 2%
[█.................................................] 2%
[█.................................................] 3%
[█.................................................] 3%
[█.................................................] 3%
[█.................................................] 4%
[█.................................................] 4%
[██................................................] 4%
[██................................................] 5%
[██................................................] 5%
[██................................................] 5%
[██................................................] 5%
[██................................................] 6%
[███...............................................] 6%
[███...............................................] 6%
[███...............................................] 7%
[███...............................................] 7%
[███...............................................] 7%
[███...............................................] 8%
[███...............................................] 8%
[████..............................................] 8%
[████..............................................] 8%
[████..............................................] 9%
[████..............................................] 9%
[████..............................................] 9%
[████..............................................] 10%
[████..............................................] 10%
[█████.............................................] 10%
[█████.............................................] 11%
[█████.............................................] 11%
[█████.............................................] 11%
[█████.............................................] 11%
[█████.............................................] 12%
[██████............................................] 12%
[██████............................................] 12%
[██████............................................] 13%
[██████............................................] 13%
[██████............................................] 13%
[██████............................................] 14%
[██████............................................] 14%
[███████...........................................] 14%
[███████...........................................] 14%
[███████...........................................] 15%
[███████...........................................] 15%
[███████...........................................] 15%
[███████...........................................] 16%
[███████...........................................] 16%
[████████..........................................] 16%
[████████..........................................] 17%
[████████..........................................] 17%
[████████..........................................] 17%
[████████..........................................] 18%
[████████..........................................] 18%
[█████████.........................................] 18%
[█████████.........................................] 18%
[█████████.........................................] 19%
[█████████.........................................] 19%
[█████████.........................................] 19%
[█████████.........................................] 20%
[█████████.........................................] 20%
[██████████........................................] 20%
[██████████........................................] 21%
[██████████........................................] 21%
[██████████........................................] 21%
[██████████........................................] 21%
[██████████........................................] 22%
[███████████.......................................] 22%
[███████████.......................................] 22%
[███████████.......................................] 23%
[███████████.......................................] 23%
[███████████.......................................] 23%
[███████████.......................................] 24%
[███████████.......................................] 24%
[████████████......................................] 24%
[████████████......................................] 24%
[████████████......................................] 25%
[████████████......................................] 25%
[████████████......................................] 25%
[████████████......................................] 26%
[████████████......................................] 26%
[█████████████.....................................] 26%
[█████████████.....................................] 27%
[█████████████.....................................] 27%
[█████████████.....................................] 27%
[█████████████.....................................] 27%
[█████████████.....................................] 28%
[██████████████....................................] 28%
[██████████████....................................] 28%
[██████████████....................................] 29%
[██████████████....................................] 29%
[██████████████....................................] 29%
[██████████████....................................] 30%
[██████████████....................................] 30%
[███████████████...................................] 30%
[███████████████...................................] 30%
[███████████████...................................] 31%
[███████████████...................................] 31%
[███████████████...................................] 31%
[███████████████...................................] 32%
[███████████████...................................] 32%
[████████████████..................................] 32%
[████████████████..................................] 33%
[████████████████..................................] 33%
[████████████████..................................] 33%
[████████████████..................................] 34%
[████████████████..................................] 34%
[█████████████████.................................] 34%
[█████████████████.................................] 34%
[█████████████████.................................] 35%
[█████████████████.................................] 35%
[█████████████████.................................] 35%
[█████████████████.................................] 36%
[█████████████████.................................] 36%
[██████████████████................................] 36%
[██████████████████................................] 37%
[██████████████████................................] 37%
[██████████████████................................] 37%
[██████████████████................................] 37%
[██████████████████................................] 38%
[███████████████████...............................] 38%
[███████████████████...............................] 38%
[███████████████████...............................] 39%
[███████████████████...............................] 39%
[███████████████████...............................] 39%
[███████████████████...............................] 40%
[███████████████████...............................] 40%
[████████████████████..............................] 40%
[████████████████████..............................] 40%
[████████████████████..............................] 41%
[████████████████████..............................] 41%
[████████████████████..............................] 41%
[████████████████████..............................] 42%
[████████████████████..............................] 42%
[█████████████████████.............................] 42%
[█████████████████████.............................] 43%
[█████████████████████.............................] 43%
[█████████████████████.............................] 43%
[█████████████████████.............................] 43%
[█████████████████████.............................] 44%
[██████████████████████............................] 44%
[██████████████████████............................] 44%
[██████████████████████............................] 45%
[██████████████████████............................] 45%
[██████████████████████............................] 45%
[██████████████████████............................] 46%
[██████████████████████............................] 46%
[███████████████████████...........................] 46%
[███████████████████████...........................] 46%
[███████████████████████...........................] 47%
[███████████████████████...........................] 47%
[███████████████████████...........................] 47%
[███████████████████████...........................] 48%
[███████████████████████...........................] 48%
[████████████████████████..........................] 48%
[████████████████████████..........................] 49%
[████████████████████████..........................] 49%
[████████████████████████..........................] 49%
[████████████████████████..........................] 50%
[████████████████████████..........................] 50%
[█████████████████████████.........................] 50%
[█████████████████████████.........................] 50%
[█████████████████████████.........................] 51%
[█████████████████████████.........................] 51%
[█████████████████████████.........................] 51%
[█████████████████████████.........................] 52%
[█████████████████████████.........................] 52%
[██████████████████████████........................] 52%
[██████████████████████████........................] 53%
[██████████████████████████........................] 53%
[██████████████████████████........................] 53%
[██████████████████████████........................] 53%
[██████████████████████████........................] 54%
[███████████████████████████.......................] 54%
[███████████████████████████.......................] 54%
[███████████████████████████.......................] 55%
[███████████████████████████.......................] 55%
[███████████████████████████.......................] 55%
[███████████████████████████.......................] 56%
[███████████████████████████.......................] 56%
[████████████████████████████......................] 56%
[████████████████████████████......................] 56%
[████████████████████████████......................] 57%
[████████████████████████████......................] 57%
[████████████████████████████......................] 57%
[████████████████████████████......................] 58%
[████████████████████████████......................] 58%
[█████████████████████████████.....................] 58%
[█████████████████████████████.....................] 59%
[█████████████████████████████.....................] 59%
[█████████████████████████████.....................] 59%
[█████████████████████████████.....................] 59%
[█████████████████████████████.....................] 60%
[██████████████████████████████....................] 60%
[██████████████████████████████....................] 60%
[██████████████████████████████....................] 61%
[██████████████████████████████....................] 61%
[██████████████████████████████....................] 61%
[██████████████████████████████....................] 62%
[██████████████████████████████....................] 62%
[███████████████████████████████...................] 62%
[███████████████████████████████...................] 62%
[███████████████████████████████...................] 63%
[███████████████████████████████...................] 63%
[███████████████████████████████...................] 63%
[███████████████████████████████...................] 64%
[███████████████████████████████...................] 64%
[████████████████████████████████..................] 64%
[████████████████████████████████..................] 65%
[████████████████████████████████..................] 65%
[████████████████████████████████..................] 65%
[████████████████████████████████..................] 66%
[████████████████████████████████..................] 66%
[█████████████████████████████████.................] 66%
[█████████████████████████████████.................] 66%
[█████████████████████████████████.................] 67%
[█████████████████████████████████.................] 67%
[█████████████████████████████████.................] 67%
[█████████████████████████████████.................] 68%
[█████████████████████████████████.................] 68%
[██████████████████████████████████................] 68%
[██████████████████████████████████................] 69%
[██████████████████████████████████................] 69%
[██████████████████████████████████................] 69%
[██████████████████████████████████................] 69%
[██████████████████████████████████................] 70%
[███████████████████████████████████...............] 70%
[███████████████████████████████████...............] 70%
[███████████████████████████████████...............] 71%
[███████████████████████████████████...............] 71%
[███████████████████████████████████...............] 71%
[███████████████████████████████████...............] 72%
[███████████████████████████████████...............] 72%
[████████████████████████████████████..............] 72%
[████████████████████████████████████..............] 72%
[████████████████████████████████████..............] 73%
[████████████████████████████████████..............] 73%
[████████████████████████████████████..............] 73%
[████████████████████████████████████..............] 74%
[████████████████████████████████████..............] 74%
[█████████████████████████████████████.............] 74%
[█████████████████████████████████████.............] 75%
[█████████████████████████████████████.............] 75%
[█████████████████████████████████████.............] 75%
[█████████████████████████████████████.............] 75%
[█████████████████████████████████████.............] 76%
[██████████████████████████████████████............] 76%
[██████████████████████████████████████............] 76%
[██████████████████████████████████████............] 77%
[██████████████████████████████████████............] 77%
[██████████████████████████████████████............] 77%
[██████████████████████████████████████............] 78%
[██████████████████████████████████████............] 78%
[███████████████████████████████████████...........] 78%
[███████████████████████████████████████...........] 78%
[███████████████████████████████████████...........] 79%
[███████████████████████████████████████...........] 79%
[███████████████████████████████████████...........] 79%
[███████████████████████████████████████...........] 80%
[███████████████████████████████████████...........] 80%
[████████████████████████████████████████..........] 80%
[████████████████████████████████████████..........] 81%
[████████████████████████████████████████..........] 81%
[████████████████████████████████████████..........] 81%
[████████████████████████████████████████..........] 82%
[████████████████████████████████████████..........] 82%
[█████████████████████████████████████████.........] 82%
[█████████████████████████████████████████.........] 82%
[█████████████████████████████████████████.........] 83%
[█████████████████████████████████████████.........] 83%
[█████████████████████████████████████████.........] 83%
[█████████████████████████████████████████.........] 84%
[█████████████████████████████████████████.........] 84%
[██████████████████████████████████████████........] 84%
[██████████████████████████████████████████........] 85%
[██████████████████████████████████████████........] 85%
[██████████████████████████████████████████........] 85%
[██████████████████████████████████████████........] 85%
[██████████████████████████████████████████........] 86%
[███████████████████████████████████████████.......] 86%
[███████████████████████████████████████████.......] 86%
[███████████████████████████████████████████.......] 87%
[███████████████████████████████████████████.......] 87%
[███████████████████████████████████████████.......] 87%
[███████████████████████████████████████████.......] 88%
[███████████████████████████████████████████.......] 88%
[████████████████████████████████████████████......] 88%
[████████████████████████████████████████████......] 88%
[████████████████████████████████████████████......] 89%
[████████████████████████████████████████████......] 89%
[████████████████████████████████████████████......] 89%
[████████████████████████████████████████████......] 90%
[████████████████████████████████████████████......] 90%
[█████████████████████████████████████████████.....] 90%
[█████████████████████████████████████████████.....] 91%
[█████████████████████████████████████████████.....] 91%
[█████████████████████████████████████████████.....] 91%
[█████████████████████████████████████████████.....] 91%
[█████████████████████████████████████████████.....] 92%
[██████████████████████████████████████████████....] 92%
[██████████████████████████████████████████████....] 92%
[██████████████████████████████████████████████....] 93%
[██████████████████████████████████████████████....] 93%
[██████████████████████████████████████████████....] 93%
[██████████████████████████████████████████████....] 94%
[██████████████████████████████████████████████....] 94%
[███████████████████████████████████████████████...] 94%
[███████████████████████████████████████████████...] 94%
[███████████████████████████████████████████████...] 95%
[███████████████████████████████████████████████...] 95%
[███████████████████████████████████████████████...] 95%
[███████████████████████████████████████████████...] 96%
[███████████████████████████████████████████████...] 96%
[████████████████████████████████████████████████..] 96%
[████████████████████████████████████████████████..] 97%
[████████████████████████████████████████████████..] 97%
[████████████████████████████████████████████████..] 97%
[████████████████████████████████████████████████..] 97%
[████████████████████████████████████████████████..] 98%
[█████████████████████████████████████████████████.] 98%
[█████████████████████████████████████████████████.] 98%
[█████████████████████████████████████████████████.] 99%
[█████████████████████████████████████████████████.] 99%
[█████████████████████████████████████████████████.] 99%
[█████████████████████████████████████████████████.] 100%
[█████████████████████████████████████████████████.] 100%
[██████████████████████████████████████████████████] 100%

Transform and split the data#

The data is originally up to 200 particles per jet (zero-padded), and the features are the standard 4-vectors \((E, p_x, p_y, p_z)\). We can assume the particles are massless so \(E=\sqrt{p_x^2+p_y^2+p_z^2}\) and there are only 3 degrees of freedom.

We will transform to relative coordinates centered on the jet using the function jetnet.utils.cartesian_to_relEtaPhiPt:

\begin{align} \eta^\mathrm{rel} &=\eta^\mathrm{particle} - \eta^\mathrm{jet}\ \phi^\mathrm{rel} &=\phi^\mathrm{particle} - \phi^\mathrm{jet} \pmod{2\pi}\ p_\mathrm{T}^\mathrm{rel} &= p_\mathrm{T}^\mathrm{particle}/p_\mathrm{T}^\mathrm{jet} \end{align}

import numpy as np

indices = np.random.permutation(np.arange(len(data)))[:max_jets]
# transform the data
transformed_particle_data = data.particle_transform(data.particle_data[indices])
# split qcd background and top quark signal
qcd_data = transformed_particle_data[data.jet_data[indices][:, 0] == 0]
top_data = transformed_particle_data[data.jet_data[indices][:, 0] == 1]
from sklearn.model_selection import train_test_split

qcd_train, qcd_test = train_test_split(qcd_data, test_size=0.2, random_state=42)
top_train, top_test = train_test_split(top_data, test_size=0.2, random_state=42)

Convert data to jet images#

import numpy as np

#  convert full dataset
qcd_train_images = np.expand_dims(jetnet.utils.to_image(qcd_train, im_size=im_size, maxR=jet_r), axis=-1)
qcd_test_images = np.expand_dims(jetnet.utils.to_image(qcd_test, im_size=im_size, maxR=jet_r), axis=-1)
top_test_images = np.expand_dims(jetnet.utils.to_image(top_test, im_size=im_size, maxR=jet_r), axis=-1)

# rescale so sum is 1 (it should be close already)
qcd_train_images = qcd_train_images / np.sum(qcd_train_images.reshape(-1, 1, 1, 1, im_size * im_size), axis=-1)
qcd_test_images = qcd_test_images / np.sum(qcd_test_images.reshape(-1, 1, 1, 1, im_size * im_size), axis=-1)
top_test_images = top_test_images / np.sum(top_test_images.reshape(-1, 1, 1, 1, im_size * im_size), axis=-1)

Visualize the jet images#

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm


def plot_jet_images(images, titles, filename="jet_image.pdf"):

    n_images = len(images)
    plt.figure(figsize=(5 * n_images, 5))

    for i, (image, title) in enumerate(zip(images, titles)):
        plt.subplot(1, n_images, i + 1)
        plt.title(title)
        plt.imshow(image, origin="lower", norm=LogNorm(vmin=1e-3, vmax=1))
        cbar = plt.colorbar()
        plt.xlabel(r"$\Delta\eta$ cell", fontsize=15)
        plt.ylabel(r"$\Delta\phi$ cell", fontsize=15)
        cbar.set_label(r"$p_T/p_T^{jet}$", fontsize=15)

    plt.tight_layout()
    plt.savefig(filename)
plot_jet_images([qcd_test_images[0], top_test_images[0]], ["QCD jet image", "Top quark jet image"])
_images/ba5c0ff2245f610bed592e75c8a71ef812c70ab29ddd424229bcd798ab4e463f.png

Define the autoencoder model#

from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Dense,
    Input,
    Conv2D,
    Conv2DTranspose,
    Reshape,
    Flatten,
    Softmax,
)

x_in = Input(shape=(im_size, im_size, 1))
x = Conv2D(128, kernel_size=(3, 3), strides=(2, 2), activation="relu", padding="same")(x_in)
x = Conv2D(128, kernel_size=(3, 3), strides=(2, 2), activation="relu", padding="same")(x)
x = Flatten()(x)

x_enc = Dense(2, name="bottleneck")(x)

x = Dense(int(im_size * im_size / 16) * 128, activation="relu")(x_enc)
x = Reshape((int(im_size / 4), int(im_size / 4), 128))(x)
x = Conv2DTranspose(128, kernel_size=(3, 3), strides=(2, 2), activation="relu", padding="same")(x)
x = Conv2DTranspose(1, kernel_size=(3, 3), strides=(2, 2), activation="linear", padding="same")(x)
x_out = Softmax(name="softmax", axis=[-2, -3])(x)
model = Model(inputs=x_in, outputs=x_out, name="autoencoder")

model.compile(loss="mse", optimizer="adam")
model.summary()

# save the encoder-only model for easy access to latent space
encoder = Model(inputs=x_in, outputs=x_enc, name="encoder")
2024-06-05 04:35:03.871528: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[9], line 1
----> 1 from tensorflow.keras.models import Model
      2 from tensorflow.keras.layers import (
      3     Dense,
      4     Input,
   (...)
      9     Softmax,
     10 )
     12 x_in = Input(shape=(im_size, im_size, 1))

ModuleNotFoundError: No module named 'tensorflow.keras'

Train the autoencoder model#

history = model.fit(
    qcd_train_images,
    qcd_train_images,
    batch_size=32,
    epochs=10,
    verbose=0,
    validation_data=(qcd_test_images, qcd_test_images),
)

Reconstruction performance#

qcd_reco_image = model.predict(qcd_test_images[0:1], verbose=0).reshape(im_size, im_size)
plot_jet_images([qcd_test_images[0], qcd_reco_image], ["Input QCD jet image", "Reconstructed QCD jet image"])
top_reco_image = model.predict(top_test_images[0:1], verbose=0).reshape(im_size, im_size)
plot_jet_images([top_test_images[0], top_reco_image], ["Input top quark jet image", "Reconstructed top quark jet image"])

Anomaly detection performance#

qcd_reco_images = model.predict(qcd_test_images, verbose=0)
top_reco_images = model.predict(top_test_images, verbose=0)
diff_qcd = np.power((qcd_reco_images - qcd_test_images), 2)
loss_qcd = np.mean(diff_qcd.reshape(-1, im_size * im_size), axis=-1)

diff_top = np.power((top_reco_images - top_test_images), 2)
loss_top = np.mean(diff_top.reshape(-1, im_size * im_size), axis=-1)

loss_all = np.concatenate([loss_qcd, loss_top])
plt.figure()
bins = np.arange(0, np.max(loss_all), np.max(loss_all) / 100)
plt.hist(loss_qcd, label="QCD jets", bins=bins, alpha=0.8)
plt.hist(loss_top, label="Top quark jets", bins=bins, alpha=0.8)
plt.legend(title="Autoencoder")
plt.xlabel("MSE loss")
plt.ylabel("Jets")
plt.xlim(0, np.max(loss_all))
plt.show()

Exercises#

  1. Plot the ROC curve for the MSE loss of the autoencoder on the merged testing sample of QCD and top quark jets, assuming a label of 1 for top quark jets and a label of 0 for QCD jets. Report the AUC.

  2. Perform a PCA on only the QCD training images using sklearn.decomposition.PCA with 2 components. Note you will have to reshape the image tensors so that they are 2D instead of 4D (as required by the autoencoder), e.g. qcd_test_images.reshape(-1, im_size * im_size)). Plot the distribution of the reconstruction losses for top quark jets and QCD jets separately. Hint: review https://rittikghosh.com/autoencoder.html.

  3. Plot the PCA ROC curve similar to part 1. Report the AUC.

  4. Plot the 2D latent space for the QCD and top quark test images for both the autoencoder and the PCA.