Computing Shapley values for a torch model
This notebook illustrates how to wrap a torch.nn.Module model using skorch and use it with pyDVL to compute Shapley values. The model is a simple convolutional neural network (CNN) that classifies handwritten digits from the MNIST dataset.
The notebook follows almost verbatim the first few steps of the notebook on MSR Banzhaf.
Setup
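The setup cells are hidden in the rendered notebook. A minimal sketch of what they might define follows; all names and values here (n_epochs, batch_size, n_jobs, the way the MNIST data is wrapped) are assumptions for illustration, not the notebook's actual setup code:
import torch

from pydvl.valuation.dataset import Dataset

# Assumed training hyperparameters and parallelism settings
n_epochs = 24
batch_size = 256
n_jobs = 4
device = "cuda" if torch.cuda.is_available() else "cpu"

# train and test are assumed to be pyDVL Datasets wrapping the MNIST digits,
# built from arrays such as x_train, y_train, x_test, y_test:
# train = Dataset(x_train, y_train)
# test = Dataset(x_test, y_test)

# The filecache helper used later in this notebook is also defined in the
# hidden setup.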
Creating the utility using a skorch model
Now we can calculate the contribution of each training sample to the model's performance. We use a simple CNN written in torch, wrapped as a skorch model. One thing to keep in mind is to pass torch_load_kwargs={"weights_only": False} to the NeuralNetClassifier constructor; otherwise the model will fail to pickle and parallelization won't work.
from skorch import NeuralNetClassifier

from support.banzhaf import SimpleCNN

model = NeuralNetClassifier(
    SimpleCNN,
    criterion=torch.nn.CrossEntropyLoss,
    lr=0.01,
    max_epochs=n_epochs,
    batch_size=batch_size,
    train_split=None,  # we score on our own test set instead
    optimizer=torch.optim.Adam,
    device=device,
    verbose=False,
    torch_load_kwargs={"weights_only": False},  # needed for pickling during parallelization
)
model.fit(*train.data());
As with every model-based valuation method, we need a scoring function to measure the performance of the model over the test set. We will use accuracy, but it can be anything, e.g. \(R^2\): any string from the standard sklearn scoring methods can be passed to SkorchSupervisedScorer.
We group our skorch model and the scoring function into an instance of ModelUtility .
from pydvl.valuation.scorers import SkorchSupervisedScorer
from pydvl.valuation.stopping import MinUpdates
from pydvl.valuation.utility import ModelUtility

accuracy_over_test_set = SkorchSupervisedScorer(
    model, test_data=test, default=0.0, range=(0, 1)
)
utility = ModelUtility(model=model, scorer=accuracy_over_test_set)
In order to compute Shapley values, we use TMCShapleyValuation.
We choose to stop computation with the MinUpdates stopping criterion, which terminates after a fixed number of updates to each value estimate.
We also define a relative TruncationPolicy, which early-stops the computation of marginal utilities within a permutation once the utility is close to the total utility. This heuristic, introduced in the Data-Shapley paper under the name Truncated Monte Carlo Shapley, can speed up computation considerably. Note how we tell it to wait until at least 30% of every permutation has been processed before it starts evaluating, to ensure that noise doesn't stop the computation too early. A sketch of how these two objects might be instantiated follows.
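The stopping and truncation objects used below are defined in the hidden setup. A minimal sketch, assuming a RelativeTruncation policy from pydvl.valuation.samplers.truncation; the module path, the rtol value and the update count are assumptions and may differ between pyDVL versions:
from pydvl.valuation.samplers.truncation import RelativeTruncation

# Stop after a fixed number of updates per value estimate (assumed figure)
stopping = MinUpdates(5000)
# Truncate a permutation once the utility is within 5% of the total utility,
# but only after at least 30% of the permutation has been processed
truncation = RelativeTruncation(rtol=0.05, burn_in_fraction=0.3)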
We now instantiate and fit the valuation. Note how parallelization is just a matter of using joblib's context manager parallel_config to set the number of jobs.
from joblib import parallel_config

from pydvl.valuation.methods import TMCShapleyValuation

valuation = TMCShapleyValuation(
    utility, truncation=truncation, is_done=stopping, progress=True
)

# filecache is a very simple caching wrapper (defined in the notebook's
# hidden setup), not intended for production code
cached_fit = filecache("shapley_skorch_result.pkl")(lambda d: valuation.fit(d).result)

with parallel_config(n_jobs=n_jobs):
    result = cached_fit(train)
The results object is of type ValuationResult , and contains values, variances and number of updates of the Monte Carlo estimates. It can be indexed, sliced and copied in natural ways.
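For example, one might inspect the lowest-valued samples like this (a hedged sketch: sort and the values, variances and counts attributes follow the ValuationResult API, but exact signatures may vary across pyDVL versions):
result.sort(key="value")       # sort in place by estimated value
lowest_ten = result[:10]       # slicing returns a new ValuationResult
print(lowest_ten.values)       # Monte Carlo value estimates
print(lowest_ten.variances)    # variances of the estimates
print(lowest_ten.counts)       # number of updates per estimate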
Let us plot the results. In the next cell we will take the 10 images with the lowest values and plot them with 95% Normal confidence intervals.