Deep Sets#
We will start by looking at Deep Sets networks using PyTorch. The architecture is based on the following paper: Deep Sets (Zaheer et al., arXiv:1703.06114).
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from tqdm.notebook import tqdm
import numpy as np
# For Colab
!pip install wget
import wget
!pip install -U PyYAML
!pip install uproot
!pip install awkward
!pip install mplhep
!pip install torch_scatter
import yaml
import os.path
# WGET for colab
if not os.path.exists("definitions_lorentz.yml"):
url = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/definitions_lorentz.yml"
definitionsFile = wget.download(url)
with open("definitions_lorentz.yml") as file:
# The FullLoader parameter handles the conversion from YAML
# scalar values to the Python dictionary format
definitions = yaml.load(file, Loader=yaml.FullLoader)
features = definitions["features"]
spectators = definitions["spectators"]
labels = definitions["labels"]
nfeatures = definitions["nfeatures"]
nspectators = definitions["nspectators"]
nlabels = definitions["nlabels"]
ntracks = definitions["ntracks"]
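As a quick optional check (a small addition to the original notebook, assuming `features` and `labels` are lists of branch names as defined in the YAML file), we can print a few of the quantities we just loaded:
# Optional sanity check of the loaded definitions
print("number of features:", nfeatures)
print("number of labels:", nlabels)
print("max number of tracks per jet:", ntracks)
print("first few feature names:", features[:4])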
Dataset loader#
Here we have to define the dataset loader.
# If in colab
if not os.path.exists("GraphDataset.py"):
urlDSD = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/DeepSetsDataset.py"
DSD = wget.download(urlDSD)
if not os.path.exists("utils.py"):
urlUtils = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/utils.py"
utils = wget.download(urlUtils)
from DeepSetsDataset import DeepSetsDataset
# For colab
import os.path
if not os.path.exists("ntuple_merged_90.root"):
urlFILE = "http://opendata.cern.ch/eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/train/ntuple_merged_90.root"
dataFILE = wget.download(urlFILE)
train_files = ["ntuple_merged_90.root"]
## Locally with XRootD
# train_files = ['root://eospublic.cern.ch//eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/train/ntuple_merged_90.root']
# test_files = ['root://eospublic.cern.ch//eos/opendata/cms/datascience/HiggsToBBNtupleProducerTool/HiggsToBBNTuple_HiggsToBB_QCD_RunII_13TeV_MC/test/ntuple_merged_1.root']
train_generator = DeepSetsDataset(
features,
labels,
spectators,
start_event=0,
stop_event=10000,
npad=ntracks,
file_names=train_files,
)
train_generator.process()
test_generator = DeepSetsDataset(
features,
labels,
spectators,
start_event=10001,
stop_event=14001,
npad=ntracks,
file_names=train_files,
)
test_generator.process()
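As an optional check (a sketch, assuming `datas` is the list of processed per-file datasets that we also use later with `ConcatDataset`), we can count how many jets were processed:
# Optional: count the processed jets in the training and testing generators
print("processed training files:", len(train_generator.datas))
print("processed training jets:", sum(len(d) for d in train_generator.datas))
print("processed testing jets:", sum(len(d) for d in test_generator.datas))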
Deep Sets Network#
Deep Sets models are designed to be explicitly permutation invariant. At their core they are composed of two networks, \(\phi\) and \(\rho\), such that the total network \(f\) is given by
\( \begin{align} f &= \rho\left(\Sigma_{\mathbf{x}_i\in\mathcal{X}}\phi(\mathbf{x}_i)\right) \label{eq:deepsets-functions} \end{align} \)
where \(\mathbf{x}_i\) are the features for the i-th element in the input sequence \(\mathcal{X}\).
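To make the permutation invariance concrete, here is a minimal sketch (not part of the original notebook) with toy \(\phi\) and \(\rho\) networks of hypothetical sizes: because the per-element outputs are summed before \(\rho\) is applied, shuffling the elements does not change \(f\).
import torch
# Toy phi and rho networks (hypothetical sizes, for illustration only)
phi = torch.nn.Linear(4, 8)
rho = torch.nn.Linear(8, 2)
x = torch.randn(10, 4)            # a "set" of 10 elements with 4 features each
perm = torch.randperm(10)         # random reordering of the elements
f_original = rho(phi(x).sum(dim=0))
f_permuted = rho(phi(x[perm]).sum(dim=0))
print(torch.allclose(f_original, f_permuted, atol=1e-6))  # True: the output is order independent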
We will define a Deep Sets model that takes as input up to 60 tracks per jet (zero-padded), each described by 6 features.
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import (
Sequential as Seq,
Linear as Lin,
ReLU,
BatchNorm1d,
AvgPool1d,
Sigmoid,
Conv1d,
)
from torch_scatter import scatter_mean
# ntracks = 60
inputs = 6
hidden1 = 64
hidden2 = 32
hidden3 = 16
classify1 = 50
outputs = 2
class DeepSets(torch.nn.Module):
def __init__(self):
super(DeepSets, self).__init__()
self.phi = Seq(
Conv1d(inputs, hidden1, 1),
BatchNorm1d(hidden1),
ReLU(),
Conv1d(hidden1, hidden2, 1),
BatchNorm1d(hidden2),
ReLU(),
Conv1d(hidden2, hidden3, 1),
BatchNorm1d(hidden3),
ReLU(),
)
self.rho = Seq(
Lin(hidden3, classify1),
BatchNorm1d(classify1),
ReLU(),
Lin(classify1, outputs),
Sigmoid(),
)
def forward(self, x):
# x has shape (batch, features, tracks); the 1x1 convolutions in phi act on each track independently
out = self.phi(x)
# average the per-track embeddings over the track dimension (a permutation-invariant aggregation)
index = torch.zeros(ntracks, dtype=torch.long, device=out.device)
out = scatter_mean(out, index, dim=-1)
return self.rho(torch.squeeze(out))
model = DeepSets()
print(model)
print("----------")
print({l: model.state_dict()[l].shape for l in model.state_dict()})
model = DeepSets().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
DeepSets(
(phi): Sequential(
(0): Conv1d(6, 64, kernel_size=(1,), stride=(1,))
(1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv1d(64, 32, kernel_size=(1,), stride=(1,))
(4): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
(6): Conv1d(32, 16, kernel_size=(1,), stride=(1,))
(7): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU()
)
(rho): Sequential(
(0): Linear(in_features=16, out_features=50, bias=True)
(1): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Linear(in_features=50, out_features=2, bias=True)
(4): Sigmoid()
)
)
----------
{'phi.0.weight': torch.Size([64, 6, 1]), 'phi.0.bias': torch.Size([64]), 'phi.1.weight': torch.Size([64]), 'phi.1.bias': torch.Size([64]), 'phi.1.running_mean': torch.Size([64]), 'phi.1.running_var': torch.Size([64]), 'phi.1.num_batches_tracked': torch.Size([]), 'phi.3.weight': torch.Size([32, 64, 1]), 'phi.3.bias': torch.Size([32]), 'phi.4.weight': torch.Size([32]), 'phi.4.bias': torch.Size([32]), 'phi.4.running_mean': torch.Size([32]), 'phi.4.running_var': torch.Size([32]), 'phi.4.num_batches_tracked': torch.Size([]), 'phi.6.weight': torch.Size([16, 32, 1]), 'phi.6.bias': torch.Size([16]), 'phi.7.weight': torch.Size([16]), 'phi.7.bias': torch.Size([16]), 'phi.7.running_mean': torch.Size([16]), 'phi.7.running_var': torch.Size([16]), 'phi.7.num_batches_tracked': torch.Size([]), 'rho.0.weight': torch.Size([50, 16]), 'rho.0.bias': torch.Size([50]), 'rho.1.weight': torch.Size([50]), 'rho.1.bias': torch.Size([50]), 'rho.1.running_mean': torch.Size([50]), 'rho.1.running_var': torch.Size([50]), 'rho.1.num_batches_tracked': torch.Size([]), 'rho.3.weight': torch.Size([2, 50]), 'rho.3.bias': torch.Size([2])}
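As a quick sanity check (an optional addition), we can push a dummy batch through the network and confirm the output shape: a batch of 8 jets, with 6 features for each of the 60 padded tracks, should yield an (8, 2) output.
# Optional shape check (sketch): 8 jets, `inputs` features per track, `ntracks` padded tracks
model.eval()
dummy = torch.randn(8, inputs, ntracks).to(device)
with torch.no_grad():
    print(model(dummy).shape)  # expected: torch.Size([8, 2])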
Define training loop#
@torch.no_grad()
def test(model, loader, total, batch_size, leave=False):
model.eval()
xentropy = nn.CrossEntropyLoss(reduction="mean")
sum_loss = 0.0
t = tqdm(enumerate(loader), total=total / batch_size, leave=leave)
for i, data in t:
x = data[0].to(device)
y = data[1].to(device)
y = torch.argmax(y, dim=1)
batch_output = model(x)
batch_loss_item = xentropy(batch_output, y).item()
sum_loss += batch_loss_item
t.set_description("loss = %.5f" % (batch_loss_item))
t.refresh() # to show immediately the update
return sum_loss / (i + 1)
def train(model, optimizer, loader, total, batch_size, leave=False):
model.train()
xentropy = nn.CrossEntropyLoss(reduction="mean")
sum_loss = 0.0
t = tqdm(enumerate(loader), total=total / batch_size, leave=leave)
for i, data in t:
x = data[0].to(device)
y = data[1].to(device)
y = torch.argmax(y, dim=1)
optimizer.zero_grad()
batch_output = model(x)
batch_loss = xentropy(batch_output, y)
batch_loss.backward()
batch_loss_item = batch_loss.item()
t.set_description("loss = %.5f" % batch_loss_item)
t.refresh() # to show immediately the update
sum_loss += batch_loss_item
optimizer.step()
return sum_loss / (i + 1)
Define training, validation, testing data generators#
from torch.utils.data import ConcatDataset
train_generator_data = ConcatDataset(train_generator.datas)
test_generator_data = ConcatDataset(test_generator.datas)
from torch.utils.data import random_split, DataLoader
torch.manual_seed(0)
valid_frac = 0.20
train_length = len(train_generator_data)
valid_num = int(valid_frac * train_length)
batch_size = 32
train_dataset, valid_dataset = random_split(
train_generator_data, [train_length - valid_num, valid_num]
)
# Example collate function for graph-style datasets (unused here);
# using it would require `from torch_geometric.data import Batch`
def collate(items):
l = sum(items, [])
return Batch.from_data_list(l)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# train_loader.collate_fn = collate
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
# valid_loader.collate_fn = collate
test_loader = DataLoader(test_generator_data, batch_size=batch_size, shuffle=False)
# test_loader.collate_fn = collate
train_samples = len(train_dataset)
valid_samples = len(valid_dataset)
test_samples = len(test_generator_data)
print(train_length)
print(train_samples)
print(valid_samples)
print(test_samples)
9387
7510
1877
3751
Train#
import os.path as osp
n_epochs = 30
stale_epochs = 0
best_valid_loss = 99999
patience = 5
t = tqdm(range(0, n_epochs))
for epoch in t:
loss = train(
model,
optimizer,
train_loader,
train_samples,
batch_size,
leave=bool(epoch == n_epochs - 1),
)
valid_loss = test(
model,
valid_loader,
valid_samples,
batch_size,
leave=bool(epoch == n_epochs - 1),
)
print("Epoch: {:02d}, Training Loss: {:.4f}".format(epoch, loss))
print(" Validation Loss: {:.4f}".format(valid_loss))
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
modpath = osp.join("deepsets_best.pth")
print("New best model saved to:", modpath)
torch.save(model.state_dict(), modpath)
stale_epochs = 0
else:
print("Stale epoch")
stale_epochs += 1
if stale_epochs >= patience:
print("Early stopping after %i stale epochs" % patience)
break
Epoch: 00, Training Loss: 0.4502
Validation Loss: 0.4461
New best model saved to: deepsets_best.pth
Epoch: 01, Training Loss: 0.4396
Validation Loss: 0.4470
Stale epoch
Epoch: 02, Training Loss: 0.4404
Validation Loss: 0.4452
New best model saved to: deepsets_best.pth
Epoch: 03, Training Loss: 0.4397
Validation Loss: 0.4475
Stale epoch
Epoch: 04, Training Loss: 0.4396
Validation Loss: 0.4489
Stale epoch
Epoch: 05, Training Loss: 0.4397
Validation Loss: 0.4465
Stale epoch
Epoch: 06, Training Loss: 0.4392
Validation Loss: 0.4467
Stale epoch
Epoch: 07, Training Loss: 0.4392
Validation Loss: 0.4489
Stale epoch
Early stopping after 5 stale epochs
Evaluate on testing data#
# In case you need to load the model from a pth file
# Trained on 4 vectors (as above in notebook)
# urlPTH = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/deepsets_best_4vec.pth"
# pthFile = wget.download(urlPTH)
# model.load_state_dict(torch.load("deepsets_best_4vec.pth"))
# Trained on all possible inputs (a different configuration)
# urlPTH = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/deepsets_best_AllTraining.pth"
# pthFile = wget.download(urlPTH)
# model.load_state_dict(torch.load("deepsets_best_AllTraining.pth"))
model.eval()
t = tqdm(enumerate(test_loader), total=test_samples / batch_size)
y_test = []
y_predict = []
track_pt = []
for i, data in t:
x = data[0].to(device)
y = data[1].to(device)
track_pt.append(x[:, 0, 0].cpu().numpy())
batch_output = model(x)
y_predict.append(batch_output.detach().cpu().numpy())
y_test.append(y.cpu().numpy())
track_pt = np.concatenate(track_pt)
y_test = np.concatenate(y_test)
y_predict = np.concatenate(y_predict)
# For Colab
import matplotlib.pyplot as plt
_, bins, _ = plt.hist(
track_pt[y_test[:, 1] == 1], bins=50, label="sig", histtype="step"
)
_, bins, _ = plt.hist(
track_pt[y_test[:, 1] == 0], bins=bins, label="bkg", histtype="step"
)
plt.legend()
plt.semilogy()
from sklearn.metrics import roc_curve, auc
import mplhep as hep
plt.style.use(hep.style.ROOT)
# create ROC curves
fpr_deepset, tpr_deepset, threshold_deepset = roc_curve(y_test[:, 1], y_predict[:, 1])
with open("deepset_roc.npy", "wb") as f:
np.save(f, fpr_deepset)
np.save(f, tpr_deepset)
np.save(f, threshold_deepset)
# plot ROC curves
plt.figure()
plt.plot(
tpr_deepset,
fpr_deepset,
lw=2.5,
label="DeepSet, AUC = {:.1f}%".format(auc(fpr_deepset, tpr_deepset) * 100),
)
plt.xlabel(r"True positive rate")
plt.ylabel(r"False positive rate")
plt.ylim(0.001, 1)
plt.xlim(0, 1)
plt.grid(True)
plt.legend(loc="upper left")
plt.show()
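To verify that the trained model is indeed permutation invariant, we now re-evaluate it on the test set after randomly shuffling the track ordering within each jet; the resulting ROC curve should be essentially identical to the original one.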
model.eval()
t = tqdm(enumerate(test_loader), total=test_samples / batch_size)
y_test = []
y_predict = []
track_pt = []
for i, data in t:
x = data[0].to(device)
y = data[1].to(device)
idx = torch.randperm(x.size(2))
x = x[:, :, idx]
track_pt.append(x[:, 0, 0].cpu().numpy())
batch_output = model(x)
y_predict.append(batch_output.detach().cpu().numpy())
y_test.append(y.cpu().numpy())
track_pt = np.concatenate(track_pt)
y_test = np.concatenate(y_test)
y_predict = np.concatenate(y_predict)
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import mplhep as hep
plt.style.use(hep.style.ROOT)
# create ROC curves
fpr_deepset_perm, tpr_deepset_perm, threshold_deepset_perm = roc_curve(
y_test[:, 1], y_predict[:, 1]
)
with open("deepset_perm_roc.npy", "wb") as f:
np.save(f, fpr_deepset)
np.save(f, tpr_deepset)
np.save(f, threshold_deepset)
with open("deepset_roc.npy", "rb") as f:
fpr_deepset = np.load(f)
tpr_deepset = np.load(f)
threshold_deepset = np.load(f)
# plot ROC curves
plt.figure()
plt.plot(
tpr_deepset,
fpr_deepset,
lw=2.5,
label="DeepSet, AUC = {:.1f}%".format(auc(fpr_deepset, tpr_deepset) * 100),
)
plt.plot(
tpr_deepset_perm,
fpr_deepset_perm,
lw=2.5,
label="DeepSet (Permuted), AUC = {:.1f}%".format(
auc(fpr_deepset_perm, tpr_deepset_perm) * 100
),
linestyle="dashed",
)
plt.xlabel(r"True positive rate")
plt.ylabel(r"False positive rate")
plt.ylim(0.001, 1)
plt.xlim(0, 1)
plt.grid(True)
plt.legend(loc="upper left")
plt.show()
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import mplhep as hep
plt.style.use(hep.style.ROOT)
# create ROC curves
pt_split = 5.0
fpr_deepset_lowb, tpr_deepset_lowb, threshold_deepset_lowb = roc_curve(
y_test[track_pt <= pt_split, 1], y_predict[track_pt <= pt_split, 1]
)
fpr_deepset_highb, tpr_deepset_highb, threshold_deepset_highb = roc_curve(
y_test[track_pt > pt_split, 1], y_predict[track_pt > pt_split, 1]
)
# plot ROC curves
plt.figure()
plt.plot(
tpr_deepset_lowb,
fpr_deepset_lowb,
lw=2.5,
label="DeepSet (low pT), AUC = {:.1f}%".format(
auc(fpr_deepset_lowb, tpr_deepset_lowb) * 100
),
)
plt.plot(
tpr_deepset_highb,
fpr_deepset_highb,
lw=2.5,
label="DeepSet (high pT), AUC = {:.1f}%".format(
auc(fpr_deepset_highb, tpr_deepset_highb) * 100
),
)
plt.xlabel(r"True positive rate")
plt.ylabel(r"False positive rate")
plt.ylim(0.001, 1)
plt.xlim(0, 1)
plt.grid(True)
plt.legend(loc="upper left")
plt.show()
If you finish this notebook, you can go back and retrain with different inputs. Replace the code above:
# WGET for colab
if not os.path.exists("definitions_lorentz.yml"):
url = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/definitions_lorentz.yml"
definitionsFile = wget.download(url)
with open("definitions_lorentz.yml") as file:
# The FullLoader parameter handles the conversion from YAML
# scalar values to the Python dictionary format
definitions = yaml.load(file, Loader=yaml.FullLoader)
with
# WGET for colab
if not os.path.exists("definitions.yml"):
url = "https://raw.githubusercontent.com/jmduarte/iaifi-summer-school/main/book/definitions.yml"
definitionsFile = wget.download(url)
with open("definitions.yml") as file:
# The FullLoader parameter handles the conversion from YAML
# scalar values to the Python dictionary format
definitions = yaml.load(file, Loader=yaml.FullLoader)