Influence functions for data mislabeling
In this notebook, we will take a closer look at the theory of influence functions with the help of a synthetic dataset. Data mislabeling occurs whenever some examples in a (usually large) dataset are wrongly labeled. In real life this happens fairly often, e.g. as a consequence of human error or noise in the data.
Let's consider a binary classification problem with the following notation:

$$ \mathcal{D} = \{(x_i, y_i)\}_{i=1}^N, \quad x_i \in \mathbb{R}^d, \quad y_i \in \{0, 1\}. $$
In other words, we have a dataset containing \(N\) samples, each with label 1 or 0. As a typical example, you can think of \(y\) indicating whether a patient has a disease based on some feature representation \(x\).
Let's now introduce a toy model that will help us delve into the theory and practical utility of influence functions. We will assume that \(y\) is a Bernoulli random variable, while the input \(x\) follows a d-dimensional Gaussian distribution whose parameters depend on the label \(y\). More precisely:

$$ y \sim \text{Ber}(p), \qquad x \mid y \sim \mathcal{N}(\mu_y, \Sigma), $$

with fixed means \(\mu_0, \mu_1\) and diagonal covariance \(\Sigma\). Implementing the sampling scheme in Python is straightforward: first sample \(y\), then sample \(x\) conditioned on it.
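As a minimal sketch of this scheme (the success probability, class means, and noise level below are illustrative choices, not the notebook's exact values):

import numpy as np

def sample_dataset(n, p=0.5, means=((0.0, 0.0), (1.0, 1.0)), sigma=0.5, seed=0):
    rng = np.random.default_rng(seed)
    y = rng.binomial(1, p, size=n)                 # first sample the labels
    mu = np.asarray(means)[y]                      # class mean for each sample
    x = mu + sigma * rng.standard_normal((n, 2))   # diagonal covariance sigma^2 I
    return x, y

x, y = sample_dataset(1000)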
Imports
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import random
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from pydvl.influence.torch import DirectInfluence, CgInfluence
from support.shapley import (
synthetic_classification_dataset,
decision_boundary_fixed_variance_2d,
)
from support.common import (
plot_gaussian_blobs,
plot_losses,
plot_influences,
)
from support.torch import (
fit_torch_model,
TorchLogisticRegression,
)
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.optim import AdamW, lr_scheduler
from torch.utils.data import DataLoader, TensorDataset
Dataset
The following code snippet generates the aforementioned dataset.
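A sketch of what this step might look like; the means, sigma, and split sizes are illustrative choices, and the exact signature of synthetic_classification_dataset lives in the notebook's support package:

num_samples = 10000
num_features = 2
sigma = 0.2                                   # assumed noise level
means = np.asarray([[0.0, 0.0], [1.0, 1.0]])  # assumed class means

# Hypothetical call; check support/shapley.py for the actual signature.
train_data, val_data, test_data = synthetic_classification_dataset(
    means, sigma, num_samples, train_size=0.7, test_size=0.2
)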
Given the simplicity of the dataset, we can compute the optimal decision boundary exactly (the one that maximizes accuracy). The following code maps a continuous line of z-values to a 2-dimensional curve in feature space (more details are in the appendix of this notebook).
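A sketch of this step, assuming decision_boundary_fixed_variance_2d returns a function from z-values to points in feature space; the range of z-values is an illustrative choice:

decision_boundary_fn = decision_boundary_fixed_variance_2d(means[0], means[1])
decision_boundary = decision_boundary_fn(np.linspace(-1.5, 1.5, 100))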
Plotting the dataset
Let's plot the dataset with the sample labels and the optimal decision boundary:
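A hypothetical call to the imported plot_gaussian_blobs helper; the keyword names are guesses based on the imports, not a confirmed API:

plot_gaussian_blobs(
    train_data,
    test_data,
    xlabel="$x_0$",
    ylabel="$x_1$",
    line=decision_boundary,
    suptitle="Plot of train-test data",
)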
Note that some samples cross the optimal decision boundary and will be wrongly classified by it. The optimal decision boundary cannot discriminate these, since the mislabeling is a consequence of random noise in the data.
Training the model
We will now train a logistic regression model on the training data. This can be done with the following code:
model = TorchLogisticRegression(num_features)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# Training hyperparameters
num_epochs = 50
lr = 0.05
weight_decay = 0.05
batch_size = 256

# Wrap the numpy splits into PyTorch data loaders
train_data_loader = DataLoader(
    TensorDataset(
        torch.as_tensor(train_data[0]),
        torch.as_tensor(train_data[1], dtype=torch.float64).unsqueeze(-1),
    ),
    batch_size=batch_size,
    shuffle=True,
)
val_data_loader = DataLoader(
    TensorDataset(
        torch.as_tensor(val_data[0]),
        torch.as_tensor(val_data[1], dtype=torch.float64).unsqueeze(-1),
    ),
    batch_size=batch_size,
    shuffle=True,
)

# AdamW with a cosine-annealed learning rate
optimizer = AdamW(params=model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
losses = fit_torch_model(
    model=model,
    training_data=train_data_loader,
    val_data=val_data_loader,
    loss=F.binary_cross_entropy,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=num_epochs,
    device=device,
)
And let's check that the model is not overfitting:
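A sketch of the check, assuming fit_torch_model returns training and validation losses that the imported plot_losses helper can display directly:

plot_losses(losses)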
A look at the confusion matrix also shows good results:
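A sketch of the confusion-matrix computation, assuming test_data is an (inputs, labels) pair like train_data and that the model outputs probabilities:

model.eval()
with torch.no_grad():
    # Cast the inputs to the model's dtype if needed
    probs = model(torch.as_tensor(test_data[0]).to(device))
pred_y_test = (probs > 0.5).squeeze().cpu().numpy().astype(int)

cm = confusion_matrix(test_data[1], pred_y_test)
ConfusionMatrixDisplay(confusion_matrix=cm).plot()
plt.show()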
Calculating influences
It is important that the model converges to a point near the optimum, since influence values assume that we are at (or close to) a minimum of the loss landscape. The function

$$ I: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R} $$

measures the influence of the data point \(x_1\) on \(x_2\), conditioned on the training targets \(y_1\) and \(y_2\), through the model parameters \(\theta\). If the loss function \(L\) is differentiable, we can take \(I\) to be
$$ I(x_1, x_2) = \nabla_\theta L(x_1, y_1)^\mathsf{T} \; H_\theta^{-1} \; \nabla_\theta L(x_2, y_2), $$ where \(H_\theta = \frac{1}{N} \sum_{i=1}^N \nabla^2_\theta L(x_i, y_i)\) is the Hessian of the empirical training loss. See "Understanding Black-box Predictions via Influence Functions" for a detailed derivation of this formula.
Let's take a subset of the training data points for which we will calculate the influence values.
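A hypothetical subsampling step; the subset size and the use of a plain random choice are assumptions, not the notebook's exact procedure:

num_subset = 300
subset_indices = np.random.choice(len(train_data[0]), num_subset, replace=False)
x, y = train_data[0][subset_indices], train_data[1][subset_indices]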
In pyDVL, the influence of the training points on the test points can be calculated with the following code:
train_x = torch.as_tensor(x)
train_y = torch.as_tensor(y, dtype=torch.float64).unsqueeze(-1)
test_x = torch.as_tensor(test_data[0])
test_y = torch.as_tensor(test_data[1], dtype=torch.float64).unsqueeze(-1)

train_data_loader = DataLoader(
    TensorDataset(train_x, train_y),
    batch_size=batch_size,
)

# fit computes (and inverts) the Hessian of the training loss explicitly
influence_model = DirectInfluence(
    model,
    F.binary_cross_entropy,
    hessian_regularization=0.0,
)
influence_model = influence_model.fit(train_data_loader)

# Influence of each training point on each test point ("up-weighting" mode)
influence_values = influence_model.influences(
    test_x, test_y, train_x, train_y, mode="up"
)
The above explicitly constructs the Hessian. This can often be computationally expensive, and the approximate conjugate-gradient solver should be used for bigger models.
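As a sketch, the same computation with the CgInfluence class imported above only swaps the implementation; solver-specific arguments such as tolerances and iteration limits are left at their defaults here:

cg_influence_model = CgInfluence(
    model,
    F.binary_cross_entropy,
    hessian_regularization=0.0,
)
cg_influence_model = cg_influence_model.fit(train_data_loader)
cg_influence_values = cg_influence_model.influences(
    test_x, test_y, train_x, train_y, mode="up"
)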
With mode 'up', influence values have shape [N, M], where N is the number of test samples and M is the number of training samples. They therefore associate to each training sample its influence on each test sample. Mode 'perturbation' instead returns an array of shape [N, M, F], where F is the number of input features, i.e. the length of x.
In our case, in order to obtain the total average influence of a training point, we can simply average across test samples.
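A one-line sketch of this averaging, assuming influence_values is a torch tensor of shape [N, M]:

mean_train_influences = influence_values.cpu().numpy().mean(axis=0)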
Let's plot the results (adjust colorbar_limits for a better color gradient).
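A hypothetical call to the imported plot_influences helper; the keyword names and the colorbar limits below are assumptions:

plot_influences(
    x,
    mean_train_influences,
    line=decision_boundary,
    xlabel="$x_0$",
    ylabel="$x_1$",
    suptitle="Influences of training points",
    colorbar_limits=(-0.3, 0.3),
)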