Hands-on 08: Model Compression#
! pip install --user --quiet tensorflow-model-optimization
from tensorflow.keras.utils import to_categorical
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# Seed NumPy's and TensorFlow's global random number generators with the same value.
seed = 42
np.random.seed(seed)
import tensorflow as tf
tf.random.set_seed(seed)
2024-06-05 04:35:10.515277: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
Cell In[2], line 1
----> 1 from tensorflow.keras.utils import to_categorical
2 from sklearn.datasets import fetch_openml
3 from sklearn.model_selection import train_test_split
ModuleNotFoundError: No module named 'tensorflow.keras'
Fetch the jet tagging dataset from Open ML#
# Download the jet tagging dataset from OpenML and split it into features / labels.
data = fetch_openml("hls4ml_lhc_jets_hlf")
X, y = data["data"], data["target"]

# Encode the class labels as integers first, then one-hot encode them. The number
# of one-hot columns is taken from the encoder rather than hard-coded.
le = LabelEncoder()
y_int = le.fit_transform(y)
classes = le.classes_
y_onehot = to_categorical(y_int, len(classes))

# Hold out 20% of the samples for testing; reuse the notebook-wide `seed` so the
# split follows the same seed as the NumPy / TensorFlow generators.
X_train_val, X_test, y_train_val, y_test = train_test_split(X, y_onehot, test_size=0.2, random_state=seed)

# Fit the scaler on the training portion only and apply the same transform to the test set.
scaler = StandardScaler()
X_train_val = scaler.fit_transform(X_train_val)
X_test = scaler.transform(X_test)
Now construct a model#
We’ll use the same architecture as in part 1: 3 hidden layers with 64, then 32, then 32 neurons. Each hidden layer will use ReLU activation.
Add an output layer with 5 neurons (one for each class), then finish with Softmax activation.
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l1
from callbacks import all_callbacks
# Baseline (unpruned) classifier: 16 inputs -> 64 -> 32 -> 32 -> 5 units, with ReLU
# after each hidden Dense layer and softmax after the output layer.
model = Sequential(
    [
        Dense(64, input_shape=(16,), name="fc1", kernel_initializer="lecun_uniform"),
        Activation(activation="relu", name="relu1"),
        Dense(32, name="fc2", kernel_initializer="lecun_uniform"),
        Activation(activation="relu", name="relu2"),
        Dense(32, name="fc3", kernel_initializer="lecun_uniform"),
        Activation(activation="relu", name="relu3"),
        Dense(5, name="output", kernel_initializer="lecun_uniform"),
        Activation(activation="softmax", name="softmax"),
    ]
)
Train the unpruned model#
# Compile the baseline model with Adam and categorical cross-entropy.
adam = Adam(learning_rate=0.0001)
model.compile(optimizer=adam, loss=["categorical_crossentropy"], metrics=["accuracy"])

# Callback bundle from the local `callbacks` helper module, writing into "unpruned_model".
callbacks = all_callbacks(
    outputDir="unpruned_model",
    stop_patience=1000,
    lr_factor=0.5,
    lr_patience=10,
    lr_epsilon=0.000001,
    lr_cooldown=2,
    lr_minimum=0.0000001,
)

# Train for 30 epochs, holding out 25% of the training data for validation.
model.fit(
    X_train_val,
    y_train_val,
    batch_size=1024,
    epochs=30,
    validation_split=0.25,
    shuffle=True,
    callbacks=callbacks.callbacks,
)
Train the pruned model#
This time we’ll use the TensorFlow Model Optimization sparsity API to train a sparse model (forcing many weights to ‘0’). In this instance, the target sparsity is 75%.
from tensorflow_model_optimization.python.core.sparsity.keras import prune, pruning_callbacks, pruning_schedule
from tensorflow_model_optimization.sparsity.keras import strip_pruning
# Same topology as the baseline model, but every Dense kernel carries an L1 penalty.
pruned_model = Sequential(
    [
        Dense(64, input_shape=(16,), name="fc1", kernel_initializer="lecun_uniform", kernel_regularizer=l1(0.0001)),
        Activation(activation="relu", name="relu1"),
        Dense(32, name="fc2", kernel_initializer="lecun_uniform", kernel_regularizer=l1(0.0001)),
        Activation(activation="relu", name="relu2"),
        Dense(32, name="fc3", kernel_initializer="lecun_uniform", kernel_regularizer=l1(0.0001)),
        Activation(activation="relu", name="relu3"),
        Dense(5, name="output", kernel_initializer="lecun_uniform", kernel_regularizer=l1(0.0001)),
        Activation(activation="softmax", name="softmax"),
    ]
)

