Influence functions for data mislabeling
In this notebook, we will take a closer look at the theory of influence functions with the help of a synthetic dataset. Data mislabeling occurs whenever some examples in a (typically large) dataset carry the wrong label. In real life this happens fairly often, e.g. as a consequence of human error or noise in the data.
Let's consider a classification problem with the following notation:
In other words, we have a dataset containing \(N\) samples, each with label 1 or 0. As a typical example, you can think of \(y\) indicating whether a patient has a disease, based on some feature representation \(x\).
Let's now introduce a toy model that will help us delve into the theory and practical utility of influence functions. We will assume that \(y\) is a Bernoulli (binary) random variable, while the input \(x\) follows a d-dimensional Gaussian distribution that depends on the label \(y\). More precisely:

$$ y \sim \text{Ber}(p), \qquad x \mid y \sim \mathcal{N}(\mu_y, \sigma^2 I_d), $$

with fixed means and diagonal covariance. Implementing the sampling scheme in Python is straightforward and can be achieved by first sampling \(y\) and afterwards \(x\).
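A minimal sketch of such a sampling scheme; all parameter values below are illustrative assumptions, independent of the dataset generated later:

# Illustrative parameters (assumptions, chosen only for this sketch).
rng = np.random.default_rng(16)
p = 0.5                                      # P(y = 1)
mus = np.asarray([[0.0, 0.0], [1.0, 1.0]])   # one mean per class
sample_sigma = 0.2                           # shared standard deviation
n = 1000

# First sample y ~ Ber(p), then x | y ~ N(mu_y, sigma^2 I).
y_sample = rng.binomial(1, p, size=n)
x_sample = mus[y_sample] + sample_sigma * rng.standard_normal((n, 2))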
Imports
%load_ext autoreload
%autoreload
%matplotlib inline
import os
import random
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from pydvl.influence.torch import DirectInfluence, CgInfluence
from support.shapley import (
synthetic_classification_dataset,
decision_boundary_fixed_variance_2d,
)
from support.common import (
plot_gaussian_blobs,
plot_losses,
plot_influences,
)
from support.torch import (
fit_torch_model,
TorchLogisticRegression,
)
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.optim import AdamW, lr_scheduler
from torch.utils.data import DataLoader, TensorDataset
Dataset
The following code snippet generates the aforementioned dataset.
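A sketch of what such a snippet could look like. The signature assumed here for synthetic_classification_dataset (class means, a shared standard deviation, the number of samples and the split sizes) and all parameter values are illustrative assumptions; the helper is taken to return (features, labels) pairs for the train, validation and test splits.

num_samples = 10000
num_features = 2
sigma = 0.2
means = np.asarray([[0.0, 0.0], [1.0, 1.0]])

# Hypothetical call: generate the three splits from the Gaussian mixture above.
train_data, val_data, test_data = synthetic_classification_dataset(
    means, sigma, num_samples, train_size=0.7, test_size=0.2
)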
Given the simplicity of the dataset, we can calculate the optimal decision boundary exactly (the one that maximizes accuracy). The following code maps a continuous range of \(z\) values to a 2-dimensional vector in feature space (more details are in the appendix to this notebook).
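A sketch of such a mapping, assuming decision_boundary_fixed_variance_2d takes the two class means and returns a function from scalar \(z\) values to points on the optimal boundary (the argument convention and the range of \(z\) values are assumptions):

# Hypothetical usage: build the optimal boundary for the two class means and
# evaluate it on a range of z values to obtain a polyline in feature space.
decision_boundary_fn = decision_boundary_fixed_variance_2d(means[0], means[1])
decision_boundary = decision_boundary_fn(np.linspace(-1.5, 1.5, 100))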
Plotting the dataset
Let's plot the dataset with the respective labels and the optimal decision boundary.
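A plain-matplotlib sketch of such a plot (the notebook's plot_gaussian_blobs helper produces a similar figure); it assumes train_data is a (features, labels) pair and decision_boundary is an array of 2-d points:

# Scatter of the training data coloured by label, plus the optimal boundary.
x_train, y_train = train_data
plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train, cmap="coolwarm", s=10)
plt.plot(
    decision_boundary[:, 0],
    decision_boundary[:, 1],
    color="black",
    label="optimal decision boundary",
)
plt.xlabel("$x_0$")
plt.ylabel("$x_1$")
plt.legend()
plt.show()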
Note that there are samples that fall on the wrong side of the optimal decision boundary and will therefore be wrongly classified. The optimal decision boundary cannot discriminate these, since this overlap is a consequence of the random noise in the data.
Training the model
We will now train a logistic regression model on the training data. This can be done with the following code:
model = TorchLogisticRegression(num_features)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
num_epochs = 50
lr = 0.05
weight_decay = 0.05
batch_size = 256
train_data_loader = DataLoader(
TensorDataset(
torch.as_tensor(train_data[0]),
torch.as_tensor(train_data[1], dtype=torch.float64).unsqueeze(-1),
),
batch_size=batch_size,
shuffle=True,
)
val_data_loader = DataLoader(
TensorDataset(
torch.as_tensor(val_data[0]),
torch.as_tensor(val_data[1], dtype=torch.float64).unsqueeze(-1),
),
batch_size=batch_size,
shuffle=True,
)
optimizer = AdamW(params=model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
losses = fit_torch_model(
model=model,
training_data=train_data_loader,
val_data=val_data_loader,
loss=F.binary_cross_entropy,
optimizer=optimizer,
scheduler=scheduler,
num_epochs=num_epochs,
device=device,
)
And let's check that the model is not overfitting.
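For instance, by plotting the loss curves. This is a sketch assuming fit_torch_model returns the per-epoch training and validation losses as a pair (the notebook's plot_losses helper serves the same purpose):

# Compare training and validation loss per epoch.
train_losses, val_losses = losses
plt.plot(train_losses, label="training loss")
plt.plot(val_losses, label="validation loss")
plt.xlabel("epoch")
plt.ylabel("binary cross-entropy")
plt.legend()
plt.show()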
A look at the confusion matrix also shows good results.
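For example, on the held-out test split. This sketch assumes the model outputs probabilities and that test_data is a (features, labels) pair:

# Predicted labels on the test set via a 0.5 threshold on the output probabilities.
with torch.no_grad():
    probs = model(torch.as_tensor(test_data[0]).to(device)).squeeze(-1).cpu().numpy()
pred_labels = (probs > 0.5).astype(int)

y_true = np.asarray(test_data[1]).astype(int)
cm = confusion_matrix(y_true, pred_labels)
ConfusionMatrixDisplay(confusion_matrix=cm).plot()
plt.show()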
Calculating influences
It is important that the model converges to a point near the optimum, since the influence values assume that we are at a minimum (or close to one) in the loss landscape. The function \(I(x_1, x_2)\) measures the influence of the data point \(x_1\) on \(x_2\), conditioned on the training targets \(y_1\) and \(y_2\), through some model parameters \(\theta\). If the loss function \(L\) is differentiable, we can take \(I\) to be

$$ I(x_1, x_2) = \nabla_\theta\, L(x_1, y_1)^\mathsf{T} \; H_\theta^{-1} \; \nabla_\theta\, L(x_2, y_2). $$

See "Understanding Black-box Predictions via Influence Functions" for a detailed derivation of this formula.
Let's take a subset of the training data points, for which we will calculate the influence values.
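For instance (the subset size of 300 is an arbitrary choice for illustration, and train_data is again assumed to be a (features, labels) pair):

# Take a small subset of the training data for the influence computation.
x, y = train_data
x, y = x[:300], y[:300]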
In pyDVL, the influence of the training points on the test points can be calculated with the following code:
train_x = torch.as_tensor(x)
train_y = torch.as_tensor(y, dtype=torch.float64).unsqueeze(-1)
test_x = torch.as_tensor(test_data[0])
test_y = torch.as_tensor(test_data[1], dtype=torch.float64).unsqueeze(-1)
train_data_loader = DataLoader(
TensorDataset(train_x, train_y),
batch_size=batch_size,
)
influence_model = DirectInfluence(
model,
F.binary_cross_entropy,
hessian_regularization=0.0,
)
influence_model = influence_model.fit(train_data_loader)
influence_values = influence_model.influences(
test_x, test_y, train_x, train_y, mode="up"
)
The above explicitly constructs the Hessian. This can often be computationally expensive, and an approximate method such as conjugate gradient should be used for bigger models.
With influence mode 'up', the training influences have shape [N x M], where N is the number of test samples and M is the number of training samples. They therefore associate to each training sample its influence on each test sample. Mode 'perturbation', instead, returns an array of shape [N x M x F], where F is the number of input features, i.e. the length of \(x\).
In our case, in order to obtain the total average influence of each training point, we can simply average across test samples.
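Concretely, since the influence array has shape [N x M] with the test dimension first, averaging over axis 0 yields one value per training point:

# One average influence value per training sample (mean over the test axis).
mean_train_influences = np.mean(influence_values.cpu().numpy(), axis=0)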
Let's plot the results (adjust colorbar_limits for a better colour gradient).
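A plain-matplotlib sketch of such a plot (the notebook's plot_influences helper produces something similar; its colorbar_limits argument corresponds to clipping the colour range):

# Scatter of the training subset coloured by average influence.
sc = plt.scatter(x[:, 0], x[:, 1], c=mean_train_influences, cmap="coolwarm")
# Pass vmin/vmax to scatter above to clip the colour range (cf. colorbar_limits).
plt.plot(decision_boundary[:, 0], decision_boundary[:, 1], color="black")
plt.colorbar(sc, label="influence values")
plt.xlabel("$x_0$")
plt.ylabel("$x_1$")
plt.show()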
We can see that, as we approach the separation line, the influences tend to move away from zero, i.e. the points become more decisive for model training, some in a positive and some in a negative way.
As a further test, let's introduce some labelling errors into \(y\) and see how the distribution of the influences changes. Let's flip the first 10 labels and calculate the influences:
y_corrupted = np.copy(y)
y_corrupted[:10] = [1 - yi for yi in y[:10]]
train_y_corrupted = torch.as_tensor(y_corrupted, dtype=torch.float64).unsqueeze(-1)
train_corrupted_data_loader = DataLoader(
TensorDataset(
train_x,
train_y_corrupted,
),
batch_size=batch_size,
)
influence_model = DirectInfluence(
model,
F.binary_cross_entropy,
hessian_regularization=0.0,
)
influence_model = influence_model.fit(train_corrupted_data_loader)
influence_values = influence_model.influences(
test_x, test_y, train_x, train_y_corrupted, mode="up"
)
mean_train_influences = np.mean(influence_values.cpu().numpy(), axis=0)
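A sketch of the corresponding plot, marking the corrupted points (the first ten of the subset) with red circles:

# Influence values after label corruption, with the corrupted points circled.
sc = plt.scatter(x[:, 0], x[:, 1], c=mean_train_influences, cmap="coolwarm")
plt.scatter(
    x[:10, 0], x[:10, 1],
    facecolors="none", edgecolors="red", s=80, label="corrupted labels",
)
plt.plot(decision_boundary[:, 0], decision_boundary[:, 1], color="black")
plt.colorbar(sc, label="influence values")
plt.legend()
plt.show()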
Red circles indicate the points which have been corrupted. We can see that the mislabelled data have a more negative average influence on the model, especially those that are farther away from the decision boundary.
Inversion through conjugate gradient
The "direct" method that we have used above involves the inversion of the Hessian matrix of the model. If a model has \(n\) training points and \(\theta \in \mathbb{R}^p\) parameters, this requires \(O(n \ p^2 + p^3)\) operations, which for larger models, like neural networks, becomes quickly unfeasible. Conjugate gradient avoids the explicit computation of the Hessian via a technique called implicit Hessian-vector products (HVPs), which typically takes \(O(n \ p)\) operations.
In the next cell we will use conjugate gradient to compute the influence factors. Since logistic regression is a very simple model, "cg" actually slows down the computation with respect to the direct method, which in this case is a much better choice. Nevertheless, we are able to verify that the influences calculated with "cg" are the same (up to a small error) as those calculated directly.
influence_model = CgInfluence(
model,
F.binary_cross_entropy,
hessian_regularization=0.0,
)
influence_model = influence_model.fit(train_corrupted_data_loader)
influence_values = influence_model.influences(
test_x, test_y, train_x, train_y_corrupted
)
mean_train_influences = np.mean(influence_values.cpu().numpy(), axis=0)
print("Average mislabelled data influence:", np.mean(mean_train_influences[:10]))
print("Average correct data influence:", np.mean(mean_train_influences[10:]))
The averages are very similar to the ones calculated through the direct method. The same is true for the plot.
Appendix: Calculating the decision boundary
To obtain the optimal discriminator, one has to solve the equation

$$ p(y=0 \mid x) = p(y=1 \mid x) $$
and determine the solution set \(X\). Let's take the following class-conditional probabilities, consistent with the data model above:

$$ p(x \mid y=0) = \mathcal{N}(x \mid \mu_1, \sigma^2 I), \qquad p(x \mid y=1) = \mathcal{N}(x \mid \mu_2, \sigma^2 I). $$
For a single fixed diagonal variance parameterized by \(\sigma\) (and balanced classes), the optimal discriminator lies at the points which are equidistant from the means of the two distributions, i.e.

$$ \| x - \mu_1 \|^2 = \| x - \mu_2 \|^2 . $$
This is just the implicit description of the line. The explicit form can be obtained by enforcing a functional form \(f(z) = x = a z + b\) with \(z \in \mathbb{R}\) onto \(x\). Inserting it into the previous equation gives

$$ \| a z + b - \mu_1 \|^2 = \| a z + b - \mu_2 \|^2 . $$
We can write \(a\) explicitly since, by symmetry, it is expected to be orthogonal to \(\mu_2 - \mu_1\). Then, solving for \(b\), the solution can be found to be

$$ a = \begin{pmatrix} 0 & -1 \\ 1 & 0 \end{pmatrix} (\mu_2 - \mu_1), \qquad b = \frac{\mu_1 + \mu_2}{2}, $$

i.e. (up to the scaling of \(a\)) the line through the midpoint of the two means, perpendicular to the segment that joins them.