
Influence functions for neural networks

This notebook explores the use of influence functions for convolutional neural networks. In the first part we will investigate the usefulness, or lack thereof, of influence functions for the interpretation of a classifier's outputs.

For our study we choose a pre-trained ResNet18, fine-tuned on the tiny-imagenet dataset. This dataset was created for a Stanford course on Deep Learning for Computer Vision. It is a subset of the famous ImageNet with 200 classes instead of 1000, and its images are down-sampled to a lower resolution of 64x64 pixels.

After tuning the last layers of the network, we will use pyDVL to find the most and the least influential training images for the test set. This can sometimes be used to explain inference errors or to direct efforts during data collection, although with our model and data the results will turn out to be inconclusive. This illustrates well-known issues of influence functions for neural networks.

However, in the final part of the notebook we will see that influence functions are an effective tool for finding anomalous or corrupted data points.

We conclude with an appendix covering some of the basic theoretical concepts used.

If you are reading this in the documentation, some boilerplate has been omitted for convenience.

Imports and setup

from pydvl.influence.torch import CgInfluence
from pydvl.reporting.plots import plot_influence_distribution_by_label
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, f1_score

Loading and preprocessing the dataset

We arbitrarily pick two classes to work with: 90 and 100, corresponding respectively to dining tables and boats in Venice (you can of course select any other two classes, or more of them, although that would imply longer training times and some modifications in the notebook below). The dataset is loaded with load_preprocess_imagenet(), which returns three pandas DataFrames with the training, validation and test sets respectively. Each DataFrame has three columns: normalized images, labels and the original images. Note that you can load a subset of the data by decreasing downsampling_ratio.

label_names = {90: "tables", 100: "boats"}
train_ds, val_ds, test_ds = load_preprocess_imagenet(
    train_size=0.8,
    test_size=0.1,
    keep_labels=label_names,
    downsampling_ratio=1,
)

print("Normalised image dtype:", train_ds["normalized_images"][0].dtype)
print("Label type:", type(train_ds["labels"][0]))
print("Image type:", type(train_ds["images"][0]))
train_ds.info()

Let's take a closer look at a few image samples.


Let's now further pre-process the data and prepare for model training. The helper function process_io converts the normalized images into tensors and the labels to the indices 0 and 1 to train the classifier.

def process_io(df: pd.DataFrame, labels: dict) -> Tuple[torch.Tensor, torch.Tensor]:
    x = df["normalized_images"]
    y = df["labels"]
    # Map each value of `labels` (e.g. "tables", "boats") to a model index 0, 1, ...
    ds_label_to_model_label = {
        ds_label: idx for idx, ds_label in enumerate(labels.values())
    }
    # Stack the per-row image tensors into a single batch tensor on the target device
    x_nn = torch.stack(x.tolist()).to(DEVICE)
    y_nn = torch.tensor([ds_label_to_model_label[yi] for yi in y], device=DEVICE)
    return x_nn, y_nn


train_x, train_y = process_io(train_ds, label_names)
val_x, val_y = process_io(val_ds, label_names)
test_x, test_y = process_io(test_ds, label_names)

batch_size = 768
train_data = DataLoader(TensorDataset(train_x, train_y), batch_size=batch_size)
test_data = DataLoader(TensorDataset(test_x, test_y), batch_size=batch_size)
val_data = DataLoader(TensorDataset(val_x, val_y), batch_size=batch_size)
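To see what process_io does in isolation, here is a minimal sketch on a toy DataFrame. The function name process_io_sketch, the random 3x64x64 "images" and the string labels are hypothetical; we assume, as the mapping in process_io implies, that the labels column stores the class names from label_names. Everything stays on the CPU.

```python
from typing import Dict, Tuple

import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset


def process_io_sketch(df: pd.DataFrame, labels: Dict[int, str]) -> Tuple[torch.Tensor, torch.Tensor]:
    # Same mapping as process_io: class names -> contiguous model indices 0, 1, ...
    name_to_index = {name: idx for idx, name in enumerate(labels.values())}
    # Stack the per-row image tensors into a single (N, C, H, W) tensor
    x = torch.stack(df["normalized_images"].tolist())
    y = torch.tensor([name_to_index[name] for name in df["labels"]])
    return x, y


label_names = {90: "tables", 100: "boats"}
# Hypothetical toy data: 4 random "normalized images" with class-name labels
toy_df = pd.DataFrame(
    {
        "normalized_images": [torch.randn(3, 64, 64) for _ in range(4)],
        "labels": ["tables", "boats", "boats", "tables"],
    }
)
x, y = process_io_sketch(toy_df, label_names)
loader = DataLoader(TensorDataset(x, y), batch_size=3)

print(x.shape)  # torch.Size([4, 3, 64, 64])
print(y.tolist())  # [0, 1, 1, 0]
print([len(batch_x) for batch_x, _ in loader])  # [3, 1]
```

Since the DataLoader does not drop the last incomplete batch by default, 4 samples with a batch size of 3 yield batches of 3 and 1.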

Model definition and training

We use a ResNet18 from torchvision with final layers modified for binary classification.

For training, we use the convenience class TrainingManager which transparently handles persistence after training. It is not part of the main pyDVL package but just a way to reduce clutter in this notebook.

We train the model for 50 epochs (or load the cached weights, if available) and save the results. Then we plot the training and validation loss curves.

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_ft = new_resnet_model(output_size=len(label_names))
mgr = TrainingManager(
    "model_ft",
    model_ft,
    nn.CrossEntropyLoss(),
    train_data,
    val_data,
    MODEL_PATH,
    device=device,
)
# Set use_cache=False to retrain the model
train_loss, val_loss = mgr.train(n_epochs=50, use_cache=True)
plot_losses(Losses(train_loss, val_loss))
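TrainingManager is likewise omitted. As a rough idea of what a training run involves, the hypothetical train_sketch below records the mean training and validation loss per epoch, which is the kind of data the loss curves show. It uses Adam and a toy linearly separable dataset purely for illustration; the optimizer and learning rate actually used by TrainingManager are not shown in this notebook.

```python
from typing import List, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_sketch(model, loss_fn, train_loader, val_loader, n_epochs, lr=1e-2) -> Tuple[List[float], List[float]]:
    # Only optimise parameters that are not frozen
    optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)
    train_losses, val_losses = [], []
    for _ in range(n_epochs):
        model.train()
        total, count = 0.0, 0
        for x, y in train_loader:
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
            # Weight each batch loss by its size to get a per-sample mean
            total += loss.item() * len(x)
            count += len(x)
        train_losses.append(total / count)

        model.eval()
        total, count = 0.0, 0
        with torch.no_grad():
            for x, y in val_loader:
                total += loss_fn(model(x), y).item() * len(x)
                count += len(x)
        val_losses.append(total / count)
    return train_losses, val_losses


torch.manual_seed(0)
# Hypothetical toy data: the class is the sign of the first feature
x = torch.randn(200, 2)
y = (x[:, 0] > 0).long()
train = DataLoader(TensorDataset(x[:160], y[:160]), batch_size=32, shuffle=True)
val = DataLoader(TensorDataset(x[160:], y[160:]), batch_size=32)
train_loss, val_loss = train_sketch(nn.Linear(2, 2), nn.CrossEntropyLoss(), train, val, n_epochs=20)
print(f"train loss: {train_loss[0]:.3f} -> {train_loss[-1]:.3f}")
```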

The confusion matrix and \(F_1\) score look good, especially considering the low resolution of the images and their complexity (they contain different objects).

with torch.no_grad():
    # Predicted class = index of the largest logit for each test image
    pred_y_test = model_ft(test_x).argmax(dim=1).cpu()
model_score = f1_score(test_y.cpu(), pred_y_test, average="weighted")

cm = confusion_matrix(test_y.cpu(), pred_y_test)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=list(label_names.values()))
print("f1_score of model:", model_score)
disp.plot();
f1_score of model: 0.9062805208898536
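To make the two metrics concrete, here is a toy example with four hypothetical test points. With average="weighted", f1_score computes the F1 score of each class and averages them, weighting each class by its number of true instances (its support).

```python
import numpy as np
import torch
from sklearn.metrics import confusion_matrix, f1_score

# Logits for 4 hypothetical test points and 2 classes (0 = "tables", 1 = "boats")
logits = torch.tensor([[2.0, -1.0], [0.3, 0.5], [-1.0, 1.5], [0.1, 0.9]])
y_true = np.array([0, 0, 1, 1])
# The predicted class is the index of the largest logit in each row: [0, 1, 1, 1]
y_pred = logits.argmax(dim=1).numpy()

# Rows are true classes, columns are predicted classes: [[1, 1], [0, 2]]
cm = confusion_matrix(y_true, y_pred)

# Per-class F1: class 0 has P=1, R=0.5 -> 2/3; class 1 has P=2/3, R=1 -> 0.8
per_class = f1_score(y_true, y_pred, average=None)
weighted = f1_score(y_true, y_pred, average="weighted")
support = np.bincount(y_true)
print(round(weighted, 4))  # 0.7333
```

Here both classes have the same support, so the weighted F1 is simply the mean of 2/3 and 0.8.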


Influence computation

Let's now calculate influences! The central interface for computing influences is InfluenceFunctionModel. Since ResNet18 is quite big, explicitly forming and inverting its Hessian is impractical, so we pick the conjugate gradient implementation CgInfluence, which takes a trained torch.nn.Module and the training loss, and is then fitted on the training data. Another important parameter is the Hessian regularization term hessian_reg, which should be chosen as small as possible while still allowing the computation to converge (further details on why this is important can be found in the Appendix).

influence_model = CgInfluence(mgr.model, mgr.loss, hessian_reg, progress=True)
influence_model = influence_model.fit(train_data)
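CgInfluence avoids inverting the Hessian by solving a linear system of the form \((H + \lambda I)x = b\) with the conjugate gradient method, which only needs Hessian-vector products. The sketch below is a textbook version of that idea on a tiny logistic regression with 3 parameters, where \(\lambda\) plays the role of hessian_reg; it is not pyDVL's implementation, which may differ in batching, preconditioning and stopping criteria. The Hessian-vector products use double backpropagation, and the result is checked against a direct solve with the explicit Hessian. All data and names are hypothetical.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy data: 50 points, 3 features, binary labels
X = torch.randn(50, 3, dtype=torch.float64)
y = (X[:, 0] - X[:, 1] > 0).double()
# Parameters at which the Hessian is evaluated (zeros for simplicity)
theta = torch.zeros(3, dtype=torch.float64, requires_grad=True)


def loss_fn(params: torch.Tensor) -> torch.Tensor:
    # Mean logistic loss log(1 + e^z) - y z; no overflow here since z = X @ params is small
    z = X @ params
    return (torch.log1p(torch.exp(z)) - y * z).mean()


def hvp(v: torch.Tensor, reg: float) -> torch.Tensor:
    # (H + reg * I) v, where H v is obtained by differentiating grad . v (double backprop)
    (grad,) = torch.autograd.grad(loss_fn(theta), theta, create_graph=True)
    (hv,) = torch.autograd.grad(grad @ v, theta)
    return hv + reg * v


def conjugate_gradient(b: torch.Tensor, reg: float, tol: float = 1e-10, max_iter: int = 100) -> torch.Tensor:
    # Solve (H + reg * I) x = b using only matrix-vector products
    x = torch.zeros_like(b)
    r = b.clone()  # residual b - A x for the initial guess x = 0
    p = r.clone()  # first search direction
    rs_old = r @ r
    for _ in range(max_iter):
        Ap = hvp(p, reg)
        alpha = rs_old / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if rs_new.sqrt() < tol:
            break
        # New direction, conjugate to the previous ones w.r.t. H + reg * I
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


reg = 0.01
b = torch.randn(3, dtype=torch.float64)  # stands in for a gradient of the loss
x_cg = conjugate_gradient(b, reg)

# Reference: explicit Hessian and a direct linear solve
H = torch.autograd.functional.hessian(loss_fn, theta.detach())
x_direct = torch.linalg.solve(H + reg * torch.eye(3, dtype=torch.float64), b)
print(torch.allclose(x_cg, x_direct, atol=1e-8))  # True
```

In exact arithmetic CG converges in at most as many iterations as there are parameters. For a network like ResNet18 it is stopped long before that, and a larger \(\lambda\) makes the system better conditioned (faster convergence) at the cost of biasing the solution, which is the trade-off behind choosing hessian_reg as small as possible.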

On the fitted influence object we can call the method influences, which takes some test data and an input dataset with labels (typically the training data, or a subset of it). We use the influence type "up" (mode="up"). The other option, "perturbation", is beyond the scope of this notebook, but more information can be found in the notebook using the Wine dataset or in the pyDVL documentation.

influences = influence_model.influences(test_x, test_y, train_x, train_y, mode="up")

The output is a matrix of size test_set_length x training_set_length. Each row represents a test data point and each column a training data point, so that entry \((i,j)\) represents the influence of training point \(j\) on test point \(i\).
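For intuition about what each entry of this matrix contains, the sketch below computes a full influence matrix for a tiny logistic regression with the upweighting formula of Koh and Liang (2017), \(I(z_{\text{test}}, z_j) = -\nabla_\theta L(z_{\text{test}})^\top (H + \lambda I)^{-1} \nabla_\theta L(z_j)\), using closed-form per-sample gradients and Hessian. The data, the gradient-descent fit and all names are hypothetical. Sign conventions vary between references and libraries, so the sign here may differ from that of pyDVL's output.

```python
import torch

torch.manual_seed(0)
dtype = torch.float64

# Hypothetical toy data: 40 training and 5 test points with 3 features
X_train = torch.randn(40, 3, dtype=dtype)
y_train = (X_train[:, 0] > 0).to(dtype)
X_test = torch.randn(5, 3, dtype=dtype)
y_test = (X_test[:, 0] > 0).to(dtype)


def per_sample_grads(X, y, theta):
    # Gradient of each sample's logistic loss w.r.t. theta: (sigmoid(x . theta) - y) x
    return (torch.sigmoid(X @ theta) - y).unsqueeze(1) * X  # shape (n, d)


def hessian(X, theta):
    # Hessian of the mean logistic loss: mean over samples of p (1 - p) x x^T
    p = torch.sigmoid(X @ theta)
    return (X.T * (p * (1 - p))) @ X / len(X)  # shape (d, d)


# Approximately minimise the mean training loss with plain gradient descent
theta = torch.zeros(3, dtype=dtype)
for _ in range(200):
    theta -= per_sample_grads(X_train, y_train, theta).mean(dim=0)

reg = 1e-3  # plays the role of hessian_reg
H_reg = hessian(X_train, theta) + reg * torch.eye(3, dtype=dtype)
g_train = per_sample_grads(X_train, y_train, theta)  # (n_train, d)
g_test = per_sample_grads(X_test, y_test, theta)  # (n_test, d)

# Entry (i, j): influence of training point j on the loss of test point i
influences = -g_test @ torch.linalg.solve(H_reg, g_train.T)
print(influences.shape)  # torch.Size([5, 40])
```

Solving against all training gradients at once is only feasible because the toy model has 3 parameters; for a large network this is exactly the step that CG (or another approximation) replaces.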

Analysing influences

With the computed influences we can study single images or all of them together:

Influence on a single test image

Let's take any image in the test set:


Now we plot the histogram of the influence that all training images have on the image selected above, separated by their label.
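The plot is produced with pyDVL's plot_influence_distribution_by_label. The numerical step underneath, splitting one row of the influence matrix by training label and binning each part on common edges, can be sketched with NumPy alone. The random influence matrix and labels below are placeholders, and this is not pyDVL's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
# Placeholder influence matrix (n_test x n_train) and training labels
influences = rng.normal(size=(3, 10))
train_labels = np.array([0, 1, 0, 1, 1, 0, 0, 1, 0, 1])
test_idx = 0

# Influence of every training point on the chosen test point
row = influences[test_idx]
by_label = {int(label): row[train_labels == label] for label in np.unique(train_labels)}

# Common bin edges make the per-label histograms directly comparable
edges = np.histogram_bin_edges(row, bins=5)
counts = {label: np.histogram(values, bins=edges)[0] for label, values in by_label.items()}
for label, c in counts.items():
    print(label, c)
```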


Rather unsurprisingly, the training points with the highest influence have the same label as the test image. Now we can take the training images with that label and show those with the highest and lowest scores.


Looking at the images, it is difficult to explain why those on the right are more influential than those on the left. At first sight, the choice seems to be random (or at the very least noisy). Let's dig in a bit more by looking at average influences:

Analysing the average influence on test samples

By averaging each column of the influence matrix over all its rows, i.e. over all test points, we obtain the average influence of each training sample on the whole test set:

avg_influences = np.mean(influences.cpu().numpy(), axis=0)

Once again, let's plot the histogram of influence values by label.


Next, for each class (select it by changing the label key) we look at the top and bottom images by average influence, i.e. the training images with the highest and lowest average influence over all test images.
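Selecting those images amounts to restricting the average influences to one class and sorting them. A minimal sketch, with a hypothetical helper top_bottom_by_label and made-up values:

```python
import numpy as np


def top_bottom_by_label(avg_influences, train_labels, label, k):
    # Indices (into the full training set) of the points with the given label
    idx = np.flatnonzero(train_labels == label)
    # Sort those indices by average influence, ascending
    order = idx[np.argsort(avg_influences[idx])]
    # (highest first, lowest first)
    return order[-k:][::-1], order[:k]


avg_influences = np.array([0.5, -0.2, 1.3, 0.0, -1.1, 0.7, 0.2])
train_labels = np.array([0, 1, 0, 0, 1, 0, 1])
top, bottom = top_bottom_by_label(avg_influences, train_labels, label=0, k=2)
print(top, bottom)  # [2 5] [3 0]
```

Passing a single row of the influence matrix instead of avg_influences gives the per-test-image selection used earlier.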