# Constant 75% sparsity target, starting at step 2000 and applied every 100 steps.
sparsity_schedule = pruning_schedule.ConstantSparsity(0.75, begin_step=2000, frequency=100)
pruning_params = {"pruning_schedule": sparsity_schedule}

# Wrap the model so low-magnitude weights are pruned during training.
pruned_model = prune.prune_low_magnitude(pruned_model, **pruning_params)
We’ll use the same settings as before: Adam optimizer with categorical crossentropy loss.
The callbacks will decay the learning rate and save the model into the directory pruned_model.
# Same optimizer settings as the unpruned model. Use the `learning_rate` keyword:
# the legacy `lr` alias is deprecated and rejected by recent Keras releases.
adam = Adam(learning_rate=0.0001)
pruned_model.compile(optimizer=adam, loss=["categorical_crossentropy"], metrics=["accuracy"])

# Callback bundle from the local `callbacks` helper module, writing into "pruned_model".
callbacks = all_callbacks(
    stop_patience=1000,
    lr_factor=0.5,
    lr_patience=10,
    lr_epsilon=0.000001,
    lr_cooldown=2,
    lr_minimum=0.0000001,
    outputDir="pruned_model",
)
# The pruning wrappers need this callback to advance their step counter during training.
callbacks.callbacks.append(pruning_callbacks.UpdatePruningStep())

pruned_model.fit(
    X_train_val,
    y_train_val,
    batch_size=1024,
    epochs=30,
    validation_split=0.25,
    shuffle=True,
    callbacks=callbacks.callbacks,
    verbose=0,
)

# Save the model again but with the pruning 'stripped' to use the regular layer types
pruned_model = strip_pruning(pruned_model)
pruned_model.save("pruned_model/model_best.h5")
Check sparsity#
Make a quick check that the model was indeed trained sparse. We’ll just make a histogram of the weights of the 1st layer, and hopefully observe a large peak in the bin containing ‘0’. Note logarithmic y axis.
# Flatten the first weight tensor of layer 0 from each model for comparison.
w_unpruned = model.layers[0].weights[0].numpy().flatten()
w_pruned = pruned_model.layers[0].weights[0].numpy().flatten()
bins = np.arange(-2, 2, 0.04)

# Overlay both weight distributions on a log-scaled y axis.
plt.figure(figsize=(7, 7))
for layer_weights, legend_label in ((w_unpruned, "Unpruned layer 1"), (w_pruned, "Pruned layer 1")):
    plt.hist(layer_weights, bins=bins, alpha=0.7, label=legend_label)
plt.xlabel("Weight value")
plt.ylabel("Number of weights")
plt.semilogy()
plt.legend()

# Report the percentage of weights that are exactly zero in each model.
print(f"Sparsity of unpruned model layer 1: {np.sum(w_unpruned==0)*100/np.size(w_unpruned)}% zeros")
print(f"Sparsity of pruned model layer 1: {np.sum(w_pruned==0)*100/np.size(w_pruned)}% zeros")
plt.show()
Check performance#
How does this 75% sparse model compare against the unpruned model? Let’s report the accuracy and make a ROC curve. The unpruned model is shown with solid lines, the pruned model is shown with dashed lines.
import plotting
import matplotlib.pyplot as plt
from matplotlib.legend import Legend
from matplotlib.lines import Line2D
from sklearn.metrics import accuracy_score
from tensorflow.keras.models import load_model

# Reload the unpruned model from its saved file and score both models on the test set.
unpruned_model = load_model("unpruned_model/model_best.h5")
y_ref = unpruned_model.predict(X_test, verbose=0)
y_prune = pruned_model.predict(X_test, verbose=0)

acc_unpruned = accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_ref, axis=1))
acc_pruned = accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_prune, axis=1))
print(f"Accuracy unpruned: {acc_unpruned}")
print(f"Accuracy pruned: {acc_pruned}")

# ROC curves: unpruned with the default line style, pruned with dashed lines.
fig, ax = plt.subplots(figsize=(9, 9))
_ = plotting.make_roc(y_test, y_ref, classes)
plt.gca().set_prop_cycle(None)  # reset the colors
_ = plotting.make_roc(y_test, y_prune, classes, linestyle="--")

