KNN Shapley
This notebook shows how to calculate Shapley values for the K-Nearest Neighbours algorithm. By exploiting the local structure of KNN, it is possible to compute exact values in almost linear time, as opposed to the exponential complexity of exact, model-agnostic Shapley.
The main idea is to exploit the fact that adding or removing points beyond the k-ball doesn't influence the score. Because the algorithm then essentially only needs to search for and rank neighbours by distance, it runs in \(\mathcal{O}(N \log N)\) time.
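The locality argument can be checked on a toy example. The sketch below assumes a simplified utility for a single test point (the fraction of the k nearest points of a subset whose label matches the test label). The function `knn_utility` and the data are hypothetical and are not part of pyDVL.

```python
def knn_utility(subset, distances, labels, y_test, k):
    """Fraction of the k nearest points of `subset` whose label is y_test."""
    nearest = sorted(subset, key=lambda i: distances[i])[:k]
    return sum(labels[i] == y_test for i in nearest) / k


# Distances of five training points to a single test point, and their labels.
distances = [0.1, 0.4, 0.5, 2.0, 3.0]
labels = [1, 0, 1, 0, 1]

inside = [0, 1, 2]  # the k = 3 nearest points
u_inside = knn_utility(inside, distances, labels, y_test=1, k=3)
u_with_far = knn_utility(inside + [3, 4], distances, labels, y_test=1, k=3)

# Points 3 and 4 lie beyond the k-th neighbour, so the utility is unchanged.
print(u_inside == u_with_far)  # True

# A point that enters the k nearest neighbours can change the utility.
u_before = knn_utility([1, 2, 3], distances, labels, y_test=1, k=3)
u_after = knn_utility([0, 1, 2, 3], distances, labels, y_test=1, k=3)
print(u_before, u_after)
```

Only points that can displace one of the current k nearest neighbours have a non-zero marginal contribution, which is what makes a closed-form computation possible.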
By further using approximate nearest neighbours, it is possible to achieve \((\epsilon,\delta)\)-approximations in sublinear time. However, this is not implemented in pyDVL yet.
We refer to the original paper that pyDVL implements for details: Jia, Ruoxi, David Dao, Boxin Wang, Frances Ann Hubis, Nezihe Merve Gurel, Bo Li, Ce Zhang, Costas Spanos, and Dawn Song. Efficient Task-Specific Data Valuation for Nearest Neighbor Algorithms. Proceedings of the VLDB Endowment 12, no. 11 (1 July 2019): 1610–23.
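To give an idea of what the exact computation looks like, here is a minimal pure-Python sketch of the recursion for the unweighted KNN classifier, as we read it from the cited paper. For one test point, sort the training points by increasing distance as \(\alpha_1, \dots, \alpha_N\). Then \(s_{\alpha_N} = \mathbb{1}[y_{\alpha_N} = y_{\text{test}}] / N\) and, for \(i = N-1, \dots, 1\), \(s_{\alpha_i} = s_{\alpha_{i+1}} + \frac{\mathbb{1}[y_{\alpha_i} = y_{\text{test}}] - \mathbb{1}[y_{\alpha_{i+1}} = y_{\text{test}}]}{K} \cdot \frac{\min(K, i)}{i}\). The values are then averaged over all test points. The function name `knn_shapley`, the Euclidean metric and the toy data are assumptions made for illustration. This is not pyDVL's implementation.

```python
import math


def knn_shapley(train_x, train_y, test_x, test_y, k):
    """Exact Shapley values for an unweighted KNN classifier (sketch).

    Returns one value per training point, averaged over all test points.
    Ties in distance are broken by training index (an arbitrary choice).
    """
    n = len(train_x)
    values = [0.0] * n
    for x, y in zip(test_x, test_y):
        # Rank training points by distance to the test point, nearest first.
        order = sorted(range(n), key=lambda i: math.dist(train_x[i], x))
        match = [1.0 if train_y[i] == y else 0.0 for i in order]
        # Base case: the farthest point.
        s = match[-1] / n
        values[order[-1]] += s
        # Recursion from the second farthest point down to the nearest.
        for pos in range(n - 2, -1, -1):
            rank = pos + 1  # 1-based rank of the current point
            s += (match[pos] - match[pos + 1]) / k * min(k, rank) / rank
            values[order[pos]] += s
    return [v / len(test_x) for v in values]


# Toy data: four 1-D training points, one test point of class 0.
train_x = [(0.0,), (1.0,), (2.0,), (3.0,)]
train_y = [0, 0, 1, 1]
values = knn_shapley(train_x, train_y, test_x=[(0.2,)], test_y=[0], k=1)
print(values)  # [0.5, 0.5, 0.0, 0.0]
```

The values sum to the utility of the full training set (here 1.0, since the nearest neighbour has the correct label), as expected from the efficiency property of Shapley values. The cost per test point is dominated by the sort.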
Setup
We begin by importing the main libraries and setting some defaults.
The main interface is the class KNNShapleyValuation. To use it we need to construct two Datasets (one for training and one for evaluating) and a scikit-learn KNN classifier.
Building a Dataset and a Utility
We use the sklearn iris dataset and wrap it into two Datasets by calling the factory Dataset.from_sklearn(). This automatically creates a train / test split for us, which will be used to score the model.
We then create a KNN model from scikit-learn and instantiate a KNNShapleyValuation object. This valuation departs from standard usage in that it does not use a Utility but instead takes the scikit-learn model and the test set directly as input. This is because KNN-Shapley uses a recursive formula to compute the values directly, without needing to sample subsets of the training data. Nevertheless, we provide KNNClassifierUtility for purposes of testing and experimentation.
Inspecting the results
Let us first look at the labels' distribution as a function of petal and sepal length:
If we now look at the distribution of Shapley values for each class, we see that each has samples with both high and low scores. This is expected, because an accurate model uses information from all classes.
Corrupting labels
To test how informative values are, we can corrupt some training labels and see how their Shapley values change with respect to the non-corrupted points.
# Raw training features and labels
_x, _y = train.data()
# Flip the first n_corrupted labels to the next of the 3 iris classes
n_corrupted = 10
_y[:n_corrupted] = (_y[:n_corrupted] + 1) % 3
corrupted_data = Dataset(
    _x,
    _y,
    feature_names=train.feature_names,
    target_names=train.target_names,
    description="Corrupted iris dataset",
)
# Recompute the values with the corrupted training set
knn = sk.neighbors.KNeighborsClassifier(n_neighbors=5)
contaminated_valuation = KNNShapleyValuation(
    model=knn, test_data=test, progress=True
).fit(corrupted_data)
contaminated_result = contaminated_valuation.values()
Taking the average corrupted value and comparing it to the uncorrupted one, we notice that on average anomalous points have a much lower score, i.e. they tend to be much less valuable to the model.
To do this, we first make sure that the results are ordered by data index with a call to ValuationResult.sort(), then we split the values into two groups: corrupted and non-corrupted. Note how we access the property values of the ValuationResult object. This is a numpy array of values, sorted however the object was sorted. Finally, we compute the percentile at which the mean of each group falls in the value distribution and compare them. We see that the corrupted mean is in the lowest percentile of the value distribution, while the correct mean is in the 70th percentile.
contaminated_result.sort(key="index")  # sort by data index; actually redundant
# The first n_corrupted indices are the points whose labels were flipped
corrupted_shapley_values = contaminated_result.values[:n_corrupted]
correct_shapley_values = contaminated_result.values[n_corrupted:]
mean_corrupted = np.mean(corrupted_shapley_values)
mean_correct = np.mean(correct_shapley_values)
# Share (in %) of the values in `result` that lie below each group mean
percentile_corrupted = np.round(100 * np.mean(result.values < mean_corrupted), 0)
percentile_correct = np.round(100 * np.mean(result.values < mean_correct), 0)
This is confirmed if we plot the distribution of Shapley values and circle the corrupted points in red. They all tend to have low Shapley scores, regardless of their position in space and assigned label: