
Computing Shapley values for a torch model

This notebook illustrates how to wrap a torch.nn.Module model using skorch and use it with pyDVL to compute Shapley values. The model is a simple convolutional neural network (CNN) that classifies handwritten digits from the MNIST dataset.

The notebook follows almost verbatim the first few steps of the notebook on MSR Banzhaf.

Setup

If you are reading this in the documentation, some boilerplate (including most plotting code) has been omitted for convenience.
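
For reference, the snippets below assume roughly the following setup. The concrete values are not part of the original notebook; random_state, device, n_epochs, batch_size and n_jobs are placeholders chosen here for illustration, and filecache (used further down) is a small caching helper defined in the notebook's support code.

import torch

random_state = 42  # placeholder seed; any fixed value works
device = "cuda" if torch.cuda.is_available() else "cpu"
n_epochs = 20  # placeholder number of training epochs
batch_size = 32  # placeholder batch size
n_jobs = 4  # number of parallel workers used by joblib later on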

The dataset

The data consists of ~1800 grayscale images of 8x8 pixels with 16 shades of gray, containing handwritten digits from 0 to 9. The helper function load_digits_dataset() downloads the data and prepares it for use, returning two Datasets.

train, test = load_digits_dataset(
    train_size=0.7, random_state=random_state, device=device
)
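
As a quick sanity check, we can look at the shapes of the splits. This is a sketch assuming Dataset.data() returns the feature and label arrays, as its use in model.fit(*train.data()) below suggests:

x_train, y_train = train.data()
x_test, y_test = test.data()
print(x_train.shape, y_train.shape)  # feature and label arrays of the training split
print(x_test.shape, y_test.shape)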

Creating the utility using a skorch model

Now we can calculate the contribution of each training sample to the model's performance. We use a simple CNN written in torch, wrapped as a skorch model. Keep in mind that torch_load_kwargs={"weights_only": False} must be passed to the NeuralNetClassifier constructor; otherwise the model cannot be round-tripped through pickle and parallelization won't work.

from skorch import NeuralNetClassifier
from support.banzhaf import SimpleCNN

model = NeuralNetClassifier(
    SimpleCNN,
    criterion=torch.nn.CrossEntropyLoss,
    lr=0.01,
    max_epochs=n_epochs,
    batch_size=batch_size,
    train_split=None,  # disable skorch's internal validation split
    optimizer=torch.optim.Adam,
    device=device,
    verbose=False,
    torch_load_kwargs={"weights_only": False},  # required for pickling (see above)
)
model.fit(*train.data());
Training accuracy: 0.984
Test accuracy: 0.971

As with every model-based valuation method, we need a scoring function to measure the performance of the model over the test set. We will use accuracy, but it can be any metric, e.g. \(R^2\): strings naming the standard sklearn scoring methods can be passed to SkorchSupervisedScorer.

We group our skorch model and the scoring function into an instance of ModelUtility.

from pydvl.valuation.scorers import SkorchSupervisedScorer
from pydvl.valuation.stopping import MinUpdates
from pydvl.valuation.utility import ModelUtility

accuracy_over_test_set = SkorchSupervisedScorer(
    model, test_data=test, default=0.0, range=(0, 1)
)

utility = ModelUtility(model=model, scorer=accuracy_over_test_set)
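
For instance, if we preferred \(R^2\) over accuracy, a sketch along these lines should work, assuming SkorchSupervisedScorer accepts sklearn scorer names in place of a model, like pyDVL's SupervisedScorer:

import numpy as np

# "r2" is an sklearn scorer name; the range reflects R^2's possible values
r2_over_test_set = SkorchSupervisedScorer(
    "r2", test_data=test, default=0.0, range=(-np.inf, 1)
)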

In order to compute Shapley values, we use TMCShapleyValuation.

We choose to stop computation using the MinUpdates stopping criterion, which terminates after a fixed number of value updates.
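
Stopping criteria in pyDVL can be combined with the & and | operators. For instance, the following sketch also bounds wall-clock time, assuming MaxTime is exported by pydvl.valuation.stopping alongside MinUpdates:

from pydvl.valuation.stopping import MaxTime, MinUpdates

# Stop after 100 updates per value, or after one hour, whichever happens first
combined_stopping = MinUpdates(100) | MaxTime(seconds=3600)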

We also define a relative TruncationPolicy, used to stop the computation of marginal values within a permutation early, once the utility is close to the total utility. This heuristic to speed up computation, called Truncated Monte Carlo Shapley, was introduced in the Data Shapley paper. Note how we tell it to wait until at least 30% of every permutation has been processed before it starts evaluating the condition. This ensures that noise doesn't stop the computation too early.

from pydvl.valuation.samplers import RelativeTruncation

truncation = RelativeTruncation(rtol=0.05, burn_in_fraction=0.3)
stopping = MinUpdates(100)
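
Conceptually, the rule evaluated after each marginal within a permutation looks like the following sketch (not pyDVL's actual implementation), where u_S is the utility of the samples seen so far and u_total that of the full training set:

def should_truncate(u_S: float, u_total: float, rtol: float = 0.05) -> bool:
    # Truncate when the current utility is within rtol of the total utility,
    # i.e. the remaining marginal contributions are likely negligible
    return abs(u_total - u_S) <= rtol * abs(u_total)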

We now instantiate and fit the valuation. Note how parallelizing the computation is just a matter of using joblib's parallel_config context manager to set the number of jobs.

from joblib import parallel_config

from pydvl.valuation.methods import TMCShapleyValuation

valuation = TMCShapleyValuation(
    utility, truncation=truncation, is_done=stopping, progress=True
)

# filecache is a very simple wrapper not intended for production code
cached_fit = filecache("shapley_skorch_result.pkl")(lambda d: valuation.fit(d).result)
with parallel_config(n_jobs=n_jobs):
    result = cached_fit(train)

The results object is of type ValuationResult, and contains the values, variances and number of updates of the Monte Carlo estimates. It can be indexed, sliced and copied in natural ways.
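
For instance, the following sketch prints a few estimates; the array attributes follow pyDVL's ValuationResult, while to_dataframe() is assumed to be available for a tabular view:

print(result.values[:5])  # Monte Carlo estimates of the Shapley values
print(result.variances[:5])  # variances of the estimates
print(result.counts[:5])  # number of updates each value received
df = result.to_dataframe()  # assumed helper: one row per data point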

Let us plot the results. In the next cell we will take the 10 images with the lowest score and plot their values with 95% Normal confidence intervals.
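
Since the plotting code is omitted here, the following is a minimal sketch of such a plot. It assumes that sorting accepts a key argument and that slicing a ValuationResult returns a result restricted to those data points, as the indexing remark above suggests:

import matplotlib.pyplot as plt
import numpy as np

result.sort(key="value")  # sort in place by estimated value, ascending
lowest = result[:10]  # the 10 data points with the lowest values
stderr = np.sqrt(lowest.variances / np.maximum(lowest.counts, 1))
plt.errorbar(np.arange(10), lowest.values, yerr=1.96 * stderr, fmt="o")
plt.xticks(np.arange(10), lowest.names, rotation=45)
plt.ylabel("Shapley value")
plt.show()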
