{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Lorentz-Equivariant GNN\n", "\n", "Now we will look at graph neural networks using the PyTorch Geometric library. See {cite:p}`PyTorchGeometric` for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# For Colab\n", "!pip install torch_geometric\n", "!pip install torch_sparse\n", "!pip install torch_scatter" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "import torch_geometric\n", "from tqdm.notebook import tqdm\n", "\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "local = False" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# For Colab\n", "\n", "!pip install wget\n", "import wget\n", "\n", "!pip install -U PyYAML\n", "!pip install uproot\n", "!pip install awkward\n", "!pip install mplhep" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import yaml\n", "import os.path\n", "\n", "# WGET for colab\n", "if not os.path.exists(\"definitions_lorentz.yml\"):\n", "    import wget\n", "\n", "    url = \"https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/definitions_lorentz.yml\"\n", "    definitionsFile = wget.download(url)\n", "\n", "with open(\"definitions_lorentz.yml\") as file:\n", "    # The FullLoader parameter handles the conversion from YAML\n", "    # scalar values to the Python dictionary format\n", "    definitions = yaml.load(file, Loader=yaml.FullLoader)\n", "\n", "features = definitions[\"features\"]\n", "spectators = definitions[\"spectators\"]\n", "labels = definitions[\"labels\"]\n", "\n", "nfeatures = definitions[\"nfeatures\"]\n", "nspectators = definitions[\"nspectators\"]\n", "nlabels = definitions[\"nlabels\"]\n", "ntracks = definitions[\"ntracks\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 
Graph datasets\n", "Here we define the graph dataset in a separate class, following this example: https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-larger-datasets\n", "\n", "Formally, a graph is represented by a triplet $\\mathcal G = (\\mathbf{u}, V, E)$, consisting of a graph-level, or *global*, feature vector $\\mathbf{u}$, a set of $N^v$ nodes $V$, and a set of $N^e$ edges $E$.\n", "The nodes are given by $V = \\{\\mathbf{v}_i\\}_{i=1:N^v}$, where $\\mathbf{v}_i$ represents the $i$th node's attributes.\n", "The edges connect pairs of nodes and are given by $E = \\{\\left(\\mathbf{e}_k, r_k, s_k\\right)\\}_{k=1:N^e}$, where $\\mathbf{e}_k$ represents the $k$th edge's attributes, and $r_k$ and $s_k$ are the indices of the \"receiver\" and \"sender\" nodes, respectively, connected by the $k$th edge (from the sender node to the receiver node).\n", "The receiver and sender index vectors are an alternative way of encoding the directed adjacency matrix." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Processing...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "51eb09799508446ab490fb88420cc6d3", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/8000 [00:00\n", "\n", "We will define an interaction network model similar to Ref. {cite:p}`Moreno:2019neq`, but modeling only the particle-particle interactions. It takes as input all of the tracks, each described by a four-vector, without truncating or zero-padding. To make the model Lorentz invariant, the edge and node networks only receive Minkowski inner products of these four-vectors (or quantities computed from them). Note that, although batch normalization {cite:p}`bn` layers are often used to improve the stability of the training, the blocks defined below do not include them." 
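, "\n", "The Lorentz-invariant inputs computed by the code below can be written out explicitly. This is our reading of that code, assuming each node's attributes $\\mathbf{v}_i$ form a four-vector with the time component first, as the metric $\\eta = \\mathrm{diag}(-1, 1, 1, 1)$ stored in `minkowski` implies. With the inner product and the signed-log scaling of `LorentzEdgeBlock.psi`,\n", "\n", "$$\n", "\\langle \\mathbf{a}, \\mathbf{b} \\rangle = \\mathbf{a}^\\top \\eta\\, \\mathbf{b} = -a^0 b^0 + a^1 b^1 + a^2 b^2 + a^3 b^3, \\qquad \\psi(z) = \\mathrm{sign}(z)\\, \\log\\left(|z| + 1\\right),\n", "$$\n", "\n", "the edge block passes `edge_mlp` the four numbers\n", "\n", "$$\n", "\\left( \\langle \\mathbf{v}_{s_k}, \\mathbf{v}_{s_k} \\rangle,\\; \\langle \\mathbf{v}_{s_k}, \\mathbf{v}_{r_k} \\rangle,\\; \\psi\\left(\\langle \\mathbf{v}_{r_k}, \\mathbf{v}_{r_k} \\rangle\\right),\\; \\psi\\left(\\langle \\mathbf{v}_{s_k} - \\mathbf{v}_{r_k}, \\mathbf{v}_{s_k} - \\mathbf{v}_{r_k} \\rangle\\right) \\right)\n", "$$\n", "\n", "for the $k$th edge, which runs from sender $s_k$ to receiver $r_k$. Every Lorentz transformation $\\Lambda$ satisfies $\\Lambda^\\top \\eta \\Lambda = \\eta$, so $\\langle \\Lambda \\mathbf{a}, \\Lambda \\mathbf{b} \\rangle = \\langle \\mathbf{a}, \\mathbf{b} \\rangle$. For a fixed set of edges, the edge inputs above, the node inputs $\\langle \\mathbf{v}_i, \\mathbf{v}_i \\rangle$, and therefore the network output are all unchanged when every four-vector is transformed by the same $\\Lambda$."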
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "from torch.nn import Sequential as Seq, Linear as Lin, ReLU\n", "from torch_scatter import scatter_mean\n", "from torch_geometric.nn import MetaLayer\n", "\n", "hidden = 16\n", "outputs = 2\n", "\n", "\n", "class LorentzEdgeBlock(torch.nn.Module):\n", "    def __init__(self):\n", "        super(LorentzEdgeBlock, self).__init__()\n", "        self.edge_mlp = Seq(Lin(4, hidden), ReLU(), Lin(hidden, hidden))\n", "        # Minkowski metric diag(-1, 1, 1, 1), registered as a buffer so that\n", "        # model.to(device) moves it to the same device as the inputs\n", "        self.register_buffer(\n", "            \"minkowski\",\n", "            torch.from_numpy(\n", "                np.array(\n", "                    [\n", "                        [-1.0, 0.0, 0.0, 0.0],\n", "                        [0.0, 1.0, 0.0, 0.0],\n", "                        [0.0, 0.0, 1.0, 0.0],\n", "                        [0.0, 0.0, 0.0, 1.0],\n", "                    ],\n", "                    dtype=np.float32,\n", "                )\n", "            ),\n", "        )\n", "\n", "    def psi(self, x):\n", "        # signed log scaling to compress the large dynamic range of inner products\n", "        return torch.sign(x) * torch.log(torch.abs(x) + 1)\n", "\n", "    def innerprod(self, x1, x2):\n", "        # row-wise Minkowski inner product <x1, x2>, shape [num_rows, 1]\n", "        return torch.sum(\n", "            torch.mul(torch.matmul(x1, self.minkowski), x2), 1, keepdim=True\n", "        )\n", "\n", "    def forward(self, src, dest, edge_attr, u, batch):\n", "        # edge inputs: Lorentz-invariant combinations of the sender (src)\n", "        # and receiver (dest) four-vectors\n", "        out = torch.cat(\n", "            [\n", "                self.innerprod(src, src),\n", "                self.innerprod(src, dest),\n", "                self.psi(self.innerprod(dest, dest)),\n", "                self.psi(self.innerprod(src - dest, src - dest)),\n", "            ],\n", "            dim=1,\n", "        )\n", "        return self.edge_mlp(out)\n", "\n", "\n", "class LorentzNodeBlock(torch.nn.Module):\n", "    def __init__(self):\n", "        super(LorentzNodeBlock, self).__init__()\n", "        self.node_mlp_1 = Seq(Lin(1 + hidden, hidden), ReLU(), Lin(hidden, hidden))\n", "        self.node_mlp_2 = Seq(Lin(1 + hidden, hidden), ReLU(), Lin(hidden, hidden))\n", "        # same Minkowski metric buffer as in LorentzEdgeBlock\n", "        self.register_buffer(\n", "            \"minkowski\",\n", "            torch.from_numpy(\n", "                np.array(\n", "                    [\n", "                        [-1.0, 0.0, 0.0, 0.0],\n", "                        [0.0, 1.0, 0.0, 0.0],\n", "                        [0.0, 0.0, 1.0, 0.0],\n", "                        [0.0, 0.0, 0.0, 1.0],\n", "                    ],\n", "                    dtype=np.float32,\n", "                )\n", "            ),\n", "        )\n", "\n", "    def innerprod(self, x1, x2):\n", "        return torch.sum(\n", "            torch.mul(torch.matmul(x1, self.minkowski), x2), 1, keepdim=True\n", "        )\n", "\n", "    def forward(self, x, edge_index, edge_attr, u, batch):\n", "        row, col = edge_index  # row: sender indices, col: receiver indices\n", "        # message = invariant norm of the sender four-vector + edge features\n", "        out = torch.cat([self.innerprod(x[row], x[row]), edge_attr], dim=1)\n", "        out = self.node_mlp_1(out)\n", "        # average the incoming messages at each receiver node\n", "        out = scatter_mean(out, col, dim=0, dim_size=x.size(0))\n", "        out = torch.cat([self.innerprod(x, x), out], dim=1)\n", "        return self.node_mlp_2(out)\n", "\n", "\n", "class GlobalBlock(torch.nn.Module):\n", "    def __init__(self):\n", "        super(GlobalBlock, self).__init__()\n", "        self.global_mlp = Seq(Lin(hidden, hidden), ReLU(), Lin(hidden, outputs))\n", "\n", "    def forward(self, x, edge_index, edge_attr, u, batch):\n", "        # mean-pool the node features of each graph, then map to class logits\n", "        out = scatter_mean(x, batch, dim=0)\n", "        return self.global_mlp(out)\n", "\n", "\n", "class LorentzInteractionNetwork(torch.nn.Module):\n", "    def __init__(self):\n", "        super(LorentzInteractionNetwork, self).__init__()\n", "        self.lorentzinteractionnetwork = MetaLayer(\n", "            LorentzEdgeBlock(), LorentzNodeBlock(), GlobalBlock()\n", "        )\n", "\n", "    def forward(self, x, edge_index, batch):\n", "        # no input edge or global features; the global output u holds the logits\n", "        x, edge_attr, u = self.lorentzinteractionnetwork(\n", "            x, edge_index, None, None, batch\n", "        )\n", "        return u\n", "\n", "\n", "model = LorentzInteractionNetwork().to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define training loop" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "@torch.no_grad()\n", "def test(model, loader, total, batch_size, leave=False):\n", "    model.eval()\n", "\n", "    xentropy = nn.CrossEntropyLoss(reduction=\"mean\")\n", "\n", "    sum_loss = 0.0\n", "    # number of batches, rounded up so the progress-bar total is an integer\n", "    t = tqdm(enumerate(loader), total=int(np.ceil(total / batch_size)), leave=leave)\n", "    for i, data in t:\n", "        data = data.to(device)\n", "        # convert one-hot labels to class indices for CrossEntropyLoss\n", "        y = torch.argmax(data.y, dim=1)\n", "        batch_output = model(data.x, data.edge_index, data.batch)\n", "        batch_loss_item = xentropy(batch_output, y).item()\n", "        sum_loss += batch_loss_item\n", "        t.set_description(\"loss = %.5f\" % (batch_loss_item))\n", "        t.refresh()  # to show immediately the update\n", "\n", "    return sum_loss / (i + 1)\n", "\n", "\n", "def train(model, optimizer, loader, total, batch_size, leave=False):\n", "    model.train()\n", "\n", "    xentropy = nn.CrossEntropyLoss(reduction=\"mean\")\n", "\n", "    sum_loss = 0.0\n", "    t = tqdm(enumerate(loader), total=int(np.ceil(total / batch_size)), leave=leave)\n", "    for i, data in t:\n", "        data = data.to(device)\n", "        y = torch.argmax(data.y, dim=1)\n", "        optimizer.zero_grad()\n", "        batch_output = model(data.x, data.edge_index, data.batch)\n", "        batch_loss = xentropy(batch_output, y)\n", "        batch_loss.backward()\n", "        batch_loss_item = batch_loss.item()\n", "        t.set_description(\"loss = %.5f\" % batch_loss_item)\n", "        t.refresh()  # to show immediately the update\n", "        sum_loss += batch_loss_item\n", "        optimizer.step()\n", "\n", "    return sum_loss / (i + 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define training, validation, testing data generators" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "7484\n", "5988\n", "1496\n", "1887\n" ] } ], "source": [ "from torch_geometric.data import Batch\n", "from torch_geometric.loader import DataListLoader\n", "from torch.utils.data import random_split\n", "\n", "\n", "def collate(items):\n", "    # each dataset item is a list of Data objects: flatten them into one list,\n", "    # then merge them into a single disjoint-union Batch\n", "    l = sum(items, [])\n", "    return Batch.from_data_list(l)\n", "\n", "\n", "torch.manual_seed(0)\n", "valid_frac = 0.20\n", "full_length = len(graph_dataset)\n", "valid_num = int(valid_frac * full_length)\n", "batch_size = 32\n", "\n", "train_dataset, valid_dataset = random_split(\n", "    graph_dataset, [full_length - valid_num, valid_num]\n", ")\n", "\n", "train_loader = DataListLoader(\n", "    train_dataset, batch_size=batch_size, pin_memory=True, shuffle=True\n", ")\n", "train_loader.collate_fn = collate\n", "valid_loader = DataListLoader(\n", "    valid_dataset, batch_size=batch_size, pin_memory=True, shuffle=False\n", ")\n", "valid_loader.collate_fn = collate\n", "test_loader = DataListLoader(\n", "    test_dataset, batch_size=batch_size, pin_memory=True, shuffle=False\n", ")\n", "test_loader.collate_fn = collate\n", "\n", "\n", "train_samples = len(train_dataset)\n", "valid_samples = len(valid_dataset)\n", "test_samples = len(test_dataset)\n", "print(full_length)\n", "print(train_samples)\n", "print(valid_samples)\n", "print(test_samples)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4218389d3f8f473689e0b7f5acdf982e", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/2 [00:00= patience:\n", " print(\"Early stopping after %i stale epochs\" % patience)\n", " break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluate on testing data" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6d13d8453de64f0e866ef3be894b6b90", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/58.96875 [00:00" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from sklearn.metrics import roc_curve, auc\n", "import matplotlib.pyplot as plt\n", "import mplhep as hep\n", "\n", "plt.style.use(hep.style.ROOT)\n", "# create ROC curves\n", "fpr_lorentz, tpr_lorentz, threshold_lorentz = roc_curve(y_test[:, 1], -y_predict[:, 1])\n", "# save the ROC curve so it can be compared with other models later\n", "with open(\"lorentz_roc.npy\", \"wb\") as f:\n", "    np.save(f, fpr_lorentz)\n", "    np.save(f, tpr_lorentz)\n", "    np.save(f, threshold_lorentz)\n", "\n", "# For Colab\n", "if not os.path.exists(\"deepset_roc.npy\"):\n", "    import wget\n", "\n", "    urlROC = \"https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/deepset_roc.npy\"\n", "    rocFile = wget.download(urlROC)\n", "\n", "with open(\"deepset_roc.npy\", \"rb\") as f:\n", "    fpr_deepset = np.load(f)\n", "    tpr_deepset = np.load(f)\n", "    threshold_deepset = np.load(f)\n", "\n", "# plot ROC curves\n", "plt.figure()\n", "plt.plot(\n", "    tpr_deepset,\n", "    fpr_deepset,\n", "    lw=2.5,\n", "    label=\"DeepSet, AUC = {:.1f}%\".format(auc(fpr_deepset, tpr_deepset) * 100),\n", ")\n", "plt.plot(\n", "    tpr_lorentz,\n", "    fpr_lorentz,\n", "    lw=2.5,\n", "    label=\"Lorentz GNN, AUC = {:.1f}%\".format(auc(fpr_lorentz, tpr_lorentz) * 100),\n", ")\n", "plt.xlabel(r\"True positive rate\")\n", "plt.ylabel(r\"False positive rate\")\n", "plt.ylim(0.001, 1)\n", "plt.xlim(0, 1)\n", "plt.grid(True)\n", "plt.legend(loc=\"upper left\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from sklearn.metrics import roc_curve, auc\n", "import matplotlib.pyplot as plt\n", "import mplhep as hep\n", "\n", "plt.style.use(hep.style.ROOT)\n", "# ROC curves for the test samples with track_pt <= pt_split and track_pt > pt_split\n", "pt_split = 5.0\n", "fpr_lorentz_lowb, tpr_lorentz_lowb, threshold_lorentz_lowb = roc_curve(\n", "    y_test[track_pt <= pt_split, 1], -y_predict[track_pt <= pt_split, 1]\n", ")\n", "fpr_lorentz_highb, tpr_lorentz_highb, threshold_lorentz_highb = roc_curve(\n", "    y_test[track_pt > pt_split, 1], -y_predict[track_pt > pt_split, 1]\n", ")\n", "\n", "# plot ROC curves\n", "plt.figure()\n", "plt.plot(\n", "    tpr_lorentz_lowb,\n", "    fpr_lorentz_lowb,\n", "    lw=2.5,\n", "    label=\"Lorentz GNN (low pT), AUC = {:.1f}%\".format(\n", "        auc(fpr_lorentz_lowb, tpr_lorentz_lowb) * 100\n", "    ),\n", ")\n", "plt.plot(\n", "    tpr_lorentz_highb,\n", "    fpr_lorentz_highb,\n", "    lw=2.5,\n", "    label=\"Lorentz GNN (high pT), AUC = {:.1f}%\".format(\n", "        auc(fpr_lorentz_highb, tpr_lorentz_highb) * 100\n", "    ),\n", ")\n", "plt.xlabel(r\"True positive rate\")\n", "plt.ylabel(r\"False positive rate\")\n", "plt.ylim(0.001, 1)\n", "plt.xlim(0, 1)\n", "plt.grid(True)\n", "plt.legend(loc=\"upper left\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 2 }