Banzhaf Semi-values for data valuation
This notebook showcases Data Banzhaf: A Robust Data Valuation Framework for Machine Learning by Wang and Jia.
Computing Banzhaf semi-values using pyDVL follows essentially the same procedure as for all other semi-value-based methods, such as Shapley values. However, Data-Banzhaf tends to be more robust to stochasticity in the training process than other semi-values, a property that we study here.
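For reference, the Banzhaf value of a point \(i\) among \(n\) training points averages its marginal contributions uniformly over all subsets of the other points: \(v_i = \frac{1}{2^{n-1}} \sum_{S \subseteq N \setminus \{i\}} [u(S \cup \{i\}) - u(S)]\), where \(u\) is the utility. The following minimal sketch computes this exactly by enumeration, which is only feasible for a handful of points; toy_utility is a hypothetical stand-in for a model's score, and none of this is pyDVL code.

```python
from itertools import combinations
from typing import Callable, FrozenSet, List

UtilityFn = Callable[[FrozenSet[int]], float]


def exact_banzhaf(utility: UtilityFn, n: int) -> List[float]:
    """Exact Banzhaf values by enumerating all subsets (2**n utility calls per point)."""
    values = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for k in range(len(others) + 1):
            for subset in combinations(others, k):
                s = frozenset(subset)
                # Marginal contribution of point i to the coalition s
                total += utility(s | {i}) - utility(s)
        # Banzhaf weights every subset of the other n - 1 points equally
        values.append(total / 2 ** (n - 1))
    return values


# Hypothetical toy utility: points 0 and 1 only help together, point 2 adds 0.25
def toy_utility(s: FrozenSet[int]) -> float:
    return (1.0 if {0, 1} <= s else 0.0) + (0.25 if 2 in s else 0.0)


print(exact_banzhaf(toy_utility, 3))  # → [0.5, 0.5, 0.25]
```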
Additionally, we compare two sampling techniques: the standard permutation-based Monte Carlo sampling, and the so-called MSR (Maximum Sample Reuse) principle.
In order to highlight the strengths of Data-Banzhaf, we require a stochastic model. For this reason, we use a CNN to classify handwritten digits from the scikit-learn toy datasets.
Setup
Loading the dataset
We use a support function, load_digits_dataset(), which downloads the data and prepares it for usage. It returns four arrays that we then use to construct a Dataset. The data consists of grayscale images of shape 8x8 pixels with 16 shades of gray. These images contain handwritten digits from 0 to 9.
Training and test data are then used to instantiate a Dataset object:
Creating the utility and computing Banzhaf semivalues
Now we can calculate the contribution of each training sample to the model performance. First we need a model and a Scorer.
As a model, we use a simple CNN written in torch and wrapped into an object that converts numpy arrays into tensors (as of v0.9.0, valuation methods in pyDVL work only with numpy arrays). Note that any model that implements the protocol pydvl.utils.types.SupervisedModel, which is just the standard sklearn interface of fit(), predict() and score(), can be used to construct the utility.
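To make that interface concrete, here is a minimal sketch using Python's typing.Protocol. SupervisedModelLike mirrors the three methods named above but is written for this example and is not pyDVL's actual definition; MajorityClassModel is a hypothetical baseline that satisfies it.

```python
from collections import Counter
from typing import Any, Optional, Protocol, Sequence, runtime_checkable


@runtime_checkable
class SupervisedModelLike(Protocol):
    """Sketch of an sklearn-style interface: fit(), predict() and score()."""

    def fit(self, x: Sequence[Any], y: Sequence[Any]) -> None: ...

    def predict(self, x: Sequence[Any]) -> list: ...

    def score(self, x: Sequence[Any], y: Sequence[Any]) -> float: ...


class MajorityClassModel:
    """Hypothetical baseline that always predicts the most frequent training label."""

    def __init__(self) -> None:
        self.label: Optional[Any] = None

    def fit(self, x, y) -> None:
        self.label = Counter(y).most_common(1)[0][0]

    def predict(self, x) -> list:
        return [self.label for _ in x]

    def score(self, x, y) -> float:
        # Accuracy, like sklearn's default score() for classifiers
        predictions = self.predict(x)
        return sum(p == t for p, t in zip(predictions, y)) / len(y)


baseline = MajorityClassModel()
baseline.fit([[0], [1], [2]], [3, 3, 7])
print(isinstance(baseline, SupervisedModelLike), baseline.score([[5], [6]], [3, 7]))
```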
The final component is the scoring function. It can be anything like accuracy or \(R^2\), and is set with a string from the standard sklearn scoring methods. Please refer to that documentation for information on how to define your own scoring function.
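In sklearn, a custom scorer is a callable with the signature scorer(estimator, X, y) that returns a float where higher is better. The sketch below follows that convention; balanced_accuracy_scorer and ConstantModel are hypothetical names, and whether pyDVL's Scorer accepts such a callable in place of a string should be checked in the pyDVL documentation.

```python
from collections import defaultdict
from typing import Any, Sequence


def balanced_accuracy_scorer(model: Any, x: Sequence[Any], y: Sequence[Any]) -> float:
    """sklearn-style scorer: fitted model and data in, higher-is-better float out."""
    predictions = model.predict(x)
    hits = defaultdict(int)
    counts = defaultdict(int)
    for p, t in zip(predictions, y):
        counts[t] += 1
        hits[t] += int(p == t)
    # Average of per-class recall, so rare classes weigh as much as frequent ones
    return sum(hits[c] / counts[c] for c in counts) / len(counts)


class ConstantModel:
    """Hypothetical model that always predicts class 0."""

    def predict(self, x):
        return [0] * len(x)


print(balanced_accuracy_scorer(ConstantModel(), [[1], [2], [3], [4]], [0, 0, 0, 1]))  # → 0.5
```

Because it averages recall per class, the constant model scores 0.5 here even though its plain accuracy would be 0.75.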
We group dataset, model and scoring function into an instance of Utility and compute the Banzhaf semi-values. We take all defaults and choose to stop the computation using the MaxChecks stopping criterion, which terminates after a fixed number of calls to it. With the default batch_size of 1, this means that the model is retrained a number of times proportional to max_checks.
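The idea of a criterion that stops after a fixed number of checks, and of combining criteria with |, as done later in this notebook, can be sketched with a tiny hypothetical class. It is much simpler than pyDVL's own stopping criteria and only illustrates the control flow.

```python
from typing import Callable


class Criterion:
    """Hypothetical stopping criterion: calling it returns True when computation should stop."""

    def __init__(self, check: Callable[[], bool]):
        self._check = check

    def __call__(self) -> bool:
        return self._check()

    def __or__(self, other: "Criterion") -> "Criterion":
        def combined() -> bool:
            # Evaluate both so that stateful criteria keep counting, then stop if either says so
            a = self()
            b = other()
            return a or b

        return Criterion(combined)


def max_checks_criterion(max_checks: int) -> Criterion:
    """Stop once the criterion has been called max_checks times."""
    calls = 0

    def check() -> bool:
        nonlocal calls
        calls += 1
        return calls >= max_checks

    return Criterion(check)


done = max_checks_criterion(3)
iterations = 0
while True:
    iterations += 1  # one batch of utility evaluations would happen here
    if done():
        break
print(iterations)  # → 3
```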
Note how we enable caching using memcached (assuming memcached runs with the default configuration for localhost). This is necessary in the current preliminary implementation of permutation sampling, which is the default for compute_banzhaf_semivalues.
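The benefit of caching for permutation sampling can be illustrated with an in-memory memoiser. Walking through a permutation, the subset without point \(i\) is exactly the subset evaluated in the previous step, so a cache keyed by the set of indices turns \(2n\) utility calls into \(n+1\). The wrapper below uses functools.lru_cache and a hypothetical train_and_score function; it does not reflect how pyDVL's MemcachedCacheBackend works internally.

```python
from functools import lru_cache
from typing import FrozenSet, List

n_calls = 0


def train_and_score(indices: FrozenSet[int]) -> float:
    """Hypothetical expensive utility: stands in for retraining the model on `indices`."""
    global n_calls
    n_calls += 1
    return float(len(indices))  # placeholder score


@lru_cache(maxsize=None)
def cached_utility(indices: FrozenSet[int]) -> float:
    # frozenset is hashable, so the subset itself serves as the cache key
    return train_and_score(indices)


def permutation_marginals(permutation: List[int]) -> List[float]:
    """Marginal contribution of each index when added in the order of `permutation`."""
    marginals = []
    prefix: FrozenSet[int] = frozenset()
    for i in permutation:
        with_i = prefix | {i}
        # u(prefix) was computed as the previous `with_i`, so it is a cache hit
        marginals.append(cached_utility(with_i) - cached_utility(prefix))
        prefix = with_i
    return marginals


marginals = permutation_marginals([2, 0, 3, 1])
print(marginals, n_calls)  # → [1.0, 1.0, 1.0, 1.0] 5
```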
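from pydvl.utils import MemcachedCacheBackend, MemcachedClientConfig, Scorer, Utility
from pydvl.value import MaxChecks, compute_banzhaf_semivalues

# model, dataset, max_checks and n_jobs are defined in earlier cells
# Compute regular Banzhaf semivalues (permutation sampling by default)
utility = Utility(
    model=model,
    data=dataset,
    scorer=Scorer("accuracy", default=0.0, range=(0, 1)),
    # memcached on localhost with the default configuration
    cache_backend=MemcachedCacheBackend(MemcachedClientConfig()),
)
values = compute_banzhaf_semivalues(
    utility, done=MaxChecks(max_checks), n_jobs=n_jobs, progress=True
)
# Sort by value and convert the result to a pandas DataFrame
values.sort(key="value")
df = values.to_dataframe(column="banzhaf_value", use_names=True)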
The returned dataframe contains the mean and variance of the Monte Carlo estimates for the values:
Let us plot the results. In the next cell we will take the 30 images with the lowest score and plot their values with 95% Normal confidence intervals. Keep in mind that Permutation Monte Carlo Banzhaf is typically very noisy, and it can take many steps to arrive at a clean estimate.
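The 95% Normal confidence interval of each estimate is the mean plus or minus \(z_{0.975} \approx 1.96\) times the standard error \(\sqrt{\sigma^2 / m}\), with \(\sigma^2\) the sample variance and \(m\) the number of updates. A stdlib-only sketch of that computation and of selecting the lowest-valued points follows; it assumes the raw Monte Carlo samples per point are available, and toy_samples is made-up input, not output of this notebook.

```python
import math
from statistics import NormalDist, mean, variance
from typing import Dict, List, Tuple


def normal_ci(samples: List[float], level: float = 0.95) -> Tuple[float, float, float]:
    """Return (mean, lower, upper) of a Normal confidence interval for the mean."""
    m = mean(samples)
    stderr = math.sqrt(variance(samples) / len(samples))  # sample variance / number of updates
    z = NormalDist().inv_cdf(0.5 + level / 2)  # about 1.96 for level=0.95
    return m, m - z * stderr, m + z * stderr


def lowest_k(
    samples_per_point: Dict[int, List[float]], k: int
) -> List[Tuple[int, float, float, float]]:
    """The k points with the lowest estimated values, each with its confidence interval."""
    rows = [(idx, *normal_ci(s)) for idx, s in samples_per_point.items()]
    return sorted(rows, key=lambda row: row[1])[:k]


# Toy input: sampled marginal contributions for three hypothetical training points
toy_samples = {0: [0.02, 0.01, 0.03], 1: [-0.01, 0.00, -0.02], 2: [0.05, 0.04, 0.06]}
for idx, m, lo, hi in lowest_k(toy_samples, k=2):
    print(idx, round(m, 3), round(lo, 3), round(hi, 3))
```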
Evaluation on anomalous data ΒΆ
An interesting use-case for data valuation is finding anomalous data. Maybe some of the data is really noisy or has been mislabeled. To simulate this, we will change some of the labels of our dataset and add noise to some others. Intuitively, these anomalous data points should then have a lower value.
To evaluate this, let us first check the average value of the 10 data points that we will modify. These are currently the 10 data points with the highest values:
For the first 5 of these images we will falsify their labels, and to images 6-10 we will add some noise.
import numpy as np
from pydvl.utils import Dataset

# training_data, test_data and high_dvl (the highest-valued points) come from earlier cells
x_train_anomalous = training_data[0].copy()
y_train_anomalous = training_data[1].copy()
anomalous_indices = high_dvl.index.map(int).values[:10]

# Falsify the labels of the first 5 images by setting them to 0
y_train_anomalous[anomalous_indices[:5]] = 0

# Add Gaussian noise to images 6-10 and clip the pixel values back to [0, 1]
indices = anomalous_indices[5:10]
current_images = x_train_anomalous[indices]
noisy_images = current_images + 0.5 * np.random.randn(*current_images.shape)
x_train_anomalous[indices] = np.clip(noisy_images, 0.0, 1.0)

anomalous_dataset = Dataset(
    x_train=x_train_anomalous,
    y_train=y_train_anomalous,
    x_test=test_data[0],
    y_test=test_data[1],
)
anomalous_utility = Utility(
model=TorchCNNModel(),
data=anomalous_dataset,
scorer=Scorer("accuracy", default=0.0, range=(0, 1)),
cache_backend=MemcachedCacheBackend(MemcachedClientConfig()),
)
anomalous_values = compute_banzhaf_semivalues(
anomalous_utility, done=MaxChecks(max_checks), n_jobs=n_jobs, progress=True
)
anomalous_values.sort(key="value")
anomalous_df = anomalous_values.to_dataframe(column="banzhaf_value", use_names=True)
Let us now take a look at the low-value images and check how many of our anomalous images are part of it.
As can be seen in this figure, the values of the modified data points have decreased significantly after adding noise or falsifying their labels. This shows the potential of using Banzhaf values or other data valuation methods to detect mislabeled data points or noisy input data.
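Counting how many of the modified points end up among the lowest-valued ones is a simple set intersection. A sketch on plain lists with made-up toy values; in the notebook, the values would come from anomalous_df and the flagged indices from anomalous_indices.

```python
from typing import Iterable, Sequence


def count_in_lowest(values: Sequence[float], flagged: Iterable[int], k: int) -> int:
    """How many of the flagged indices are among the k lowest-valued points."""
    lowest = sorted(range(len(values)), key=lambda i: values[i])[:k]
    return len(set(lowest) & set(flagged))


# Toy values for 8 hypothetical points, of which points 0, 1 and 6 were corrupted
toy_values = [-0.03, -0.01, 0.02, 0.04, 0.01, -0.02, 0.05, 0.03]
print(count_in_lowest(toy_values, flagged=[0, 1, 6], k=3))  # → 2
```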
Maximum Sample Reuse Banzhaf
Although the previous results are already useful, we had to retrain the model a number of times and the variance of the value estimates was still high. This has consequences for the stability of the top-k ranking of points, which decreases the applicability of the method. We now introduce a different sampling method called Maximum Sample Reuse (MSR), which reuses every sample to update the Banzhaf values of all data points. The method was introduced by the authors of Data-Banzhaf and is much more sample-efficient, as we will show.
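The MSR estimator described by Wang and Jia samples subsets that include every point independently with probability 1/2, evaluates the utility once per subset, and estimates each value as the mean utility of the sampled subsets that contain the point minus the mean utility of those that do not. Below is a stdlib-only sketch of that estimator, not pyDVL's implementation; toy_utility is a hypothetical additive utility whose exact Banzhaf values equal its weights, which makes the estimate easy to check.

```python
import random
from typing import Callable, FrozenSet, List


def msr_banzhaf(
    utility: Callable[[FrozenSet[int]], float], n: int, n_samples: int, seed: int = 0
) -> List[float]:
    """Maximum Sample Reuse estimate of Banzhaf values: one utility call per sample."""
    rng = random.Random(seed)
    sum_in, cnt_in = [0.0] * n, [0] * n
    sum_out, cnt_out = [0.0] * n, [0] * n
    for _ in range(n_samples):
        # Include each point independently with probability 1/2 (uniform over all subsets)
        subset = frozenset(i for i in range(n) if rng.random() < 0.5)
        u = utility(subset)
        # Reuse the single evaluation to update every point's estimate
        for i in range(n):
            if i in subset:
                sum_in[i] += u
                cnt_in[i] += 1
            else:
                sum_out[i] += u
                cnt_out[i] += 1
    return [
        sum_in[i] / max(cnt_in[i], 1) - sum_out[i] / max(cnt_out[i], 1)
        for i in range(n)
    ]


weights = [0.3, -0.1, 0.0, 0.5]


def toy_utility(subset: FrozenSet[int]) -> float:
    # Additive utility: the exact Banzhaf value of point i is weights[i]
    return sum(weights[i] for i in subset)


print([round(v, 2) for v in msr_banzhaf(toy_utility, n=4, n_samples=20000)])
```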
We next construct a new utility. Note how this time we don't use a cache: the chance of sampling the same subset of the training set twice is low enough that one can dispense with it (nevertheless, it can still be useful, e.g. when running many experiments).
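How rare repeated subsets are can be quantified with a union (birthday) bound: if each of \(m\) samples is uniform over the \(2^n\) subsets of \(n\) points, the probability that any two coincide is at most \(\frac{m(m-1)}{2} \cdot 2^{-n}\). A quick check with exact arithmetic for the 200 training points used in this notebook and a hypothetical budget of one million samples:

```python
from fractions import Fraction


def collision_bound(n_points: int, n_samples: int) -> Fraction:
    """Union bound on P(some subset is sampled twice), uniform sampling over 2**n_points subsets."""
    pairs = n_samples * (n_samples - 1) // 2
    return Fraction(pairs, 2**n_points)


bound = collision_bound(n_points=200, n_samples=10**6)
print(float(bound))  # roughly 3e-49
```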
Computing the values is the same, but we now use a better stopping criterion. Instead of fixing the number of utility evaluations with MaxChecks , we use RankCorrelation to stop when the change in Spearman correlation between the ranking of two successive iterations is below a threshold.
Inspection of the values reveals (generally) much lower variances. Notice the number of updates to each value as well.
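The rule behind RankCorrelation, as described above, can be sketched in a few lines: compute the Spearman correlation between the rankings at successive checks and stop once that correlation changes by less than a tolerance. The class and helper names are hypothetical, ties are not handled, and pyDVL's RankCorrelation may differ in its details.

```python
from typing import List, Optional, Sequence


def ranks(values: Sequence[float]) -> List[int]:
    """Rank of each entry (0 = smallest); assumes no ties."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0] * len(values)
    for rank, i in enumerate(order):
        result[i] = rank
    return result


def spearman(a: Sequence[float], b: Sequence[float]) -> float:
    """Spearman correlation via 1 - 6 * sum(d^2) / (n (n^2 - 1)), valid without ties."""
    n = len(a)
    d2 = sum((ra - rb) ** 2 for ra, rb in zip(ranks(a), ranks(b)))
    return 1 - 6 * d2 / (n * (n**2 - 1))


class RankStability:
    """Hypothetical criterion: stop when successive rank correlations differ by < rtol."""

    def __init__(self, rtol: float):
        self.rtol = rtol
        self.previous_values: Optional[List[float]] = None
        self.previous_corr: Optional[float] = None

    def __call__(self, values: Sequence[float]) -> bool:
        stop = False
        if self.previous_values is not None:
            corr = spearman(self.previous_values, values)
            if self.previous_corr is not None:
                stop = abs(corr - self.previous_corr) < self.rtol
            self.previous_corr = corr
        self.previous_values = list(values)
        return stop


check = RankStability(rtol=0.01)
steps = ([1, 2, 3, 4], [1, 3, 2, 4], [1, 3.5, 2, 4], [1, 3.6, 2, 4])
print([check(v) for v in steps])  # → [False, False, False, True]
```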
Compare convergence speed of Banzhaf and MSR Banzhaf Values
Conventional samplers based on marginal contributions require evaluating the utility twice to do one update of the value, while permutation samplers do \(n+1\) evaluations for \(n\) updates. Maximum Sample Reuse (MSR), in contrast, updates all indices in every sample that the utility evaluates. We compare the convergence rates of these methods.
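For a fixed budget of utility evaluations, these rates give very different numbers of updates per data point. The sketch below only encodes the three rates stated above (no caching or batching is modelled), with the 200 training points and 1000 evaluations considered in this notebook; updates_per_point is a hypothetical helper.

```python
def updates_per_point(sampler: str, n_evaluations: int, n_points: int) -> float:
    """Average number of value updates each point receives for a given evaluation budget."""
    if sampler == "marginal":
        total_updates = n_evaluations / 2  # two evaluations per marginal, one point updated
    elif sampler == "permutation":
        # A full permutation costs n + 1 evaluations and updates n points
        total_updates = n_evaluations * n_points / (n_points + 1)
    elif sampler == "msr":
        return float(n_evaluations)  # every evaluation updates every point
    else:
        raise ValueError(f"unknown sampler: {sampler}")
    return total_updates / n_points


for name in ("marginal", "permutation", "msr"):
    print(name, round(updates_per_point(name, n_evaluations=1000, n_points=200), 2))
```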
In order to do so, we will compute the semi-values using different samplers and use a high number of iterations to make sure that the values have converged.
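from pydvl.value import (
    HistoryDeviation,
    MaxChecks,
    MSRSampler,
    compute_banzhaf_semivalues,
    compute_msr_banzhaf_semivalues,
)


def get_semivalues_and_history(
    sampler_t, max_checks=max_checks, n_jobs=n_jobs, progress=True
):
    """Compute Banzhaf semivalues with `sampler_t` and keep the history of the values."""
    # HistoryDeviation keeps a history of the values; with rtol=1e-9 it is unlikely to
    # trigger, so MaxChecks effectively bounds the run
    _history = HistoryDeviation(n_steps=max_checks, rtol=1e-9)
    if sampler_t is MSRSampler:
        semivalue_function = compute_msr_banzhaf_semivalues
    else:
        semivalue_function = compute_banzhaf_semivalues
    _values = semivalue_function(
        utility,
        sampler_t=sampler_t,
        done=MaxChecks(max_checks + 2) | _history,
        n_jobs=n_jobs,
        progress=progress,
    )
    return _history, _values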
The plot above visualizes the convergence speed of different samplers used for Banzhaf semivalue calculation. It shows the average magnitude of how much the semivalues are updated in every step of the algorithm.
As you can see, MSR Banzhaf stabilizes much faster. After 1000 iterations (subsets sampled and evaluated with the utility), Permutation Monte Carlo Banzhaf has evaluated the marginal function only about 5 times per data point (we are using 200 data points). For MSR, the semivalue of each data point was updated 1000 times. Due to this, the values converge much faster with respect to the number of utility evaluations, which is the key advantage of MSR sampling.
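The quantity shown in the convergence plot, the average magnitude of the per-step updates, can be computed from the sequence of value vectors. A minimal sketch, assuming that history is available as a list of value lists; mean_abs_updates and toy_history are hypothetical.

```python
from typing import List, Sequence


def mean_abs_updates(history: Sequence[Sequence[float]]) -> List[float]:
    """Mean absolute change of the value vector between consecutive steps."""
    return [
        sum(abs(b - a) for a, b in zip(prev, curr)) / len(curr)
        for prev, curr in zip(history, history[1:])
    ]


# Toy history of the values of 2 points over 4 steps
toy_history = [[0.0, 0.0], [0.4, -0.2], [0.3, -0.1], [0.3, -0.15]]
print(mean_abs_updates(toy_history))
```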
MSR sampling does come at a cost, however: the updates to the semivalues are noisier than in other methods. We will analyze the impact of this tradeoff in the next sections. First, let us look at how similar all the computed semivalues are. They are all Banzhaf values, so in a perfect world all samplers would result in exactly the same semivalues. However, due to randomness in the utility (recall that we use a neural network) and randomness in the samplers, the resulting values are likely never exactly the same. Another quality measure is consistency: a good sampler leads to consistent values across repeated runs, while a bad one leads to less consistent values. Let us first examine how similar the results are, then we'll look at consistency.
Similarity of the semivalues computed using different samplers
This plot shows that the samplers lead to quite different Banzhaf semivalues; however, all of them have some points in common. The MSR sampler does not seem to be significantly worse than any of the others.
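One simple way to quantify how much two samplers agree is the fraction of points that both place among their top \(k\). A sketch with hypothetical helper names and made-up values; the figure in this notebook may use a different similarity measure.

```python
from typing import Sequence, Set


def top_k(values: Sequence[float], k: int) -> Set[int]:
    """Indices of the k highest values."""
    return set(sorted(range(len(values)), key=lambda i: values[i], reverse=True)[:k])


def top_k_overlap(a: Sequence[float], b: Sequence[float], k: int) -> float:
    """Fraction of the top-k points of `a` that are also in the top-k of `b`."""
    return len(top_k(a, k) & top_k(b, k)) / k


# Toy values of 6 hypothetical points computed with two different samplers
values_sampler_1 = [0.9, 0.1, 0.5, 0.7, 0.2, 0.3]
values_sampler_2 = [0.8, 0.6, 0.4, 0.7, 0.1, 0.2]
print(top_k_overlap(values_sampler_1, values_sampler_2, k=3))
```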
In an ideal setting without randomness, the overlap of points would be higher. However, the stochastic nature of the CNN model that we use, together with the fact that we train on only 200 data points, might overshadow these results. In fact, we obtain the following rather discouraging result:
Consistency of the semivalues
Finally, we want to analyze how consistent the semivalues returned by the different samplers are. To do this, we compute the semivalues multiple times and check how many of the data points in the top 20% and in the bottom 20% of the valuation overlap across runs.
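The consistency check can be sketched as follows: for each run, select the top (or bottom) 20% of points and report the fraction of that selection that is shared by every run. The helper names are hypothetical, the input runs are made-up values, and ties are broken by index.

```python
from typing import List, Sequence, Set


def extreme_set(values: Sequence[float], fraction: float, highest: bool) -> Set[int]:
    """Indices in the top (highest=True) or bottom `fraction` of the values."""
    k = max(1, int(len(values) * fraction))
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=highest)
    return set(order[:k])


def consistency(runs: List[Sequence[float]], fraction: float = 0.2, highest: bool = True) -> float:
    """Fraction of the selected points that are selected in every run."""
    sets = [extreme_set(r, fraction, highest) for r in runs]
    common = set.intersection(*sets)
    return len(common) / len(sets[0])


# Toy values for 10 hypothetical points from three repeated runs of one sampler
runs = [
    [0.9, 0.8, 0.1, 0.2, 0.3, 0.4, 0.5, 0.0, 0.6, 0.7],
    [0.8, 0.9, 0.1, 0.2, 0.3, 0.4, 0.5, 0.0, 0.7, 0.6],
    [0.9, 0.5, 0.1, 0.2, 0.3, 0.4, 0.8, 0.0, 0.6, 0.7],
]
print(consistency(runs, highest=True), consistency(runs, highest=False))  # → 0.5 1.0
```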
Conclusion
MSR sampling updates the semivalue estimates of every index in each sample, i.e. much more frequently than any other available sampler, which leads to much faster convergence. Additionally, the sampler is more consistent in its value estimates than the other samplers, which might be caused by the higher number of value updates.
There is, alas, no general recommendation. It is best to try different samplers when computing semivalues and test which one is best suited for your use case. Nevertheless, the MSR sampler appears to be more efficient, may deliver results quickly, and is well suited for stochastic models.