{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Interaction Network GNN\n", "\n", "Now we will look at graph neural networks using the PyTorch Geometric library. See {cite:p}`PyTorchGeometric` for more details." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# For Colab\n", "!pip install torch_geometric\n", "!pip install torch_sparse\n", "!pip install torch_scatter\n", "\n", "import torch\n", "import torch_geometric\n", "\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "from tqdm.notebook import tqdm\n", "import numpy as np\n", "\n", "local = False" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# For Colab\n", "\n", "!pip install wget\n", "import wget\n", "\n", "!pip install -U PyYAML\n", "!pip install uproot\n", "!pip install awkward\n", "!pip install mplhep" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "import yaml\n", "import os.path\n", "\n", "# WGET for colab\n", "if not os.path.exists(\"definitions.yml\"):\n", "    url = \"https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/definitions.yml\"\n", "    definitionsFile = wget.download(url)\n", "\n", "with open(\"definitions.yml\") as file:\n", "    # The FullLoader parameter handles the conversion from YAML\n", "    # scalar values to the Python dictionary format\n", "    definitions = yaml.load(file, Loader=yaml.FullLoader)\n", "\n", "\n", "# You can test using only 4-vectors with:\n", "# if not os.path.exists(\"definitions_lorentz.yml\"):\n", "#     url = \"https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/definitions_lorentz.yml\"\n", "#     definitionsFile = wget.download(url)\n", "# with open('definitions_lorentz.yml') as file:\n", "#     # The FullLoader parameter handles the conversion from YAML\n", "#     # scalar values to the Python dictionary format\n", "#     definitions = 
yaml.load(file, Loader=yaml.FullLoader)\n", "\n", "features = definitions[\"features\"]\n", "spectators = definitions[\"spectators\"]\n", "labels = definitions[\"labels\"]\n", "\n", "nfeatures = definitions[\"nfeatures\"]\n", "nspectators = definitions[\"nspectators\"]\n", "nlabels = definitions[\"nlabels\"]\n", "ntracks = definitions[\"ntracks\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Graph datasets\n", "Here we have to define the graph dataset. We do this in a separate class following this example: https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-larger-datasets\n", "\n", "Formally, a graph is represented by a triplet $\\mathcal G = (\\mathbf{u}, V, E)$, consisting of a graph-level, or *global*, feature vector $\\mathbf{u}$, a set of $N^v$ nodes $V$, and a set of $N^e$ edges $E$.\n", "The nodes are given by $V = \\{\\mathbf{v}_i\\}_{i=1:N^v}$, where $\\mathbf{v}_i$ represents the $i$th node's attributes.\n", "The edges connect pairs of nodes and are given by $E = \\{\\left(\\mathbf{e}_k, r_k, s_k\\right)\\}_{k=1:N^e}$, where $\\mathbf{e}_k$ represents the $k$th edge's attributes, and $r_k$ and $s_k$ are the indices of the \"receiver\" and \n", "\"sender\" nodes, respectively, connected by the $k$th edge (from the sender node to the receiver node).\n", "The receiver and sender index vectors are an alternative way of encoding the directed adjacency matrix." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# If in colab\n", "if not os.path.exists(\"GraphDataset.py\"):\n", "    urlDSD = \"https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/GraphDataset.py\"\n", "    DSD = wget.download(urlDSD)\n", "if not os.path.exists(\"utils.py\"):\n", "    urlUtils = \"https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/utils.py\"\n", "    utils = wget.download(urlUtils)\n", "\n", "\n", "from GraphDataset import 
GraphDataset\n", "\n", "\n", "# For Colab\n", "if not os.path.exists(\"ntuple_merged_90.root\"):\n", " urlFILE = \"http://opendata.cern.ch/eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/train/ntuple_merged_90.root\"\n", " dataFILE = wget.download(urlFILE)\n", "file_names = [\"ntuple_merged_90.root\"]\n", "\n", "##If you pulled github locally\n", "# if local:\n", "# file_names = [\n", "# \"/teams/DSC180A_FA20_A00/b06particlephysics/train/ntuple_merged_10.root\"\n", "# ]\n", "# file_names_test = [\n", "# \"/teams/DSC180A_FA20_A00/b06particlephysics/test/ntuple_merged_0.root\"\n", "# ]\n", "# else:\n", "# file_names = [\n", "# \"root://eospublic.cern.ch//eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/train/ntuple_merged_10.root\"\n", "# ]\n", "# file_names_test = [\n", "# \"root://eospublic.cern.ch//eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/test/ntuple_merged_0.root\"\n", "# ]\n", "\n", "graph_dataset = GraphDataset(\n", " \"gdata_train\",\n", " features,\n", " labels,\n", " spectators,\n", " start_event=0,\n", " stop_event=8000,\n", " n_events_merge=1,\n", " file_names=file_names,\n", ")\n", "\n", "test_dataset = GraphDataset(\n", " \"gdata_test\",\n", " features,\n", " labels,\n", " spectators,\n", " start_event=8001,\n", " stop_event=10001,\n", " n_events_merge=1,\n", " file_names=file_names,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Graph neural network\n", "\n", "Here, we recapitulate the \"graph network\" (GN) formalism {cite:p}`battaglia2018relational`, which generalizes various GNNs and other similar methods.\n", "GNs are graph-to-graph mappings, whose output graphs have the same node and edge structure as the input. 
\n", "Formally, a GN block contains three \"update\" functions, $\\phi$, and three \"aggregation\" functions, $\\rho$.\n", "The stages of processing in a single GN block are:\n", "\n", "$\n", "\\begin{align}\n", " \\mathbf{e}'_k &= \\phi^e\\left(\\mathbf{e}_k, \\mathbf{v}_{r_k}, \\mathbf{v}_{s_k}, \\mathbf{u} \\right) & \\mathbf{\\bar{e}}'_i &= \\rho^{e \\rightarrow v}\\left(E'_i\\right) & \\text{(Edge block),}\\\\\n", " \\mathbf{v}'_i &= \\phi^v\\left(\\mathbf{\\bar{e}}'_i, \\mathbf{v}_i, \\mathbf{u}\\right) & \n", " \\mathbf{\\bar{e}}' &= \\rho^{e \\rightarrow u}\\left(E'\\right) & \\text{(Node block),}\\\\\n", " \\mathbf{u}' &= \\phi^u\\left(\\mathbf{\\bar{e}}', \\mathbf{\\bar{v}}', \\mathbf{u}\\right) & \n", " \\mathbf{\\bar{v}}' &= \\rho^{v \\rightarrow u}\\left(V'\\right) &\\text{(Global block).}\n", " \\label{eq:gn-functions}\n", "\\end{align}\n", "$\n", "\n", "where $E'_i = \\left\\{\\left(\\mathbf{e}'_k, r_k, s_k \\right)\\right\\}_{r_k=i,\\; k=1:N^e}$ contains the updated edge features for edges whose receiver node is the $i$th node, $E' = \\bigcup_i E_i' = \\left\\{\\left(\\mathbf{e}'_k, r_k, s_k \\right)\\right\\}_{k=1:N^e}$ is the set of updated edges, and $V'=\\left\\{\\mathbf{v}'_i\\right\\}_{i=1:N^v}$ is the set of updated nodes.\n", "\n", "We will define an interaction network model similar to Ref. {cite:p}`Moreno:2019neq`, but just modeling the particle-particle interactions. It will take as input all of the tracks (with 48 features) without truncating or zero-padding. Another modification is the use of batch normalization {cite:p}`bn` layers to improve the stability of the training." 
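, "\n", "\n", "As a hedged sketch of how the model in the next code cell appears to specialize these equations: it assumes that PyTorch Geometric's `MetaLayer` passes `x[row]` as the sender and `x[col]` as the receiver features, that no input edge or global attributes are used ($\\mathbf{e}_k$ and $\\mathbf{u}$ are `None`), and that both aggregations are means. The labels $\\phi^{v,1}$ and $\\phi^{v,2}$ for the two node MLPs are introduced here for illustration only.\n", "\n", "$\n", "\\begin{align}\n", " \\mathbf{e}'_k &= \\phi^e\\left(\\mathbf{v}_{s_k}, \\mathbf{v}_{r_k}\\right), \\\\\n", " \\mathbf{\\bar{e}}'_i &= \\frac{1}{|E'_i|} \\sum_{k:\\, r_k = i} \\phi^{v,1}\\left(\\mathbf{v}_{s_k}, \\mathbf{e}'_k\\right), \\\\\n", " \\mathbf{v}'_i &= \\phi^{v,2}\\left(\\mathbf{v}_i, \\mathbf{\\bar{e}}'_i\\right), \\\\\n", " \\mathbf{u}' &= \\phi^u\\left(\\frac{1}{N^v} \\sum_{i=1}^{N^v} \\mathbf{v}'_i\\right).\n", "\\end{align}\n", "$\n", "\n", "Under these assumptions, $\\mathbf{u}'$ is the two-component output that is passed to the cross-entropy loss."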
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch_geometric.transforms as T\n", "from torch_geometric.nn import EdgeConv, global_mean_pool\n", "from torch.nn import Sequential as Seq, Linear as Lin, ReLU, BatchNorm1d\n", "from torch_scatter import scatter_mean\n", "from torch_geometric.nn import MetaLayer\n", "\n", "inputs = 48\n", "hidden = 128\n", "outputs = 2\n", "\n", "\n", "class EdgeBlock(torch.nn.Module):\n", "    def __init__(self):\n", "        super(EdgeBlock, self).__init__()\n", "        self.edge_mlp = Seq(\n", "            Lin(inputs * 2, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, hidden)\n", "        )\n", "\n", "    def forward(self, src, dest, edge_attr, u, batch):\n", "        out = torch.cat([src, dest], 1)\n", "        return self.edge_mlp(out)\n", "\n", "\n", "class NodeBlock(torch.nn.Module):\n", "    def __init__(self):\n", "        super(NodeBlock, self).__init__()\n", "        self.node_mlp_1 = Seq(\n", "            Lin(inputs + hidden, hidden),\n", "            BatchNorm1d(hidden),\n", "            ReLU(),\n", "            Lin(hidden, hidden),\n", "        )\n", "        self.node_mlp_2 = Seq(\n", "            Lin(inputs + hidden, hidden),\n", "            BatchNorm1d(hidden),\n", "            ReLU(),\n", "            Lin(hidden, hidden),\n", "        )\n", "\n", "    def forward(self, x, edge_index, edge_attr, u, batch):\n", "        row, col = edge_index\n", "        out = torch.cat([x[row], edge_attr], dim=1)\n", "        out = self.node_mlp_1(out)\n", "        out = scatter_mean(out, col, dim=0, dim_size=x.size(0))\n", "        out = torch.cat([x, out], dim=1)\n", "        return self.node_mlp_2(out)\n", "\n", "\n", "class GlobalBlock(torch.nn.Module):\n", "    def __init__(self):\n", "        super(GlobalBlock, self).__init__()\n", "        self.global_mlp = Seq(\n", "            Lin(hidden, hidden), BatchNorm1d(hidden), ReLU(), Lin(hidden, outputs)\n", "        )\n", "\n", "    def forward(self, x, edge_index, edge_attr, u, batch):\n", "        out = scatter_mean(x, batch, dim=0)\n", "        return self.global_mlp(out)\n", "\n", "\n", "class 
InteractionNetwork(torch.nn.Module):\n", "    def __init__(self):\n", "        super(InteractionNetwork, self).__init__()\n", "        self.interactionnetwork = MetaLayer(EdgeBlock(), NodeBlock(), GlobalBlock())\n", "        self.bn = BatchNorm1d(inputs)\n", "\n", "    def forward(self, x, edge_index, batch):\n", "\n", "        x = self.bn(x)\n", "        x, edge_attr, u = self.interactionnetwork(x, edge_index, None, None, batch)\n", "        return u\n", "\n", "\n", "model = InteractionNetwork().to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define training loop" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "@torch.no_grad()\n", "def test(model, loader, total, batch_size, leave=False):\n", "    model.eval()\n", "\n", "    xentropy = nn.CrossEntropyLoss(reduction=\"mean\")\n", "\n", "    sum_loss = 0.0\n", "    t = tqdm(enumerate(loader), total=total / batch_size, leave=leave)\n", "    for i, data in t:\n", "        data = data.to(device)\n", "        y = torch.argmax(data.y, dim=1)\n", "        batch_output = model(data.x, data.edge_index, data.batch)\n", "        batch_loss_item = xentropy(batch_output, y).item()\n", "        sum_loss += batch_loss_item\n", "        t.set_description(\"loss = %.5f\" % (batch_loss_item))\n", "        t.refresh()  # to show immediately the update\n", "\n", "    return sum_loss / (i + 1)\n", "\n", "\n", "def train(model, optimizer, loader, total, batch_size, leave=False):\n", "    model.train()\n", "\n", "    xentropy = nn.CrossEntropyLoss(reduction=\"mean\")\n", "\n", "    sum_loss = 0.0\n", "    t = tqdm(enumerate(loader), total=total / batch_size, leave=leave)\n", "    for i, data in t:\n", "        data = data.to(device)\n", "        y = torch.argmax(data.y, dim=1)\n", "        optimizer.zero_grad()\n", "        batch_output = model(data.x, data.edge_index, data.batch)\n", "        batch_loss = xentropy(batch_output, y)\n", "        batch_loss.backward()\n", "        batch_loss_item = batch_loss.item()\n", "        t.set_description(\"loss = %.5f\" % batch_loss_item)\n", "        
t.refresh()  # to show immediately the update\n", "        sum_loss += batch_loss_item\n", "        optimizer.step()\n", "\n", "    return sum_loss / (i + 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define training, validation, testing data generators" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "7501\n", "6001\n", "1500\n", "1886\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/wmccorma/miniconda3/envs/ml-iaifi/lib/python3.9/site-packages/torch_geometric/deprecation.py:12: UserWarning: 'data.DataListLoader' is deprecated, use 'loader.DataListLoader' instead\n", " warnings.warn(out)\n" ] } ], "source": [ "from torch_geometric.data import Data, DataListLoader, Batch\n", "from torch.utils.data import random_split\n", "\n", "\n", "def collate(items):\n", "    l = sum(items, [])\n", "    return Batch.from_data_list(l)\n", "\n", "\n", "torch.manual_seed(0)\n", "valid_frac = 0.20\n", "full_length = len(graph_dataset)\n", "valid_num = int(valid_frac * full_length)\n", "batch_size = 32\n", "\n", "train_dataset, valid_dataset = random_split(\n", "    graph_dataset, [full_length - valid_num, valid_num]\n", ")\n", "\n", "train_loader = DataListLoader(\n", "    train_dataset, batch_size=batch_size, pin_memory=True, shuffle=True\n", ")\n", "train_loader.collate_fn = collate\n", "valid_loader = DataListLoader(\n", "    valid_dataset, batch_size=batch_size, pin_memory=True, shuffle=False\n", ")\n", "valid_loader.collate_fn = collate\n", "test_loader = DataListLoader(\n", "    test_dataset, batch_size=batch_size, pin_memory=True, shuffle=False\n", ")\n", "test_loader.collate_fn = collate\n", "\n", "\n", "train_samples = len(train_dataset)\n", "valid_samples = len(valid_dataset)\n", "test_samples = len(test_dataset)\n", "print(full_length)\n", "print(train_samples)\n", "print(valid_samples)\n", "print(test_samples)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 
Train" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "413e118e0e9041beb711de477240779e", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10 [00:00= patience:\n", " print(\"Early stopping after %i stale epochs\" % patience)\n", " break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluate on testing data" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d5d757e6dcf445f1b35aa3b7f219fdb3", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/58.9375 [00:00" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from sklearn.metrics import roc_curve, auc\n", "import matplotlib.pyplot as plt\n", "import mplhep as hep\n", "\n", "plt.style.use(hep.style.ROOT)\n", "# create ROC curves\n", "fpr_gnn, tpr_gnn, threshold_gnn = roc_curve(y_test[:, 1], y_predict[:, 1])\n", "with open(\"gnn_roc.npy\", \"wb\") as f:\n", "    np.save(f, fpr_gnn)\n", "    np.save(f, tpr_gnn)\n", "    np.save(f, threshold_gnn)\n", "\n", "\n", "# For colab: download the DeepSet ROC file only if it is not already present\n", "if not os.path.exists(\"deepset_roc.npy\"):\n", "    urlROC = \"https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/deepset_roc.npy\"\n", "    rocFile = wget.download(urlROC)\n", "\n", "with open(\"deepset_roc.npy\", \"rb\") as f:\n", "    fpr_deepset = np.load(f)\n", "    tpr_deepset = np.load(f)\n", "    threshold_deepset = np.load(f)\n", "\n", "# plot ROC curves\n", "plt.figure()\n", "plt.plot(\n", "    tpr_deepset,\n", "    fpr_deepset,\n", "    lw=2.5,\n", "    label=\"DeepSet, AUC = {:.1f}%\".format(auc(fpr_deepset, tpr_deepset) * 100),\n", ")\n", "plt.plot(\n", "    tpr_gnn,\n", "    fpr_gnn,\n", "    lw=2.5,\n", "    label=\"GNN, AUC = {:.1f}%\".format(auc(fpr_gnn, tpr_gnn) * 100),\n", ")\n", "plt.xlabel(r\"True positive rate\")\n", "plt.ylabel(r\"False positive 
rate\")\n", "plt.semilogy()\n", "plt.ylim(0.001, 1)\n", "plt.xlim(0, 1)\n", "plt.grid(True)\n", "plt.legend(loc=\"upper left\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 2 }