Deep Sets#

We will start by looking at Deep Sets networks using PyTorch. The architecture is based on the paper Deep Sets (Zaheer et al., 2017): https://arxiv.org/abs/1703.06114

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from tqdm.notebook import tqdm
import numpy as np
# For Colab

!pip install wget
import wget

!pip install -U PyYAML
!pip install uproot
!pip install awkward
!pip install mplhep
!pip install torch_scatter
import yaml
import os.path

# WGET for colab
if not os.path.exists("definitions_lorentz.yml"):
    url = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/definitions_lorentz.yml"
    definitionsFile = wget.download(url)

with open("definitions_lorentz.yml") as file:
    # The FullLoader parameter handles the conversion from YAML
    # scalar values to the Python dictionary format
    definitions = yaml.load(file, Loader=yaml.FullLoader)

features = definitions["features"]
spectators = definitions["spectators"]
labels = definitions["labels"]

nfeatures = definitions["nfeatures"]
nspectators = definitions["nspectators"]
nlabels = definitions["nlabels"]
ntracks = definitions["ntracks"]
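
As a quick sanity check, you can print what was loaded from the YAML file (an optional inspection step; the exact lists depend on the definitions file you downloaded):

print(f"{nfeatures} features: {features}")
print(f"{nlabels} labels: {labels}")
print(f"{nspectators} spectators: {spectators}")
print(f"padding to {ntracks} tracks per jet")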

Dataset loader#

Here we define the dataset loader, DeepSetsDataset, which reads the ROOT ntuples and builds zero-padded arrays of track features together with the labels.

# If in colab
if not os.path.exists("GraphDataset.py"):
    urlDSD = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/DeepSetsDataset.py"
    DSD = wget.download(urlDSD)
if not os.path.exists("utils.py"):
    urlUtils = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/utils.py"
    utils = wget.download(urlUtils)

from DeepSetsDataset import DeepSetsDataset

# For colab
import os.path

if not os.path.exists("ntuple_merged_90.root"):
    urlFILE = "http://opendata.cern.ch/eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/train/ntuple_merged_90.root"
    dataFILE = wget.download(urlFILE)
train_files = ["ntuple_merged_90.root"]

## Locally with XRootD
# train_files = ['root://eospublic.cern.ch//eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/train/ntuple_merged_90.root']
# test_files = ['root://eospublic.cern.ch//eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/test/ntuple_merged_1.root']

train_generator = DeepSetsDataset(
    features,
    labels,
    spectators,
    start_event=0,
    stop_event=10000,
    npad=ntracks,
    file_names=train_files,
)
train_generator.process()

test_generator = DeepSetsDataset(
    features,
    labels,
    spectators,
    start_event=10001,
    stop_event=14001,
    npad=ntracks,
    file_names=train_files,
)
test_generator.process()

Deep Sets Network#

Deep Sets models are designed to be explicitly permutation invariant. At their core they are composed of two networks, \(\phi\) and \(\rho\), such that the total network \(f\) is given by

\[ f = \rho\left(\sum_{\mathbf{x}_i\in\mathcal{X}}\phi(\mathbf{x}_i)\right) \]

where \(\mathbf{x}_i\) is the feature vector of the \(i\)-th element in the input set \(\mathcal{X}\).
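
Because \(\phi\) is shared across elements and the aggregation is a plain sum (or mean), the output does not depend on the ordering of the inputs. Below is a minimal numeric sketch of this property; the small toy networks and random inputs are purely illustrative.

import torch

# toy phi: a shared per-element network mapping 4 features -> 8
phi_toy = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
# toy rho: acts on the pooled (summed) representation
rho_toy = torch.nn.Linear(8, 2)

x = torch.randn(10, 4)     # a "set" of 10 elements with 4 features each
perm = torch.randperm(10)  # a random reordering of the set

out_original = rho_toy(phi_toy(x).sum(dim=0))
out_permuted = rho_toy(phi_toy(x[perm]).sum(dim=0))

print(torch.allclose(out_original, out_permuted, atol=1e-6))  # True: the order does not matter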

We will define a Deep Sets model that takes as input up to 60 tracks per jet (with zero-padding), each described by 6 features.

import torch.nn as nn
import torch.nn.functional as F
from torch.nn import (
    Sequential as Seq,
    Linear as Lin,
    ReLU,
    BatchNorm1d,
    AvgPool1d,
    Sigmoid,
    Conv1d,
)
from torch_scatter import scatter_mean

# ntracks = 60
inputs = 6
hidden1 = 64
hidden2 = 32
hidden3 = 16
classify1 = 50
outputs = 2


class DeepSets(torch.nn.Module):
    def __init__(self):
        super(DeepSets, self).__init__()
        self.phi = Seq(
            Conv1d(inputs, hidden1, 1),
            BatchNorm1d(hidden1),
            ReLU(),
            Conv1d(hidden1, hidden2, 1),
            BatchNorm1d(hidden2),
            ReLU(),
            Conv1d(hidden2, hidden3, 1),
            BatchNorm1d(hidden3),
            ReLU(),
        )
        self.rho = Seq(
            Lin(hidden3, classify1),
            BatchNorm1d(classify1),
            ReLU(),
            Lin(classify1, outputs),
            Sigmoid(),
        )

    def forward(self, x):
        # x has shape (batch, features, tracks)
        out = self.phi(x)
        # average phi's per-track outputs over the track dimension
        # (building the index on the same device as the data also makes this work on GPU)
        index = torch.zeros(out.shape[-1], dtype=torch.long, device=out.device)
        out = scatter_mean(out, index, dim=-1)
        return self.rho(torch.squeeze(out))


model = DeepSets()
print(model)
print("----------")
print({l: model.state_dict()[l].shape for l in model.state_dict()})

model = DeepSets().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
DeepSets(
  (phi): Sequential(
    (0): Conv1d(6, 64, kernel_size=(1,), stride=(1,))
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv1d(64, 32, kernel_size=(1,), stride=(1,))
    (4): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv1d(32, 16, kernel_size=(1,), stride=(1,))
    (7): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
  )
  (rho): Sequential(
    (0): Linear(in_features=16, out_features=50, bias=True)
    (1): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=50, out_features=2, bias=True)
    (4): Sigmoid()
  )
)
----------
{'phi.0.weight': torch.Size([64, 6, 1]), 'phi.0.bias': torch.Size([64]), 'phi.1.weight': torch.Size([64]), 'phi.1.bias': torch.Size([64]), 'phi.1.running_mean': torch.Size([64]), 'phi.1.running_var': torch.Size([64]), 'phi.1.num_batches_tracked': torch.Size([]), 'phi.3.weight': torch.Size([32, 64, 1]), 'phi.3.bias': torch.Size([32]), 'phi.4.weight': torch.Size([32]), 'phi.4.bias': torch.Size([32]), 'phi.4.running_mean': torch.Size([32]), 'phi.4.running_var': torch.Size([32]), 'phi.4.num_batches_tracked': torch.Size([]), 'phi.6.weight': torch.Size([16, 32, 1]), 'phi.6.bias': torch.Size([16]), 'phi.7.weight': torch.Size([16]), 'phi.7.bias': torch.Size([16]), 'phi.7.running_mean': torch.Size([16]), 'phi.7.running_var': torch.Size([16]), 'phi.7.num_batches_tracked': torch.Size([]), 'rho.0.weight': torch.Size([50, 16]), 'rho.0.bias': torch.Size([50]), 'rho.1.weight': torch.Size([50]), 'rho.1.bias': torch.Size([50]), 'rho.1.running_mean': torch.Size([50]), 'rho.1.running_var': torch.Size([50]), 'rho.1.num_batches_tracked': torch.Size([]), 'rho.3.weight': torch.Size([2, 50]), 'rho.3.bias': torch.Size([2])}
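
Before training, it can be useful to push a dummy batch through the network to confirm the expected shapes (a quick sanity check only; the input layout (batch, features, tracks) is what the Conv1d-based \(\phi\) expects):

# Shape check with a random dummy batch of 8 jets
dummy = torch.randn(8, inputs, ntracks).to(device)
with torch.no_grad():
    print(model(dummy).shape)  # expected: torch.Size([8, 2])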

Define training loop#

@torch.no_grad()
def test(model, loader, total, batch_size, leave=False):
    model.eval()

    xentropy = nn.CrossEntropyLoss(reduction="mean")

    sum_loss = 0.0
    t = tqdm(enumerate(loader), total=total / batch_size, leave=leave)
    for i, data in t:
        x = data[0].to(device)
        y = data[1].to(device)
        y = torch.argmax(y, dim=1)
        batch_output = model(x)
        batch_loss_item = xentropy(batch_output, y).item()
        sum_loss += batch_loss_item
        t.set_description("loss = %.5f" % (batch_loss_item))
        t.refresh()  # to show immediately the update

    return sum_loss / (i + 1)


def train(model, optimizer, loader, total, batch_size, leave=False):
    model.train()

    xentropy = nn.CrossEntropyLoss(reduction="mean")

    sum_loss = 0.0
    t = tqdm(enumerate(loader), total=total / batch_size, leave=leave)
    for i, data in t:
        x = data[0].to(device)
        y = data[1].to(device)
        y = torch.argmax(y, dim=1)
        optimizer.zero_grad()
        batch_output = model(x)
        batch_loss = xentropy(batch_output, y)
        batch_loss.backward()
        batch_loss_item = batch_loss.item()
        t.set_description("loss = %.5f" % batch_loss_item)
        t.refresh()  # to show immediately the update
        sum_loss += batch_loss_item
        optimizer.step()

    return sum_loss / (i + 1)

Define training, validation, testing data generators#

from torch.utils.data import ConcatDataset

train_generator_data = ConcatDataset(train_generator.datas)
test_generator_data = ConcatDataset(test_generator.datas)
from torch.utils.data import random_split, DataLoader

torch.manual_seed(0)
valid_frac = 0.20
train_length = len(train_generator_data)
valid_num = int(valid_frac * train_length)
batch_size = 32

train_dataset, valid_dataset = random_split(
    train_generator_data, [train_length - valid_num, valid_num]
)


# Note: this collate function is not used below (the collate_fn lines are commented out);
# enabling it would also require importing Batch from torch_geometric.data.
def collate(items):
    l = sum(items, [])
    return Batch.from_data_list(l)


train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# train_loader.collate_fn = collate
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
# valid_loader.collate_fn = collate
test_loader = DataLoader(test_generator_data, batch_size=batch_size, shuffle=False)
# test_loader.collate_fn = collate

train_samples = len(train_dataset)
valid_samples = len(valid_dataset)
test_samples = len(test_generator_data)
print(train_length)
print(train_samples)
print(valid_samples)
print(test_samples)
9387
7510
1877
3751
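
As a quick check of what the loaders yield, you can pull a single batch and inspect the tensor shapes (the inputs are expected to be laid out as (batch, features, tracks), with one-hot labels):

x_batch, y_batch = next(iter(train_loader))
print(x_batch.shape)  # expected: (batch_size, nfeatures, ntracks)
print(y_batch.shape)  # expected: (batch_size, nlabels), one-hot encoded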

Train#

import os.path as osp

n_epochs = 30
stale_epochs = 0
best_valid_loss = 99999
patience = 5
t = tqdm(range(0, n_epochs))

for epoch in t:
    loss = train(
        model,
        optimizer,
        train_loader,
        train_samples,
        batch_size,
        leave=bool(epoch == n_epochs - 1),
    )
    valid_loss = test(
        model,
        valid_loader,
        valid_samples,
        batch_size,
        leave=bool(epoch == n_epochs - 1),
    )
    print("Epoch: {:02d}, Training Loss:   {:.4f}".format(epoch, loss))
    print("           Validation Loss: {:.4f}".format(valid_loss))

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        modpath = osp.join("deepsets_best.pth")
        print("New best model saved to:", modpath)
        torch.save(model.state_dict(), modpath)
        stale_epochs = 0
    else:
        print("Stale epoch")
        stale_epochs += 1
    if stale_epochs >= patience:
        print("Early stopping after %i stale epochs" % patience)
        break
Epoch: 00, Training Loss:   0.4502
           Validation Loss: 0.4461
New best model saved to: deepsets_best.pth
Epoch: 01, Training Loss:   0.4396
           Validation Loss: 0.4470
Stale epoch
Epoch: 02, Training Loss:   0.4404
           Validation Loss: 0.4452
New best model saved to: deepsets_best.pth
Epoch: 03, Training Loss:   0.4397
           Validation Loss: 0.4475
Stale epoch
Epoch: 04, Training Loss:   0.4396
           Validation Loss: 0.4489
Stale epoch
Epoch: 05, Training Loss:   0.4397
           Validation Loss: 0.4465
Stale epoch
Epoch: 06, Training Loss:   0.4392
           Validation Loss: 0.4467
Stale epoch
Epoch: 07, Training Loss:   0.4392
           Validation Loss: 0.4489
Stale epoch
Early stopping after 5 stale epochs

Evaluate on testing data#

# In case you need to load the model from a pth file
# Trained on 4 vectors (as above in notebook)
# urlPTH = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/deepsets_best_4vec.pth"
# pthFile = wget.download(urlPTH)
# model.load_state_dict(torch.load("deepsets_best_4vec.pth"))
# Trained on all possible inputs (a different configuration)
# urlPTH = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/deepsets_best_AllTraining.pth"
# pthFile = wget.download(urlPTH)
# model.load_state_dict(torch.load("deepsets_best_AllTraining.pth"))

model.eval()
t = tqdm(enumerate(test_loader), total=test_samples / batch_size)
y_test = []
y_predict = []
track_pt = []
for i, data in t:
    x = data[0].to(device)
    y = data[1].to(device)
    track_pt.append(x[:, 0, 0].cpu().numpy())  # move to CPU before converting to numpy
    batch_output = model(x)
    y_predict.append(batch_output.detach().cpu().numpy())
    y_test.append(y.cpu().numpy())
track_pt = np.concatenate(track_pt)
y_test = np.concatenate(y_test)
y_predict = np.concatenate(y_predict)
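
Before making the plots, a simple overall accuracy gives a first impression of the classifier (an auxiliary check; the ROC curves below are the main evaluation):

# Fraction of jets where the predicted class matches the true class
accuracy = (y_predict.argmax(axis=1) == y_test.argmax(axis=1)).mean()
print(f"Test accuracy: {accuracy:.3f}")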
# For Colab
import matplotlib.pyplot as plt

_, bins, _ = plt.hist(
    track_pt[y_test[:, 1] == 1], bins=50, label="sig", histtype="step"
)
_, bins, _ = plt.hist(
    track_pt[y_test[:, 1] == 0], bins=bins, label="bkg", histtype="step"
)
plt.legend()
plt.semilogy()
(Figure: track pT distributions for signal and background jets, log y-axis.)
from sklearn.metrics import roc_curve, auc
import mplhep as hep

plt.style.use(hep.style.ROOT)
# create ROC curves
fpr_deepset, tpr_deepset, threshold_deepset = roc_curve(y_test[:, 1], y_predict[:, 1])
with open("deepset_roc.npy", "wb") as f:
    np.save(f, fpr_deepset)
    np.save(f, tpr_deepset)
    np.save(f, threshold_deepset)

# plot ROC curves
plt.figure()
plt.plot(
    tpr_deepset,
    fpr_deepset,
    lw=2.5,
    label="DeepSet, AUC = {:.1f}%".format(auc(fpr_deepset, tpr_deepset) * 100),
)
plt.xlabel(r"True positive rate")
plt.ylabel(r"False positive rate")
plt.ylim(0.001, 1)
plt.xlim(0, 1)
plt.grid(True)
plt.legend(loc="upper left")
plt.show()
(Figure: ROC curve for the Deep Sets classifier.)
model.eval()
t = tqdm(enumerate(test_loader), total=test_samples / batch_size)
y_test = []
y_predict = []
track_pt = []
for i, data in t:
    x = data[0].to(device)
    y = data[1].to(device)
    track_pt.append(x[:, 0, 0].cpu().numpy())  # record the first track's pT before permuting
    idx = torch.randperm(x.size(2))  # random permutation of the track axis
    x = x[:, :, idx]
    batch_output = model(x)
    y_predict.append(batch_output.detach().cpu().numpy())
    y_test.append(y.cpu().numpy())
track_pt = np.concatenate(track_pt)
y_test = np.concatenate(y_test)
y_predict = np.concatenate(y_predict)
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import mplhep as hep

plt.style.use(hep.style.ROOT)
# create ROC curves
fpr_deepset_perm, tpr_deepset_perm, threshold_deepset_perm = roc_curve(
    y_test[:, 1], y_predict[:, 1]
)
with open("deepset_perm_roc.npy", "wb") as f:
    np.save(f, fpr_deepset)
    np.save(f, tpr_deepset)
    np.save(f, threshold_deepset)

with open("deepset_roc.npy", "rb") as f:
    fpr_deepset = np.load(f)
    tpr_deepset = np.load(f)
    threshold_deepset = np.load(f)

# plot ROC curves
plt.figure()
plt.plot(
    tpr_deepset,
    fpr_deepset,
    lw=2.5,
    label="DeepSet, AUC = {:.1f}%".format(auc(fpr_deepset, tpr_deepset) * 100),
)
plt.plot(
    tpr_deepset_perm,
    fpr_deepset_perm,
    lw=2.5,
    label="DeepSet (Permuted), AUC = {:.1f}%".format(
        auc(fpr_deepset_perm, tpr_deepset_perm) * 100
    ),
    linestyle="dashed",
)
plt.xlabel(r"True positive rate")
plt.ylabel(r"False positive rate")
plt.ylim(0.001, 1)
plt.xlim(0, 1)
plt.grid(True)
plt.legend(loc="upper left")
plt.show()
(Figure: ROC curves for the Deep Sets classifier with original and permuted track ordering.)
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import mplhep as hep

plt.style.use(hep.style.ROOT)
# create ROC curves
pt_split = 5.0
fpr_deepset_lowb, tpr_deepset_lowb, threshold_deepset_lowb = roc_curve(
    y_test[track_pt <= pt_split, 1], y_predict[track_pt <= pt_split, 1]
)
fpr_deepset_highb, tpr_deepset_highb, threshold_deepset_highb = roc_curve(
    y_test[track_pt > pt_split, 1], y_predict[track_pt > pt_split, 1]
)

# plot ROC curves
plt.figure()
plt.plot(
    tpr_deepset_lowb,
    fpr_deepset_lowb,
    lw=2.5,
    label="DeepSet (low pT), AUC = {:.1f}%".format(
        auc(fpr_deepset_lowb, tpr_deepset_lowb) * 100
    ),
)
plt.plot(
    tpr_deepset_highb,
    fpr_deepset_highb,
    lw=2.5,
    label="DeepSet (high pT), AUC = {:.1f}%".format(
        auc(fpr_deepset_highb, tpr_deepset_highb) * 100
    ),
)
plt.xlabel(r"True positive rate")
plt.ylabel(r"False positive rate")
plt.ylim(0.001, 1)
plt.xlim(0, 1)
plt.grid(True)
plt.legend(loc="upper left")
plt.show()
(Figure: ROC curves for the Deep Sets classifier in the low and high track pT regions.)

If you finish this notebook, you can go back and retrain with different inputs. Replace the code above:

# WGET for colab
if not os.path.exists("definitions_lorentz.yml"):
    url = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/definitions_lorentz.yml"
    definitionsFile = wget.download(url)

with open("definitions_lorentz.yml") as file:
    # The FullLoader parameter handles the conversion from YAML
    # scalar values to the Python dictionary format
    definitions = yaml.load(file, Loader=yaml.FullLoader)

with

# WGET for colab
if not os.path.exists("definitions.yml"):
    url = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/definitions.yml"
    definitionsFile = wget.download(url)

with open("definitions.yml") as file:
    # The FullLoader parameter handles the conversion from YAML
    # scalar values to the Python dictionary format
    definitions = yaml.load(file, Loader=yaml.FullLoader)
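
Note that the model above hardcodes inputs = 6 to match the Lorentz-vector feature set. If you switch to definitions.yml, the number of input features will generally differ, so one option is to derive the input width from the YAML before rebuilding the model. A minimal sketch, assuming the new definitions dictionary from the cell above has been loaded:

# Derive the model's input width from the loaded definitions instead of hardcoding it
nfeatures = definitions["nfeatures"]
ntracks = definitions["ntracks"]
inputs = nfeatures  # the Conv1d layers in phi read this module-level variable

model = DeepSets().to(device)  # rebuild the model with the new input width
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)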