
pydvl.valuation.samplers.truncation

Truncation policies for the interruption of batched computations.

When estimating marginal contribution-based values with permutation sampling, the computation can be interrupted early. A naive approach is to stop after a fixed number of updates, using FixedTruncation.

However, a more effective approach is to stop once the utility of the current batch of samples is close enough to the total utility of the dataset. This idea is implemented in RelativeTruncation and was introduced as Truncated Monte Carlo Shapley (TMCS) in Ghorbani and Zou (2019) [1]. A slight variation is to use the standard deviation of the utilities to set the tolerance, which can be done with DeviationTruncation.
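The following is a minimal, self-contained sketch of the idea behind TMCS, not pyDVL's implementation. Marginal contributions are accumulated along one random permutation, and the loop stops as soon as the utility of the current prefix is within a relative tolerance of the total utility. The function `tmcs_permutation`, the tolerance test `|score - total| <= rtol * |total|` and the toy utility are assumptions made for illustration. Points after the truncation point keep a marginal contribution of zero, as in the original algorithm.

```python
import random
from typing import Callable, Dict, FrozenSet, List, Set


def tmcs_permutation(
    indices: List[int],
    utility: Callable[[FrozenSet[int]], float],
    rtol: float,
    rng: random.Random,
) -> Dict[int, float]:
    """Marginal contributions along one random permutation, truncated once
    the prefix utility is within `rtol` (relative) of the total utility.
    Hypothetical helper, not part of pyDVL."""
    total = utility(frozenset(indices))  # utility of the full dataset
    perm = list(indices)
    rng.shuffle(perm)
    marginals = {i: 0.0 for i in indices}  # truncated points stay at 0
    prefix: Set[int] = set()
    prev = utility(frozenset())  # utility of the empty set
    for i in perm:
        prefix.add(i)
        score = utility(frozenset(prefix))
        marginals[i] = score - prev
        prev = score
        if abs(score - total) <= rtol * abs(total):
            break  # remaining points are never evaluated
    return marginals
```

For a toy utility that saturates after three points, such as `lambda s: float(min(len(s), 3))` over ten indices with `rtol=0.0`, the loop stops after three steps. Only five utility evaluations are needed instead of twelve.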

Stopping too early

Truncation policies can lead to underestimation of values if the utility function has high variance. To avoid this, one can set a burn-in period before checking the utility, or use a policy that is less sensitive to variance.
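To illustrate why a burn-in helps, the hypothetical helper below applies a relative-tolerance check to a fixed sequence of utilities in which a single noisy spike appears early in the permutation. The function name, the tolerance test `|score - total| <= rtol * |total|` and the way the burn-in length is computed are assumptions made for this sketch.

```python
from typing import Optional, Sequence


def first_truncation(
    scores: Sequence[float],
    total: float,
    rtol: float,
    burn_in_fraction: float,
) -> Optional[int]:
    """Position at which a relative-tolerance check would stop, or None.
    The first `burn_in_fraction` of positions is never checked.
    Hypothetical helper, not part of pyDVL."""
    burn_in = int(burn_in_fraction * len(scores))
    for pos, score in enumerate(scores):
        if pos < burn_in:
            continue  # still in the burn-in period
        if abs(score - total) <= rtol * abs(total):
            return pos
    return None


# Utilities along one permutation, with a noisy spike at position 1.
scores = [0.1, 0.95, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.97, 1.0]
```

Without burn-in, the spike triggers truncation at position 1, so eight points would receive a zero marginal. With `burn_in_fraction=0.5`, the check first fires at position 8.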

References


  1. Ghorbani, A., Zou, J., 2019. Data Shapley: Equitable Valuation of Data for Machine Learning. In: Proceedings of the 36th International Conference on Machine Learning, PMLR, pp. 2242–2251. 

TruncationPolicy

TruncationPolicy()

Bases: ABC

A policy for deciding whether to stop the computation of a batch of samples.

Statistics are kept on the total number of calls and truncations as n_calls and n_truncations respectively.

ATTRIBUTE DESCRIPTION
n_calls

Number of calls to the policy.

TYPE: int

n_truncations

Number of truncations made by the policy.

TYPE: int

Todo

Because the policy objects are copied to the workers, the statistics are not accessible from the coordinating process. We need to add methods for this.

Source code in src/pydvl/valuation/samplers/truncation.py
def __init__(self) -> None:
    self.n_calls: int = 0
    self.n_truncations: int = 0
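The sketch below illustrates the limitation described in the Todo above. It assumes that workers receive a pickled copy of the policy, as process-based parallel backends typically do. `CountingPolicy` and `worker` are hypothetical names. Returning the counters explicitly from each worker is only one possible workaround, not something pyDVL is known to do.

```python
import pickle
from typing import Tuple


class CountingPolicy:
    """Hypothetical stand-in for a policy that tracks its call statistics."""

    def __init__(self) -> None:
        self.n_calls = 0
        self.n_truncations = 0

    def __call__(self, idx: int, score: float, batch_size: int) -> bool:
        self.n_calls += 1
        return False  # never truncates; only the counter matters here


def worker(serialized_policy: bytes, batch_size: int) -> Tuple[int, int]:
    # Each worker deserializes its own *copy* of the policy.
    policy = pickle.loads(serialized_policy)
    for idx in range(batch_size):
        policy(idx, 0.0, batch_size)
    # Possible workaround: send the statistics back explicitly.
    return policy.n_calls, policy.n_truncations


coordinator_policy = CountingPolicy()
results = [worker(pickle.dumps(coordinator_policy), 5) for _ in range(3)]
total_calls = sum(calls for calls, _ in results)
```

The coordinator's own counters stay at zero even though 15 calls were made in total.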

reset abstractmethod

reset(utility: UtilityBase)

(Re)set the policy to a state ready for a new permutation.

Source code in src/pydvl/valuation/samplers/truncation.py
@abstractmethod
def reset(self, utility: UtilityBase):
    """(Re)set the policy to a state ready for a new permutation."""
    ...

__call__

__call__(idx: IndexT, score: float, batch_size: int) -> bool

Check whether the computation should be interrupted.

PARAMETER DESCRIPTION
idx

Position in the batch currently being computed.

TYPE: IndexT

score

Last utility computed.

TYPE: float

batch_size

Size of the batch being computed.

TYPE: int

RETURNS DESCRIPTION
bool

True if the computation should be interrupted.

Source code in src/pydvl/valuation/samplers/truncation.py
def __call__(self, idx: IndexT, score: float, batch_size: int) -> bool:
    """Check whether the computation should be interrupted.

    Args:
        idx: Position in the batch currently being computed.
        score: Last utility computed.
        batch_size: Size of the batch being computed.

    Returns:
        `True` if the computation should be interrupted.
    """

    ret = self._check(idx, score, batch_size)
    self.n_calls += 1
    self.n_truncations += 1 if ret else 0
    return ret
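The counting wrapper above follows a template-method pattern: `__call__` delegates the decision to `_check` and only updates the statistics. The standalone sketch below mirrors that pattern without importing pyDVL. `Policy` and `ScoreThreshold` are hypothetical names, and the `utility` argument of `reset` is typed loosely as an assumption.

```python
from abc import ABC, abstractmethod
from typing import Any


class Policy(ABC):
    """Hypothetical mirror of the pattern above: subclasses implement
    `_check` and `reset`; `__call__` only adds bookkeeping."""

    def __init__(self) -> None:
        self.n_calls = 0
        self.n_truncations = 0

    @abstractmethod
    def _check(self, idx: int, score: float, batch_size: int) -> bool: ...

    @abstractmethod
    def reset(self, utility: Any) -> None: ...

    def __call__(self, idx: int, score: float, batch_size: int) -> bool:
        ret = self._check(idx, score, batch_size)
        self.n_calls += 1
        self.n_truncations += 1 if ret else 0
        return ret


class ScoreThreshold(Policy):
    """Hypothetical policy: stop once the score reaches a fixed threshold."""

    def __init__(self, threshold: float) -> None:
        super().__init__()
        self.threshold = threshold

    def _check(self, idx: int, score: float, batch_size: int) -> bool:
        return score >= self.threshold

    def reset(self, utility: Any) -> None:
        pass  # nothing to recompute between permutations
```

Calling `ScoreThreshold(0.8)` on the scores `0.2, 0.5, 0.9, 0.95` returns `False, False, True, True` and leaves `n_calls == 4` and `n_truncations == 2`. The abstract base class itself cannot be instantiated.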

NoTruncation

NoTruncation()

Bases: TruncationPolicy

A policy which never interrupts the computation.

Source code in src/pydvl/valuation/samplers/truncation.py
def __init__(self) -> None:
    self.n_calls: int = 0
    self.n_truncations: int = 0


FixedTruncation

FixedTruncation(fraction: float)

Bases: TruncationPolicy

Break a computation after a fixed number of updates.

The experiments in Appendix B of Ghorbani and Zou (2019) [1] show that, when the training set is large enough, one can simply truncate the iteration over permutations after a fixed number of steps. This works because, beyond a certain number of samples in the training set, the model becomes insensitive to new ones. Alas, this number strongly depends on the data distribution and the model, and there is no automatic way of estimating it.

PARAMETER DESCRIPTION
fraction

Fraction of updates in a batch to compute before stopping (e.g. 0.5 to compute half of the marginals in a permutation).

TYPE: float

Source code in src/pydvl/valuation/samplers/truncation.py
def __init__(self, fraction: float):
    super().__init__()
    if fraction <= 0 or fraction > 1:
        raise ValueError("fraction must be in (0, 1]")
    self.fraction = fraction
    self.count = 0  # within-permutation count
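The source above only shows the constructor. A plausible decision rule, assumed here rather than taken from pyDVL, is to count calls within a permutation and stop once `fraction * batch_size` updates have been made. With `fraction=0.5` and a batch of 10, this stops after the fifth update. `FixedFractionSketch` and its `check` method are hypothetical.

```python
class FixedFractionSketch:
    """Hypothetical illustration of a fixed-fraction truncation rule."""

    def __init__(self, fraction: float) -> None:
        if fraction <= 0 or fraction > 1:
            raise ValueError("fraction must be in (0, 1]")
        self.fraction = fraction
        self.count = 0  # updates seen in the current permutation

    def reset(self) -> None:
        self.count = 0  # start counting afresh for a new permutation

    def check(self, batch_size: int) -> bool:
        self.count += 1
        # Assumed rule: stop once fraction * batch_size updates are done.
        return self.count >= self.fraction * batch_size


policy = FixedFractionSketch(0.5)
flags = [policy.check(batch_size=10) for _ in range(10)]
```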


RelativeTruncation

RelativeTruncation(rtol: float, burn_in_fraction: float = 0.0)

Bases: TruncationPolicy

Break a computation if the utility is close enough to the total utility.

This is called "performance tolerance" in Ghorbani and Zou (2019) [1].

Warning

Initialization and reset() of this policy imply the computation of the total utility for the dataset, which can be expensive!

PARAMETER DESCRIPTION
rtol

Relative tolerance. The permutation is broken if the last computed utility is within this tolerance of the total utility.

TYPE: float

burn_in_fraction

Fraction of samples within a permutation to process before the policy starts checking.

TYPE: float DEFAULT: 0.0

Source code in src/pydvl/valuation/samplers/truncation.py
def __init__(self, rtol: float, burn_in_fraction: float = 0.0):
    super().__init__()
    assert 0 <= burn_in_fraction <= 1
    self.burn_in_fraction = burn_in_fraction
    self.rtol = rtol
    self.total_utility = 0.0
    self.count = 0  # within-permutation count
    self._is_setup = False
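The sketch below shows how such a policy could work. It assumes the tolerance test is `|score - total_utility| <= rtol * |total_utility|` and that the burn-in is measured as a fraction of `batch_size`. `RelativeSketch` is hypothetical. Its `reset` takes a plain callable and an index sequence instead of pyDVL's `UtilityBase`, and it recomputes the total utility on every call. That recomputation is the expensive step the Warning above refers to.

```python
from typing import Callable, Sequence


class RelativeSketch:
    """Hypothetical illustration of relative-tolerance truncation."""

    def __init__(self, rtol: float, burn_in_fraction: float = 0.0) -> None:
        if not 0 <= burn_in_fraction <= 1:
            raise ValueError("burn_in_fraction must be in [0, 1]")
        self.rtol = rtol
        self.burn_in_fraction = burn_in_fraction
        self.total_utility = 0.0
        self.count = 0  # within-permutation count

    def reset(
        self, utility: Callable[[Sequence[int]], float], indices: Sequence[int]
    ) -> None:
        # Expensive: one utility evaluation on the full dataset.
        self.total_utility = utility(indices)
        self.count = 0

    def check(self, score: float, batch_size: int) -> bool:
        self.count += 1
        if self.count < self.burn_in_fraction * batch_size:
            return False  # still in the burn-in period
        tol = self.rtol * abs(self.total_utility)
        return abs(score - self.total_utility) <= tol
```

With `rtol=0.05`, `burn_in_fraction=0.5`, a batch of 10 and a total utility of 1.0, an early score of 0.99 is ignored because it falls inside the burn-in. A later score of 0.98 does trigger truncation.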

DeviationTruncation

DeviationTruncation(sigmas: float, burn_in_fraction: float = 0.0)

Bases: TruncationPolicy

Break a computation if the last computed utility is close to the total utility.

This is essentially the same as RelativeTruncation, but with the tolerance determined by a multiple of the standard deviation of the utilities.

Danger

This policy can break early if the utility function has high variance. This can lead to gross underestimation of values. Use with caution.

Warning

Initialization and reset() of this policy imply the computation of the total utility for the dataset, which can be expensive!

PARAMETER DESCRIPTION
sigmas

Number of standard deviations to use as a threshold.

TYPE: float

burn_in_fraction

Fraction of samples within a permutation to process before the policy starts checking.

TYPE: float DEFAULT: 0.0

Source code in src/pydvl/valuation/samplers/truncation.py
def __init__(self, sigmas: float, burn_in_fraction: float = 0.0):
    super().__init__()
    assert 0 <= burn_in_fraction <= 1

    self.burn_in_fraction = burn_in_fraction
    self.total_utility = 0.0
    self.count = 0  # within-permutation count
    self.variance = 0.0
    self.mean = 0.0
    self.sigmas = sigmas
    self._is_setup = False
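The constructor keeps `mean` and `variance`, which suggests a running estimate of the spread of the utilities. The sketch below assumes Welford's online algorithm for that estimate, statistics that are reset per permutation, and a stopping test `|score - total_utility| < sigmas * std`. None of these details is confirmed by the source above, and `DeviationSketch` is a hypothetical name.

```python
import math


class DeviationSketch:
    """Hypothetical illustration of deviation-based truncation."""

    def __init__(self, sigmas: float, burn_in_fraction: float = 0.0) -> None:
        if not 0 <= burn_in_fraction <= 1:
            raise ValueError("burn_in_fraction must be in [0, 1]")
        self.sigmas = sigmas
        self.burn_in_fraction = burn_in_fraction
        self.total_utility = 0.0
        self.count = 0  # within-permutation count
        self.mean = 0.0  # running mean of the observed utilities
        self._m2 = 0.0  # running sum of squared deviations (Welford)

    def reset(self, total_utility: float) -> None:
        # In practice the total utility would be computed from the dataset.
        self.total_utility = total_utility
        self.count, self.mean, self._m2 = 0, 0.0, 0.0

    @property
    def variance(self) -> float:
        # Population variance of the utilities seen so far.
        return self._m2 / self.count if self.count > 1 else 0.0

    def check(self, score: float, batch_size: int) -> bool:
        # Welford's online update of the mean and squared deviations.
        self.count += 1
        delta = score - self.mean
        self.mean += delta / self.count
        self._m2 += delta * (score - self.mean)
        if self.count < self.burn_in_fraction * batch_size:
            return False  # still in the burn-in period
        std = math.sqrt(self.variance)
        return abs(score - self.total_utility) < self.sigmas * std
```

The threshold grows with the spread of the utilities. A noisy utility therefore widens the band and makes early stopping more likely, which is the risk flagged in the Danger note above.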