pydvl.valuation.samplers.truncation
Truncation policies for the interruption of batched computations.
When estimating marginal-contribution-based values with permutation sampling, the computation can be interrupted early if the utility of the current batch of samples is close enough to the total utility of the dataset. This idea is implemented in [RelativeTruncation][pydvl.valuation.samplers.truncation.RelativeTruncation] and [DeviationTruncation][pydvl.valuation.samplers.truncation.DeviationTruncation], but it can be generalized to other policies.
In particular, one can stop after a fixed number of updates, as in [FixedTruncation][pydvl.valuation.samplers.truncation.FixedTruncation], or after a flag has been set. The latter allows the main process to tell parallel or remote workers to stop computing once it determines that values have converged.
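The flag-based idea can be illustrated with a minimal, self-contained sketch. This is not pyDVL's implementation: `FlagTruncation`, `truncated_marginals` and `toy_utility` are hypothetical names, and only the call signature `(idx, score, batch_size) -> bool` follows the policies documented on this page. A `threading.Event` stands in for whatever mechanism the main process uses to signal convergence.

```python
import threading
from typing import Callable, Dict, FrozenSet, Sequence


class FlagTruncation:
    """Hypothetical policy: interrupt as soon as an external flag is set."""

    def __init__(self, flag: threading.Event):
        self.flag = flag

    def __call__(self, idx: int, score: float, batch_size: int) -> bool:
        return self.flag.is_set()


def truncated_marginals(
    permutation: Sequence[int],
    utility: Callable[[FrozenSet[int]], float],
    policy: Callable[[int, float, int], bool],
) -> Dict[int, float]:
    """Marginal contributions along one permutation, with early stopping."""
    marginals = dict.fromkeys(permutation, 0.0)
    prefix: set = set()
    prev_score = utility(frozenset())
    for idx, point in enumerate(permutation):
        prefix.add(point)
        score = utility(frozenset(prefix))
        marginals[point] = score - prev_score
        prev_score = score
        if policy(idx, score, len(permutation)):
            break  # points not yet visited keep a marginal of 0.0
    return marginals


def toy_utility(subset: FrozenSet[int]) -> float:
    # Toy utility (an assumption for the demo): saturates after two points.
    return min(len(subset), 2) / 2


stop = threading.Event()
full = truncated_marginals([3, 1, 2], toy_utility, FlagTruncation(stop))
stop.set()  # e.g. the main process decided that values have converged
cut = truncated_marginals([3, 1, 2], toy_utility, FlagTruncation(stop))
```

With the flag cleared, the whole permutation is evaluated. Once it is set, the pass stops after the first update and the remaining points keep a zero marginal.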
References
- Ghorbani, A., Zou, J., 2019. Data Shapley: Equitable Valuation of Data for Machine Learning. In: Proceedings of the 36th International Conference on Machine Learning, PMLR, pp. 2242–2251.
TruncationPolicy
Bases: ABC
A policy for deciding whether to stop computation of a batch of samples.
Statistics are kept on the number of calls and truncations as n_calls and n_truncations, respectively.
ATTRIBUTE | DESCRIPTION |
---|---|
n_calls | Number of calls to the policy. |
n_truncations | Number of truncations made by the policy. |
Todo
Because the policy objects are copied to the workers, the statistics are not accessible from the coordinating process. We need to add methods for this.
Source code in src/pydvl/valuation/samplers/truncation.py
reset
abstractmethod
reset(utility: UtilityBase)
__call__
Check whether the computation should be interrupted.
PARAMETER | DESCRIPTION |
---|---|
idx | Position in the batch currently being computed. |
score | Last utility computed. |
batch_size | Size of the batch being computed. |

RETURNS | DESCRIPTION |
---|---|
bool | Whether the computation should be interrupted. |
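The interface described above can be sketched as follows. This is a hedged illustration and not the library's source: `PolicySketch`, `NeverStop`, `AlwaysStop` and the helper `_check` are hypothetical names, and the `utility` argument of `reset` is left untyped so that the example runs without pyDVL installed. Only `reset`, `__call__(idx, score, batch_size)`, `n_calls` and `n_truncations` are taken from the documentation.

```python
from abc import ABC, abstractmethod


class PolicySketch(ABC):
    """Base class keeping statistics on calls and truncations."""

    def __init__(self) -> None:
        self.n_calls = 0
        self.n_truncations = 0

    @abstractmethod
    def _check(self, idx: int, score: float, batch_size: int) -> bool:
        """Hypothetical hook: the actual stopping rule of a subclass."""

    @abstractmethod
    def reset(self, utility=None) -> None:
        """Reset any internal state before a new batch."""

    def __call__(self, idx: int, score: float, batch_size: int) -> bool:
        # Delegate the decision, then update the statistics.
        interrupt = self._check(idx, score, batch_size)
        self.n_calls += 1
        if interrupt:
            self.n_truncations += 1
        return interrupt


class NeverStop(PolicySketch):
    """Never interrupts, in the spirit of NoTruncation."""

    def _check(self, idx: int, score: float, batch_size: int) -> bool:
        return False

    def reset(self, utility=None) -> None:
        pass


class AlwaysStop(NeverStop):
    """Interrupts on every call; only useful to show the counters."""

    def _check(self, idx: int, score: float, batch_size: int) -> bool:
        return True


never, always = NeverStop(), AlwaysStop()
for i in range(3):
    never(i, 0.0, 3)
    always(i, 0.0, 3)
```

Keeping the bookkeeping in `__call__` and the rule in a separate hook means subclasses cannot forget to update the counters.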
NoTruncation
Bases: TruncationPolicy
A policy which never interrupts the computation.
__call__
Check whether the computation should be interrupted.
PARAMETER | DESCRIPTION |
---|---|
idx | Position in the batch currently being computed. |
score | Last utility computed. |
batch_size | Size of the batch being computed. |

RETURNS | DESCRIPTION |
---|---|
bool | Whether the computation should be interrupted. |
FixedTruncation
FixedTruncation(fraction: float)
Bases: TruncationPolicy
Break a computation after a fixed number of updates.
The experiments in Appendix B of (Ghorbani and Zou, 2019) show that when the training set size is large enough, one can simply truncate the iteration over permutations after a fixed number of steps. This works because, beyond a certain number of samples in a training set, the model becomes insensitive to new ones. However, that number depends strongly on the data distribution and the model, and there is no automatic way of estimating it.
PARAMETER | DESCRIPTION |
---|---|
fraction | Fraction of updates in a batch to compute before stopping (e.g. 0.5 to compute half of the marginals in a permutation). TYPE: float |
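A minimal sketch of such a rule is shown below. It is an assumption about how the check could be written, not the library's code: `FixedFractionSketch` is a hypothetical name, and treating `idx` as zero-based (so that `idx + 1` updates have been computed) is also an assumption.

```python
class FixedFractionSketch:
    """Hypothetical policy: stop once a given fraction of a batch is done."""

    def __init__(self, fraction: float):
        if not 0.0 < fraction <= 1.0:
            raise ValueError("fraction must be in (0, 1]")
        self.fraction = fraction

    def __call__(self, idx: int, score: float, batch_size: int) -> bool:
        # idx is assumed zero-based: idx + 1 updates have been computed.
        return (idx + 1) / batch_size >= self.fraction


policy = FixedFractionSketch(fraction=0.5)
decisions = [policy(idx, 0.0, 10) for idx in range(10)]
first_stop = decisions.index(True)  # position at which the policy first fires
```

With `fraction=0.5` and a permutation of ten points, the policy first fires at position 4, that is, after five marginals have been computed. The `score` argument is ignored because the rule depends only on progress through the batch.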
__call__
Check whether the computation should be interrupted.
PARAMETER | DESCRIPTION |
---|---|
idx | Position in the batch currently being computed. |
score | Last utility computed. |
batch_size | Size of the batch being computed. |

RETURNS | DESCRIPTION |
---|---|
bool | Whether the computation should be interrupted. |
RelativeTruncation
RelativeTruncation(rtol: float)
Bases: TruncationPolicy
Break a computation if the utility is close enough to the total utility.
This is called "performance tolerance" in (Ghorbani and Zou, 2019).
Warning
Initialization and reset() of this policy imply the computation of the total utility for the dataset, which can be expensive!
PARAMETER | DESCRIPTION |
---|---|
rtol | Relative tolerance. Computation of the permutation is interrupted if the last computed utility is within this tolerance of the total utility. TYPE: float |
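The relative-tolerance check can be sketched as below. This is an illustration under stated assumptions, not pyDVL's implementation: `RelativeToleranceSketch` is a hypothetical name, the comparison `|score - total| <= rtol * |total|` is one plausible reading of "within this tolerance", and `reset` receives a precomputed total utility instead of a utility object so that the example stays self-contained.

```python
from typing import Optional


class RelativeToleranceSketch:
    """Hypothetical policy: stop when the score is close to the total utility."""

    def __init__(self, rtol: float):
        self.rtol = rtol
        self.total_utility: Optional[float] = None

    def reset(self, total_utility: float) -> None:
        # In practice this value comes from evaluating the utility on the
        # whole dataset, which is the expensive step the warning refers to.
        self.total_utility = total_utility

    def __call__(self, idx: int, score: float, batch_size: int) -> bool:
        if self.total_utility is None:
            raise RuntimeError("call reset() before using the policy")
        gap = abs(score - self.total_utility)
        return gap <= self.rtol * abs(self.total_utility)


policy = RelativeToleranceSketch(rtol=0.05)
policy.reset(total_utility=0.9)
far = policy(0, 0.80, 10)   # gap of about 0.10, threshold of about 0.045
near = policy(5, 0.88, 10)  # gap of about 0.02, threshold of about 0.045
```

Here `far` is `False` and `near` is `True`: only the second score lies within 5% of the total utility.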
DeviationTruncation
Bases: TruncationPolicy
Break a computation if the last computed utility is close to the total utility.
This is essentially the same as [RelativeTruncation][pydvl.valuation.samplers.truncation.RelativeTruncation], but with the tolerance determined by a multiple of the standard deviation of the utilities.
Warning
Initialization and reset() of this policy imply the computation of the total utility for the dataset, which can be expensive!
PARAMETER | DESCRIPTION |
---|---|
burn_in_fraction | Fraction of samples within a batch (e.g. permutation) to wait until actually checking. |
sigmas | Number of standard deviations to use as a threshold. |
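One way to realise a deviation-based threshold with a burn-in period is sketched below. Everything here is an assumption made for illustration: `DeviationSketch` is a hypothetical name, the running standard deviation is computed over the scores seen so far with Welford's algorithm, the burn-in test uses `(idx + 1) / batch_size`, and `reset` takes a precomputed total utility. The library may compute its statistics differently.

```python
import math
from typing import Optional


class DeviationSketch:
    """Hypothetical policy: stop when the score is within
    `sigmas` standard deviations of the total utility."""

    def __init__(self, sigmas: float, burn_in_fraction: float = 0.0):
        self.sigmas = sigmas
        self.burn_in_fraction = burn_in_fraction
        self.total_utility: Optional[float] = None
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def reset(self, total_utility: float) -> None:
        self.total_utility = total_utility
        self.count, self.mean, self.m2 = 0, 0.0, 0.0

    def __call__(self, idx: int, score: float, batch_size: int) -> bool:
        if self.total_utility is None:
            raise RuntimeError("call reset() before using the policy")
        # Welford's online update of mean and squared deviations.
        self.count += 1
        delta = score - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (score - self.mean)
        # Do not check during burn-in or before a variance is defined.
        in_burn_in = (idx + 1) / batch_size <= self.burn_in_fraction
        if in_burn_in or self.count < 2:
            return False
        std = math.sqrt(self.m2 / (self.count - 1))
        return abs(score - self.total_utility) <= self.sigmas * std


scores = [0.2, 0.6, 0.9, 0.95]

loose = DeviationSketch(sigmas=1.0, burn_in_fraction=0.5)
loose.reset(total_utility=1.0)
loose_decisions = [loose(i, s, len(scores)) for i, s in enumerate(scores)]

tight = DeviationSketch(sigmas=0.1, burn_in_fraction=0.5)
tight.reset(total_utility=1.0)
tight_third = [tight(i, s, len(scores)) for i, s in enumerate(scores)][2]
```

With `sigmas=1.0` the first two calls fall inside the burn-in period and the third one triggers truncation. With `sigmas=0.1` the same third score is not yet close enough.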