{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Lorentz-Equivariant GNN\n",
"\n",
"Now we will look at graph neural networks using the [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/) library. See {cite:p}`PyTorchGeometric` for more details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# For Colab. `%pip` (rather than `!pip`) installs into the environment of the running kernel.\n",
"%pip install torch_geometric\n",
"%pip install torch_sparse\n",
"%pip install torch_scatter"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"import torch_geometric\n",
"from tqdm.auto import tqdm  # widget bar in Jupyter, plain-text bar elsewhere\n",
"\n",
"# Seed before the model is built so that its weight initialisation is reproducible.\n",
"SEED = 0\n",
"torch.manual_seed(SEED)\n",
"np.random.seed(SEED)\n",
"\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# NOTE(review): `local` is not read by any active code in this notebook.\n",
"local = False"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# For Colab. `%pip` (rather than `!pip`) installs into the environment of the running kernel.\n",
"%pip install wget\n",
"%pip install -U PyYAML\n",
"%pip install uproot\n",
"%pip install awkward\n",
"%pip install mplhep\n",
"\n",
"import wget"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import yaml\n",
"import os.path\n",
"\n",
"# For Colab: fetch the feature/label definitions if they are not present locally\n",
"if not os.path.exists(\"definitions_lorentz.yml\"):\n",
"    url = \"https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/definitions_lorentz.yml\"\n",
"    wget.download(url)\n",
"\n",
"with open(\"definitions_lorentz.yml\") as file:\n",
"    # safe_load only builds plain Python types (dicts, lists, scalars); it never\n",
"    # constructs arbitrary Python objects, which matters because the file may have\n",
"    # just been downloaded from the internet.\n",
"    definitions = yaml.safe_load(file)\n",
"\n",
"features = definitions[\"features\"]\n",
"spectators = definitions[\"spectators\"]\n",
"labels = definitions[\"labels\"]\n",
"\n",
"nfeatures = definitions[\"nfeatures\"]\n",
"nspectators = definitions[\"nspectators\"]\n",
"nlabels = definitions[\"nlabels\"]\n",
"ntracks = definitions[\"ntracks\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Graph datasets\n",
"Here we have to define the graph dataset. We do this in a separate class following this example: https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-larger-datasets\n",
"\n",
"Formally, a graph is represented by a triplet $\\mathcal G = (\\mathbf{u}, V, E)$, consisting of a graph-level, or *global*, feature vector $\\mathbf{u}$, a set of $N^v$ nodes $V$, and a set of $N^e$ edges $E$.\n",
"The nodes are given by $V = \\{\\mathbf{v}_i\\}_{i=1:N^v}$, where $\\mathbf{v}_i$ represents the $i$th node's attributes.\n",
"The edges connect pairs of nodes and are given by $E = \\{\\left(\\mathbf{e}_k, r_k, s_k\\right)\\}_{k=1:N^e}$, where $\\mathbf{e}_k$ represents the $k$th edge's attributes, and $r_k$ and $s_k$ are the indices of the \"receiver\" and \n",
"\"sender\" nodes, respectively, connected by the $k$th edge (from the sender node to the receiver node).\n",
"The receiver and sender index vectors are an alternative way of encoding the directed adjacency matrix.\n",
"\n",
"
"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing...\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "51eb09799508446ab490fb88420cc6d3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/8000 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/wmccorma/IAIFI_Summer_School_2022/iaifi-summer-school-PATRICK/book/LorentzGraphDataset.py:153: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/distiller/project/conda/conda-bld/pytorch_1646756029501/work/torch/csrc/utils/tensor_new.cpp:210.)\n",
" x = torch.tensor(fourvec, dtype=torch.float).T\n",
"Done!\n",
"Processing...\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "85319282abd04bd881fe887ee3782bb1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/2000 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"['track_pt', 'track_mass', 'track_ptrel', 'track_erel', 'track_etarel', 'track_phirel']\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Done!\n"
]
}
],
"source": [
"# For Colab: download the dataset helper modules and one ROOT file if they are missing.\n",
"BOOK_URL = \"https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book\"\n",
"OPENDATA_URL = \"http://opendata.cern.ch/eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/train\"\n",
"for fname, base_url in [\n",
"    (\"LorentzGraphDataset.py\", BOOK_URL),\n",
"    (\"utils.py\", BOOK_URL),\n",
"    (\"ntuple_merged_90.root\", OPENDATA_URL),\n",
"]:\n",
"    if not os.path.exists(fname):\n",
"        wget.download(f\"{base_url}/{fname}\")\n",
"\n",
"from LorentzGraphDataset import LorentzGraphDataset\n",
"\n",
"# To use other files of the same CMS open dataset, list them here, e.g. via XRootD:\n",
"# root://eospublic.cern.ch//eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/train/ntuple_merged_10.root\n",
"file_names = [\"ntuple_merged_90.root\"]\n",
"\n",
"# (start_event, stop_event) ranges of the file(s) above for training/validation and testing.\n",
"# NOTE(review): if `stop_event` is exclusive (the progress bar reports 8000 training events),\n",
"# event 8000 is used by neither dataset -- confirm against LorentzGraphDataset before changing.\n",
"TRAIN_EVENTS = (0, 8000)\n",
"TEST_EVENTS = (8001, 10001)\n",
"\n",
"graph_dataset = LorentzGraphDataset(\n",
"    \"ldata_train\",\n",
"    features,\n",
"    labels,\n",
"    spectators,\n",
"    start_event=TRAIN_EVENTS[0],\n",
"    stop_event=TRAIN_EVENTS[1],\n",
"    n_events_merge=1,\n",
"    file_names=file_names,\n",
")\n",
"\n",
"test_dataset = LorentzGraphDataset(\n",
"    \"ldata_test\",\n",
"    features,\n",
"    labels,\n",
"    spectators,\n",
"    start_event=TEST_EVENTS[0],\n",
"    stop_event=TEST_EVENTS[1],\n",
"    n_events_merge=1,\n",
"    file_names=file_names,\n",
")\n",
"print(test_dataset.features)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Graph neural network\n",
"\n",
"Here, we recapitulate the \"graph network\" (GN) formalism {cite:p}`battaglia2018relational`, which generalizes various GNNs and other similar methods.\n",
"GNs are graph-to-graph mappings, whose output graphs have the same node and edge structure as the input. \n",
"Formally, a GN block contains three \"update\" functions, $\\phi$, and three \"aggregation\" functions, $\\rho$.\n",
"The stages of processing in a single GN block are:\n",
"\n",
"$$\n",
"\\begin{align}\n",
" \\mathbf{e}'_k &= \\phi^e\\left(\\mathbf{e}_k, \\mathbf{v}_{r_k}, \\mathbf{v}_{s_k}, \\mathbf{u} \\right) & \\mathbf{\\bar{e}}'_i &= \\rho^{e \\rightarrow v}\\left(E'_i\\right) & \\text{(Edge block),}\\\\\n",
" \\mathbf{v}'_i &= \\phi^v\\left(\\mathbf{\\bar{e}}'_i, \\mathbf{v}_i, \\mathbf{u}\\right) & \n",
" \\mathbf{\\bar{e}}' &= \\rho^{e \\rightarrow u}\\left(E'\\right) & \\text{(Node block),}\\\\\n",
" \\mathbf{u}' &= \\phi^u\\left(\\mathbf{\\bar{e}}', \\mathbf{\\bar{v}}', \\mathbf{u}\\right) & \n",
" \\mathbf{\\bar{v}}' &= \\rho^{v \\rightarrow u}\\left(V'\\right) &\\text{(Global block).}\n",
" \\label{eq:gn-functions}\n",
"\\end{align}\n",
"$$\n",
"\n",
"where $E'_i = \\left\\{\\left(\\mathbf{e}'_k, r_k, s_k \\right)\\right\\}_{r_k=i,\\; k=1:N^e}$ contains the updated edge features for edges whose receiver node is the $i$th node, $E' = \\bigcup_i E_i' = \\left\\{\\left(\\mathbf{e}'_k, r_k, s_k \\right)\\right\\}_{k=1:N^e}$ is the set of updated edges, and $V'=\\left\\{\\mathbf{v}'_i\\right\\}_{i=1:N^v}$ is the set of updated nodes.\n",
"\n",
"
\n",
"\n",
"We will define an interaction network model similar to Ref. {cite:p}`Moreno:2019neq`, but just modeling the particle-particle interactions. It will take as input the four-vectors of all of the tracks, without truncating or zero-padding. Unlike that reference, the edge and node updates only see Minkowski inner products of these four-vectors (and features computed from them), so the network output is Lorentz invariant by construction. No batch normalization {cite:p}`bn` layers are used in this minimal model."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch_geometric.transforms as T\n",
"from torch_geometric.nn import EdgeConv, global_mean_pool\n",
"from torch.nn import Sequential as Seq, Linear as Lin, ReLU, BatchNorm1d\n",
"from torch_scatter import scatter_mean\n",
"from torch_geometric.nn import MetaLayer\n",
"\n",
"hidden = 16\n",
"outputs = 2\n",
"\n",
"\n",
"def minkowski_metric():\n",
"    \"\"\"Return the Minkowski metric diag(-1, +1, +1, +1) as a float32 4x4 tensor.\"\"\"\n",
"    return torch.diag(torch.tensor([-1.0, 1.0, 1.0, 1.0], dtype=torch.float32))\n",
"\n",
"\n",
"class LorentzEdgeBlock(torch.nn.Module):\n",
"    \"\"\"Edge update computed only from Minkowski inner products of the two endpoint features.\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.edge_mlp = Seq(Lin(4, hidden), ReLU(), Lin(hidden, hidden))\n",
"        # Registered as a buffer so that `model.to(device)` moves it together with the\n",
"        # parameters; as a plain tensor attribute it stayed on the CPU and the matmul in\n",
"        # `innerprod` failed for CUDA inputs. persistent=False keeps it out of the\n",
"        # state_dict, so saved checkpoints keep the same keys.\n",
"        self.register_buffer(\"minkowski\", minkowski_metric(), persistent=False)\n",
"\n",
"    def psi(self, x):\n",
"        \"\"\"Signed logarithm sign(x) * log(|x| + 1).\"\"\"\n",
"        return torch.sign(x) * torch.log(torch.abs(x) + 1)\n",
"\n",
"    def innerprod(self, x1, x2):\n",
"        \"\"\"Row-wise Minkowski inner product x1 . eta . x2 of two (N, 4) tensors, returned as (N, 1).\"\"\"\n",
"        return torch.sum(\n",
"            torch.mul(torch.matmul(x1, self.minkowski), x2), 1, keepdim=True\n",
"        )\n",
"\n",
"    def forward(self, src, dest, edge_attr, u, batch):\n",
"        # src/dest are the per-edge endpoint features passed in by MetaLayer;\n",
"        # edge_attr, u and batch are not used.\n",
"        out = torch.cat(\n",
"            [\n",
"                self.innerprod(src, src),\n",
"                self.innerprod(src, dest),\n",
"                self.psi(self.innerprod(dest, dest)),\n",
"                self.psi(self.innerprod(src - dest, src - dest)),\n",
"            ],\n",
"            dim=1,\n",
"        )\n",
"        return self.edge_mlp(out)\n",
"\n",
"\n",
"class LorentzNodeBlock(torch.nn.Module):\n",
"    \"\"\"Node update from Minkowski self inner products and mean-aggregated edge messages.\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.node_mlp_1 = Seq(Lin(1 + hidden, hidden), ReLU(), Lin(hidden, hidden))\n",
"        self.node_mlp_2 = Seq(Lin(1 + hidden, hidden), ReLU(), Lin(hidden, hidden))\n",
"        # Buffer, so it follows the module across devices (see LorentzEdgeBlock).\n",
"        self.register_buffer(\"minkowski\", minkowski_metric(), persistent=False)\n",
"\n",
"    def innerprod(self, x1, x2):\n",
"        \"\"\"Row-wise Minkowski inner product x1 . eta . x2 of two (N, 4) tensors, returned as (N, 1).\"\"\"\n",
"        return torch.sum(\n",
"            torch.mul(torch.matmul(x1, self.minkowski), x2), 1, keepdim=True\n",
"        )\n",
"\n",
"    def forward(self, x, edge_index, edge_attr, u, batch):\n",
"        row, col = edge_index\n",
"        # One message per edge from the row-endpoint's self inner product and the edge features\n",
"        out = torch.cat([self.innerprod(x[row], x[row]), edge_attr], dim=1)\n",
"        out = self.node_mlp_1(out)\n",
"        # Average the messages onto their col-endpoint nodes\n",
"        out = scatter_mean(out, col, dim=0, dim_size=x.size(0))\n",
"        out = torch.cat([self.innerprod(x, x), out], dim=1)\n",
"        return self.node_mlp_2(out)\n",
"\n",
"\n",
"class GlobalBlock(torch.nn.Module):\n",
"    \"\"\"Graph-level readout: mean of each graph's node features, then an MLP to `outputs` logits.\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.global_mlp = Seq(Lin(hidden, hidden), ReLU(), Lin(hidden, outputs))\n",
"\n",
"    def forward(self, x, edge_index, edge_attr, u, batch):\n",
"        out = scatter_mean(x, batch, dim=0)  # one row per graph in the batch\n",
"        return self.global_mlp(out)\n",
"\n",
"\n",
"class LorentzInteractionNetwork(torch.nn.Module):\n",
"    \"\"\"Single MetaLayer (edge -> node -> global) whose global output is the per-graph logits.\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.lorentzinteractionnetwork = MetaLayer(\n",
"            LorentzEdgeBlock(), LorentzNodeBlock(), GlobalBlock()\n",
"        )\n",
"\n",
"    def forward(self, x, edge_index, batch):\n",
"        # No input edge or global features: edge_attr and u start as None.\n",
"        x, edge_attr, u = self.lorentzinteractionnetwork(\n",
"            x, edge_index, None, None, batch\n",
"        )\n",
"        return u\n",
"\n",
"\n",
"model = LorentzInteractionNetwork().to(device)\n",
"# NOTE(review): lr=1e-6 is very small for Adam; in the recorded training output the losses\n",
"# barely change over two epochs (0.7810 -> 0.7805). Consider a larger value such as 1e-3.\n",
"learning_rate = 1e-6\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define training loop"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"\n",
"@torch.no_grad()\n",
"def test(model, loader, total, batch_size, leave=False):\n",
"    \"\"\"Return the mean per-batch cross-entropy loss of `model` on `loader` without updating it.\n",
"\n",
"    `total` (number of samples) and `batch_size` only size the progress bar.\n",
"    Returns NaN if `loader` yields no batches.\n",
"    \"\"\"\n",
"    model.eval()\n",
"\n",
"    xentropy = nn.CrossEntropyLoss(reduction=\"mean\")\n",
"\n",
"    sum_loss = 0.0\n",
"    n_batches = 0\n",
"    # Integer number of batches: a fractional total (e.g. 187.125) gives tqdm a non-integer length.\n",
"    t = tqdm(loader, total=math.ceil(total / batch_size), leave=leave)\n",
"    for data in t:\n",
"        data = data.to(device)\n",
"        # data.y has one column per class (presumably one-hot); argmax gives the class index\n",
"        y = torch.argmax(data.y, dim=1)\n",
"        batch_output = model(data.x, data.edge_index, data.batch)\n",
"        batch_loss_item = xentropy(batch_output, y).item()\n",
"        sum_loss += batch_loss_item\n",
"        n_batches += 1\n",
"        t.set_description(\"loss = %.5f\" % (batch_loss_item))\n",
"        t.refresh()  # to show immediately the update\n",
"\n",
"    # Mean over batches; NaN when there were no batches.\n",
"    return sum_loss / n_batches if n_batches else float(\"nan\")\n",
"\n",
"\n",
"def train(model, optimizer, loader, total, batch_size, leave=False):\n",
"    \"\"\"Run one training epoch of `model` over `loader`; return the mean per-batch loss.\n",
"\n",
"    `total` (number of samples) and `batch_size` only size the progress bar.\n",
"    Returns NaN if `loader` yields no batches.\n",
"    \"\"\"\n",
"    model.train()\n",
"\n",
"    xentropy = nn.CrossEntropyLoss(reduction=\"mean\")\n",
"\n",
"    sum_loss = 0.0\n",
"    n_batches = 0\n",
"    t = tqdm(loader, total=math.ceil(total / batch_size), leave=leave)\n",
"    for data in t:\n",
"        data = data.to(device)\n",
"        y = torch.argmax(data.y, dim=1)\n",
"        optimizer.zero_grad()\n",
"        batch_output = model(data.x, data.edge_index, data.batch)\n",
"        batch_loss = xentropy(batch_output, y)\n",
"        batch_loss.backward()\n",
"        batch_loss_item = batch_loss.item()\n",
"        t.set_description(\"loss = %.5f\" % batch_loss_item)\n",
"        t.refresh()  # to show immediately the update\n",
"        sum_loss += batch_loss_item\n",
"        n_batches += 1\n",
"        optimizer.step()\n",
"\n",
"    return sum_loss / n_batches if n_batches else float(\"nan\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define training, validation, testing data generators"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"7484\n",
"5988\n",
"1496\n",
"1887\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/wmccorma/miniconda3/envs/ml-iaifi/lib/python3.9/site-packages/torch_geometric/deprecation.py:12: UserWarning: 'data.DataListLoader' is deprecated, use 'loader.DataListLoader' instead\n",
" warnings.warn(out)\n"
]
}
],
"source": [
"from torch_geometric.data import Data, Batch\n",
"from torch.utils.data import DataLoader, random_split\n",
"\n",
"\n",
"def collate(items):\n",
"    \"\"\"Flatten a batch of dataset items (each a list of graphs) into one `Batch`.\"\"\"\n",
"    graphs = [graph for item in items for graph in item]\n",
"    return Batch.from_data_list(graphs)\n",
"\n",
"\n",
"def make_loader(dataset, batch_size, shuffle):\n",
"    \"\"\"Return a torch DataLoader over `dataset` whose batches are merged by `collate`.\"\"\"\n",
"    # collate_fn is passed at construction instead of using the deprecated\n",
"    # torch_geometric.data.DataListLoader and overwriting its collate_fn afterwards.\n",
"    return DataLoader(\n",
"        dataset,\n",
"        batch_size=batch_size,\n",
"        shuffle=shuffle,\n",
"        pin_memory=True,\n",
"        collate_fn=collate,\n",
"    )\n",
"\n",
"\n",
"torch.manual_seed(0)  # reproducible train/validation split\n",
"valid_frac = 0.20\n",
"full_length = len(graph_dataset)\n",
"valid_num = int(valid_frac * full_length)\n",
"batch_size = 32\n",
"\n",
"train_dataset, valid_dataset = random_split(\n",
"    graph_dataset, [full_length - valid_num, valid_num]\n",
")\n",
"\n",
"train_loader = make_loader(train_dataset, batch_size, shuffle=True)\n",
"valid_loader = make_loader(valid_dataset, batch_size, shuffle=False)\n",
"test_loader = make_loader(test_dataset, batch_size, shuffle=False)\n",
"\n",
"train_samples = len(train_dataset)\n",
"valid_samples = len(valid_dataset)\n",
"test_samples = len(test_dataset)\n",
"print(full_length)\n",
"print(train_samples)\n",
"print(valid_samples)\n",
"print(test_samples)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4218389d3f8f473689e0b7f5acdf982e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/2 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/187.125 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/46.75 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 00, Training Loss: 0.7810\n",
" Validation Loss: 0.7829\n",
"New best model saved to: lorentznetwork_best.pth\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1114825255ea478bbf1252cd406b9927",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/187.125 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ff4de51a4b5947e88522ef67466ae6e3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/46.75 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 01, Training Loss: 0.7805\n",
" Validation Loss: 0.7821\n",
"New best model saved to: lorentznetwork_best.pth\n"
]
}
],
"source": [
"import os.path as osp\n",
"\n",
"n_epochs = 2\n",
"stale_epochs = 0\n",
"# Start from +inf so that any finite validation loss counts as an improvement\n",
"# (a fixed sentinel such as 99999 silently fails for larger losses).\n",
"best_valid_loss = float(\"inf\")\n",
"patience = 5  # stop after this many consecutive epochs without improvement\n",
"modpath = osp.join(\"lorentznetwork_best.pth\")\n",
"t = tqdm(range(0, n_epochs))\n",
"\n",
"for epoch in t:\n",
"    is_last_epoch = epoch == n_epochs - 1  # keep the progress bars of the final epoch\n",
"    loss = train(\n",
"        model,\n",
"        optimizer,\n",
"        train_loader,\n",
"        train_samples,\n",
"        batch_size,\n",
"        leave=is_last_epoch,\n",
"    )\n",
"    valid_loss = test(\n",
"        model,\n",
"        valid_loader,\n",
"        valid_samples,\n",
"        batch_size,\n",
"        leave=is_last_epoch,\n",
"    )\n",
"    print(\"Epoch: {:02d}, Training Loss: {:.4f}\".format(epoch, loss))\n",
"    print(\"           Validation Loss: {:.4f}\".format(valid_loss))\n",
"\n",
"    if valid_loss < best_valid_loss:\n",
"        best_valid_loss = valid_loss\n",
"        print(\"New best model saved to:\", modpath)\n",
"        torch.save(model.state_dict(), modpath)\n",
"        stale_epochs = 0\n",
"    else:\n",
"        print(\"Stale epoch\")\n",
"        stale_epochs += 1\n",
"    if stale_epochs >= patience:\n",
"        print(\"Early stopping after %i stale epochs\" % patience)\n",
"        break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluate on testing data"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6d13d8453de64f0e866ef3be894b6b90",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/58.96875 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import math\n",
"\n",
"# Evaluate the best checkpoint saved by the training loop rather than the weights of the\n",
"# last epoch (they differ whenever a later epoch did not improve the validation loss).\n",
"# To evaluate the provided pre-trained weights instead (trained on 4-vectors as above),\n",
"# download them first:\n",
"# urlPTH = \"https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/lorentznetwork_best.pth\"\n",
"# pthFile = wget.download(urlPTH)\n",
"best_model_path = \"lorentznetwork_best.pth\"\n",
"if os.path.exists(best_model_path):\n",
"    model.load_state_dict(torch.load(best_model_path, map_location=device))\n",
"\n",
"model.eval()\n",
"y_test = []\n",
"y_predict = []\n",
"track_pt = []\n",
"with torch.no_grad():  # inference only: no autograd graph needed\n",
"    for data in tqdm(test_loader, total=math.ceil(test_samples / batch_size)):\n",
"        data = data.to(device)\n",
"        # First node of each graph: node 0, plus every node where data.batch changes value.\n",
"        # Built on data.batch's own device, so this also works on CUDA.\n",
"        is_first = torch.ones_like(data.batch, dtype=torch.bool)\n",
"        is_first[1:] = data.batch[1:] != data.batch[:-1]\n",
"        # NOTE(review): column 0 of the first node is stored as \"track_pt\", but the node\n",
"        # features are four-vectors used with a (-,+,+,+) metric, so it may be the energy;\n",
"        # confirm against LorentzGraphDataset.\n",
"        track_pt.append(data.x[is_first, 0].cpu().numpy())\n",
"        batch_output = model(data.x, data.edge_index, data.batch)\n",
"        y_predict.append(batch_output.cpu().numpy())\n",
"        y_test.append(data.y.cpu().numpy())\n",
"track_pt = np.concatenate(track_pt)\n",
"y_test = np.concatenate(y_test)\n",
"y_predict = np.concatenate(y_predict)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import roc_curve, auc\n",
"import matplotlib.pyplot as plt\n",
"import mplhep as hep\n",
"\n",
"plt.style.use(hep.style.ROOT)\n",
"# create ROC curves\n",
"# Discriminant: difference of the class-1 and class-0 logits. It is a monotonic function of\n",
"# the softmax probability of class 1, and roc_curve expects larger scores to mean\n",
"# \"more like the positive class\" (a negated score would invert the curve).\n",
"score_lorentz = y_predict[:, 1] - y_predict[:, 0]\n",
"fpr_lorentz, tpr_lorentz, threshold_lorentz = roc_curve(y_test[:, 1], score_lorentz)\n",
"with open(\"lorentz_roc.npy\", \"wb\") as f:\n",
"    np.save(f, fpr_lorentz)\n",
"    np.save(f, tpr_lorentz)\n",
"    np.save(f, threshold_lorentz)\n",
"\n",
"\n",
"# For colab: fetch the DeepSet ROC curve for comparison. The existence check must name\n",
"# the same .npy file that is downloaded and read below.\n",
"if not os.path.exists(\"deepset_roc.npy\"):\n",
"    urlROC = \"https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/deepset_roc.npy\"\n",
"    rocFile = wget.download(urlROC)\n",
"\n",
"with open(\"deepset_roc.npy\", \"rb\") as f:\n",
"    fpr_deepset = np.load(f)\n",
"    tpr_deepset = np.load(f)\n",
"    threshold_deepset = np.load(f)\n",
"\n",
"# plot ROC curves (true positive rate on x, false positive rate on y)\n",
"plt.figure()\n",
"plt.plot(\n",
"    tpr_deepset,\n",
"    fpr_deepset,\n",
"    lw=2.5,\n",
"    label=\"DeepSet, AUC = {:.1f}%\".format(auc(fpr_deepset, tpr_deepset) * 100),\n",
")\n",
"plt.plot(\n",
"    tpr_lorentz,\n",
"    fpr_lorentz,\n",
"    lw=2.5,\n",
"    label=\"Lorentz GNN, AUC = {:.1f}%\".format(auc(fpr_lorentz, tpr_lorentz) * 100),\n",
")\n",
"plt.xlabel(r\"True positive rate\")\n",
"plt.ylabel(r\"False positive rate\")\n",
"plt.yscale(\"log\")  # the y range below spans three decades\n",
"plt.ylim(0.001, 1)\n",
"plt.xlim(0, 1)\n",
"plt.grid(True)\n",
"plt.legend(loc=\"upper left\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import roc_curve, auc\n",
"import matplotlib.pyplot as plt\n",
"import mplhep as hep\n",
"\n",
"plt.style.use(hep.style.ROOT)\n",
"# create ROC curves separately for jets below / above pt_split in `track_pt`\n",
"# NOTE(review): `track_pt` is column 0 of each graph's first node; confirm it really is pT.\n",
"pt_split = 5.0\n",
"low_pt = track_pt <= pt_split\n",
"high_pt = track_pt > pt_split\n",
"# Discriminant: logit difference, monotonic in the softmax probability of class 1\n",
"# (larger = more like the positive class, as roc_curve expects).\n",
"score_lorentz = y_predict[:, 1] - y_predict[:, 0]\n",
"fpr_lorentz_lowb, tpr_lorentz_lowb, threshold_lorentz_lowb = roc_curve(\n",
"    y_test[low_pt, 1], score_lorentz[low_pt]\n",
")\n",
"fpr_lorentz_highb, tpr_lorentz_highb, threshold_lorentz_highb = roc_curve(\n",
"    y_test[high_pt, 1], score_lorentz[high_pt]\n",
")\n",
"\n",
"# plot ROC curves (true positive rate on x, false positive rate on y)\n",
"plt.figure()\n",
"plt.plot(\n",
"    tpr_lorentz_lowb,\n",
"    fpr_lorentz_lowb,\n",
"    lw=2.5,\n",
"    label=\"Lorentz GNN (low pT), AUC = {:.1f}%\".format(\n",
"        auc(fpr_lorentz_lowb, tpr_lorentz_lowb) * 100\n",
"    ),\n",
")\n",
"plt.plot(\n",
"    tpr_lorentz_highb,\n",
"    fpr_lorentz_highb,\n",
"    lw=2.5,\n",
"    label=\"Lorentz GNN (high pT), AUC = {:.1f}%\".format(\n",
"        auc(fpr_lorentz_highb, tpr_lorentz_highb) * 100\n",
"    ),\n",
")\n",
"plt.xlabel(r\"True positive rate\")\n",
"plt.ylabel(r\"False positive rate\")\n",
"plt.yscale(\"log\")  # the y range below spans three decades\n",
"plt.ylim(0.001, 1)\n",
"plt.xlim(0, 1)\n",
"plt.grid(True)\n",
"plt.legend(loc=\"upper left\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}