Hands-on 08: Model Compression#

! pip install --user --quiet tensorflow-model-optimization
from tensorflow.keras.utils import to_categorical
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
seed = 42
np.random.seed(seed)
import tensorflow as tf

tf.random.set_seed(seed)

Fetch the jet tagging dataset from OpenML#

data = fetch_openml("hls4ml_lhc_jets_hlf")
X, y = data["data"], data["target"]

le = LabelEncoder()
y_onehot = le.fit_transform(y)
y_onehot = to_categorical(y_onehot, 5)
classes = le.classes_

X_train_val, X_test, y_train_val, y_test = train_test_split(X, y_onehot, test_size=0.2, random_state=42)


scaler = StandardScaler()
X_train_val = scaler.fit_transform(X_train_val)
X_test = scaler.transform(X_test)
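
A quick look at what we just loaded (a small addition for orientation; the dataset provides 16 high-level features per jet and 5 jet classes):

print("Train/val:", X_train_val.shape, " Test:", X_test.shape)
print("Classes:", classes)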

Now construct a model#

We’ll use the same architecture as in part 1: three hidden layers with 64, 32, and 32 neurons. Each hidden layer uses a ReLU activation. We then add an output layer with 5 neurons (one for each class) and finish with a softmax activation.

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l1
from callbacks import all_callbacks
model = Sequential()
model.add(Dense(64, input_shape=(16,), name="fc1", kernel_initializer="lecun_uniform"))
model.add(Activation(activation="relu", name="relu1"))
model.add(Dense(32, name="fc2", kernel_initializer="lecun_uniform"))
model.add(Activation(activation="relu", name="relu2"))
model.add(Dense(32, name="fc3", kernel_initializer="lecun_uniform"))
model.add(Activation(activation="relu", name="relu3"))
model.add(Dense(5, name="output", kernel_initializer="lecun_uniform"))
model.add(Activation(activation="softmax", name="softmax"))

Train the unpruned model#

adam = Adam(learning_rate=0.0001)
model.compile(optimizer=adam, loss=["categorical_crossentropy"], metrics=["accuracy"])
callbacks = all_callbacks(
    stop_patience=1000,
    lr_factor=0.5,
    lr_patience=10,
    lr_epsilon=0.000001,
    lr_cooldown=2,
    lr_minimum=0.0000001,
    outputDir="unpruned_model",
)
model.fit(
    X_train_val, y_train_val, batch_size=1024, epochs=30, validation_split=0.25, shuffle=True, callbacks=callbacks.callbacks
)

Train the pruned model#

This time we’ll use the TensorFlow Model Optimization sparsity API to train a sparse model (forcing many weights to ‘0’). In this instance, the target sparsity is 75%.

from tensorflow_model_optimization.python.core.sparsity.keras import prune, pruning_callbacks, pruning_schedule
from tensorflow_model_optimization.sparsity.keras import strip_pruning

pruned_model = Sequential()
pruned_model.add(Dense(64, input_shape=(16,), name="fc1", kernel_initializer="lecun_uniform", kernel_regularizer=l1(0.0001)))
pruned_model.add(Activation(activation="relu", name="relu1"))
pruned_model.add(Dense(32, name="fc2", kernel_initializer="lecun_uniform", kernel_regularizer=l1(0.0001)))
pruned_model.add(Activation(activation="relu", name="relu2"))
pruned_model.add(Dense(32, name="fc3", kernel_initializer="lecun_uniform", kernel_regularizer=l1(0.0001)))
pruned_model.add(Activation(activation="relu", name="relu3"))
pruned_model.add(Dense(5, name="output", kernel_initializer="lecun_uniform", kernel_regularizer=l1(0.0001)))
pruned_model.add(Activation(activation="softmax", name="softmax"))

pruning_params = {"pruning_schedule": pruning_schedule.ConstantSparsity(0.75, begin_step=2000, frequency=100)}
pruned_model = prune.prune_low_magnitude(pruned_model, **pruning_params)
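
As a rough sanity check on the schedule (a minimal sketch, assuming the batch size of 1024 and the 25% validation split used in the fit call below), we can translate begin_step and frequency into epochs:

steps_per_epoch = int(len(X_train_val) * 0.75) // 1024  # training samples per epoch / batch size
print(f"~{steps_per_epoch} steps per epoch")
print(f"Pruning starts around epoch {2000 / steps_per_epoch:.1f} and the mask is updated every {100 / steps_per_epoch:.2f} epochs")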

We’ll use the same settings as before: Adam optimizer with categorical crossentropy loss. The callbacks will decay the learning rate and save the best model into the pruned_model directory.

adam = Adam(learning_rate=0.0001)
pruned_model.compile(optimizer=adam, loss=["categorical_crossentropy"], metrics=["accuracy"])
callbacks = all_callbacks(
    stop_patience=1000,
    lr_factor=0.5,
    lr_patience=10,
    lr_epsilon=0.000001,
    lr_cooldown=2,
    lr_minimum=0.0000001,
    outputDir="pruned_model",
)
callbacks.callbacks.append(pruning_callbacks.UpdatePruningStep())
pruned_model.fit(
    X_train_val,
    y_train_val,
    batch_size=1024,
    epochs=30,
    validation_split=0.25,
    shuffle=True,
    callbacks=callbacks.callbacks,
    verbose=0,
)
# Save the model again but with the pruning 'stripped' to use the regular layer types
pruned_model = strip_pruning(pruned_model)
pruned_model.save("pruned_model/model_best.h5")

Check sparsity#

Let’s quickly check that the model was indeed trained sparse. We’ll make a histogram of the weights of the first layer and hopefully observe a large peak in the bin containing ‘0’. Note the logarithmic y-axis.

bins = np.arange(-2, 2, 0.04)
w_unpruned = model.layers[0].weights[0].numpy().flatten()
w_pruned = pruned_model.layers[0].weights[0].numpy().flatten()

plt.figure(figsize=(7, 7))

plt.hist(w_unpruned, bins=bins, alpha=0.7, label="Unpruned layer 1")
plt.hist(w_pruned, bins=bins, alpha=0.7, label="Pruned layer 1")

plt.xlabel("Weight value")
plt.ylabel("Number of weights")
plt.semilogy()
plt.legend()

print(f"Sparsity of unpruned model layer 1: {np.sum(w_unpruned==0)*100/np.size(w_unpruned)}% zeros")
print(f"Sparsity of pruned model layer 1: {np.sum(w_pruned==0)*100/np.size(w_pruned)}% zeros")
plt.show()
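
For a fuller picture than the first layer alone, here is a short sketch (an addition, not part of the original check) that reports the sparsity of every Dense layer in the stripped, pruned model:

# Fraction of exactly-zero weights in each Dense kernel of the pruned model
for layer in pruned_model.layers:
    if isinstance(layer, Dense):
        w = layer.get_weights()[0]
        print(f"{layer.name}: {100 * np.sum(w == 0) / w.size:.1f}% zeros")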

Check performance#

How does this 75% sparse model compare against the unpruned model? Let’s report the accuracy and make a ROC curve. The unpruned model is shown with solid lines, the pruned model with dashed lines.

import plotting
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from tensorflow.keras.models import load_model

unpruned_model = load_model("unpruned_model/model_best.h5")

y_ref = unpruned_model.predict(X_test, verbose=0)
y_prune = pruned_model.predict(X_test, verbose=0)

print("Accuracy unpruned: {}".format(accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_ref, axis=1))))
print("Accuracy pruned:   {}".format(accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_prune, axis=1))))
fig, ax = plt.subplots(figsize=(9, 9))
_ = plotting.make_roc(y_test, y_ref, classes)
plt.gca().set_prop_cycle(None)  # reset the colors
_ = plotting.make_roc(y_test, y_prune, classes, linestyle="--")

from matplotlib.lines import Line2D

lines = [Line2D([0], [0], ls="-"), Line2D([0], [0], ls="--")]
from matplotlib.legend import Legend

leg = Legend(ax, lines, labels=["Unpruned", "Pruned"], loc="lower right", frameon=False)
ax.add_artist(leg)
plt.show()
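
Quantize the model with TensorFlow Lite#

To compress the model further we apply post-training quantization with the TFLite converter. We first convert the pruned model using the default optimizations (dynamic-range quantization of the weights), then repeat the conversion with a representative dataset so the converter can also calibrate the quantization of activations; the calibrated model is the one we save and evaluate below.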
converter = tf.lite.TFLiteConverter.from_keras_model(pruned_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

tflite_model_quant = converter.convert()
def representative_data_gen():
    for input_value in tf.data.Dataset.from_tensor_slices(X_train_val.astype(np.float32)).batch(1).take(100):
        # Model has only one input so each data point has one element.
        yield [input_value]


converter = tf.lite.TFLiteConverter.from_keras_model(pruned_model)  # quantize the pruned model, so the evaluation below is pruned + quantized
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen

tflite_model_quant = converter.convert()
import pathlib

tflite_models_dir = pathlib.Path("tflite_models/")
tflite_models_dir.mkdir(exist_ok=True, parents=True)

# Save the quantized model:
tflite_model_quant_file = tflite_models_dir / "model_quant.tflite"
tflite_model_quant_file.write_bytes(tflite_model_quant)
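
As a quick look at the compression benefit (a small addition; it assumes the model files written earlier in this notebook are present), we can compare on-disk sizes. Keep in mind that unstructured pruning alone doesn’t shrink the dense HDF5 file, since the zero weights are still stored explicitly; the size reduction here comes from quantization, and the pruned zeros mainly help if the file is subsequently compressed (e.g. with gzip).

import os

for path in ["unpruned_model/model_best.h5", "pruned_model/model_best.h5", str(tflite_model_quant_file)]:
    print(f"{path}: {os.path.getsize(path) / 1024:.1f} KiB")
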
# Helper function to run inference on a TFLite model
def run_tflite_model(tflite_file, X_test_indices):
    global X_test

    # Initialize the interpreter
    interpreter = tf.lite.Interpreter(model_path=str(tflite_file))
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]

    predictions = np.zeros((len(X_test_indices), 5), dtype=np.float32)
    for i, X_test_index in enumerate(X_test_indices):
        X_test_i = X_test[X_test_index]

        # Check if the input type is quantized, then rescale input data to uint8
        if input_details["dtype"] == np.uint8:
            input_scale, input_zero_point = input_details["quantization"]
            X_test_i = X_test_i / input_scale + input_zero_point

        X_test_i = np.expand_dims(X_test_i, axis=0).astype(input_details["dtype"])
        interpreter.set_tensor(input_details["index"], X_test_i)
        interpreter.invoke()
        output = interpreter.get_tensor(output_details["index"])[0]
        predictions[i] = output

    return predictions
X_test_indices = list(range(0, len(X_test)))

y_quant = run_tflite_model(tflite_model_quant_file, X_test_indices)
print("Accuracy pruned+quantized:   {}".format(accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_quant, axis=1))))
fig, ax = plt.subplots(figsize=(9, 9))
_ = plotting.make_roc(y_test, y_ref, classes)
plt.gca().set_prop_cycle(None)  # reset the colors
_ = plotting.make_roc(y_test, y_prune, classes, linestyle="--")
plt.gca().set_prop_cycle(None)  # reset the colors
_ = plotting.make_roc(y_test, y_quant, classes, linestyle="-.")

from matplotlib.lines import Line2D

lines = [Line2D([0], [0], ls="-"), Line2D([0], [0], ls="--"), Line2D([0], [0], ls="-.")]
from matplotlib.legend import Legend

leg = Legend(ax, lines, labels=["Unpruned", "Pruned", "Quantized"], loc="lower right", frameon=False)
ax.add_artist(leg)
plt.show()
# Print the raw prediction arrays from the quantized, pruned, and unpruned models for a direct comparison
print(y_quant)
print(y_prune)
print(y_ref)