{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Week 8 Notebook: Extending the Model\n",
"===============================================================\n",
"\n",
"This week, we will look at graph neural networks using the PyTorch Geometric library: . See {cite:p}`PyTorchGeometric` for more details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch_geometric\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"from tqdm.notebook import tqdm\n",
"import numpy as np\n",
"local = False"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import yaml\n",
"\n",
"with open('definitions.yml') as file:\n",
" # The FullLoader parameter handles the conversion from YAML\n",
" # scalar values to Python the dictionary format\n",
" definitions = yaml.load(file, Loader=yaml.FullLoader)\n",
" \n",
"features = definitions['features']\n",
"spectators = definitions['spectators']\n",
"labels = definitions['labels']\n",
"\n",
"nfeatures = definitions['nfeatures']\n",
"nspectators = definitions['nspectators']\n",
"nlabels = definitions['nlabels']\n",
"ntracks = definitions['ntracks']"
]
},
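{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, `definitions.yml` is expected to provide at least the keys read above. A minimal sketch of its structure is shown below; the entry names are placeholders for illustration, not the real feature list."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical skeleton of definitions.yml (placeholder entries, illustration only)\n",
"example_definitions = {\n",
"    'features':   ['track_pt', 'track_eta'],   # per-track input features\n",
"    'spectators': ['fj_pt', 'fj_sdmass'],      # stored for plotting, not used as inputs\n",
"    'labels':     ['label_QCD', 'label_Hbb'],  # classification targets\n",
"    'nfeatures': 2, 'nspectators': 2, 'nlabels': 2, 'ntracks': 60\n",
"}"
]
},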
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Graph datasets\n",
"Here we have to define the graph dataset. We do this in a separate class following this example: https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-larger-datasets\n",
"\n",
"Formally, a graph is represented by a triplet $\\mathcal G = (\\mathbf{u}, V, E)$, consisting of a graph-level, or *global*, feature vector $\\mathbf{u}$, a set of $N^v$ nodes $V$, and a set of $N^e$ edges $E$.\n",
"The nodes are given by $V = \\{\\mathbf{v}_i\\}_{i=1:N^v}$, where $\\mathbf{v}_i$ represents the $i$th node's attributes.\n",
"The edges connect pairs of nodes and are given by $E = \\{\\left(\\mathbf{e}_k, r_k, s_k\\right)\\}_{k=1:N^e}$, where $\\mathbf{e}_k$ represents the $k$th edge's attributes, and $r_k$ and $s_k$ are the indices of the \"receiver\" and \n",
"\"sender\" nodes, respectively, connected by the $k$th edge (from the sender node to the receiver node).\n",
"The receiver and sender index vectors are an alternative way of encoding the directed adjacency matrix.\n",
"\n",
"
"
]
},
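{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of this formalism (toy values, not part of the dataset pipeline), PyTorch Geometric stores the sender and receiver index vectors as the two rows of the `edge_index` tensor of a `Data` object:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch_geometric.data import Data\n",
"\n",
"# three nodes with two toy attributes each\n",
"v = torch.tensor([[0., 1.], [2., 3.], [4., 5.]])\n",
"# two directed edges, 0 -> 1 and 2 -> 1:\n",
"# row 0 holds the sender indices s_k, row 1 the receiver indices r_k\n",
"edge_index = torch.tensor([[0, 2],\n",
"                           [1, 1]])\n",
"print(Data(x=v, edge_index=edge_index))  # Data(x=[3, 2], edge_index=[2, 2])"
]
},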
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from GraphDataset import GraphDataset\n",
"if local:\n",
" file_names = ['/teams/DSC180A_FA20_A00/b06particlephysics/train/ntuple_merged_10.root']\n",
" file_names_test = ['/teams/DSC180A_FA20_A00/b06particlephysics/test/ntuple_merged_0.root']\n",
"else:\n",
" file_names = ['root://eospublic.cern.ch//eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/train/ntuple_merged_10.root']\n",
" file_names_test = ['root://eospublic.cern.ch//eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/test/ntuple_merged_0.root']\n",
"\n",
"graph_dataset = GraphDataset('gdata_train', features, labels, spectators, n_events=1000, n_events_merge=1, \n",
" file_names=file_names)\n",
"\n",
"test_dataset = GraphDataset('gdata_test', features, labels, spectators, n_events=2000, n_events_merge=1, \n",
" file_names=file_names_test)"
]
},
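{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can peek at one element of the training dataset. Because the events are merged in chunks of `n_events_merge=1`, each element is a list of `Data` objects rather than a bare `Data` object; the collate function defined later flattens these lists before batching."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# each dataset element is a (short) list of Data objects\n",
"print(graph_dataset[0])"
]
},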
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Graph neural network\n",
"\n",
"Here, we recapitulate the \"graph network\" (GN) formalism {cite:p}`battaglia2018relational`, which generalizes various GNNs and other similar methods.\n",
"GNs are graph-to-graph mappings, whose output graphs have the same node and edge structure as the input. \n",
"Formally, a GN block contains three \"update\" functions, $\\phi$, and three \"aggregation\" functions, $\\rho$.\n",
"The stages of processing in a single GN block are:\n",
"\n",
"$\n",
"\\begin{align}\n",
" \\mathbf{e}'_k &= \\phi^e\\left(\\mathbf{e}_k, \\mathbf{v}_{r_k}, \\mathbf{v}_{s_k}, \\mathbf{u} \\right) & \\mathbf{\\bar{e}}'_i &= \\rho^{e \\rightarrow v}\\left(E'_i\\right) & \\text{(Edge block),}\\\\\n",
" \\mathbf{v}'_i &= \\phi^v\\left(\\mathbf{\\bar{e}}'_i, \\mathbf{v}_i, \\mathbf{u}\\right) & \n",
" \\mathbf{\\bar{e}}' &= \\rho^{e \\rightarrow u}\\left(E'\\right) & \\text{(Node block),}\\\\\n",
" \\mathbf{u}' &= \\phi^u\\left(\\mathbf{\\bar{e}}', \\mathbf{\\bar{v}}', \\mathbf{u}\\right) & \n",
" \\mathbf{\\bar{v}}' &= \\rho^{v \\rightarrow u}\\left(V'\\right) &\\text{(Global block).}\n",
" \\label{eq:gn-functions}\n",
"\\end{align}\n",
"$\n",
"\n",
"where $E'_i = \\left\\{\\left(\\mathbf{e}'_k, r_k, s_k \\right)\\right\\}_{r_k=i,\\; k=1:N^e}$ contains the updated edge features for edges whose receiver node is the $i$th node, $E' = \\bigcup_i E_i' = \\left\\{\\left(\\mathbf{e}'_k, r_k, s_k \\right)\\right\\}_{k=1:N^e}$ is the set of updated edges, and $V'=\\left\\{\\mathbf{v}'_i\\right\\}_{i=1:N^v}$ is the set of updated nodes.\n",
"\n",
"
\n",
"\n",
"We will define an interaction network model similar to Ref. {cite:p}`Moreno:2019neq`, but just modeling the particle-particle interactions. It will take as input all of the tracks (with 48 features) without truncating or zero-padding. Another modification is the use of batch normalization {cite:p}`bn` layers to improve the stability of the training."
]
},
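{
"cell_type": "markdown",
"metadata": {},
"source": [
"The aggregation functions $\\rho$ must be invariant under permutations of their inputs; in the implementation below they are element-wise means, computed with `scatter_mean` from the `torch_scatter` package. A minimal sketch (toy values) of how `scatter_mean` realizes $\\rho^{e \\rightarrow v}$:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch_scatter import scatter_mean\n",
"\n",
"messages = torch.tensor([[1.], [3.], [5.]])  # one updated feature per edge\n",
"receivers = torch.tensor([0, 0, 1])          # receiver index r_k of each edge\n",
"# node 0 averages its two incoming messages; node 1 receives a single one\n",
"print(scatter_mean(messages, receivers, dim=0, dim_size=2))  # tensor([[2.], [5.]])"
]
},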
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch_geometric.transforms as T\n",
"from torch_geometric.nn import EdgeConv, global_mean_pool\n",
"from torch.nn import Sequential as Seq, Linear as Lin, ReLU, BatchNorm1d\n",
"from torch_scatter import scatter_mean\n",
"from torch_geometric.nn import MetaLayer\n",
"\n",
"inputs = 48\n",
"hidden = 128\n",
"outputs = 2\n",
"\n",
"class EdgeBlock(torch.nn.Module):\n",
" def __init__(self):\n",
" super(EdgeBlock, self).__init__()\n",
" self.edge_mlp = Seq(Lin(inputs*2, hidden), \n",
" BatchNorm1d(hidden),\n",
" ReLU(),\n",
" Lin(hidden, hidden))\n",
"\n",
" def forward(self, src, dest, edge_attr, u, batch):\n",
" out = torch.cat([src, dest], 1)\n",
" return self.edge_mlp(out)\n",
"\n",
"class NodeBlock(torch.nn.Module):\n",
" def __init__(self):\n",
" super(NodeBlock, self).__init__()\n",
" self.node_mlp_1 = Seq(Lin(inputs+hidden, hidden), \n",
" BatchNorm1d(hidden),\n",
" ReLU(), \n",
" Lin(hidden, hidden))\n",
" self.node_mlp_2 = Seq(Lin(inputs+hidden, hidden), \n",
" BatchNorm1d(hidden),\n",
" ReLU(), \n",
" Lin(hidden, hidden))\n",
"\n",
" def forward(self, x, edge_index, edge_attr, u, batch):\n",
" row, col = edge_index\n",
" out = torch.cat([x[row], edge_attr], dim=1)\n",
" out = self.node_mlp_1(out)\n",
" out = scatter_mean(out, col, dim=0, dim_size=x.size(0))\n",
" out = torch.cat([x, out], dim=1)\n",
" return self.node_mlp_2(out)\n",
"\n",
" \n",
"class GlobalBlock(torch.nn.Module):\n",
" def __init__(self):\n",
" super(GlobalBlock, self).__init__()\n",
" self.global_mlp = Seq(Lin(hidden, hidden), \n",
" BatchNorm1d(hidden),\n",
" ReLU(), \n",
" Lin(hidden, outputs))\n",
"\n",
" def forward(self, x, edge_index, edge_attr, u, batch):\n",
" out = scatter_mean(x, batch, dim=0)\n",
" return self.global_mlp(out)\n",
"\n",
"\n",
"class InteractionNetwork(torch.nn.Module):\n",
" def __init__(self):\n",
" super(InteractionNetwork, self).__init__()\n",
" self.interactionnetwork = MetaLayer(EdgeBlock(), NodeBlock(), GlobalBlock())\n",
" self.bn = BatchNorm1d(inputs)\n",
" \n",
" def forward(self, x, edge_index, batch):\n",
" \n",
" x = self.bn(x)\n",
" x, edge_attr, u = self.interactionnetwork(x, edge_index, None, None, batch)\n",
" return u\n",
" \n",
"model = InteractionNetwork().to(device)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr = 1e-2)"
]
},
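{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (random inputs, untrained weights), we can pass one fully connected toy jet through the network and confirm that the output is a single pair of logits per graph:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"n_tracks = 5\n",
"x = torch.rand(n_tracks, inputs, device=device)\n",
"# fully connected directed graph without self-loops\n",
"adj = torch.ones(n_tracks, n_tracks) - torch.eye(n_tracks)\n",
"edge_index = adj.nonzero().t().contiguous().to(device)\n",
"batch = torch.zeros(n_tracks, dtype=torch.long, device=device)  # all tracks belong to one graph\n",
"model.eval()  # eval mode: BatchNorm1d cannot train on a single-graph batch\n",
"with torch.no_grad():\n",
"    print(model(x, edge_index, batch))  # shape [1, 2]"
]
},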
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define training loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@torch.no_grad()\n",
"def test(model, loader, total, batch_size, leave=False):\n",
" model.eval()\n",
" \n",
" xentropy = nn.CrossEntropyLoss(reduction='mean')\n",
"\n",
" sum_loss = 0.\n",
" t = tqdm(enumerate(loader), total=total/batch_size, leave=leave)\n",
" for i, data in t:\n",
" data = data.to(device)\n",
" y = torch.argmax(data.y, dim=1)\n",
" batch_output = model(data.x, data.edge_index, data.batch)\n",
" batch_loss_item = xentropy(batch_output, y).item()\n",
" sum_loss += batch_loss_item\n",
" t.set_description(\"loss = %.5f\" % (batch_loss_item))\n",
" t.refresh() # to show immediately the update\n",
"\n",
" return sum_loss/(i+1)\n",
"\n",
"def train(model, optimizer, loader, total, batch_size, leave=False):\n",
" model.train()\n",
" \n",
" xentropy = nn.CrossEntropyLoss(reduction='mean')\n",
"\n",
" sum_loss = 0.\n",
" t = tqdm(enumerate(loader), total=total/batch_size, leave=leave)\n",
" for i, data in t:\n",
" data = data.to(device)\n",
" y = torch.argmax(data.y, dim=1)\n",
" optimizer.zero_grad()\n",
" batch_output = model(data.x, data.edge_index, data.batch)\n",
" batch_loss = xentropy(batch_output, y)\n",
" batch_loss.backward()\n",
" batch_loss_item = batch_loss.item()\n",
" t.set_description(\"loss = %.5f\" % batch_loss_item)\n",
" t.refresh() # to show immediately the update\n",
" sum_loss += batch_loss_item\n",
" optimizer.step()\n",
" \n",
" return sum_loss/(i+1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define training, validation, testing data generators"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch_geometric.data import Data, DataListLoader, Batch\n",
"from torch.utils.data import random_split\n",
"\n",
"def collate(items):\n",
" l = sum(items, [])\n",
" return Batch.from_data_list(l)\n",
"\n",
"torch.manual_seed(0)\n",
"valid_frac = 0.20\n",
"full_length = len(graph_dataset)\n",
"valid_num = int(valid_frac*full_length)\n",
"batch_size = 32\n",
"\n",
"train_dataset, valid_dataset = random_split(graph_dataset, [full_length-valid_num,valid_num])\n",
"\n",
"train_loader = DataListLoader(train_dataset, batch_size=batch_size, pin_memory=True, shuffle=True)\n",
"train_loader.collate_fn = collate\n",
"valid_loader = DataListLoader(valid_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)\n",
"valid_loader.collate_fn = collate\n",
"test_loader = DataListLoader(test_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)\n",
"test_loader.collate_fn = collate\n",
"\n",
"\n",
"train_samples = len(train_dataset)\n",
"valid_samples = len(valid_dataset)\n",
"test_samples = len(test_dataset)\n",
"print(full_length)\n",
"print(train_samples)\n",
"print(valid_samples)\n",
"print(test_samples)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os.path as osp\n",
"\n",
"n_epochs = 10\n",
"stale_epochs = 0\n",
"best_valid_loss = 99999\n",
"patience = 5\n",
"t = tqdm(range(0, n_epochs))\n",
"\n",
"for epoch in t:\n",
" loss = train(model, optimizer, train_loader, train_samples, batch_size, leave=bool(epoch==n_epochs-1))\n",
" valid_loss = test(model, valid_loader, valid_samples, batch_size, leave=bool(epoch==n_epochs-1))\n",
" print('Epoch: {:02d}, Training Loss: {:.4f}'.format(epoch, loss))\n",
" print(' Validation Loss: {:.4f}'.format(valid_loss))\n",
"\n",
" if valid_loss < best_valid_loss:\n",
" best_valid_loss = valid_loss\n",
" modpath = osp.join('interactionnetwork_best.pth')\n",
" print('New best model saved to:',modpath)\n",
" torch.save(model.state_dict(),modpath)\n",
" stale_epochs = 0\n",
" else:\n",
" print('Stale epoch')\n",
" stale_epochs += 1\n",
" if stale_epochs >= patience:\n",
" print('Early stopping after %i stale epochs'%patience)\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluate on testing data"
]
},
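{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we reload the best checkpoint saved during training, so that the evaluation uses the weights with the lowest validation loss rather than those from the final epoch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.load_state_dict(torch.load('interactionnetwork_best.pth', map_location=device))"
]
},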
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.eval()\n",
"t = tqdm(enumerate(test_loader),total=test_samples/batch_size)\n",
"y_test = []\n",
"y_predict = []\n",
"for i,data in t:\n",
" data = data.to(device) \n",
" batch_output = model(data.x, data.edge_index, data.batch) \n",
" y_predict.append(batch_output.detach().cpu().numpy())\n",
" y_test.append(data.y.cpu().numpy())\n",
"y_test = np.concatenate(y_test)\n",
"y_predict = np.concatenate(y_predict)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import roc_curve, auc\n",
"import matplotlib.pyplot as plt\n",
"import mplhep as hep\n",
"plt.style.use(hep.style.ROOT)\n",
"# create ROC curves\n",
"fpr_gnn, tpr_gnn, threshold_gnn = roc_curve(y_test[:,1], y_predict[:,1])\n",
" \n",
"# plot ROC curves\n",
"plt.figure()\n",
"plt.plot(tpr_gnn, fpr_gnn, lw=2.5, label=\"GNN, AUC = {:.1f}%\".format(auc(fpr_gnn,tpr_gnn)*100))\n",
"plt.xlabel(r'True positive rate')\n",
"plt.ylabel(r'False positive rate')\n",
"plt.semilogy()\n",
"plt.ylim(0.001, 1)\n",
"plt.xlim(0, 1)\n",
"plt.grid(True)\n",
"plt.legend(loc='upper left')\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}