{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Hands-on 06: Graph data and GNNs: Tagging Higgs boson jets\n", "\n", "This week, we will look at graph neural networks using the PyTorch Geometric library. See {cite:p}`PyTorchGeometric` for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch_geometric\n", "from tqdm.notebook import tqdm\n", "import numpy as np\n", "\n", "# use the first GPU if one is available, otherwise the CPU\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# set to True to read the ROOT files from local paths instead of the CERN EOS open-data server\n", "local = False" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import yaml\n", "\n", "with open(\"definitions.yml\") as file:\n", "    # The FullLoader parameter handles the conversion from YAML\n", "    # scalar values to the Python dictionary format\n", "    definitions = yaml.load(file, Loader=yaml.FullLoader)\n", "\n", "features = definitions[\"features\"]\n", "spectators = definitions[\"spectators\"]\n", "labels = definitions[\"labels\"]\n", "\n", "nfeatures = definitions[\"nfeatures\"]\n", "nspectators = definitions[\"nspectators\"]\n", "nlabels = definitions[\"nlabels\"]\n", "ntracks = definitions[\"ntracks\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Graph datasets\n", "Here, we define the graph dataset. 
We do this in a separate class, `GraphDataset`, following this example: https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-larger-datasets\n", "\n", "Formally, a graph is represented by a triplet $\\mathcal G = (\\mathbf{u}, V, E)$, consisting of a graph-level, or *global*, feature vector $\\mathbf{u}$, a set of $N^v$ nodes $V$, and a set of $N^e$ edges $E$.\n", "The nodes are given by $V = \\{\\mathbf{v}_i\\}_{i=1:N^v}$, where $\\mathbf{v}_i$ represents the $i$th node's attributes.\n", "The edges connect pairs of nodes and are given by $E = \\{\\left(\\mathbf{e}_k, r_k, s_k\\right)\\}_{k=1:N^e}$, where $\\mathbf{e}_k$ represents the $k$th edge's attributes, and $r_k$ and $s_k$ are the indices of the \"receiver\" and \"sender\" nodes, respectively, connected by the $k$th edge (from the sender node to the receiver node).\n", "The receiver and sender index vectors are an alternative way of encoding the directed adjacency matrix.\n", "In PyTorch Geometric, both are stored in a single `edge_index` tensor of shape $[2, N^e]$, whose first row holds the sender (source) indices and whose second row holds the receiver (target) indices." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from GraphDataset import GraphDataset\n", "\n", "if local:\n", "    file_names = [\"/teams/DSC180A_FA20_A00/b06particlephysics/train/ntuple_merged_10.root\"]\n", "    file_names_test = [\"/teams/DSC180A_FA20_A00/b06particlephysics/test/ntuple_merged_0.root\"]\n", "else:\n", "    file_names = [\n", "        \"root://eospublic.cern.ch//eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/train/ntuple_merged_10.root\"\n", "    ]\n", "    file_names_test = [\n", "        \"root://eospublic.cern.ch//eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/test/ntuple_merged_0.root\"\n", "    ]\n", "\n", "graph_dataset = GraphDataset(\n", "    \"gdata_train\", features, labels, spectators, n_events=1000, n_events_merge=1, file_names=file_names\n", ")\n", "\n", "test_dataset = GraphDataset(\n", "    \"gdata_test\", features, labels, spectators, 
n_events=2000, n_events_merge=1, file_names=file_names_test\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Graph neural network\n", "\n", "Here, we recapitulate the \"graph network\" (GN) formalism {cite:p}`battaglia2018relational`, which generalizes various GNNs and other similar methods.\n", "GNs are graph-to-graph mappings whose output graphs have the same node and edge structure as the input.\n", "Formally, a GN block contains three \"update\" functions, $\\phi$, and three \"aggregation\" functions, $\\rho$.\n", "The stages of processing in a single GN block are:\n", "\n", "$\n", "\\begin{align}\n", " \\mathbf{e}'_k &= \\phi^e\\left(\\mathbf{e}_k, \\mathbf{v}_{r_k}, \\mathbf{v}_{s_k}, \\mathbf{u} \\right) & \\mathbf{\\bar{e}}'_i &= \\rho^{e \\rightarrow v}\\left(E'_i\\right) & \\text{(Edge block),}\\\\\n", " \\mathbf{v}'_i &= \\phi^v\\left(\\mathbf{\\bar{e}}'_i, \\mathbf{v}_i, \\mathbf{u}\\right) & \n", " \\mathbf{\\bar{e}}' &= \\rho^{e \\rightarrow u}\\left(E'\\right) & \\text{(Node block),}\\\\\n", " \\mathbf{u}' &= \\phi^u\\left(\\mathbf{\\bar{e}}', \\mathbf{\\bar{v}}', \\mathbf{u}\\right) & \n", " \\mathbf{\\bar{v}}' &= \\rho^{v \\rightarrow u}\\left(V'\\right) &\\text{(Global block).}\n", " \\label{eq:gn-functions}\n", "\\end{align}\n", "$\n", "\n", "where $E'_i = \\left\\{\\left(\\mathbf{e}'_k, r_k, s_k \\right)\\right\\}_{r_k=i,\\; k=1:N^e}$ contains the updated edge features for edges whose receiver node is the $i$th node, $E' = \\bigcup_i E_i' = \\left\\{\\left(\\mathbf{e}'_k, r_k, s_k \\right)\\right\\}_{k=1:N^e}$ is the set of updated edges, and $V'=\\left\\{\\mathbf{v}'_i\\right\\}_{i=1:N^v}$ is the set of updated nodes.\n", "\n", "We will define an interaction network model similar to Ref. {cite:p}`Moreno:2019neq`, but modeling only the particle-particle interactions. It takes as input all of the tracks in each jet (48 features per track) without truncation or zero-padding, since a GNN can process graphs with a variable number of nodes. 
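Another modification is the use of batch normalization {cite:p}`bn` layers to improve the stability of the training.\n", "\n", "In the code below, `EdgeBlock`, `NodeBlock`, and `GlobalBlock` implement $\\phi^e$, $\\phi^v$, and $\\phi^u$, respectively, and PyTorch Geometric's `MetaLayer` chains them into a single GN block. Because the input graphs have no edge or global features, $\\phi^e$ depends only on the sender and receiver node features, and $\\phi^v$ does not use $\\mathbf{u}$. The edge messages are averaged over the incoming edges of each node ($\\rho^{e \\rightarrow v}$), the updated node features are averaged over each graph ($\\rho^{v \\rightarrow u}$), and $\\phi^u$ maps this average to the output global feature vector $\\mathbf{u}'$, which holds the two class scores (logits) for each jet." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "from torch.nn import Sequential as Seq, Linear as Lin, ReLU, BatchNorm1d\n", "from torch_scatter import scatter_mean\n", "from torch_geometric.nn import MetaLayer\n", "\n", "inputs = 48  # number of input features per track\n", "hidden = 128  # width of the hidden layers\n", "outputs = 2  # number of output classes\n", "\n", "\n", "class EdgeBlock(torch.nn.Module):\n", "    \"\"\"phi^e: update each edge from its sender (src) and receiver (dest) node features.\"\"\"\n", "\n", "    def __init__(self):\n", "        super(EdgeBlock, self).__init__()\n", "        self.edge_mlp = Seq(Lin(inputs * 2, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, hidden))\n", "\n", "    def forward(self, src, dest, edge_attr, u, batch):\n", "        # edge_attr and u are unused: the input graphs have no edge or global features\n", "        out = torch.cat([src, dest], 1)\n", "        return self.edge_mlp(out)\n", "\n", "\n", "class NodeBlock(torch.nn.Module):\n", "    \"\"\"phi^v: update each node from its own features and the aggregated messages of its incoming edges.\"\"\"\n", "\n", "    def __init__(self):\n", "        super(NodeBlock, self).__init__()\n", "        self.node_mlp_1 = Seq(Lin(inputs + hidden, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, hidden))\n", "        self.node_mlp_2 = Seq(Lin(inputs + hidden, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, hidden))\n", "\n", "    def forward(self, x, edge_index, edge_attr, u, batch):\n", "        row, col = edge_index  # sender (source) and receiver (target) node indices\n", "        out = torch.cat([x[row], edge_attr], dim=1)\n", "        out = self.node_mlp_1(out)\n", "        # rho^{e->v}: average the messages over the incoming edges of each receiver node\n", "        out = scatter_mean(out, col, dim=0, dim_size=x.size(0))\n", "        out = torch.cat([x, out], dim=1)\n", "        return self.node_mlp_2(out)\n", "\n", "\n", "class GlobalBlock(torch.nn.Module):\n", "    \"\"\"phi^u: map the mean of the updated node features of each graph to the class logits.\"\"\"\n", "\n", "    def __init__(self):\n", "        super(GlobalBlock, self).__init__()\n", "        self.global_mlp = Seq(Lin(hidden, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, outputs))\n", "\n", "    def forward(self, x, edge_index, edge_attr, u, batch):\n", "        # rho^{v->u}: average the updated node features over each graph in the batch\n", "        out = scatter_mean(x, batch, dim=0)\n", "        return self.global_mlp(out)\n", "\n", "\n", "class 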
InteractionNetwork(torch.nn.Module):\n", "    def __init__(self):\n", "        super(InteractionNetwork, self).__init__()\n", "        # a single GN block: edge update, then node update, then global update\n", "        self.interactionnetwork = MetaLayer(EdgeBlock(), NodeBlock(), GlobalBlock())\n", "        # batch normalization of the input track features\n", "        self.bn = BatchNorm1d(inputs)\n", "\n", "    def forward(self, x, edge_index, batch):\n", "        x = self.bn(x)\n", "        # there are no edge or global input features, so edge_attr and u are None\n", "        x, edge_attr, u = self.interactionnetwork(x, edge_index, None, None, batch)\n", "        # the updated global features are the per-graph class logits\n", "        return u\n", "\n", "\n", "model = InteractionNetwork().to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define the training loop" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@torch.no_grad()\n", "def test(model, loader, total, batch_size, leave=False):\n", "    model.eval()\n", "\n", "    xentropy = nn.CrossEntropyLoss(reduction=\"mean\")\n", "\n", "    sum_loss = 0.0\n", "    t = tqdm(enumerate(loader), total=total / batch_size, leave=leave)\n", "    for i, data in t:\n", "        data = data.to(device)\n", "        # convert the per-class target columns into the class indices expected by CrossEntropyLoss\n", "        y = torch.argmax(data.y, dim=1)\n", "        batch_output = model(data.x, data.edge_index, data.batch)\n", "        batch_loss_item = xentropy(batch_output, y).item()\n", "        sum_loss += batch_loss_item\n", "        t.set_description(\"loss = %.5f\" % (batch_loss_item))\n", "        t.refresh()  # show the update immediately\n", "\n", "    return sum_loss / (i + 1)\n", "\n", "\n", "def train(model, optimizer, loader, total, batch_size, leave=False):\n", "    model.train()\n", "\n", "    xentropy = nn.CrossEntropyLoss(reduction=\"mean\")\n", "\n", "    sum_loss = 0.0\n", "    t = tqdm(enumerate(loader), total=total / batch_size, leave=leave)\n", "    for i, data in t:\n", "        data = data.to(device)\n", "        y = torch.argmax(data.y, dim=1)\n", "        optimizer.zero_grad()\n", "        batch_output = model(data.x, data.edge_index, data.batch)\n", "        batch_loss = xentropy(batch_output, y)\n", "        batch_loss.backward()\n", "        batch_loss_item = batch_loss.item()\n", "        t.set_description(\"loss = %.5f\" % batch_loss_item)\n", 
"        t.refresh()  # show the update immediately\n", "        sum_loss += batch_loss_item\n", "        optimizer.step()\n", "\n", "    return sum_loss / (i + 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define the training, validation, and testing data loaders" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torch_geometric.data import DataListLoader, Batch\n", "from torch.utils.data import random_split\n", "\n", "\n", "def collate(items):\n", "    # each dataset item is a list of graphs: flatten the lists and merge all graphs into one Batch\n", "    data_list = sum(items, [])\n", "    return Batch.from_data_list(data_list)\n", "\n", "\n", "torch.manual_seed(0)\n", "valid_frac = 0.20\n", "full_length = len(graph_dataset)\n", "valid_num = int(valid_frac * full_length)\n", "batch_size = 32\n", "\n", "train_dataset, valid_dataset = random_split(graph_dataset, [full_length - valid_num, valid_num])\n", "\n", "train_loader = DataListLoader(train_dataset, batch_size=batch_size, pin_memory=True, shuffle=True)\n", "train_loader.collate_fn = collate\n", "valid_loader = DataListLoader(valid_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)\n", "valid_loader.collate_fn = collate\n", "test_loader = DataListLoader(test_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)\n", "test_loader.collate_fn = collate\n", "\n", "\n", "train_samples = len(train_dataset)\n", "valid_samples = len(valid_dataset)\n", "test_samples = len(test_dataset)\n", "print(full_length)\n", "print(train_samples)\n", "print(valid_samples)\n", "print(test_samples)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os.path as osp\n", "\n", "n_epochs = 10\n", "stale_epochs = 0\n", "best_valid_loss = float(\"inf\")\n", "patience = 5  # number of epochs without improvement before stopping early\n", "t = tqdm(range(0, n_epochs))\n", "\n", "for epoch in t:\n", "    loss = train(model, optimizer, train_loader, train_samples, batch_size, leave=bool(epoch == n_epochs - 1))\n", "    valid_loss = test(model, 
valid_loader, valid_samples, batch_size, leave=bool(epoch == n_epochs - 1))\n", "    print(\"Epoch: {:02d}, Training Loss:   {:.4f}\".format(epoch, loss))\n", "    print(\"           Validation Loss: {:.4f}\".format(valid_loss))\n", "\n", "    if valid_loss < best_valid_loss:\n", "        best_valid_loss = valid_loss\n", "        modpath = osp.join(\"interactionnetwork_best.pth\")\n", "        print(\"New best model saved to:\", modpath)\n", "        torch.save(model.state_dict(), modpath)\n", "        stale_epochs = 0\n", "    else:\n", "        print(\"Stale epoch\")\n", "        stale_epochs += 1\n", "    if stale_epochs >= patience:\n", "        print(\"Early stopping after %i stale epochs\" % patience)\n", "        break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluate on the testing data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# reload the weights of the model with the lowest validation loss\n", "model.load_state_dict(torch.load(modpath, map_location=device))\n", "model.eval()\n", "t = tqdm(enumerate(test_loader), total=test_samples / batch_size)\n", "y_test = []\n", "y_predict = []\n", "with torch.no_grad():\n", "    for i, data in t:\n", "        data = data.to(device)\n", "        batch_output = model(data.x, data.edge_index, data.batch)\n", "        # convert the logits into class probabilities\n", "        y_predict.append(torch.softmax(batch_output, dim=1).cpu().numpy())\n", "        y_test.append(data.y.cpu().numpy())\n", "y_test = np.concatenate(y_test)\n", "y_predict = np.concatenate(y_predict)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import roc_curve, auc\n", "import matplotlib.pyplot as plt\n", "import mplhep as hep\n", "\n", "plt.style.use(hep.style.ROOT)\n", "# create the ROC curve from the predicted probability of class 1\n", "fpr_gnn, tpr_gnn, threshold_gnn = roc_curve(y_test[:, 1], y_predict[:, 1])\n", "\n", "# plot the ROC curve: true positive rate (x) vs. false positive rate (y, log scale)\n", "plt.figure()\n", "plt.plot(tpr_gnn, fpr_gnn, lw=2.5, label=\"GNN, AUC = {:.1f}%\".format(auc(fpr_gnn, tpr_gnn) * 100))\n", "plt.xlabel(r\"True positive rate\")\n", "plt.ylabel(r\"False positive rate\")\n", "plt.semilogy()\n", "plt.ylim(0.001, 1)\n", "plt.xlim(0, 1)\n", "plt.grid(True)\n", "plt.legend(loc=\"upper left\")\n", "plt.show()" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.6" } }, "nbformat": 4, "nbformat_minor": 2 }