
Banzhaf Semi-values for data valuation

This notebook showcases Data Banzhaf: A Robust Data Valuation Framework for Machine Learning by Wang and Jia.

Computing Banzhaf semi-values with pyDVL follows essentially the same procedure as for all other semi-value-based methods like Shapley values. However, Data Banzhaf tends to be more robust to stochasticity in the training process than other semi-values, a property that we study here.
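For reference, the Banzhaf value of a training point \(i\) is its average marginal contribution over all subsets of the remaining data, with every subset weighted equally: \(\phi_{\text{Banzhaf}}(i) = \frac{1}{2^{n-1}} \sum_{S \subseteq D \setminus \{i\}} \left[ u(S \cup \{i\}) - u(S) \right]\), where \(D\) is the training set of size \(n\) and \(u\) is the utility. This uniform weighting over subsets (instead of the subset-size-dependent weighting of Shapley values) is what makes the estimates less sensitive to noise in \(u\).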

Additionally, we compare two sampling techniques: the standard permutation-based Monte Carlo sampling, and the so-called MSR (Maximum Sample Reuse) principle.

In order to highlight the strengths of Data Banzhaf, we require a stochastic model. For this reason, we use a CNN to classify handwritten digits from the scikit-learn toy datasets.

Setup

If you are reading this in the documentation, some boilerplate (including most plotting code) has been omitted for convenience.

The dataset

The data consists of ~1800 grayscale images of 8x8 pixels with 16 shades of gray, containing handwritten digits from 0 to 9. The helper function load_digits_dataset() downloads the data and prepares it for usage, returning two Datasets.

train, test = load_digits_dataset(train_size=0.7, random_state=random_state)

Creating the utility and computing Banzhaf semi-values

Now we can calculate the contribution of each training sample to the model performance. We use a simple CNN written in torch, wrapped into an object that converts numpy arrays into tensors (as of v0.9.2, valuation methods in pyDVL work only with numpy arrays). Note that any model implementing the protocol SupervisedModel, which is just the standard sklearn interface of fit(), predict() and score(), can be used to construct the utility.
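For illustration, here is a minimal sketch of that interface (the class name and method bodies are placeholders, not part of pyDVL):

import numpy as np


class MyWrappedModel:
    """Any object exposing these three methods satisfies SupervisedModel."""

    def fit(self, x: np.ndarray, y: np.ndarray) -> None:
        ...  # train the model on numpy arrays

    def predict(self, x: np.ndarray) -> np.ndarray:
        ...  # return predictions as a numpy array

    def score(self, x: np.ndarray, y: np.ndarray) -> float:
        ...  # return a scalar performance metric, e.g. accuracy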

from support.banzhaf import TorchCNNModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TorchCNNModel(lr=0.001, epochs=n_epochs, batch_size=batch_size, device=device)
model.fit(*train.data())
Training accuracy: 0.698
Test accuracy: 0.711

As with all other model-based valuation methods, Data Banzhaf needs a scoring function to measure the performance of the model over the test set. We will use accuracy, but it can be any scoring function, e.g. \(R^2\), specified with a string from the standard sklearn scoring methods and passed to SupervisedScorer.

We group our torch model and the scoring function into an instance of ModelUtility .

from pydvl.valuation.samplers import PermutationSampler, RelativeTruncation
from pydvl.valuation.scorers import SupervisedScorer
from pydvl.valuation.stopping import MinUpdates
from pydvl.valuation.utility import ModelUtility

accuracy_over_test_set = SupervisedScorer(
    "accuracy", test_data=test, default=0.0, range=(0, 1)
)

utility = ModelUtility(model=model, scorer=accuracy_over_test_set)
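Purely as an illustration, a scorer for any other sklearn metric is built the same way. For instance, for a regression model one might use \(R^2\); this is a sketch assuming the same constructor arguments as above (note that sklearn's "r2" is unbounded below):

import numpy as np

# Hypothetical alternative scorer for a regression task (not the
# classification setup used in this notebook).
r2_over_test_set = SupervisedScorer("r2", test_data=test, default=0.0, range=(-np.inf, 1))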

In order to compute the Banzhaf semi-values, we use DataBanzhafValuation , which also requires choosing a sampler and a stopping criterion.

We use the standard PermutationSampler, and choose to stop computation with the MinUpdates stopping criterion, which terminates after a fixed number of value updates. This is a simple stopping criterion, but not a very efficient one. We will later compare it to RankCorrelation, which terminates once the change in Spearman correlation between the rankings of two successive iterations falls below a given threshold.

We also define a relative TruncationPolicy, which stops the computation of marginal values within a permutation early, once the utility gets close to the total utility. This heuristic, called Truncated Monte Carlo Shapley, was introduced in the Data Shapley paper to speed up computation. Note how we tell it to wait until at least 50% of every permutation has been processed before it starts evaluating the truncation condition. This ensures that noise doesn't stop the computation too early.

truncation = RelativeTruncation(rtol=0.05, burn_in_fraction=0.5)
sampler = PermutationSampler(truncation=truncation)
stopping = MinUpdates(100)
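As an aside, stopping criteria can be composed with the logical operators | and & (we use this further below). A sketch combining the criterion above with MaxChecks, which bounds the number of times the criterion is checked:

from pydvl.valuation.stopping import MaxChecks

# Stop after 100 updates per value, or once the criterion has been
# checked 5000 times, whichever happens first.
stopping = MinUpdates(100) | MaxChecks(5000)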

We now instantiate and fit the valuation. Note how parallelization is just a matter of using joblib's context manager parallel_config in order to set the number of jobs.

from joblib import parallel_config

from pydvl.valuation.methods import DataBanzhafValuation

valuation = DataBanzhafValuation(
    utility, sampler=sampler, is_done=stopping, progress=True
)

with parallel_config(n_jobs=n_jobs):
    valuation.fit(train)

values = valuation.values()
values.sort(key="value")
df = values.to_dataframe(column="banzhaf_value")

For convenience, we have transformed the values into a dataframe. It includes columns with the mean, variance and number of updates of the Monte Carlo estimates for the values:

banzhaf_value banzhaf_value_variances banzhaf_value_counts
240 -7.488506e-42 1.109645e-94 100
991 -4.254834e-42 3.424364e-94 100
792 -2.627003e-42 1.318692e-100 100
265 -2.535799e-42 2.243686e-95 100
1020 -1.209766e-42 1.270162e-93 100
... ... ... ...
1021 3.573181e-42 1.264143e-93 100
312 6.377917e-42 8.087408e-97 100
1018 1.126661e-41 1.038240e-93 100
653 1.610007e-41 2.148715e-110 100
433 2.065544e-41 2.471744e-102 100

1257 rows × 3 columns

Let us plot the results. In the next cell we will take the 10 images with the lowest score and plot their values with 95% Normal confidence intervals. Keep in mind that Permutation Monte Carlo Banzhaf is typically very noisy, and it can take many steps to arrive at a clean estimate.
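Most plotting code is omitted from the documentation; the following is a minimal sketch of such a plot, under the assumption that the standard error of each estimate is \(\sqrt{\text{variance} / \text{counts}}\):

import matplotlib.pyplot as plt

lowest = df.head(10)  # df is sorted in ascending order of value
std_err = np.sqrt(lowest["banzhaf_value_variances"] / lowest["banzhaf_value_counts"])
ci95 = 1.96 * std_err  # half-width of a 95% Normal confidence interval

plt.errorbar(range(10), lowest["banzhaf_value"], yerr=ci95, fmt="o", capsize=4)
plt.xticks(range(10), lowest.index)
plt.xlabel("data index")
plt.ylabel("Banzhaf value")
plt.show()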


Evaluation on anomalous data

An interesting use-case for data valuation is finding anomalous data. Maybe some of the data is really noisy or has been mislabeled. To simulate this for our dataset, we will change some of the labels and add noise to some images. Intuitively, these anomalous data points should then have a lower value.

To evaluate this, let us first check the average value of the 10 data points with the highest value, as these will be the ones that we modify:
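The variable high_values used below is defined in omitted boilerplate; a plausible reconstruction:

# The 10 highest-valued points (df is sorted in ascending order of value)
high_values = df.tail(10)
print(f"Average value of the 10 highest-valued points: {high_values['banzhaf_value'].mean():.3e}")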

For the first 5 images we will falsify their labels; to images 6-10 we will add some noise.

x_train_anomalous = train.data().x.copy()
y_train_anomalous = train.data().y.copy()
anomalous_indices = high_values.index.map(int).values

# Change the label of the first 5 images
y_train_anomalous[anomalous_indices[:5]] = np.mod(
    y_train_anomalous[anomalous_indices[:5]] + 1, 10
)

# Add noise to images 6-10
current_images = x_train_anomalous[anomalous_indices[5:10]]
noisy_images = current_images + 0.5 * np.random.randn(*current_images.shape)
noisy_images[noisy_images < 0] = 0.0
noisy_images[noisy_images > 1] = 1.0
x_train_anomalous[anomalous_indices[5:10]] = noisy_images
from pydvl.valuation.dataset import Dataset

anomalous_dataset = Dataset(x=x_train_anomalous, y=y_train_anomalous)

# Note that we reuse the same stopping criterion. fit() resets it, but
# to be sure we can always call stopping.reset()
anomalous_valuation = DataBanzhafValuation(
    utility, sampler=sampler, is_done=stopping.reset(), progress=True
)

with parallel_config(n_jobs=n_jobs):
    anomalous_valuation.fit(anomalous_dataset)

anomalous_values = anomalous_valuation.values()
anomalous_values.sort(key="value")
anomalous_df = anomalous_values.to_dataframe(column="banzhaf_value")

Let us now look at how the value has changed for the images that we manipulated:
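The comparison code is also omitted here; a minimal sketch, assuming both dataframes are indexed by the original data indices:

import pandas as pd

comparison = pd.DataFrame(
    {
        "original": df["banzhaf_value"].loc[anomalous_indices],
        "anomalous": anomalous_df["banzhaf_value"].loc[anomalous_indices],
    }
)
print(comparison)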

As this comparison shows, the valuation of the data points decreased significantly after adding noise or falsifying their labels. This demonstrates the potential of Banzhaf values and other data valuation methods for detecting mislabeled data points or noisy input data.

Maximum Sample Reuse Banzhaf

The previous results are already useful, but we had to retrain the model many times, and the variance of the value estimates was still high. This affects the stability of the top-k ranking of points, which limits the applicability of the method. We will now use a different sampling method called Maximum Sample Reuse (MSR), which reuses every sample to update all value estimates. The method was introduced by the authors of Data Banzhaf and is much more sample-efficient, as we will show.

All that is required to compute the values with MSR is to use MSRSampler as the sampler.

Because values converge much faster, we can use a better stopping criterion. Instead of fixing the number of value updates with MinUpdates , we use RankCorrelation to stop when the change in Spearman correlation between the ranking of two successive iterations is below a threshold. Despite the much stricter stopping criterion, fitting the Banzhaf values with the MSR sampler is much faster.

from pydvl.valuation.samplers import MSRSampler
from pydvl.valuation.stopping import RankCorrelation

valuation = DataBanzhafValuation(
    utility,
    sampler=MSRSampler(batch_size=32, seed=random_state),
    is_done=RankCorrelation(rtol=1e-4, burn_in=64),
    progress=True,
)

with parallel_config(n_jobs=n_jobs):
    valuation.fit(train)

msr_values = valuation.values()
msr_values.sort(key="value")
msr_df = msr_values.to_dataframe(column="banzhaf_value")

Inspection of the values reveals (generally) much lower variances. Notice the much higher number of updates to each value as well.
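For a rough check, one can compare the two result dataframes directly:

# Compare Monte Carlo variance and updates per value between the
# permutation-based run (df) and the MSR run (msr_df).
print("mean variance (permutation):", df["banzhaf_value_variances"].mean())
print("mean variance (MSR):", msr_df["banzhaf_value_variances"].mean())
print("mean updates per value (MSR):", msr_df["banzhaf_value_counts"].mean())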

Compare convergence speed of Banzhaf and MSR Banzhaf Values

Conventional margin-based samplers require two evaluations of the utility for one update of a value, and permutation samplers need \(n+1\) evaluations for \(n\) updates. Maximum Sample Reuse (MSR) instead updates all indices with every sample that the utility evaluates. We compare the convergence rates of these methods.
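To make the bookkeeping concrete, here is a back-of-the-envelope count of value updates per utility evaluation, using the 200 training points of this comparison:

n = 200       # training points in this comparison
evals = 1000  # utility evaluations (sampled subsets that are scored)

updates_margin = evals / 2                 # 2 evaluations per single-value update
updates_permutation = evals * n / (n + 1)  # n updates per n + 1 evaluations
updates_msr = evals * n                    # each evaluation updates all n values

print(updates_margin, round(updates_permutation), updates_msr)
# 500.0 995 200000 total value updates, respectively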

In order to do so, we will compute the semi-values using different samplers and use a high number of iterations to make sure that the values have converged.

max_checks = 1000
moving_avg = 200
from typing import Type

from pydvl.valuation import IndexSampler, SemivalueValuation, ValuationResult
from pydvl.valuation.stopping import HistoryDeviation, MaxChecks


def compute_semivalues_and_history(
    method_t: Type[SemivalueValuation],
    sampler_t: Type[IndexSampler],
    sampler_args: dict,
    max_checks: int,
    progress: bool = True,
):
    history = HistoryDeviation(n_steps=max_checks, rtol=1e-6)
    valuation = method_t(
        utility,
        sampler=sampler_t(**sampler_args, seed=random_state),
        is_done=MaxChecks(max_checks + 2) | history,
        progress=progress,
    )
    with parallel_config(n_jobs=n_jobs):
        valuation.fit(train)

    return history, valuation.values()
from collections import OrderedDict

from pydvl.valuation.samplers import (
    AntitheticSampler,
    HarmonicSampleSize,
    RandomIndexIteration,
    RandomSizeIteration,
    StratifiedSampler,
    UniformSampler,
)

experiments = OrderedDict(
    [
        (
            PermutationSampler,
            {
                "name": "Permutation",
                "truncation": RelativeTruncation(rtol=0.05, burn_in_fraction=0.5),
            },
        ),
        (MSRSampler, {"name": "MSR", "kwargs": {"batch_size": 16}}),
        (UniformSampler, {"name": "Uniform", "kwargs": {}}),
        (AntitheticSampler, {"name": "Antithetic", "kwargs": {}}),
        (
            StratifiedSampler,
            {
                "name": "Stratified",
                "kwargs": {
                    "sample_sizes": HarmonicSampleSize(1),
                    "sample_sizes_iteration": RandomSizeIteration,
                    "index_iteration": RandomIndexIteration,
                },
            },
        ),
    ]
)

results = {}
history = {}

for sampler_t, params in experiments.items():
    history[sampler_t], results[sampler_t] = compute_semivalues_and_history(
        DataBanzhafValuation, sampler_t, params.get("kwargs", {}), max_checks
    )

The plot above visualizes the convergence speed of the different samplers used for Banzhaf semi-value calculation: it shows the average magnitude of the semi-value updates at every step of the algorithm.

As you can see, MSR Banzhaf stabilizes much faster. After 1000 iterations (subsets sampled and evaluated with the utility), Permutation Monte Carlo Banzhaf has evaluated the marginal function only about 5 times per data point (we are using 200 data points). With MSR, the semi-value of each data point was updated 1000 times. Because of this, the values converge much faster with respect to the number of utility evaluations, which is the key advantage of MSR sampling.

MSR sampling does come at a cost, however: the updates to the semi-values are noisier than in other methods. We will analyze the impact of this tradeoff in the next sections. First, let us look at how similar the computed semi-values are. They are all Banzhaf values, so in a perfect world all samplers should produce exactly the same values. However, due to randomness in the utility (recall that we use a neural network) and in the samplers, the resulting values will likely never be exactly the same. Consistency is another quality measure: a good sampler should produce very similar values across repeated runs, a bad one less so. Let us first examine how similar the results are; then we'll look at consistency.

Similarity of the semi-values computed using different samplers

This plot shows that the samplers lead to quite different Banzhaf semi-values; however, all of them have some points in common. The MSR sampler does not seem significantly worse than the others.

In an ideal setting without randomness the overlap of points would be higher. However, the stochastic nature of the CNN model, together with the fact that we use only 200 data points for training, may overshadow these results. As a matter of fact, we obtain the following rather discouraging result:

Consistency of the semi-values

Finally, we want to analyze how consistent the semi-values are when computed with the different samplers. In order to do this, we calculate them multiple times and check how many of the data points in the top and bottom 20% of the valuation overlap across runs.
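The evaluation code is omitted here; a minimal sketch of the overlap computation, assuming that repeated runs yield a list of value-sorted dataframes per sampler:

def top_fraction_overlap(runs: list, fraction: float = 0.2) -> float:
    """Fraction of indices common to the top `fraction` of every run."""
    k = int(len(runs[0]) * fraction)
    top_sets = [set(run.tail(k).index) for run in runs]
    return len(set.intersection(*top_sets)) / k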

Conclusion

MSR sampling updates the semi-value estimates of every index in each sample, much more frequently than any other available sampler, which leads to much faster convergence. Additionally, the sampler is more consistent in its value estimates than the other samplers, which might be explained by the higher number of value updates.

There is, alas, no general recommendation. It is best to try different samplers when computing semi-values and test which one is best suited to your use case. Nevertheless, the MSR sampler appears to be more efficient: it may deliver results faster and is well suited to stochastic models.