Hands-on 06: Graph data and GNNs: Tagging Higgs boson jets#

This week, we will look at graph neural networks using the PyTorch Geometric library: https://pytorch-geometric.readthedocs.io/. See [] for more details.

The Docker image jmduarte/phys139:latest should work, but there are some ongoing issues with it.

For now, use the standard ghcr.io/ucsd-ets/scipy-ml-notebook:2025.1-stable image and install some additional libraries:

!pip install torch_geometric mplhep
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.2.1+cu121.html 
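
The wheel index in the second pip command must match the installed torch version and CUDA build; a quick check:

import torch

print(torch.__version__)   # e.g. 2.2.1, matching the +cu121 wheel URL above
print(torch.version.cuda)  # e.g. 12.1 (None for CPU-only builds)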

Data should also be downloaded locally:

!mkdir -p data
!wget -O data/ntuple_merged_10.root https://opendata.cern.ch/record/12102/files/assets/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/train/ntuple_merged_10.root
!wget -O data/ntuple_merged_0.root https://opendata.cern.ch/record/12102/files/assets/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/test/ntuple_merged_0.root

import torch
import torch_geometric

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from tqdm.notebook import tqdm
import numpy as np
import yaml

with open("definitions.yml") as file:
    # The FullLoader parameter handles the conversion from YAML
    # scalar values to the Python dictionary format
    definitions = yaml.load(file, Loader=yaml.FullLoader)

features = definitions["features"]
spectators = definitions["spectators"]
labels = definitions["labels"]

nfeatures = definitions["nfeatures"]
nspectators = definitions["nspectators"]
nlabels = definitions["nlabels"]
ntracks = definitions["ntracks"]
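
For reference, definitions.yml has the following overall structure once loaded. The names and counts below are illustrative placeholders, except nfeatures (48) and nlabels (2), which the model below relies on:

# Illustrative structure of the loaded definitions (placeholder values)
definitions_example = {
    "features": ["track_pt", "track_eta", "..."],  # per-track input variables (48 total)
    "spectators": ["fj_pt", "..."],                # kept for plotting/selection, not model inputs
    "labels": ["label_QCD", "label_Hbb"],          # target classes (2 total)
    "nfeatures": 48,
    "nspectators": 2,   # placeholder
    "nlabels": 2,
    "ntracks": 60,      # placeholder: maximum number of tracks per jet
}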

Graph datasets#

Here we have to define the graph dataset. We do this in a separate class following this example: https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-larger-datasets

Formally, a graph is represented by a triplet \(\mathcal G = (\mathbf{u}, V, E)\), consisting of a graph-level, or global, feature vector \(\mathbf{u}\), a set of \(N^v\) nodes \(V\), and a set of \(N^e\) edges \(E\). The nodes are given by \(V = \{\mathbf{v}_i\}_{i=1:N^v}\), where \(\mathbf{v}_i\) represents the \(i\)th node’s attributes. The edges connect pairs of nodes and are given by \(E = \{\left(\mathbf{e}_k, r_k, s_k\right)\}_{k=1:N^e}\), where \(\mathbf{e}_k\) represents the \(k\)th edge’s attributes, and \(r_k\) and \(s_k\) are the indices of the “receiver” and “sender” nodes, respectively, connected by the \(k\)th edge (from the sender node to the receiver node). The receiver and sender index vectors are an alternative way of encoding the directed adjacency matrix.

*Figure: the global, node, and edge attributes of a graph.*
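
As a minimal sketch, a single graph is stored in PyTorch Geometric as a Data object, with the sender and receiver index vectors stacked into edge_index (toy values):

import torch
from torch_geometric.data import Data

# A toy directed graph: 3 nodes with 2 features each and 2 edges.
# edge_index[0] holds the sender indices s_k; edge_index[1] the receiver indices r_k.
x = torch.tensor([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]])
edge_index = torch.tensor([[0, 1],   # senders
                           [1, 2]])  # receivers
data = Data(x=x, edge_index=edge_index)
print(data)  # Data(x=[3, 2], edge_index=[2, 2])
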
from GraphDataset import GraphDataset

local = True  # use the files downloaded to data/ above

if local:
    file_names = ["data/ntuple_merged_10.root"]
    file_names_test = ["data/ntuple_merged_0.root"]
else:
    file_names = [
        "root://eospublic.cern.ch//eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/train/ntuple_merged_10.root"
    ]
    file_names_test = [
        "root://eospublic.cern.ch//eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/test/ntuple_merged_0.root"
    ]

graph_dataset = GraphDataset(
    "gdata_train", features, labels, spectators, n_events=1000, n_events_merge=1, file_names=file_names
)

test_dataset = GraphDataset(
    "gdata_test", features, labels, spectators, n_events=2000, n_events_merge=1, file_names=file_names_test
)
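
If processing succeeds, you can inspect one entry as a quick sanity check (a sketch; since n_events_merge=1, an entry may come back as a list of Data objects):

# Illustrative sanity check on the processed training dataset
sample = graph_dataset[0]
data = sample[0] if isinstance(sample, list) else sample
print(data)  # expect per-track features x=[n_tracks, 48] and a one-hot label y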

Graph neural network#

Here, we recapitulate the “graph network” (GN) formalism [], which generalizes various GNNs and other similar methods. GNs are graph-to-graph mappings, whose output graphs have the same node and edge structure as the input. Formally, a GN block contains three “update” functions, \(\phi\), and three “aggregation” functions, \(\rho\). The stages of processing in a single GN block are:

\[
\begin{aligned}
\mathbf{e}'_k &= \phi^e\left(\mathbf{e}_k, \mathbf{v}_{r_k}, \mathbf{v}_{s_k}, \mathbf{u}\right) & \mathbf{\bar{e}}'_i &= \rho^{e \rightarrow v}\left(E'_i\right) & &\text{(Edge block),}\\
\mathbf{v}'_i &= \phi^v\left(\mathbf{\bar{e}}'_i, \mathbf{v}_i, \mathbf{u}\right) & \mathbf{\bar{e}}' &= \rho^{e \rightarrow u}\left(E'\right) & &\text{(Node block),}\\
\mathbf{u}' &= \phi^u\left(\mathbf{\bar{e}}', \mathbf{\bar{v}}', \mathbf{u}\right) & \mathbf{\bar{v}}' &= \rho^{v \rightarrow u}\left(V'\right) & &\text{(Global block),}
\end{aligned}
\]

where \(E'_i = \left\{\left(\mathbf{e}'_k, r_k, s_k \right)\right\}_{r_k=i,\; k=1:N^e}\) contains the updated edge features for edges whose receiver node is the \(i\)th node, \(E' = \bigcup_i E_i' = \left\{\left(\mathbf{e}'_k, r_k, s_k \right)\right\}_{k=1:N^e}\) is the set of updated edges, and \(V'=\left\{\mathbf{v}'_i\right\}_{i=1:N^v}\) is the set of updated nodes.

*Figure: a full GN block, with edge, node, and global updates.*

We will define an interaction network model similar to Ref. [1], but modeling only the particle-particle interactions. It takes as input all of the tracks (each with 48 features) without truncation or zero-padding. Another modification is the use of batch normalization [] layers to improve the stability of the training.
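
In the blocks below, the aggregation functions \(\rho\) are simple means computed with scatter_mean from torch_scatter. As a small illustration of \(\rho^{e \rightarrow v}\) (with made-up shapes):

import torch
from torch_scatter import scatter_mean

edge_feats = torch.arange(8, dtype=torch.float).reshape(4, 2)  # 4 edges, 2 features each
receivers = torch.tensor([0, 0, 1, 2])                         # receiver index r_k per edge
per_node = scatter_mean(edge_feats, receivers, dim=0, dim_size=3)
print(per_node)  # row i is the mean of the edge features received by node i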

import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.nn import EdgeConv, global_mean_pool
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, BatchNorm1d
from torch_scatter import scatter_mean
from torch_geometric.nn import MetaLayer

inputs = 48
hidden = 128
outputs = 2


class EdgeBlock(torch.nn.Module):
    def __init__(self):
        super(EdgeBlock, self).__init__()
        self.edge_mlp = Seq(Lin(inputs * 2, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, hidden))

    def forward(self, src, dest, edge_attr, u, batch):
        # phi^e: concatenate sender and receiver node features for each edge
        out = torch.cat([src, dest], 1)
        return self.edge_mlp(out)


class NodeBlock(torch.nn.Module):
    def __init__(self):
        super(NodeBlock, self).__init__()
        self.node_mlp_1 = Seq(Lin(inputs + hidden, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, hidden))
        self.node_mlp_2 = Seq(Lin(inputs + hidden, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, hidden))

    def forward(self, x, edge_index, edge_attr, u, batch):
        row, col = edge_index
        # per-edge message from the sender node features and updated edge features
        out = torch.cat([x[row], edge_attr], dim=1)
        out = self.node_mlp_1(out)
        # rho^{e->v}: mean-aggregate the messages onto the receiver nodes
        out = scatter_mean(out, col, dim=0, dim_size=x.size(0))
        # phi^v: update each node from its features and aggregated messages
        out = torch.cat([x, out], dim=1)
        return self.node_mlp_2(out)


class GlobalBlock(torch.nn.Module):
    def __init__(self):
        super(GlobalBlock, self).__init__()
        self.global_mlp = Seq(Lin(hidden, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, outputs))

    def forward(self, x, edge_index, edge_attr, u, batch):
        # rho^{v->u}: mean-aggregate node features per graph, then phi^u
        out = scatter_mean(x, batch, dim=0)
        return self.global_mlp(out)


class InteractionNetwork(torch.nn.Module):
    def __init__(self):
        super(InteractionNetwork, self).__init__()
        self.interactionnetwork = MetaLayer(EdgeBlock(), NodeBlock(), GlobalBlock())
        self.bn = BatchNorm1d(inputs)

    def forward(self, x, edge_index, batch):
        x = self.bn(x)  # normalize the input track features
        x, edge_attr, u = self.interactionnetwork(x, edge_index, None, None, batch)
        return u


model = InteractionNetwork().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
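
As a quick shape check, you can run the untrained network on a toy graph with made-up inputs. Note model.eval() is needed here because the BatchNorm1d inside the global block cannot run in training mode on a batch of one graph:

model.eval()
x_toy = torch.rand(6, inputs, device=device)                        # 6 tracks, 48 features
edge_index_toy = torch.tensor([[0, 1, 2, 3, 4, 5],
                               [1, 2, 3, 4, 5, 0]], device=device)  # a simple directed ring
batch_toy = torch.zeros(6, dtype=torch.long, device=device)         # all nodes in graph 0
with torch.no_grad():
    out = model(x_toy, edge_index_toy, batch_toy)
print(out.shape)  # torch.Size([1, 2]): one graph, two class logits
model.train()  # restore training mode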

Define training loop#

@torch.no_grad()
def test(model, loader, total, batch_size, leave=False):
    model.eval()

    xentropy = nn.CrossEntropyLoss(reduction="mean")

    sum_loss = 0.0
    t = tqdm(enumerate(loader), total=total / batch_size, leave=leave)
    for i, data in t:
        data = data.to(device)
        y = torch.argmax(data.y, dim=1)
        batch_output = model(data.x, data.edge_index, data.batch)
        batch_loss_item = xentropy(batch_output, y).item()
        sum_loss += batch_loss_item
        t.set_description("loss = %.5f" % (batch_loss_item))
        t.refresh()  # to show immediately the update

    return sum_loss / (i + 1)


def train(model, optimizer, loader, total, batch_size, leave=False):
    model.train()

    xentropy = nn.CrossEntropyLoss(reduction="mean")

    sum_loss = 0.0
    t = tqdm(enumerate(loader), total=total / batch_size, leave=leave)
    for i, data in t:
        data = data.to(device)
        y = torch.argmax(data.y, dim=1)
        optimizer.zero_grad()
        batch_output = model(data.x, data.edge_index, data.batch)
        batch_loss = xentropy(batch_output, y)
        batch_loss.backward()
        batch_loss_item = batch_loss.item()
        t.set_description("loss = %.5f" % batch_loss_item)
        t.refresh()  # to show immediately the update
        sum_loss += batch_loss_item
        optimizer.step()

    return sum_loss / (i + 1)

Define training, validation, testing data generators#

from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataListLoader
from torch.utils.data import random_split


def collate(items):
    # each item is a list of Data objects (n_events_merge=1), so flatten before batching
    flattened = sum(items, [])
    return Batch.from_data_list(flattened)


torch.manual_seed(0)
valid_frac = 0.20
full_length = len(graph_dataset)
valid_num = int(valid_frac * full_length)
batch_size = 32

train_dataset, valid_dataset = random_split(graph_dataset, [full_length - valid_num, valid_num])

train_loader = DataListLoader(train_dataset, batch_size=batch_size, pin_memory=True, shuffle=True)
train_loader.collate_fn = collate
valid_loader = DataListLoader(valid_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)
valid_loader.collate_fn = collate
test_loader = DataListLoader(test_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)
test_loader.collate_fn = collate


train_samples = len(train_dataset)
valid_samples = len(valid_dataset)
test_samples = len(test_dataset)
print(full_length)
print(train_samples)
print(valid_samples)
print(test_samples)

Train#

We’ll train for only 1 epoch to save time, but you can increase this to around 10 for better performance. Note that early stopping with a patience of 5 epochs is also implemented.

import os.path as osp

n_epochs = 1
stale_epochs = 0
best_valid_loss = 99999
patience = 5
t = tqdm(range(0, n_epochs))

for epoch in t:
    loss = train(model, optimizer, train_loader, train_samples, batch_size, leave=bool(epoch == n_epochs - 1))
    valid_loss = test(model, valid_loader, valid_samples, batch_size, leave=bool(epoch == n_epochs - 1))
    print("Epoch: {:02d}, Training Loss:   {:.4f}".format(epoch, loss))
    print("           Validation Loss: {:.4f}".format(valid_loss))

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        modpath = osp.join("interactionnetwork_best.pth")
        print("New best model saved to:", modpath)
        torch.save(model.state_dict(), modpath)
        stale_epochs = 0
    else:
        print("Stale epoch")
        stale_epochs += 1
    if stale_epochs >= patience:
        print("Early stopping after %i stale epochs" % patience)
        break
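
Since the checkpoint with the best validation loss was saved to disk, it is sensible to reload it before evaluating (otherwise the weights from the final epoch are used):

# Reload the best checkpoint before evaluating on the test set
model.load_state_dict(torch.load("interactionnetwork_best.pth", map_location=device))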

Evaluate on testing data#

model.eval()
t = tqdm(enumerate(test_loader), total=test_samples / batch_size)
y_test = []
y_predict = []
with torch.no_grad():  # gradients are not needed for inference
    for i, data in t:
        data = data.to(device)
        batch_output = model(data.x, data.edge_index, data.batch)
        y_predict.append(batch_output.cpu().numpy())
        y_test.append(data.y.cpu().numpy())
y_test = np.concatenate(y_test)
y_predict = np.concatenate(y_predict)
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import mplhep as hep

plt.style.use(hep.style.ROOT)
# create ROC curves
fpr_gnn, tpr_gnn, threshold_gnn = roc_curve(y_test[:, 1], y_predict[:, 1])

# plot ROC curves
plt.figure()
plt.plot(tpr_gnn, fpr_gnn, lw=2.5, label="GNN, AUC = {:.1f}%".format(auc(fpr_gnn, tpr_gnn) * 100))
plt.xlabel(r"True positive rate")
plt.ylabel(r"False positive rate")
plt.semilogy()
plt.ylim(0.001, 1)
plt.xlim(0, 1)
plt.grid(True)
plt.legend(loc="upper left")
plt.show()

Exercises#

  1. Replace the Interaction Network model with a Deep Set model and check the performance. A partial implementation of the model is below; one possible completion is sketched after it:

class DeepSet(torch.nn.Module):
    def __init__(self):
        super(DeepSet, self).__init__()
        self.bn = BatchNorm1d(inputs)
        self.node_mlp = Seq(Lin(inputs, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, hidden))
        self.global_mlp = Seq(Lin(hidden, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, outputs))

    def forward(self, x, edge_index, batch):
        x = self.bn(x)
        x = ... # apply node_mlp to x
        mean = ... # take mean over all node features x in one graph (see GlobalBlock of InteractionNetwork)
        out = ... # apply global mlp to mean
        return out
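
For reference, one possible completion (a sketch, reusing scatter_mean for the per-graph mean, as in the GlobalBlock above):

class DeepSet(torch.nn.Module):
    def __init__(self):
        super(DeepSet, self).__init__()
        self.bn = BatchNorm1d(inputs)
        self.node_mlp = Seq(Lin(inputs, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, hidden))
        self.global_mlp = Seq(Lin(hidden, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, outputs))

    def forward(self, x, edge_index, batch):
        x = self.bn(x)
        x = self.node_mlp(x)                  # apply the node MLP to each track independently
        mean = scatter_mean(x, batch, dim=0)  # permutation-invariant mean over each graph
        return self.global_mlp(mean)          # map the pooled representation to class logits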