
Influence functions for outlier detection

This notebook shows how to use pyDVL to calculate influence functions for a neural network (NN) model trained on an arbitrary dataset, and how these influences can be used to find anomalous or corrupted data points.

It uses the wine dataset from sklearn: given 13 input features describing a particular bottle, each related to a physical or chemical property (e.g. magnesium concentration, malic acid content, alcohol content), the model has to predict which of 3 classes the wine belongs to. For more details, please refer to the sklearn documentation.

If you are reading this in the documentation, some boilerplate has been omitted for convenience.

Let's start by importing the required modules, loading the dataset and splitting it into training, validation and test sets. We use a large test set (60% of the data) to obtain a less noisy estimate of the average influence.

Imports

%load_ext autoreload
%autoreload
%matplotlib inline

import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from support.common import plot_losses
from support.torch import TorchMLP, fit_torch_model
from pydvl.influence import compute_influences
from pydvl.influence.torch import TorchTwiceDifferentiable
from support.shapley import load_wine_dataset
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, f1_score
from torch.optim import Adam, lr_scheduler
from torch.utils.data import DataLoader, TensorDataset

Constants

random_state = 24
is_CI = os.environ.get("CI")
random.seed(random_state)
np.random.seed(random_state)
torch.manual_seed(random_state)

Dataset

training_data, val_data, test_data, feature_names = load_wine_dataset(
    train_size=0.3, test_size=0.6
)
# In CI we only use a subset of the training set
if is_CI:
    training_data = (training_data[0][:10], training_data[1][:10])

We corrupt the first 10 training points by flipping their labels, shifting each label to the next class (modulo 3):

num_corrupted_idxs = 10
# Shift each of the first num_corrupted_idxs labels to the next class: 0 -> 1, 1 -> 2, 2 -> 0
training_data[1][:num_corrupted_idxs] = (training_data[1][:num_corrupted_idxs] + 1) % 3
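Shifting each label to the next class modulo the number of classes, as the cell above does, guarantees that every corrupted label is wrong rather than randomly re-drawn (which could reproduce the original label). A minimal numpy sketch on made-up labels (the values are illustrative, not taken from the wine dataset):

```python
import numpy as np

num_classes = 3
labels = np.array([0, 1, 2, 2, 1, 0, 0, 1])  # illustrative labels
num_corrupted = 5

corrupted = labels.copy()
# Shift each of the first `num_corrupted` labels to the next class, wrapping
# around (0 -> 1, 1 -> 2, 2 -> 0). With more than one class, every corrupted
# label is therefore guaranteed to differ from the original one.
corrupted[:num_corrupted] = (corrupted[:num_corrupted] + 1) % num_classes

print(corrupted)  # [1 2 0 0 2 0 0 1]
```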

and wrap the datasets in PyTorch data loaders:

training_data_loader = DataLoader(
    TensorDataset(*training_data), batch_size=32, shuffle=False
)
val_data_loader = DataLoader(TensorDataset(*val_data), batch_size=32, shuffle=False)
test_data_loader = DataLoader(TensorDataset(*test_data), batch_size=32, shuffle=False)

Fit a neural network to the data

We will train a neural network with two hidden layers of 16 units each. The notebook's support module provides convenience wrappers (TorchMLP and fit_torch_model) to initialise and train a PyTorch NN. If you already have a trained model, you can skip this section.

feature_dimension = 13
num_classes = 3
network_size = [16, 16]
layers_size = [feature_dimension, *network_size, num_classes]
num_epochs = 300
lr = 0.005
weight_decay = 0.01
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

nn_model = TorchMLP(layers_size)
nn_model.to(device)

optimizer = Adam(params=nn_model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

losses = fit_torch_model(
    model=nn_model,
    training_data=training_data_loader,
    val_data=val_data_loader,
    loss=F.cross_entropy,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=num_epochs,
)

Let's check that training has converged to a stable minimum by plotting the training and validation losses:

plot_losses(losses)

Since this is a classification problem, let's also look at the confusion matrix on the test set:

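The cell that computes pred_y_test and draws this plot is omitted. Below is a minimal numpy sketch of what it presumably computes, assuming pred_y_test is the index of the largest network output (logit) for each test point. The helper confusion_matrix_np and the logits are hypothetical; sklearn's confusion_matrix, imported above, uses the same layout (rows are true classes, columns are predicted classes).

```python
import numpy as np

def confusion_matrix_np(y_true, y_pred, num_classes):
    """Entry [t, p] counts test points of true class t predicted as class p."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    # np.add.at accumulates correctly even when the same (t, p) pair repeats.
    np.add.at(cm, (y_true, y_pred), 1)
    return cm

# Hypothetical logits for 5 test points and 3 classes (illustrative values).
logits = np.array([
    [2.0, 0.1, -1.0],
    [0.2, 1.5, 0.3],
    [-0.5, 0.0, 3.0],
    [1.2, 1.0, 0.9],
    [0.0, 2.2, 0.1],
])
y_test = np.array([0, 1, 2, 1, 1])

# Assumed: the predicted class is the index of the largest logit.
pred_y_test = logits.argmax(axis=1)
cm = confusion_matrix_np(y_test, pred_y_test, num_classes=3)
print(cm)
```

Here the fourth point (true class 1) is misclassified as class 0, which shows up as the single off-diagonal count in row 1, column 0.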

Let's also compute the model's weighted F1 score on the test set:

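# Predicted class = index of the largest network output for each test point
with torch.no_grad():
    pred_y_test = nn_model(test_data[0].to(device)).argmax(dim=1).cpu()
f1_score(test_data[1], pred_y_test, average="weighted")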
0.9906846833902615
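With average="weighted", f1_score computes the F1 score of each class separately and averages them with weights proportional to each class's support (its number of true instances). A from-scratch numpy sketch of that definition on toy labels; weighted_f1 is a hypothetical helper, not a pyDVL or sklearn function:

```python
import numpy as np

def weighted_f1(y_true, y_pred, num_classes):
    """Per-class F1 averaged with weights equal to each class's support."""
    f1s, supports = [], []
    for c in range(num_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN); set to 0 if the class never occurs.
        f1s.append(2 * tp / denom if denom > 0 else 0.0)
        supports.append(np.sum(y_true == c))
    supports = np.array(supports)
    return float(np.dot(f1s, supports) / supports.sum())

y_true = np.array([0, 1, 2, 1, 1])
y_pred = np.array([0, 1, 2, 0, 1])
print(weighted_f1(y_true, y_pred, num_classes=3))  # 0.8133...
```

Class 1 has the largest support (3 of 5 points), so its F1 of 0.8 dominates the average.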

Let's now calculate the influence of each training point on the model's performance on the test set.

Calculating influences for small neural networks

The following cell calculates the influence of each training data point on the neural network. Neural networks typically have a very bumpy, non-convex loss landscape, which training explores until it finds a configuration of the parameters that minimises the loss. Influence functions rely on the important assumption that the model lies at a minimum (at least a local one) of this loss; if this does not hold, many issues can arise, for instance the Hessian of the loss may not be positive definite. To avoid this, a regularisation term should be used whenever dealing with large and noisy models; here it is set through the hessian_regularization parameter.
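A common form of regularisation, and the one we assume hessian_regularization applies here (an assumption, not confirmed by this notebook), is to add λI to the Hessian before inverting it. This shifts every eigenvalue up by λ, so a Hessian with small negative eigenvalues, as found at a point that is not quite a minimum, becomes positive definite once λ exceeds their magnitude. A toy numpy illustration with a made-up 2x2 Hessian:

```python
import numpy as np

# Toy "Hessian" at a point that is not a proper minimum: one direction has
# slightly negative curvature, so H is indefinite.
H = np.array([[2.0, 0.0],
              [0.0, -0.05]])
lam = 0.1  # plays the role of hessian_regularization (assumed: H + lam * I)

print(np.linalg.eigvalsh(H))        # one negative eigenvalue
H_reg = H + lam * np.eye(2)
print(np.linalg.eigvalsh(H_reg))    # all eigenvalues shifted up by lam

# The regularised matrix is positive definite, so it can safely be solved
# against a gradient to obtain an influence factor (H + lam I)^{-1} g.
g_test = np.array([1.0, 1.0])       # hypothetical test-loss gradient
s_test = np.linalg.solve(H_reg, g_test)
print(s_test)
```

Note that the low-curvature direction still gets a large factor (1 / 0.05 = 20), which is one reason influences in poorly conditioned models are noisy.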

train_influences = compute_influences(
    TorchTwiceDifferentiable(nn_model, F.cross_entropy),
    training_data=training_data_loader,
    test_data=test_data_loader,
    influence_type="up",
    inversion_method="direct",
    hessian_regularization=0.1,
    progress=True,
)

The returned matrix, train_influences, has one row per test point and one column per training point. Each element \(a_{i,j}\) stores the influence that training point \(j\) has on the loss of test point \(i\).
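For intuition about what each entry contains, here is a minimal numpy sketch of the standard first-order up-weighting influence, with the regularised Hessian H + λI in place of H. Up-weighting training point j by a small ε moves the parameters by about -ε(H + λI)^{-1}∇L(z_j), which changes the loss at test point i by about -ε∇L(z_i)^T(H + λI)^{-1}∇L(z_j). The sketch computes ∇L(z_i)^T(H + λI)^{-1}∇L(z_j), i.e. the approximate decrease in test loss, so negative entries mark harmful points, in line with the signs of the results further down; that this matches pyDVL's exact sign convention is an assumption. The gradients and Hessian are random stand-ins.

```python
import numpy as np

rng = np.random.default_rng(0)
n_params, n_train, n_test = 4, 6, 3
lam = 0.1

# Hypothetical per-point loss gradients w.r.t. the model parameters.
grads_train = rng.normal(size=(n_train, n_params))
grads_test = rng.normal(size=(n_test, n_params))

# Hypothetical symmetric positive semi-definite Hessian of the training loss.
A = rng.normal(size=(n_params, n_params))
H = A @ A.T / n_params

# Influence factors s_i = (H + lam I)^{-1} g_test_i, one column per test point.
factors = np.linalg.solve(H + lam * np.eye(n_params), grads_test.T)  # (n_params, n_test)

# a[i, j] = g_test_i^T (H + lam I)^{-1} g_train_j
influences = factors.T @ grads_train.T  # (n_test, n_train)
print(influences.shape)  # (3, 6)
```

Solving once per test point and then taking dot products with all training gradients avoids ever forming the inverse Hessian explicitly.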

Averaging each column over all test points (i.e. over the rows, axis=0) gives an estimate of the overall influence of each training point on the model's performance on the test set.

mean_train_influences = np.mean(train_influences.numpy(), axis=0)

The following histogram shows that the mean influence varies widely across the training set (note the log scale on the y axis).


As expected, the corrupted points tend to have a negative influence on the model:

Average influence of corrupted points:  -0.06944679
Average influence of other points:  0.04018428
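The cell computing these averages is omitted. Since the corrupted points are the first num_corrupted_idxs training points, the comparison amounts to splitting mean_train_influences at that index. A numpy sketch on a synthetic influence matrix (all values made up) that does this and then, as one possible outlier-detection rule, flags the k training points with the most negative mean influence:

```python
import numpy as np

rng = np.random.default_rng(1)
n_test, n_train, num_corrupted_idxs = 50, 20, 4

# Synthetic influence matrix: rows are test points, columns training points.
# The first `num_corrupted_idxs` columns are made harmful on purpose.
influences = rng.normal(loc=0.05, scale=0.02, size=(n_test, n_train))
influences[:, :num_corrupted_idxs] -= 0.1

# Average over test points (rows): one score per training point.
mean_influences = influences.mean(axis=0)

print("Corrupted:", mean_influences[:num_corrupted_idxs].mean())
print("Others:   ", mean_influences[num_corrupted_idxs:].mean())

# Flag the k training points with the most negative mean influence and check
# how many of them are actually corrupted.
k = num_corrupted_idxs
suspects = np.argsort(mean_influences)[:k]
recovered = np.isin(suspects, np.arange(num_corrupted_idxs)).mean()
print("Flagged:", sorted(suspects.tolist()), "fraction corrupted:", recovered)
```

On real data the separation is far less clean than in this synthetic example, so such a ranking is better treated as a shortlist for inspection than as a definitive label.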

Influence of training features

We have seen how to calculate the influence of single training points on each test point using influence_type "up". With influence_type "perturbation" we can also calculate the influence of each input feature of the training points. The next cell computes these feature influences; averaging them over training and test points lets us assess which features are most relevant to model performance.

feature_influences = compute_influences(
    TorchTwiceDifferentiable(nn_model, F.cross_entropy),
    training_data=training_data_loader,
    test_data=test_data_loader,
    influence_type="perturbation",
    inversion_method="direct",
    hessian_regularization=1,
    progress=True,
)
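The plotting cell is omitted. A minimal numpy sketch of the aggregation step, assuming (not confirmed here) that the perturbation influences form an array of shape (n_test, n_train, n_features); the array below is a synthetic stand-in:

```python
import numpy as np

rng = np.random.default_rng(2)
n_test, n_train, n_features = 5, 8, 13
# Synthetic stand-in; the (n_test, n_train, n_features) layout is an assumption.
feature_influences = rng.normal(size=(n_test, n_train, n_features))

# Average over test points and training points: one value per feature.
mean_feature_influences = feature_influences.mean(axis=(0, 1))

# Rank features by the magnitude of their average influence, largest first.
ranking = np.argsort(-np.abs(mean_feature_influences))
print(ranking[:3])  # indices of the three most influential features
```

With the real data, the indices in ranking can be mapped to names through feature_names returned by load_wine_dataset.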

Speeding up influences for big models

Computing and inverting the Hessian matrix, which influence functions require, can be computationally expensive and numerically delicate, but there are techniques to speed this up. PyDVL lets you choose between the full method ("direct") and the conjugate gradient method ("cg"). The former should only be used for very small networks like the one in this example, while for larger ones "cg" is advisable.
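The difference lies in how the inverse Hessian is applied. A direct method builds the full p x p Hessian (p = number of parameters) and solves against it, which for a dense solve costs O(p²) memory and O(p³) time. Conjugate gradient solves (H + λI)x = b iteratively and only needs Hessian-vector products, which can be computed without forming H. Below is a self-contained numpy sketch of textbook conjugate gradient compared with a direct solve, including a percentage error similar to the one reported further down. The random matrix and vector stand in for a real Hessian and gradient, and pyDVL's own CG implementation may differ in stopping criteria and other details.

```python
import numpy as np

def conjugate_gradient(matvec, b, rtol=1e-10, max_iter=1000):
    """Solve A x = b for symmetric positive definite A, using only v -> A v."""
    x = np.zeros_like(b)
    r = b - matvec(x)          # residual
    p = r.copy()               # search direction
    rs_old = r @ r
    b_norm = np.linalg.norm(b)
    for _ in range(max_iter):
        if np.sqrt(rs_old) <= rtol * b_norm:
            break
        Ap = matvec(p)
        alpha = rs_old / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x

rng = np.random.default_rng(3)
n = 50
A = rng.normal(size=(n, n))
H = A @ A.T / n                # hypothetical PSD "Hessian"
lam = 0.1
b = rng.normal(size=n)         # stands in for a test-loss gradient

# CG touches the matrix only through matrix-vector products...
x_cg = conjugate_gradient(lambda v: H @ v + lam * v, b)
# ...while the direct method factorises the full n x n matrix.
x_direct = np.linalg.solve(H + lam * np.eye(n), b)

pct_error = 100 * np.linalg.norm(x_cg - x_direct) / np.linalg.norm(x_direct)
print(f"Percentage error of cg over direct method: {pct_error:.2e} %")
```

The regularisation λ also helps CG: it bounds the condition number of H + λI, which controls how many iterations are needed.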

cg_train_influences = compute_influences(
    TorchTwiceDifferentiable(nn_model, F.cross_entropy),
    training_data=training_data_loader,
    test_data=test_data_loader,
    influence_type="up",
    inversion_method="cg",
    hessian_regularization=0.1,
    progress=True,
)
mean_cg_train_influences = np.mean(cg_train_influences.numpy(), axis=0)

Let's compare the results obtained through conjugate gradient with those from the direct method:

Percentage error of cg over direct method:5.899372013118409e-05 %

This was a quick introduction to the pyDVL interface for influence functions. Despite their speed and simplicity, influence functions are known to be a very noisy estimator of data quality, as pointed out in the paper "Influence functions in deep learning are fragile". The size of the network, the weight decay, the inversion method used for calculating influences and the size of the test set all contribute to the total amount of noise. Experiments may therefore give quantitatively and qualitatively different results if not averaged across several realisations. Shapley values, on the contrary, have been shown to be more robust, but this comes at the cost of high computational requirements. PyDVL employs several parallelization and caching techniques to optimize such calculations.


Last update: 2023-10-14
Created: 2023-10-14