
Class-wise Shapley

Class-wise Shapley (CWS) (Schoch et al., 2022)1 offers a Shapley framework tailored for classification problems. Given a sample \(x_i\) with label \(y_i \in \mathbb{N}\), let \(D_{y_i}\) be the subset of \(D\) with labels \(y_i\), and \(D_{-y_i}\) be the complement of \(D_{y_i}\) in \(D\). The key idea is that the sample \((x_i, y_i)\) might improve the overall model performance on \(D\), while being detrimental to the performance on \(D_{y_i}\), e.g. because of a wrong label. To address this issue, the authors introduced

\[ v_u(i) = \frac{1}{2^{|D_{-y_i}|}} \sum_{S_{-y_i}} \left [ \frac{1}{|D_{y_i}|}\sum_{S_{y_i}} \binom{|D_{y_i}|-1}{|S_{y_i}|}^{-1} \delta(S_{y_i} | S_{-y_i}) \right ], \]

where \(S_{y_i} \subseteq D_{y_i} \setminus \{i\}\) and \(S_{-y_i} \subseteq D_{-y_i}\) is arbitrary (in particular, not the complement of \(S_{y_i}\)). The function \(\delta\) is called set-conditional marginal Shapley value and is defined as

\[ \delta(S | C) = u(S_{+i} | C) - u(S | C), \]

for any set \(S\) such that \(i \notin S\), \(i \notin C\), and \(S \cap C = \emptyset\).

In practical applications, this quantity is estimated with Monte Carlo sampling of both the powerset and the set of index permutations (Castro et al., 2009)2. Typically, this requires fewer samples than the original Shapley value, although the actual speed-up depends on the model and the dataset.
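For intuition, the formula can also be evaluated exactly on toy data by enumerating all subsets. The sketch below assumes a user-supplied function `u(S, C)` playing the role of the set-conditional utility \(u(S | C)\); all names are illustrative, and the computation is exponential in the dataset size, which is precisely why sampling is used in practice.

```python
from itertools import combinations
from math import comb

def exact_cws_value(i, in_class, out_of_class, u):
    """Brute-force the class-wise Shapley value of sample `i` by direct
    enumeration of the formula above. `u(S, C)` stands for u(S | C).
    Toy illustration only: exponential in the dataset size."""
    rest = [j for j in in_class if j != i]  # D_{y_i} \ {i}
    total = 0.0
    for r in range(len(out_of_class) + 1):  # all subsets S_{-y_i}
        for s_out in combinations(out_of_class, r):
            inner = 0.0
            for k in range(len(rest) + 1):  # all subsets S_{y_i}
                for s_in in combinations(rest, k):
                    # Set-conditional marginal contribution of i
                    delta = u(set(s_in) | {i}, set(s_out)) - u(set(s_in), set(s_out))
                    inner += delta / comb(len(rest), k)
            total += inner / len(in_class)
    return total / 2 ** len(out_of_class)
```

For a utility that simply counts the in-class points, every sample's marginal contribution is 1, and the value evaluates to exactly 1.0, which is a quick sanity check of the weights.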

Computing classwise Shapley values

Like all other game-theoretic valuation methods, CWS requires a Utility object constructed from a model and a dataset, with the peculiarity of requiring a specific ClasswiseScorer. The entry point is the function compute_classwise_shapley_values:

from pydvl.value import *

model = ...
data = Dataset(...)
scorer = ClasswiseScorer(...)
utility = Utility(model, data, scorer)
values = compute_classwise_shapley_values(
    utility,
    done=HistoryDeviation(n_steps=500, rtol=5e-2) | MaxUpdates(5000),
    truncation=RelativeTruncation(utility, rtol=0.01),
)

The class-wise scorer

In order to use the classwise Shapley value, one needs to define a ClasswiseScorer. This scorer is defined as

\[ u(S) = f(a_S(D_{y_i})) g(a_S(D_{-y_i})), \]

where \(f\) and \(g\) are monotonically increasing functions, \(a_S(D_{y_i})\) is the in-class accuracy, and \(a_S(D_{-y_i})\) is the out-of-class accuracy (the names originate from the authors' choice of accuracy as the base metric, but in principle any other score, such as \(F_1\), can be used).

The authors show that \(f(x)=x\) and \(g(x)=e^x\) have favorable properties and are therefore the defaults, but we leave the option to set different functions \(f\) and \(g\) for an exploration with different base scores.
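A quick numeric check of why these defaults emphasize in-class performance: with \(f(x)=x\) and \(g(x)=e^x\), the in-class accuracy enters linearly while the out-of-class accuracy only modulates it through a multiplicative discount. The accuracies below are made up for illustration.

```python
import math

# Discounted utility with the default choices f(x) = x, g(x) = e^x
u = lambda a_in, a_out: a_in * math.exp(a_out)

# Hypothetical accuracies for two candidate subsets: one strong
# in-class, one strong out-of-class.
high_in = u(0.9, 0.5)
high_out = u(0.5, 0.9)
```

Here `high_in > high_out`: swapping the two accuracies lowers the utility, reflecting that the in-class score dominates under the defaults.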

The default class-wise scorer

Constructing the CWS scorer requires choosing a metric and the functions \(f\) and \(g\):

import numpy as np
from pydvl.value.shapley.classwise import ClasswiseScorer

# These are the defaults
identity = lambda x: x
scorer = ClasswiseScorer(
    "accuracy",
    in_class_discount_fn=identity,
    out_of_class_discount_fn=np.exp,
)

Surface of the discounted utility function

The level curves for \(f(x)=x\) and \(g(x)=e^x\) are depicted below. The contour lines are annotated with their respective gradients.

Level curves of the class-wise utility


Evaluation

We illustrate the method with two experiments: point removal and noise removal, as well as an analysis of the distribution of the values. For this we employ the nine datasets used in (Schoch et al., 2022)1, with the same pre-processing. For images, PCA is used to reduce the features computed by a pre-trained Resnet18 model down to 32 dimensions. Standard loc-scale normalization is performed for all models except gradient boosting, since the latter is not sensitive to the scale of the features.

Datasets used for evaluation
| Dataset        | Data Type | Classes | Input Dims | OpenML ID |
|----------------|-----------|---------|------------|-----------|
| Diabetes       | Tabular   | 2       | 8          | 37        |
| Click          | Tabular   | 2       | 11         | 1216      |
| CPU            | Tabular   | 2       | 21         | 197       |
| Covertype      | Tabular   | 7       | 54         | 1596      |
| Phoneme        | Tabular   | 2       | 5          | 1489      |
| FMNIST         | Image     | 2       | 32         | 40996     |
| CIFAR10        | Image     | 2       | 32         | 40927     |
| MNIST (binary) | Image     | 2       | 32         | 554       |
| MNIST (multi)  | Image     | 10      | 32         | 554       |

We show the mean and coefficient of variation (CV) \(\frac{\sigma}{\mu}\) of an "inner metric" (WAD or AUC, depending on the experiment). The former shows the performance of the method, whereas the latter displays its stability: we normalize by the mean to see the relative effect of the standard deviation across runs. Ideally the mean is maximal and the CV minimal.
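As a reminder, the CV is simply the standard deviation normalized by the mean. A minimal sketch with made-up scores:

```python
import numpy as np

# Hypothetical metric values from five runs of the same method
scores = np.array([0.21, 0.19, 0.22, 0.20, 0.18])

# Coefficient of variation: relative dispersion across runs
cv = scores.std() / scores.mean()
```

A lower CV means the method's ranking of points is more reproducible across repeated runs.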

Finally, we note that for all sampling-based valuation methods the same number of evaluations of the marginal utility was used. This is important to make the algorithms comparable, but in practice one should consider using a more sophisticated stopping criterion.

Dataset pruning for logistic regression (point removal)

In (best-)point removal, one first computes values for the training set and then removes in sequence the points with the highest values. After each removal, the remaining points are used to train the model from scratch and performance is measured on a test set. This produces a curve of performance vs. number of points removed which we show below.
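The procedure can be sketched as follows; `train_and_score` is a hypothetical callback that retrains the model from scratch on the given indices and returns test accuracy:

```python
def removal_curve(order, train_and_score):
    """Accuracy-vs-removals curve for best-point removal.

    order: training indices sorted from highest to lowest value.
    train_and_score(indices): retrains from scratch on `indices` and
    returns test-set accuracy. Both names are placeholders for this sketch.
    """
    curve = []
    remaining = list(order)
    while remaining:
        curve.append(train_and_score(remaining))
        remaining.pop(0)  # remove the highest-valued remaining point
    return curve
```

With a toy scorer that returns the fraction of points kept, the curve decreases linearly, as expected when every point is equally useful.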

As a scalar summary of this curve, (Schoch et al., 2022)1 define Weighted Accuracy Drop (WAD) as:

\[ \text{WAD} = \sum_{j=1}^{n} \left [ \frac{1}{j} \sum_{i=1}^{j} \left ( a_{T_{-\{1 \colon i-1 \}}}(D) - a_{T_{-\{1 \colon i \}}}(D) \right ) \right ] = \sum_{j=1}^{n} \frac{a_T(D) - a_{T_{-\{1 \colon j \}}}(D)}{j}, \]

where \(a_T(D)\) is the accuracy of the model (trained on \(T\)) evaluated on \(D\) and \(T_{-\{1 \colon j \}}\) is the set \(T\) without elements from \(\{1, \dots , j \}\).
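For concreteness, here is a toy implementation of WAD from the recorded accuracy curve, using the telescoped reading of the sum in which each term is the total accuracy drop after \(j\) removals, weighted by \(1/j\). The variable names are illustrative.

```python
def weighted_accuracy_drop(acc):
    """WAD from acc = [a_T(D), a_{T_-{1}}(D), ..., a_{T_-{1:n}}(D)],
    the test accuracies recorded after each removal. Each term is the
    cumulative accuracy drop after j removals, weighted by 1/j."""
    a0 = acc[0]
    return sum((a0 - a_j) / j for j, a_j in enumerate(acc[1:], start=1))
```

For example, the accuracy sequence `[0.9, 0.8, 0.7]` yields \(0.1/1 + 0.2/2 = 0.2\).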

We run the point removal experiment for a logistic regression model five times and compute WAD for each run, then report the mean \(\mu_\text{WAD}\) and standard deviation \(\sigma_\text{WAD}\).

Mean WAD for best-point removal on logistic regression. Values computed using LOO, CWS, Beta Shapley, and TMCS

We see that CWS is competitive with the three other methods. It outperforms TMCS in all problems except MNIST (multi), where TMCS has a slight advantage.

In order to understand the variability of WAD we look at its coefficient of variation (lower is better):

Coefficient of Variation of WAD for best-point removal on logistic regression. Values computed using LOO, CWS, Beta Shapley, and TMCS

CWS is not the best method in terms of CV. For CIFAR10, Click, CPU and MNIST (binary) Beta Shapley has the lowest CV. For Diabetes, MNIST (multi) and Phoneme CWS is the winner and for FMNIST and Covertype TMCS takes the lead. Besides LOO, TMCS has the highest relative standard deviation.

The following plot shows accuracy vs number of samples removed. Random values serve as a baseline. The shaded area represents the 95% bootstrap confidence interval of the mean across 5 runs.

Accuracy after best-sample removal using values from logistic regression

Because samples are removed from high to low valuation order, we expect a steep decrease in the curve.

Overall we conclude that in terms of mean WAD, CWS and TMCS perform best, with CWS's CV on par with Beta Shapley's, making CWS a competitive method.

Dataset pruning for a neural network by value transfer

Transfer of values from one model to another is probably of greater practical relevance: values are computed using a cheap model and used to prune the dataset before training a more expensive one.

The following plot shows accuracy vs number of samples removed for transfer from logistic regression to a neural network. The shaded area represents the 95% bootstrap confidence interval of the mean across 5 runs.

Accuracy after sample removal using values transferred from logistic regression to an MLP

As in the previous experiment samples are removed from high to low valuation order and hence we expect a steep decrease in the curve. CWS is competitive with the other methods, especially in very unbalanced datasets like Click. In other datasets, like Covertype, Diabetes and MNIST (multi) the performance is on par with TMCS.

Detection of mis-labeled data points

The next experiment tries to detect mis-labeled data points in binary classification tasks. 20% of the labels are flipped at random (we don't consider multi-class datasets because there is no unique flipping strategy). The following table shows the mean area under the curve (AUC) over five runs.
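The corruption step can be sketched as follows (a hypothetical helper, not part of pydvl):

```python
import numpy as np

def flip_labels(y, fraction=0.2, seed=0):
    """Flip `fraction` of the binary labels in `y` uniformly at random,
    returning the corrupted copy and the flipped indices.
    Toy sketch of the corruption used in the experiment."""
    rng = np.random.default_rng(seed)
    y_noisy = y.copy()
    idx = rng.choice(len(y), size=int(fraction * len(y)), replace=False)
    y_noisy[idx] = 1 - y_noisy[idx]
    return y_noisy, idx
```

The returned indices serve as ground truth when measuring how well low values flag the corrupted points.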

Mean AUC for mis-labeled data point detection. Values computed using LOO, CWS, Beta Shapley, and TMCS

In the majority of cases TMCS has a slight advantage over CWS, except for Click, where CWS has a slight edge, most probably due to the unbalanced nature of the dataset. The following plot shows the CV for the AUC of the five runs.

Coefficient of variation of AUC for mis-labeled data point detection. Values computed using LOO, CWS, Beta Shapley, and TMCS

In terms of CV, CWS has a clear edge over TMCS and Beta Shapley.

Finally, we look at the ROC curves obtained by training the classifier on the first \(n\) samples in increasing order of valuation (i.e. starting with the worst):

Mean ROC across 5 runs with 95% bootstrap CI

Although at first sight TMCS seems to be the winner, CWS stays competitive after factoring in running time. For a perfectly balanced dataset, CWS needs on average fewer samples than TMCS.

Value distribution

For illustration, we compare the distribution of values computed by TMCS and CWS.

Histogram and estimated density of the values computed by TMCS and CWS on all nine datasets

For Click TMCS has a multi-modal distribution of values. We hypothesize that this is due to the highly unbalanced nature of the dataset, and notice that CWS has a single mode, leading to its greater performance on this dataset.


Conclusion

CWS is an effective way to handle classification problems, in particular for unbalanced datasets. It reduces computing requirements by considering in-class and out-of-class points separately.

  1. Schoch, S., Xu, H., Ji, Y., 2022. CS-Shapley: Class-wise Shapley Values for Data Valuation in Classification, in: Advances in Neural Information Processing Systems (NeurIPS 2022). 

  2. Castro, J., Gómez, D., Tejada, J., 2009. Polynomial calculation of the Shapley value based on sampling. Computers & Operations Research, Selected papers presented at the Tenth International Symposium on Locational Decisions (ISOLDE X) 36, 1726–1730.