# Second legend that maps each line style to its model.
lines = [Line2D([0], [0], ls=style) for style in ("-", "--")]
leg = Legend(ax, lines, labels=["Unpruned", "Pruned"], loc="lower right", frameon=False)
ax.add_artist(leg)
plt.show()
# Convert the pruned Keras model to TFLite with the default optimization set.
# NOTE(review): `converter` and `tflite_model_quant` are both reassigned by the
# representative-dataset conversion further down, so this result is never written to disk.
converter = tf.lite.TFLiteConverter.from_keras_model(pruned_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model_quant = converter.convert()
def representative_data_gen():
    """Yield 100 single-sample float32 batches drawn from ``X_train_val``.

    Each yielded item is a one-element list, because the model has a single input.
    """
    samples = tf.data.Dataset.from_tensor_slices(X_train_val.astype(np.float32))
    for single_batch in samples.batch(1).take(100):
        yield [single_batch]
# Convert the *pruned* model (not the unpruned baseline) so that the result really
# is the "pruned+quantized" model reported below. The representative dataset gives
# the converter sample inputs to use during quantization.
converter = tf.lite.TFLiteConverter.from_keras_model(pruned_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
tflite_model_quant = converter.convert()
import pathlib

# Make sure the output directory exists (creating parents as needed).
tflite_models_dir = pathlib.Path("tflite_models/")
tflite_models_dir.mkdir(parents=True, exist_ok=True)

# Write the converted model's bytes to disk.
tflite_model_quant_file = tflite_models_dir.joinpath("model_quant.tflite")
tflite_model_quant_file.write_bytes(tflite_model_quant)
# Helper function to run inference on a TFLite model
def run_tflite_model(tflite_file, X_test_indices):
    """Run a TFLite model on selected rows of the module-level ``X_test``.

    Args:
        tflite_file: Path to the ``.tflite`` model file to load.
        X_test_indices: Sequence of row indices into ``X_test`` to evaluate.

    Returns:
        A float32 array with one row of model outputs per requested index.
        If the model's output tensor is integer-quantized, the values are
        dequantized back to real numbers before being returned.
    """
    # Initialize the interpreter
    interpreter = tf.lite.Interpreter(model_path=str(tflite_file))
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]

    # Hoist the (loop-invariant) quantization parameters out of the per-sample loop.
    quantized_dtypes = (np.uint8, np.int8)
    input_dtype = input_details["dtype"]
    input_is_quantized = input_dtype in quantized_dtypes
    output_is_quantized = output_details["dtype"] in quantized_dtypes
    input_scale, input_zero_point = input_details["quantization"]
    output_scale, output_zero_point = output_details["quantization"]

    # Size the result from the model's own output tensor instead of assuming 5 classes.
    n_outputs = int(output_details["shape"][-1])
    predictions = np.zeros((len(X_test_indices), n_outputs), dtype=np.float32)

    for i, X_test_index in enumerate(X_test_indices):
        X_test_i = X_test[X_test_index]
        if input_is_quantized:
            # Map real values onto the integer grid: round to nearest and clip to
            # the dtype's range so the cast below neither truncates nor wraps around.
            limits = np.iinfo(input_dtype)
            X_test_i = np.clip(np.round(X_test_i / input_scale + input_zero_point), limits.min, limits.max)
        X_test_i = np.expand_dims(X_test_i, axis=0).astype(input_dtype)
        interpreter.set_tensor(input_details["index"], X_test_i)
        interpreter.invoke()
        output = interpreter.get_tensor(output_details["index"])[0]
        if output_is_quantized:
            output = (output.astype(np.float32) - output_zero_point) * output_scale
        predictions[i] = output
    return predictions
# Run the saved TFLite model on every row of the test set and report its accuracy.
X_test_indices = list(range(len(X_test)))
y_quant = run_tflite_model(tflite_model_quant_file, X_test_indices)
acc_quant = accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_quant, axis=1))
print(f"Accuracy pruned+quantized: {acc_quant}")
from matplotlib.legend import Legend
from matplotlib.lines import Line2D

# ROC curves for all three sets of predictions on one axis: the unpruned model
# with the default line style, then the pruned and TFLite results with their own styles.
fig, ax = plt.subplots(figsize=(9, 9))
_ = plotting.make_roc(y_test, y_ref, classes)
for scores, line_style in ((y_prune, "--"), (y_quant, "-.")):
    plt.gca().set_prop_cycle(None)  # reset the colors
    _ = plotting.make_roc(y_test, scores, classes, linestyle=line_style)

# Second legend that maps each line style to its model.
lines = [Line2D([0], [0], ls=line_style) for line_style in ("-", "--", "-.")]
leg = Legend(ax, lines, labels=["Unpruned", "Pruned", "Quantized"], loc="lower right", frameon=False)
ax.add_artist(leg)
plt.show()
# Dump the raw prediction arrays (TFLite, pruned, unpruned) for visual inspection.
print(y_quant, y_prune, y_ref, sep="\n")