Class-wise Shapley
Class-wise Shapley (CWS) (Schoch et al., 2022) offers a Shapley framework tailored for classification problems. Given a sample \(x_i\) with label \(y_i \in \mathbb{N}\), let \(D_{y_i}\) be the subset of \(D\) with label \(y_i\), and \(D_{-y_i}\) be the complement of \(D_{y_i}\) in \(D\). The key idea is that the sample \((x_i, y_i)\) might improve the overall model performance on \(D\) while being detrimental to the performance on \(D_{y_i}\), e.g. because of a wrong label. To address this issue, the authors introduce

\[
v_u(i) = \frac{1}{2^{|D_{-y_i}|}} \sum_{S_{-y_i}} \left[ \frac{1}{|D_{y_i}|} \sum_{S_{y_i}} \binom{|D_{y_i}| - 1}{|S_{y_i}|}^{-1} \delta(S_{y_i} \mid S_{-y_i}) \right],
\]
where \(S_{y_i} \subseteq D_{y_i} \setminus \{i\}\) and \(S_{-y_i} \subseteq D_{-y_i}\) is arbitrary (in particular, not the complement of \(S_{y_i}\)). The function \(\delta\) is called the set-conditional marginal Shapley value and is defined as

\[
\delta(S \mid C) = u(S \cup \{i\} \mid C) - u(S \mid C),
\]

for any set \(S\) such that \(i \notin S, C\) and \(S \cap C = \emptyset\).
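For very small sets, the definition of the class-wise value can be evaluated exactly by brute force, which makes the two nested averages explicit: \(\delta\) is averaged uniformly over all out-of-class subsets \(S_{-y_i}\), and with the usual Shapley weights over all in-class subsets \(S_{y_i}\). The following is a minimal sketch, not the library's implementation; the callable `u(S, C)` is a hypothetical stand-in for the conditional utility \(u(S \mid C)\), and the function names are illustrative.

```python
from itertools import chain, combinations
from math import comb
from typing import Callable, FrozenSet, Iterator, List

# u(S, C): hypothetical stand-in for the conditional utility u(S | C)
Utility = Callable[[FrozenSet[int], FrozenSet[int]], float]


def powerset(items: List[int]) -> Iterator[FrozenSet[int]]:
    """Yield every subset of `items` as a frozenset."""
    sizes = range(len(items) + 1)
    for c in chain.from_iterable(combinations(items, k) for k in sizes):
        yield frozenset(c)


def exact_classwise_value(
    i: int, in_class: List[int], out_of_class: List[int], u: Utility
) -> float:
    """Brute-force class-wise Shapley value of point i.

    in_class is D_{y_i} (it must contain i), out_of_class is D_{-y_i}.
    Exponential cost: only usable for a handful of points.
    """
    n = len(in_class)
    others = [j for j in in_class if j != i]
    outer = 0.0
    for c in powerset(out_of_class):  # S_{-y_i}: uniform weight 1 / 2^{|D_{-y_i}|}
        inner = 0.0
        for s in powerset(others):  # S_{y_i}, a subset of D_{y_i} \ {i}
            delta = u(s | {i}, c) - u(s, c)  # delta(S_{y_i} | S_{-y_i})
            inner += delta / comb(n - 1, len(s))  # Shapley weight within the class
        outer += inner / n
    return outer / 2 ** len(out_of_class)
```

As a sanity check, for an additive utility such as `u(S, C) = sum of per-point weights over S` (plus any term depending only on `C`), every point's value reduces to its own weight.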
In practice, this quantity is estimated by Monte Carlo sampling of both the powerset of \(D_{-y_i}\) and the set of permutations of the indices in \(D_{y_i}\) (Castro et al., 2009). Typically, this requires fewer samples than the original Shapley value, although the actual speed-up depends on the model and the dataset.
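A minimal sketch of such a sampling scheme follows, assuming a callable utility `u(S, C)` as before: each iteration draws a uniformly random subset of the out-of-class points (each point kept with probability 1/2) and a random permutation of the in-class points, then credits every in-class point with its marginal contribution along that permutation, in the style of Castro et al. The function name `mc_classwise_values` is hypothetical, and the sketch omits truncation and the other optimizations of the library's `compute_classwise_shapley_values`.

```python
import random
from typing import Callable, Dict, FrozenSet, List, Optional

# u(S, C): hypothetical stand-in for the conditional utility u(S | C)
Utility = Callable[[FrozenSet[int], FrozenSet[int]], float]


def mc_classwise_values(
    in_class: List[int],
    out_of_class: List[int],
    u: Utility,
    n_samples: int,
    seed: Optional[int] = None,
) -> Dict[int, float]:
    """Monte Carlo estimate of class-wise Shapley values for all in-class points."""
    rng = random.Random(seed)
    totals = {i: 0.0 for i in in_class}
    for _ in range(n_samples):
        # Uniform sample from the powerset of D_{-y}: keep each point with prob. 1/2
        c = frozenset(j for j in out_of_class if rng.random() < 0.5)
        # Uniform random permutation of D_{y}
        perm = list(in_class)
        rng.shuffle(perm)
        prefix: FrozenSet[int] = frozenset()
        prev = u(prefix, c)
        for i in perm:
            prefix = prefix | {i}
            cur = u(prefix, c)
            totals[i] += cur - prev  # delta(points before i in perm | c)
            prev = cur
    return {i: t / n_samples for i, t in totals.items()}
```

Each sample costs \(|D_{y_i}| + 1\) utility evaluations and yields one marginal contribution for every in-class point at once, which is why permutation sampling is preferred over sampling subsets point by point.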
Computing classwise Shapley values
Like all other game-theoretic valuation methods, CWS requires a Utility object constructed with a model and a dataset, with the peculiarity that it requires a specific ClasswiseScorer. The entry point is the function compute_classwise_shapley_values:
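```python
from pydvl.value import *

model = ...  # a supervised model, e.g. a scikit-learn classifier
data = Dataset(...)
scorer = ClasswiseScorer(...)  # the class-wise scorer described in the next section
utility = Utility(model, data, scorer)
values = compute_classwise_shapley_values(
    utility,
    # stop when the values are stable over 500 steps, or after 5000 updates
    done=HistoryDeviation(n_steps=500, rtol=5e-2) | MaxUpdates(5000),
    # truncate permutations early, with a relative tolerance of 1%
    truncation=RelativeTruncation(utility, rtol=0.01),
    # stopping criterion for resampling out-of-class subsets
    done_sample_complements=MaxChecks(1),
    normalize_values=True,
)
```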
The class-wise scorer
In order to use the class-wise Shapley value, one needs to define a ClasswiseScorer. This scorer is defined as

\[
u(S) = f(a_S(D_{y_i})) \, g(a_S(D_{-y_i})),
\]

where \(f\) and \(g\) are monotonically increasing functions, \(a_S(D_{y_i})\) is the in-class accuracy, and \(a_S(D_{-y_i})\) is the out-of-class accuracy (the names originate from the authors' choice of accuracy, but in principle any other score, like \(F_1\), can be used).
The authors show that \(f(x)=x\) and \(g(x)=e^x\) have favorable properties and are therefore used as defaults, but different functions \(f\) and \(g\) can be set, e.g. to explore other base scores.
The default class-wise scorer
Constructing the CWS scorer requires choosing a metric and the functions \(f\) and \(g\):
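The following is a minimal, library-independent sketch of how the discounted utility combines a metric with \(f\) and \(g\), using the defaults \(f(x)=x\) and \(g(x)=e^x\). It assumes plain accuracy computed separately on the test points with label `label` (in-class) and without it (out-of-class); the names `make_classwise_utility` and `accuracy` are hypothetical, and the actual ClasswiseScorer may differ in details such as rescaling of the partial scores.

```python
import math
from typing import Callable, Sequence

Metric = Callable[[Sequence[int], Sequence[int]], float]


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    if len(y_true) == 0:
        return 0.0  # assumption: an empty subset scores 0
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def make_classwise_utility(
    metric: Metric = accuracy,
    f: Callable[[float], float] = lambda x: x,  # in-class discount, default identity
    g: Callable[[float], float] = math.exp,  # out-of-class discount, default e^x
) -> Callable[[Sequence[int], Sequence[int], int], float]:
    def utility(y_true: Sequence[int], y_pred: Sequence[int], label: int) -> float:
        in_idx = [k for k, t in enumerate(y_true) if t == label]
        out_idx = [k for k, t in enumerate(y_true) if t != label]
        a_in = metric([y_true[k] for k in in_idx], [y_pred[k] for k in in_idx])
        a_out = metric([y_true[k] for k in out_idx], [y_pred[k] for k in out_idx])
        return f(a_in) * g(a_out)  # u = f(in-class score) * g(out-of-class score)

    return utility
```

For example, with `y_true = [0, 0, 1, 1]` and `y_pred = [0, 1, 1, 1]`, the utility for label 0 is \(0.5 \cdot e^{1}\): half of the in-class points and all of the out-of-class points are predicted correctly.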
Surface of the discounted utility function
The level curves for \(f(x)=x\) and \(g(x)=e^x\) are depicted below. The lines are the contour lines of the utility, annotated with their respective gradients.
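A short derivation (not taken from the paper) makes the shape of these curves explicit. Writing \(a_\text{in} = a_S(D_{y_i})\) and \(a_\text{out} = a_S(D_{-y_i})\), the default utility, its level curve for a level \(c > 0\), and its gradient are

\[
u = a_\text{in}\, e^{a_\text{out}}, \qquad
u = c \iff a_\text{in} = c\, e^{-a_\text{out}}, \qquad
\nabla u = \left( e^{a_\text{out}},\ a_\text{in}\, e^{a_\text{out}} \right).
\]

Since both accuracies lie in \([0, 1]\), \(\partial u / \partial a_\text{in} = e^{a_\text{out}} \geq a_\text{in}\, e^{a_\text{out}} = \partial u / \partial a_\text{out}\), with equality only at \(a_\text{in} = 1\): under these defaults, the utility reacts at least as strongly to in-class accuracy as to out-of-class accuracy.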
Evaluation
We illustrate the method with two experiments, point removal and noise removal, as well as an analysis of the distribution of the values. For this we employ the nine datasets used in (Schoch et al., 2022), with the same pre-processing. For images, PCA is used to reduce the features extracted by a pre-trained ResNet18 model to 32 dimensions. Standard location-scale normalization is performed for all models except gradient boosting, since the latter is not sensitive to the scale of the features.
Datasets used for evaluation
Dataset | Data Type | Classes | Input Dims | OpenML ID |
---|---|---|---|---|
Diabetes | Tabular | 2 | 8 | 37 |
Click | Tabular | 2 | 11 | 1216 |
CPU | Tabular | 2 | 21 | 197 |
Covertype | Tabular | 7 | 54 | 1596 |
Phoneme | Tabular | 2 | 5 | 1489 |
FMNIST | Image | 2 | 32 | 40996 |
CIFAR10 | Image | 2 | 32 | 40927 |
MNIST (binary) | Image | 2 | 32 | 554 |
MNIST (multi) | Image | 10 | 32 | 554 |
We report the mean and the coefficient of variation (CV) \(\frac{\sigma}{\mu}\) of an "inner metric". The former shows the performance of the method, whereas the latter displays its stability: normalizing the standard deviation by the mean shows its relative effect. Ideally, the mean is maximal and the CV minimal.
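As a small illustration of the two summary statistics, the sketch below computes the mean and CV of a list of per-run scores. It assumes the sample standard deviation (the document does not say whether the sample or population version was used), and `mean_and_cv` is a hypothetical helper name.

```python
import statistics
from typing import Sequence, Tuple


def mean_and_cv(scores: Sequence[float]) -> Tuple[float, float]:
    """Mean and coefficient of variation sigma / mu of per-run scores."""
    mu = statistics.mean(scores)
    if mu == 0:
        raise ValueError("CV is undefined for a zero mean")
    sigma = statistics.stdev(scores)  # sample standard deviation (assumption)
    return mu, sigma / mu
```

For scores `[1.0, 2.0, 3.0]` this gives a mean of 2.0 and a CV of 0.5.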
Finally, we note that for all sampling-based valuation methods the same number of evaluations of the marginal utility was used. This is important to make the algorithms comparable, but in practice one should consider using a more sophisticated stopping criterion.
Dataset pruning for logistic regression (point removal)
In (best-)point removal, one first computes values for the training set and then removes in sequence the points with the highest values. After each removal, the remaining points are used to train the model from scratch, and performance is measured on a test set. This produces a curve of performance vs. number of points removed, which we show below.
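The procedure can be sketched as follows, under the assumption of a hypothetical callback `train_and_score(indices)` that trains a fresh model on the given training indices and returns its score on a fixed test set. The name `point_removal_curve` is also illustrative.

```python
from typing import Callable, List, Sequence


def point_removal_curve(
    values: Sequence[float], train_and_score: Callable[[List[int]], float]
) -> List[float]:
    """Performance after removing the j highest-valued points, for j = 0, ..., n-1.

    train_and_score is a hypothetical callback: it trains a fresh model on the
    given training indices and returns its score on a fixed test set.
    """
    # Indices sorted from highest to lowest value
    order = sorted(range(len(values)), key=lambda k: values[k], reverse=True)
    # order[j:] is the training set with the j best points removed;
    # we stop before the training set becomes empty
    return [train_and_score(sorted(order[j:])) for j in range(len(order))]
```

Replacing `values` with random numbers gives the random baseline shown in the plots.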
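As a scalar summary of this curve, (Schoch et al., 2022) define the Weighted Accuracy Drop (WAD) over \(n\) removal steps as:

\[
\text{WAD} = \sum_{j=1}^{n} \left( \frac{1}{j} \sum_{i=1}^{j} \left( a_{T_{-\{1 \colon i-1\}}}(D) - a_{T_{-\{1 \colon i\}}}(D) \right) \right) = \sum_{j=1}^{n} \frac{a_T(D) - a_{T_{-\{1 \colon j\}}}(D)}{j},
\]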
where \(a_T(D)\) is the accuracy of the model (trained on \(T\)) evaluated on \(D\), and \(T_{-\{1 \colon j\}}\) is the set \(T\) without elements from \(\{1, \dots, j\}\), with the points of \(T\) indexed in decreasing order of value.
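Given the accuracies along the removal curve, WAD is a weighted sum of accuracy drops relative to the full training set. A minimal sketch, assuming `acc[j]` holds the accuracy after removing the `j` highest-valued points (so `acc[0]` is \(a_T(D)\)); `weighted_accuracy_drop` is a hypothetical name.

```python
from typing import Sequence


def weighted_accuracy_drop(acc: Sequence[float]) -> float:
    """WAD = sum_{j=1}^{n} (acc[0] - acc[j]) / j, with n = len(acc) - 1."""
    base = acc[0]  # accuracy with the full training set
    return sum((base - acc[j]) / j for j in range(1, len(acc)))
```

Early drops carry the largest weights, so a valuation method that puts the most useful points first obtains a larger WAD.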
We run the point removal experiment for a logistic regression model five times and compute WAD for each run, then report the mean \(\mu_\text{WAD}\) and standard deviation \(\sigma_\text{WAD}\).
We see that CWS is competitive with all three other methods. In all problems except MNIST (multi) it outperforms TMCS (Truncated Monte Carlo Shapley), while on MNIST (multi) TMCS has a slight advantage.
In order to understand the variability of WAD we look at its coefficient of variation (lower is better):
CWS is not the best method in terms of CV. For CIFAR10, Click, CPU and MNIST (binary), Beta Shapley has the lowest CV. For Diabetes, MNIST (multi) and Phoneme, CWS is the winner, and for FMNIST and Covertype, TMCS takes the lead. Apart from LOO (leave-one-out), TMCS has the highest relative standard deviation.
The following plot shows accuracy vs number of samples removed. Random values serve as a baseline. The shaded area represents the 95% bootstrap confidence interval of the mean across 5 runs.
Because samples are removed from high to low valuation order, we expect a steep decrease in the curve.
Overall we conclude that in terms of mean WAD, CWS and TMCS perform best, with CWS's CV on par with Beta Shapley's, making CWS a competitive method.
Dataset pruning for a neural network by value transfer
Transfer of values from one model to another is probably of greater practical relevance: values are computed using a cheap model and used to prune the dataset before training a more expensive one.
The following plot shows accuracy vs number of samples removed for transfer from logistic regression to a neural network. The shaded area represents the 95% bootstrap confidence interval of the mean across 5 runs.
As in the previous experiment, samples are removed from high to low valuation order, and hence we expect a steep decrease in the curve. CWS is competitive with the other methods, especially in very unbalanced datasets like Click. In other datasets, like Covertype, Diabetes and MNIST (multi), the performance is on par with TMCS.
Detection of mis-labeled data points
The next experiment tries to detect mis-labeled data points in binary classification tasks. The labels of 20% of the indices are flipped at random (we don't consider multi-class datasets because there isn't a unique flipping strategy). The following table shows the mean of the area under the curve (AUC) over five runs.
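The sketch below illustrates the setup under two assumptions: labels are binary in \(\{0, 1\}\), and the detection AUC is the standard rank-based (Mann-Whitney) ROC AUC for flagging flipped points by low value, i.e. the probability that a random flipped point has a lower value than a random clean one. The exact AUC computed in the experiments may differ; `flip_labels` and `detection_auc` are hypothetical names.

```python
import random
from typing import List, Optional, Sequence, Set, Tuple


def flip_labels(
    y: Sequence[int], fraction: float = 0.2, seed: Optional[int] = None
) -> Tuple[List[int], Set[int]]:
    """Flip a random `fraction` of binary labels (0 <-> 1).

    Returns the noisy labels and the set of flipped indices.
    """
    rng = random.Random(seed)
    n_flip = int(round(fraction * len(y)))
    flipped = set(rng.sample(range(len(y)), n_flip))
    y_noisy = [1 - t if k in flipped else t for k, t in enumerate(y)]
    return y_noisy, flipped


def detection_auc(values: Sequence[float], flipped: Set[int]) -> float:
    """P(value of a flipped point < value of a clean point); ties count 1/2."""
    bad = [values[k] for k in flipped]
    good = [v for k, v in enumerate(values) if k not in flipped]
    if not bad or not good:
        raise ValueError("need both flipped and clean points")
    wins = 0.0
    for b in bad:
        for g in good:
            wins += 1.0 if b < g else (0.5 if b == g else 0.0)
    return wins / (len(bad) * len(good))
```

An AUC of 1.0 means every flipped point received a lower value than every clean point; 0.5 corresponds to random ordering.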
In the majority of cases TMCS has a slight advantage over CWS, except for Click, where CWS has a slight edge, most probably due to the unbalanced nature of the dataset. The following plot shows the CV of the AUC over the five runs.
In terms of CV, CWS has a clear edge over TMCS and Beta Shapley.
Finally, we look at the ROC curves obtained by training the classifier on the first \(n\) samples in increasing order of valuation (i.e. starting with the worst):
Although at first sight TMCS seems to be the winner, CWS stays competitive after factoring in running time. For a perfectly balanced dataset, CWS needs on average fewer samples than TMCS.
Value distribution
For illustration, we compare the distribution of values computed by TMCS and CWS. For Click, TMCS has a multi-modal distribution of values. We hypothesize that this is due to the highly unbalanced nature of the dataset, and notice that CWS has a single mode, leading to its better performance on this dataset.
Conclusion
CWS is an effective way to handle classification problems, in particular for unbalanced datasets. It reduces the computing requirements by considering in-class and out-of-class points separately.
- Schoch, S., Xu, H., Ji, Y., 2022. CS-Shapley: Class-wise Shapley Values for Data Valuation in Classification. In: Proc. of the Thirty-Sixth Conference on Neural Information Processing Systems (NeurIPS 2022).
- Castro, J., Gómez, D., Tejada, J., 2009. Polynomial calculation of the Shapley value based on sampling. Computers & Operations Research 36, 1726–1730. Selected papers presented at the Tenth International Symposium on Locational Decisions (ISOLDE X). https://doi.org/10.1016/j.cor.2008.04.004