pydvl.value.shapley.montecarlo¶
Monte Carlo approximations to Shapley data values.
Warning
You probably want to use the common interface provided by compute_shapley_values() instead of directly using the functions in this module.
Because exact computation of Shapley values requires \(\mathcal{O}(2^n)\) re-trainings of the model, several Monte Carlo approximations are available. The first two sample from the powerset of the training data directly: combinatorial_montecarlo_shapley() and owen_sampling_shapley(). The latter uses a reformulation in terms of a continuous extension of the utility.
Alternatively, employing a reformulation of the Shapley value as a sum over permutations, one has the implementation in permutation_montecarlo_shapley(), with the option to pass an early stopping strategy to reduce computation, as done in Truncated Monte Carlo Shapley (TMCS).
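For orientation, here is a minimal sketch of selecting one of these approximations through the common interface recommended above. The Dataset/Utility construction, the ShapleyMode member name and the MaxUpdates stopping criterion are assumptions about the wider pydvl API (circa version 0.9) and may differ in your installation.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression

from pydvl.utils import Dataset, Utility
from pydvl.value import MaxUpdates, ShapleyMode, compute_shapley_values

# Wrap model, data and scorer into a Utility (assumed construction).
data = Dataset.from_sklearn(load_breast_cancer(), train_size=0.7)
u = Utility(LogisticRegression(max_iter=1000), data, "accuracy")

# Permutation-based Monte Carlo Shapley, stopping after 100 value updates.
values = compute_shapley_values(
    u,
    mode=ShapleyMode.PermutationMontecarlo,  # assumed enum member name
    done=MaxUpdates(100),
)
```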
Also see
It is also possible to use group_testing_shapley() to reduce the number of evaluations of the utility. The method is however typically outperformed by others in this module.
Also see
Additionally, you can consider grouping your data points using GroupedDataset and computing the values of the groups instead. This is not to be confused with "group testing" as implemented in group_testing_shapley(): any of the algorithms mentioned above, including Group Testing, can work to valuate groups of samples as units.
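To illustrate the grouping idea, here is a sketch in which training points are assigned to groups and a GroupedDataset replaces the plain Dataset inside the Utility, so that each group is valuated as a single unit. The GroupedDataset.from_dataset constructor and the rest of the setup are assumptions about the pydvl API; check the GroupedDataset documentation for the exact interface.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression

from pydvl.utils import Dataset, GroupedDataset, Utility

# Assign each training point to one of 10 groups (here at random, purely for illustration).
data = Dataset.from_sklearn(load_breast_cancer(), train_size=0.7)
group_labels = np.random.default_rng(0).integers(0, 10, size=len(data.x_train))

# Each group becomes a single unit for valuation (from_dataset is an assumed constructor).
grouped = GroupedDataset.from_dataset(data, group_labels)
u = Utility(LogisticRegression(max_iter=1000), grouped, "accuracy")
# Any algorithm in this module can now be applied to u; the result has one value per group.
```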
References¶
1. Ghorbani, A., Zou, J., 2019. Data Shapley: Equitable Valuation of Data for Machine Learning. In: Proceedings of the 36th International Conference on Machine Learning, PMLR, pp. 2242–2251.
permutation_montecarlo_shapley¶
permutation_montecarlo_shapley(
u: Utility,
done: StoppingCriterion,
*,
truncation: TruncationPolicy = NoTruncation(),
n_jobs: int = 1,
parallel_backend: Optional[ParallelBackend] = None,
config: Optional[ParallelConfig] = None,
progress: bool = False,
seed: Optional[Seed] = None
) -> ValuationResult
Computes an approximate Shapley value by sampling independent permutations of the index set, approximating the sum:

\[
v_u(i) = \frac{1}{n!} \sum_{\sigma \in \Pi(n)} \left[ u(\sigma_{:i} \cup \{i\}) - u(\sigma_{:i}) \right],
\]

where \(\sigma_{:i}\) denotes the set of indices in permutation \(\sigma\) before the position where \(i\) appears (see [[data-valuation]] for details).
This implements the method described in (Ghorbani and Zou, 2019)1 with a double stopping criterion.
Todo
Think of how to add Gelman-Rubin or some other more principled stopping criterion.
Instead of naively implementing the expectation, we sequentially add points to coalitions from a permutation and incrementally compute marginal utilities. We stop computing marginals for a given permutation based on a TruncationPolicy. (Ghorbani and Zou, 2019)1 mention two policies: one that stops after a certain fraction of marginals are computed, implemented in FixedTruncation, and one that stops if the last computed utility ("score") is close to the total utility using the standard deviation of the utility as a measure of proximity, implemented in BootstrapTruncation.
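To make the incremental scheme concrete, here is a self-contained sketch of one permutation pass with truncation, written against a generic set function rather than pydvl's classes. The names utility, one_permutation_pass and the relative-tolerance check are placeholders for this illustration, not pydvl APIs.

```python
import numpy as np

def one_permutation_pass(utility, n, rng, rtol=0.01):
    """One Monte Carlo pass: marginal contributions along one random permutation,
    truncated once the running utility is within rtol of the total utility.
    `utility` maps a tuple of indices from range(n) to a score (placeholder)."""
    total = utility(tuple(range(n)))       # u(N), used by the truncation check
    sigma = rng.permutation(n)
    marginals = np.zeros(n)
    prev_score = utility(())               # utility of the empty coalition
    coalition = []
    for i in sigma:
        if abs(total - prev_score) <= rtol * abs(total):
            break                          # truncate: remaining marginals stay zero
        coalition.append(int(i))
        score = utility(tuple(coalition))
        marginals[i] = score - prev_score  # incremental marginal of index i
        prev_score = score
    return marginals                       # one sample of the value of every index

rng = np.random.default_rng(0)
toy_utility = lambda s: len(s)             # toy additive utility: each point worth 1
print(one_permutation_pass(toy_utility, n=5, rng=rng))
```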
We keep sampling permutations and updating all Shapley values until the StoppingCriterion returns True.
PARAMETER | DESCRIPTION |
---|---|
u | Utility object with model, data, and scoring function. TYPE: Utility |
done | Function checking whether computation must stop. TYPE: StoppingCriterion |
truncation | An optional callable which decides whether to interrupt processing a permutation and set all subsequent marginals to zero. Typically used to stop computation when the marginal is small. TYPE: TruncationPolicy DEFAULT: NoTruncation() |
n_jobs | Number of jobs across which to distribute the computation. TYPE: int DEFAULT: 1 |
parallel_backend | Parallel backend instance to use for parallelizing computations. TYPE: Optional[ParallelBackend] DEFAULT: None |
config | (DEPRECATED) Object configuring parallel computation, with cluster address, number of cpus, etc. TYPE: Optional[ParallelConfig] DEFAULT: None |
progress | Whether to display a progress bar. TYPE: bool DEFAULT: False |
seed | Either an instance of a numpy random number generator or a seed for it. TYPE: Optional[Seed] DEFAULT: None |
RETURNS | DESCRIPTION |
---|---|
ValuationResult | Object with the data values. |
Changed in version 0.9.0: Deprecated the config argument and added a parallel_backend argument to allow users to pass the parallel backend instance directly.
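A usage sketch following the signature documented above. The Dataset/Utility construction and the MaxUpdates criterion are assumptions about the surrounding pydvl API; the call itself uses only the parameters listed on this page.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression

from pydvl.utils import Dataset, Utility
from pydvl.value import MaxUpdates                      # assumed stopping criterion
from pydvl.value.shapley.montecarlo import permutation_montecarlo_shapley

# Utility wraps model, data and scorer (construction assumed from the wider pydvl API).
data = Dataset.from_sklearn(load_breast_cancer(), train_size=0.7)
u = Utility(LogisticRegression(max_iter=1000), data, "accuracy")

# Call follows the documented signature; truncation defaults to NoTruncation().
result = permutation_montecarlo_shapley(u, done=MaxUpdates(500), n_jobs=4, progress=True)
print(result.values)   # estimated data values, one per training index
```

A TruncationPolicy such as the FixedTruncation or BootstrapTruncation mentioned above can be passed via the truncation argument to recover TMCS-style early stopping.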
Source code in src/pydvl/value/shapley/montecarlo.py
combinatorial_montecarlo_shapley¶
combinatorial_montecarlo_shapley(
u: Utility,
done: StoppingCriterion,
*,
n_jobs: int = 1,
parallel_backend: Optional[ParallelBackend] = None,
config: Optional[ParallelConfig] = None,
progress: bool = False,
seed: Optional[Seed] = None
) -> ValuationResult
Computes an approximate Shapley value using the combinatorial definition:

\[
v_u(i) = \frac{1}{n} \sum_{S \subseteq N \setminus \{i\}} \binom{n-1}{|S|}^{-1} \left[ u(S \cup \{i\}) - u(S) \right].
\]
This consists of randomly sampling subsets from the power set of the training indices in u.data, and computing their marginal utilities. See Data valuation for details.
Note that because sampling is done with replacement, the approximation is poor even for \(2^{m}\) subsets with \(m>n\), even though there are \(2^{n-1}\) subsets for each \(i\). Prefer permutation_montecarlo_shapley().
Parallelization is done by splitting the set of indices across processes and computing the sum over subsets \(S \subseteq N \setminus \{i\}\) separately.
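The estimator can be illustrated with a small self-contained sketch: for each index \(i\), subsets of the remaining indices are drawn uniformly at random (with replacement, as noted above) and the weighted marginal utilities are averaged, with a correction factor that converts the uniform expectation into the combinatorial sum. The per-index loop also makes the split used for parallelization apparent. Here utility is a placeholder for any set function; this is not pydvl code.

```python
from itertools import compress
from math import comb

import numpy as np

def combinatorial_mc_shapley(utility, n, n_samples, rng):
    """Uniform powerset sampling estimate of the combinatorial Shapley formula.
    `utility` maps a tuple of indices from range(n) to a score (placeholder)."""
    values = np.zeros(n)
    correction = 2 ** (n - 1) / n                   # rescales the uniform expectation
    for i in range(n):
        rest = [j for j in range(n) if j != i]
        estimate = 0.0
        for _ in range(n_samples):
            mask = rng.random(n - 1) < 0.5          # uniform subset of N \ {i}
            s = tuple(compress(rest, mask))
            weight = 1.0 / comb(n - 1, len(s))      # combinatorial weight
            estimate += weight * (utility(s + (i,)) - utility(s))
        values[i] = correction * estimate / n_samples
    return values

rng = np.random.default_rng(0)
toy_utility = lambda s: len(s)                      # additive toy utility
print(combinatorial_mc_shapley(toy_utility, n=4, n_samples=200, rng=rng))
# each estimated value should be close to 1 for this utility
```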
PARAMETER | DESCRIPTION |
---|---|
u | Utility object with model, data, and scoring function. TYPE: Utility |
done | Stopping criterion for the computation. TYPE: StoppingCriterion |
n_jobs | Number of parallel jobs across which to distribute the computation. Each worker receives a chunk of indices. TYPE: int DEFAULT: 1 |
parallel_backend | Parallel backend instance to use for parallelizing computations. TYPE: Optional[ParallelBackend] DEFAULT: None |
config | (DEPRECATED) Object configuring parallel computation, with cluster address, number of cpus, etc. TYPE: Optional[ParallelConfig] DEFAULT: None |
progress | Whether to display progress bars for each job. TYPE: bool DEFAULT: False |
seed | Either an instance of a numpy random number generator or a seed for it. TYPE: Optional[Seed] DEFAULT: None |
RETURNS | DESCRIPTION |
---|---|
ValuationResult | Object with the data values. |
Changed in version 0.9.0: Deprecated the config argument and added a parallel_backend argument to allow users to pass the parallel backend instance directly.
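A call sketch mirroring the one for the permutation variant; as before, the Utility construction and MaxUpdates are assumptions about the surrounding pydvl API, while the call itself follows the documented signature.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression

from pydvl.utils import Dataset, Utility
from pydvl.value import MaxUpdates                      # assumed stopping criterion
from pydvl.value.shapley.montecarlo import combinatorial_montecarlo_shapley

data = Dataset.from_sklearn(load_breast_cancer(), train_size=0.7)
u = Utility(LogisticRegression(max_iter=1000), data, "accuracy")

# Indices are split across 4 workers, each sampling subsets for its chunk.
result = combinatorial_montecarlo_shapley(u, done=MaxUpdates(1000), n_jobs=4)
print(result.values)
```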