Hands-on 06: Graph data and GNNs: Tagging Higgs boson jets
This week, we will look at graph neural networks using the PyTorch Geometric (PyG) library; see the documentation at https://pytorch-geometric.readthedocs.io/ for more details.
If you chose the Docker image jmduarte/phys139:latest, this notebook should work out of the box.
Depending on the environment used, you may need to install additional libraries:
!pip install uproot torch_geometric mplhep
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.2+cu118.html
Data can also be downloaded locally:
!wget https://opendata.cern.ch/record/12102/files/assets/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/train/ntuple_merged_10.root
!wget https://opendata.cern.ch/record/12102/files/assets/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/test/ntuple_merged_0.root
On DataHub, two files have been copied locally to:
/higgs/train/ntuple_merged_10.root
/higgs/test/ntuple_merged_0.root
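To verify the download, you can peek at the training file with uproot. This is a minimal sketch, assuming the file sits in the current working directory; the tree name "deepntuplizer/tree" is what this CMS open dataset is expected to use, which you can confirm from the printed keys.

import uproot

# open the downloaded training ntuple and inspect its contents
with uproot.open("ntuple_merged_10.root") as f:
    print(f.keys())  # list the top-level directories/trees
    tree = f["deepntuplizer/tree"]  # assumed tree path; check f.keys()
    print(tree.num_entries)  # number of jets in the file
    print(tree.keys()[:10])  # first few branch names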
import torch
import torch_geometric
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from tqdm.notebook import tqdm
import numpy as np
import yaml
with open("definitions.yml") as file:
    # The FullLoader parameter handles the conversion from YAML
    # scalar values to the Python dictionary format
    definitions = yaml.load(file, Loader=yaml.FullLoader)
features = definitions["features"]
spectators = definitions["spectators"]
labels = definitions["labels"]
nfeatures = definitions["nfeatures"]
nspectators = definitions["nspectators"]
nlabels = definitions["nlabels"]
ntracks = definitions["ntracks"]
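As a quick sanity check, you can print what was just loaded (a small usage example; all names come from definitions.yml above):

print(f"nfeatures = {nfeatures}, nspectators = {nspectators}, nlabels = {nlabels}, ntracks = {ntracks}")
print("features:", features)
print("labels:", labels)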
Graph datasets
Here, we define the graph dataset. We do this in a separate class, following this example: https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-larger-datasets
Formally, a graph is represented by a triplet \(\mathcal G = (\mathbf{u}, V, E)\), consisting of a graph-level, or global, feature vector \(\mathbf{u}\), a set of \(N^v\) nodes \(V\), and a set of \(N^e\) edges \(E\). The nodes are given by \(V = \{\mathbf{v}_i\}_{i=1:N^v}\), where \(\mathbf{v}_i\) represents the \(i\)th node’s attributes. The edges connect pairs of nodes and are given by \(E = \{\left(\mathbf{e}_k, r_k, s_k\right)\}_{k=1:N^e}\), where \(\mathbf{e}_k\) represents the \(k\)th edge’s attributes, and \(r_k\) and \(s_k\) are the indices of the “receiver” and “sender” nodes, respectively, connected by the \(k\)th edge (from the sender node to the receiver node). The receiver and sender index vectors are an alternative way of encoding the directed adjacency matrix.
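In PyTorch Geometric, the sender and receiver index vectors are packed into a single edge_index tensor of shape \([2, N^e]\): by PyG convention, the first row holds the source (sender) indices and the second row the target (receiver) indices. As a minimal standalone illustration (toy values, not part of the dataset code), here is a fully connected directed graph on three nodes:

import torch
from torch_geometric.data import Data

# three nodes with two features each
x = torch.randn(3, 2)
# row 0: sender indices s_k; row 1: receiver indices r_k
edge_index = torch.tensor(
    [[0, 0, 1, 1, 2, 2],
     [1, 2, 0, 2, 0, 1]], dtype=torch.long
)
data = Data(x=x, edge_index=edge_index)
print(data)  # Data(x=[3, 2], edge_index=[2, 6])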
from GraphDataset import GraphDataset
local = False
if local:
    file_names = ["/higgs/train/ntuple_merged_10.root"]
    file_names_test = ["/higgs/test/ntuple_merged_0.root"]
else:
    file_names = [
        "root://eospublic.cern.ch//eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/train/ntuple_merged_10.root"
    ]
    file_names_test = [
        "root://eospublic.cern.ch//eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/test/ntuple_merged_0.root"
    ]

graph_dataset = GraphDataset(
    "gdata_train", features, labels, spectators, n_events=1000, n_events_merge=1, file_names=file_names
)

test_dataset = GraphDataset(
    "gdata_test", features, labels, spectators, n_events=2000, n_events_merge=1, file_names=file_names_test
)
Note: in some environments, constructing the datasets from the remote root:// (XRootD) URLs fails with ValueError: Protocol not known, because the fsspec filesystem layer used by PyTorch Geometric cannot resolve the XRootD protocol. If this happens, download the files locally as shown above and set local = True.
Graph neural network
Here, we recapitulate the “graph network” (GN) formalism of Battaglia et al., which generalizes various GNNs and other similar methods. GNs are graph-to-graph mappings, whose output graphs have the same node and edge structure as the input. Formally, a GN block contains three “update” functions, \(\phi\), and three “aggregation” functions, \(\rho\). The stages of processing in a single GN block are:
\(
\begin{align}
\mathbf{e}'_k &= \phi^e\left(\mathbf{e}_k, \mathbf{v}_{r_k}, \mathbf{v}_{s_k}, \mathbf{u}\right) & \mathbf{\bar{e}}'_i &= \rho^{e \rightarrow v}\left(E'_i\right) & \text{(Edge block),}\\
\mathbf{v}'_i &= \phi^v\left(\mathbf{\bar{e}}'_i, \mathbf{v}_i, \mathbf{u}\right) & \mathbf{\bar{e}}' &= \rho^{e \rightarrow u}\left(E'\right) & \text{(Node block),}\\
\mathbf{u}' &= \phi^u\left(\mathbf{\bar{e}}', \mathbf{\bar{v}}', \mathbf{u}\right) & \mathbf{\bar{v}}' &= \rho^{v \rightarrow u}\left(V'\right) & \text{(Global block),}
\end{align}
\)
where \(E'_i = \left\{\left(\mathbf{e}'_k, r_k, s_k \right)\right\}_{r_k=i,\; k=1:N^e}\) contains the updated edge features for edges whose receiver node is the \(i\)th node, \(E' = \bigcup_i E_i' = \left\{\left(\mathbf{e}'_k, r_k, s_k \right)\right\}_{k=1:N^e}\) is the set of updated edges, and \(V'=\left\{\mathbf{v}'_i\right\}_{i=1:N^v}\) is the set of updated nodes.
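In the implementation below, the aggregation \(\rho^{e \rightarrow v}\) is realized with scatter_mean, which averages the updated edge features \(\mathbf{e}'_k\) over all edges sharing the same receiver node \(r_k\). A minimal standalone sketch of this operation (toy numbers, not dataset code):

import torch
from torch_scatter import scatter_mean

# four edges with 2-dimensional updated features e'_k
edge_feats = torch.tensor([[1.0, 0.0], [3.0, 0.0], [0.0, 2.0], [0.0, 4.0]])
# receiver index r_k of each edge: two edges point to node 0, two to node 1
receivers = torch.tensor([0, 0, 1, 1])
# per-node mean of incoming edge features: rho^{e->v}
agg = scatter_mean(edge_feats, receivers, dim=0, dim_size=2)
print(agg)  # tensor([[2., 0.], [0., 3.]])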
We will define an interaction network model similar to Ref. [1], but modeling only the particle-particle interactions. It takes as input all of the tracks in a jet (each with 48 features), without truncation or zero-padding. Another modification is the use of batch normalization (Ioffe and Szegedy) layers to improve the stability of the training.
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.nn import EdgeConv, global_mean_pool
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, BatchNorm1d
from torch_scatter import scatter_mean
from torch_geometric.nn import MetaLayer
inputs = 48
hidden = 128
outputs = 2
class EdgeBlock(torch.nn.Module):
    def __init__(self):
        super(EdgeBlock, self).__init__()
        self.edge_mlp = Seq(Lin(inputs * 2, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, hidden))

    def forward(self, src, dest, edge_attr, u, batch):
        # phi^e: update each edge from its sender and receiver node features
        out = torch.cat([src, dest], 1)
        return self.edge_mlp(out)


class NodeBlock(torch.nn.Module):
    def __init__(self):
        super(NodeBlock, self).__init__()
        self.node_mlp_1 = Seq(Lin(inputs + hidden, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, hidden))
        self.node_mlp_2 = Seq(Lin(inputs + hidden, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, hidden))

    def forward(self, x, edge_index, edge_attr, u, batch):
        row, col = edge_index
        # combine each sender node's features with the updated edge features
        out = torch.cat([x[row], edge_attr], dim=1)
        out = self.node_mlp_1(out)
        # rho^{e->v}: mean-aggregate the incoming messages per receiver node
        out = scatter_mean(out, col, dim=0, dim_size=x.size(0))
        # phi^v: update each node from its features and aggregated messages
        out = torch.cat([x, out], dim=1)
        return self.node_mlp_2(out)


class GlobalBlock(torch.nn.Module):
    def __init__(self):
        super(GlobalBlock, self).__init__()
        self.global_mlp = Seq(Lin(hidden, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, outputs))

    def forward(self, x, edge_index, edge_attr, u, batch):
        # rho^{v->u}: mean-aggregate node features per graph, then phi^u
        out = scatter_mean(x, batch, dim=0)
        return self.global_mlp(out)


class InteractionNetwork(torch.nn.Module):
    def __init__(self):
        super(InteractionNetwork, self).__init__()
        self.interactionnetwork = MetaLayer(EdgeBlock(), NodeBlock(), GlobalBlock())
        self.bn = BatchNorm1d(inputs)

    def forward(self, x, edge_index, batch):
        # normalize the input track features, then apply a single GN block
        x = self.bn(x)
        x, edge_attr, u = self.interactionnetwork(x, edge_index, None, None, batch)
        return u
model = InteractionNetwork().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
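As a quick, optional sanity check, you can run the untrained model on a small random batch and confirm it returns one row of logits per graph. The numbers below are arbitrary toy values, not part of the physics pipeline:

# toy batch: two graphs with 3 and 2 nodes, `inputs` features per node
x = torch.randn(5, inputs).to(device)
edge_index = torch.tensor(
    [[0, 0, 1, 1, 2, 2, 3],
     [1, 2, 0, 2, 0, 1, 4]], dtype=torch.long
).to(device)
batch = torch.tensor([0, 0, 0, 1, 1], dtype=torch.long).to(device)

model.eval()  # use running BatchNorm statistics for this tiny batch
with torch.no_grad():
    out = model(x, edge_index, batch)
print(out.shape)  # expected: torch.Size([2, 2]), one logit row per graph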
Define training loop
@torch.no_grad()
def test(model, loader, total, batch_size, leave=False):
    model.eval()
    xentropy = nn.CrossEntropyLoss(reduction="mean")
    sum_loss = 0.0
    t = tqdm(enumerate(loader), total=total / batch_size, leave=leave)
    for i, data in t:
        data = data.to(device)
        y = torch.argmax(data.y, dim=1)
        batch_output = model(data.x, data.edge_index, data.batch)
        batch_loss_item = xentropy(batch_output, y).item()
        sum_loss += batch_loss_item
        t.set_description("loss = %.5f" % (batch_loss_item))
        t.refresh()  # show the update immediately
    return sum_loss / (i + 1)
def train(model, optimizer, loader, total, batch_size, leave=False):
    model.train()
    xentropy = nn.CrossEntropyLoss(reduction="mean")
    sum_loss = 0.0
    t = tqdm(enumerate(loader), total=total / batch_size, leave=leave)
    for i, data in t:
        data = data.to(device)
        y = torch.argmax(data.y, dim=1)
        optimizer.zero_grad()
        batch_output = model(data.x, data.edge_index, data.batch)
        batch_loss = xentropy(batch_output, y)
        batch_loss.backward()
        batch_loss_item = batch_loss.item()
        t.set_description("loss = %.5f" % batch_loss_item)
        t.refresh()  # show the update immediately
        sum_loss += batch_loss_item
        optimizer.step()
    return sum_loss / (i + 1)
Define training, validation, testing data generators
from torch_geometric.data import Data, DataListLoader, Batch
from torch.utils.data import random_split
def collate(items):
    # each dataset item is a list of Data objects; flatten the list of lists
    # and merge everything into a single batch
    data_list = sum(items, [])
    return Batch.from_data_list(data_list)
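To see what Batch.from_data_list does, here is a toy example (illustration only): the node features of all graphs are stacked into one tensor, the edge indices of later graphs are shifted, and a batch vector records which graph each node belongs to.

# Data and Batch are imported below; torch is already imported
g1 = Data(x=torch.randn(3, 2), edge_index=torch.tensor([[0, 1], [1, 2]]))
g2 = Data(x=torch.randn(2, 2), edge_index=torch.tensor([[0], [1]]))
merged = Batch.from_data_list([g1, g2])
print(merged.x.shape)     # torch.Size([5, 2]): node features stacked
print(merged.batch)       # tensor([0, 0, 0, 1, 1]): graph index per node
print(merged.edge_index)  # g2's edge indices shifted by g1's 3 nodes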
torch.manual_seed(0)
valid_frac = 0.20
full_length = len(graph_dataset)
valid_num = int(valid_frac * full_length)
batch_size = 32
train_dataset, valid_dataset = random_split(graph_dataset, [full_length - valid_num, valid_num])
train_loader = DataListLoader(train_dataset, batch_size=batch_size, pin_memory=True, shuffle=True)
train_loader.collate_fn = collate
valid_loader = DataListLoader(valid_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)
valid_loader.collate_fn = collate
test_loader = DataListLoader(test_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)
test_loader.collate_fn = collate
train_samples = len(train_dataset)
valid_samples = len(valid_dataset)
test_samples = len(test_dataset)
print(full_length)
print(train_samples)
print(valid_samples)
print(test_samples)
Train
import os.path as osp
n_epochs = 10
stale_epochs = 0
best_valid_loss = 99999
patience = 5
t = tqdm(range(0, n_epochs))
for epoch in t:
    loss = train(model, optimizer, train_loader, train_samples, batch_size, leave=bool(epoch == n_epochs - 1))
    valid_loss = test(model, valid_loader, valid_samples, batch_size, leave=bool(epoch == n_epochs - 1))
    print("Epoch: {:02d}, Training Loss: {:.4f}".format(epoch, loss))
    print("           Validation Loss: {:.4f}".format(valid_loss))

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        modpath = osp.join("interactionnetwork_best.pth")
        print("New best model saved to:", modpath)
        torch.save(model.state_dict(), modpath)
        stale_epochs = 0
    else:
        print("Stale epoch")
        stale_epochs += 1
    if stale_epochs >= patience:
        print("Early stopping after %i stale epochs" % patience)
        break
Evaluate on testing data
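Before evaluating, it is a good idea to reload the best checkpoint saved by the training loop; otherwise the weights from the last, possibly stale, epoch are used. This is a suggested extra step:

# restore the weights from the best validation epoch
# (modpath was set when the best model was saved during training)
model.load_state_dict(torch.load(modpath))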
model.eval()
t = tqdm(enumerate(test_loader), total=test_samples / batch_size)
y_test = []
y_predict = []
for i, data in t:
    data = data.to(device)
    batch_output = model(data.x, data.edge_index, data.batch)
    y_predict.append(batch_output.detach().cpu().numpy())
    y_test.append(data.y.cpu().numpy())
y_test = np.concatenate(y_test)
y_predict = np.concatenate(y_predict)
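Note that the model outputs unnormalized logits (GlobalBlock applies no softmax), and the ROC curve below uses the raw logit of class 1 directly. If you prefer probability scores, you can convert the logits first, e.g. with scipy (a sketch; since the softmax probability depends on the logit difference \(z_1 - z_0\) rather than on \(z_1\) alone, the resulting curve can differ slightly):

from scipy.special import softmax

# convert the two-class logits to probabilities; column 1 is the signal score
y_predict_proba = softmax(y_predict, axis=1)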
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import mplhep as hep
plt.style.use(hep.style.ROOT)
# create ROC curves
fpr_gnn, tpr_gnn, threshold_gnn = roc_curve(y_test[:, 1], y_predict[:, 1])
# plot ROC curves
plt.figure()
plt.plot(tpr_gnn, fpr_gnn, lw=2.5, label="GNN, AUC = {:.1f}%".format(auc(fpr_gnn, tpr_gnn) * 100))
plt.xlabel(r"True positive rate")
plt.ylabel(r"False positive rate")
plt.semilogy()
plt.ylim(0.001, 1)
plt.xlim(0, 1)
plt.grid(True)
plt.legend(loc="upper left")
plt.show()
Exercises
Replace the interaction network model with a Deep Sets model and check the performance. A partial implementation of the model is below:
class DeepSet(torch.nn.Module):
    def __init__(self):
        super(DeepSet, self).__init__()
        self.bn = BatchNorm1d(inputs)
        self.node_mlp = Seq(Lin(inputs, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, hidden))
        self.global_mlp = Seq(Lin(hidden, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, outputs))

    def forward(self, x, edge_index, batch):
        x = self.bn(x)
        x = ...  # apply node_mlp to x
        mean = ...  # take mean over all node features x in one graph (see GlobalBlock of InteractionNetwork)
        out = ...  # apply global mlp to mean
        return out