Truncated

References


  1. Ghorbani, A., Zou, J., 2019. Data Shapley: Equitable Valuation of Data for Machine Learning. In: Proceedings of the 36th International Conference on Machine Learning, PMLR, pp. 2242–2251. 

TruncationPolicy()

Bases: ABC

A policy for deciding whether to stop computing marginals in a permutation.

Statistics are kept on the number of calls and truncations as n_calls and n_truncations respectively.

ATTRIBUTE DESCRIPTION
n_calls

Number of calls to the policy.

TYPE: int

n_truncations

Number of truncations made by the policy.

TYPE: int

Todo

Because the policy objects are copied to the workers, the statistics are not accessible from the coordinating process. We need to add methods for this.

Source code in src/pydvl/value/shapley/truncated.py
def __init__(self) -> None:
    self.n_calls: int = 0
    self.n_truncations: int = 0

reset(u=None) abstractmethod

Reset the policy to a state ready for a new permutation.

Source code in src/pydvl/value/shapley/truncated.py
@abc.abstractmethod
def reset(self, u: Optional[Utility] = None):
    """Reset the policy to a state ready for a new permutation."""
    ...

__call__(idx, score)

Check whether the computation should be interrupted.

PARAMETER DESCRIPTION
idx

Position in the permutation currently being computed.

TYPE: int

score

Last utility computed.

TYPE: float

RETURNS DESCRIPTION
bool

True if the computation should be interrupted.

Source code in src/pydvl/value/shapley/truncated.py
def __call__(self, idx: int, score: float) -> bool:
    """Check whether the computation should be interrupted.

    Args:
        idx: Position in the permutation currently being computed.
        score: Last utility computed.

    Returns:
        `True` if the computation should be interrupted.
    """
    ret = self._check(idx, score)
    self.n_calls += 1
    self.n_truncations += 1 if ret else 0
    return ret
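
The bookkeeping above can be tried in isolation. The following is a minimal sketch, not the pyDVL class: `PolicySketch` is a stand-in that copies the `__call__` logic shown above, and `StopAtPosition` is a hypothetical subclass invented for illustration. It assumes that concrete policies only need to supply `_check` and `reset`.

```python
from abc import ABC, abstractmethod


class PolicySketch(ABC):
    """Stand-in that mirrors the bookkeeping shown above (not the pyDVL class)."""

    def __init__(self) -> None:
        self.n_calls = 0
        self.n_truncations = 0

    @abstractmethod
    def _check(self, idx: int, score: float) -> bool:
        """Subclass hook: return True to stop the current permutation."""

    @abstractmethod
    def reset(self) -> None:
        """Prepare the policy for a new permutation."""

    def __call__(self, idx: int, score: float) -> bool:
        # Delegate the decision, then update the statistics.
        ret = self._check(idx, score)
        self.n_calls += 1
        self.n_truncations += 1 if ret else 0
        return ret


class StopAtPosition(PolicySketch):
    """Hypothetical policy: stop once position `limit` is reached."""

    def __init__(self, limit: int) -> None:
        super().__init__()
        self.limit = limit

    def _check(self, idx: int, score: float) -> bool:
        return idx >= self.limit

    def reset(self) -> None:
        pass  # this policy keeps no per-permutation state


policy = StopAtPosition(limit=3)
decisions = [policy(idx, score=0.0) for idx in range(5)]
print(decisions, policy.n_calls, policy.n_truncations)
```

Here the policy is queried at every position for demonstration purposes; a real sampler would stop querying after the first `True`.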

NoTruncation()

Bases: TruncationPolicy

A policy which never interrupts the computation.

Source code in src/pydvl/value/shapley/truncated.py
def __init__(self) -> None:
    self.n_calls: int = 0
    self.n_truncations: int = 0

__call__(idx, score)

Check whether the computation should be interrupted.

PARAMETER DESCRIPTION
idx

Position in the permutation currently being computed.

TYPE: int

score

Last utility computed.

TYPE: float

RETURNS DESCRIPTION
bool

True if the computation should be interrupted.

Source code in src/pydvl/value/shapley/truncated.py
def __call__(self, idx: int, score: float) -> bool:
    """Check whether the computation should be interrupted.

    Args:
        idx: Position in the permutation currently being computed.
        score: Last utility computed.

    Returns:
        `True` if the computation should be interrupted.
    """
    ret = self._check(idx, score)
    self.n_calls += 1
    self.n_truncations += 1 if ret else 0
    return ret

FixedTruncation(u, fraction)

Bases: TruncationPolicy

Break a permutation after computing a fixed number of marginals.

The experiments in Appendix B of (Ghorbani and Zou, 2019) [1] show that when the training set is large enough, one can simply stop processing each permutation after a fixed number of steps. This works because, beyond a certain training set size, the model becomes insensitive to additional samples. However, this threshold depends strongly on the data distribution and the model, and there is no automatic way of estimating it.

PARAMETER DESCRIPTION
u

Utility object with model, data, and scoring function

TYPE: Utility

fraction

Fraction of marginals in a permutation to compute before stopping (e.g. 0.5 to compute half of the marginals).

TYPE: float

Source code in src/pydvl/value/shapley/truncated.py
def __init__(self, u: Utility, fraction: float):
    super().__init__()
    if fraction <= 0 or fraction > 1:
        raise ValueError("fraction must be in (0, 1]")
    self.max_marginals = len(u.data) * fraction
    self.count = 0
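
To make the effect of `fraction` concrete, here is a minimal sketch under stated assumptions. The check itself is not shown on this page, so the sketch assumes that the policy counts every marginal and stops once `count >= max_marginals`, and that `reset` sets the counter back to zero. `FixedTruncationSketch` and its `n_data` argument are hypothetical stand-ins for the real class and for `len(u.data)`.

```python
class FixedTruncationSketch:
    """Illustrative stand-in; takes the dataset size instead of a Utility."""

    def __init__(self, n_data: int, fraction: float) -> None:
        if fraction <= 0 or fraction > 1:
            raise ValueError("fraction must be in (0, 1]")
        self.max_marginals = n_data * fraction
        self.count = 0

    def reset(self) -> None:
        # Assumption: a new permutation starts with a fresh budget.
        self.count = 0

    def check(self, idx: int, score: float) -> bool:
        # Assumption: every call consumes one unit of the budget.
        self.count += 1
        return self.count >= self.max_marginals


policy = FixedTruncationSketch(n_data=10, fraction=0.5)
stops = [policy.check(idx, score=0.0) for idx in range(10)]
first_stop = stops.index(True)  # zero-based position of the first truncation
print(first_stop)
```

With 10 data points and `fraction=0.5`, the budget is 5 marginals, so under these assumptions the fifth call (position 4) is the first to request truncation.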

__call__(idx, score)

Check whether the computation should be interrupted.

PARAMETER DESCRIPTION
idx

Position in the permutation currently being computed.

TYPE: int

score

Last utility computed.

TYPE: float

RETURNS DESCRIPTION
bool

True if the computation should be interrupted.

Source code in src/pydvl/value/shapley/truncated.py
def __call__(self, idx: int, score: float) -> bool:
    """Check whether the computation should be interrupted.

    Args:
        idx: Position in the permutation currently being computed.
        score: Last utility computed.

    Returns:
        `True` if the computation should be interrupted.
    """
    ret = self._check(idx, score)
    self.n_calls += 1
    self.n_truncations += 1 if ret else 0
    return ret

RelativeTruncation(u, rtol)

Bases: TruncationPolicy

Break a permutation if the marginal utility is too low.

This is called "performance tolerance" in (Ghorbani and Zou, 2019) [1].

PARAMETER DESCRIPTION
u

Utility object with model, data, and scoring function

TYPE: Utility

rtol

Relative tolerance. The permutation is broken if the last computed utility is less than total_utility * rtol.

TYPE: float

Source code in src/pydvl/value/shapley/truncated.py
def __init__(self, u: Utility, rtol: float):
    super().__init__()
    self.rtol = rtol
    logger.info("Computing total utility for permutation truncation.")
    self.total_utility = self.reset(u)
    self._u = u
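
The check that uses `rtol` is not shown on this page. As a rough illustration of a performance tolerance in the sense of Ghorbani and Zou, the sketch below assumes that a permutation is stopped once the last score lies within `rtol * |total_utility|` of the total utility. The helper `within_relative_tolerance` is hypothetical, and the library's actual comparison may differ.

```python
def within_relative_tolerance(score: float, total_utility: float, rtol: float) -> bool:
    """Assumed criterion: the score is within rtol (relative) of the total utility."""
    return abs(score - total_utility) <= rtol * abs(total_utility)


# Made-up utilities along one permutation, with a total utility of 0.90.
total = 0.90
scores = [0.40, 0.70, 0.85, 0.89, 0.90]

stop_at = next(
    pos
    for pos, score in enumerate(scores)
    if within_relative_tolerance(score, total, rtol=0.05)
)
print(stop_at)
```

In this made-up run the tolerance is 0.045, so position 3 (score 0.89) is the first one close enough to the total utility to stop.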

__call__(idx, score)

Check whether the computation should be interrupted.

PARAMETER DESCRIPTION
idx

Position in the permutation currently being computed.

TYPE: int

score

Last utility computed.

TYPE: float

RETURNS DESCRIPTION
bool

True if the computation should be interrupted.

Source code in src/pydvl/value/shapley/truncated.py
def __call__(self, idx: int, score: float) -> bool:
    """Check whether the computation should be interrupted.

    Args:
        idx: Position in the permutation currently being computed.
        score: Last utility computed.

    Returns:
        `True` if the computation should be interrupted.
    """
    ret = self._check(idx, score)
    self.n_calls += 1
    self.n_truncations += 1 if ret else 0
    return ret

BootstrapTruncation(u, n_samples, sigmas=1)

Bases: TruncationPolicy

Break a permutation if the last computed utility is close to the total utility, measured as a multiple of the standard deviation of the utilities.

PARAMETER DESCRIPTION
u

Utility object with model, data, and scoring function

TYPE: Utility

n_samples

Number of bootstrap samples to use to compute the variance of the utilities.

TYPE: int

sigmas

Number of standard deviations to use as a threshold.

TYPE: float DEFAULT: 1

Source code in src/pydvl/value/shapley/truncated.py
def __init__(self, u: Utility, n_samples: int, sigmas: float = 1):
    super().__init__()
    self.n_samples = n_samples
    logger.info("Computing total utility for permutation truncation.")
    self.total_utility = u(u.data.indices)
    self.count: int = 0
    self.variance: float = 0
    self.mean: float = 0
    self.sigmas: float = sigmas
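
The constructor stores a running `count`, `mean` and `variance`, but the check that uses them is not shown here. The sketch below is one plausible reading, labeled as an assumption throughout: it updates the moments with Welford's algorithm, waits until `n_samples` scores have been seen, and then stops when the last score is within `sigmas` standard deviations of the total utility. The class name, the warm-up rule and the use of the population variance are all illustrative choices.

```python
import math


class BootstrapSketch:
    """Illustrative stand-in; takes the total utility directly."""

    def __init__(self, total_utility: float, n_samples: int, sigmas: float = 1.0) -> None:
        self.total_utility = total_utility
        self.n_samples = n_samples
        self.sigmas = sigmas
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def check(self, score: float) -> bool:
        # Welford's online update of mean and squared deviations.
        self.count += 1
        delta = score - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (score - self.mean)
        # Assumption: no truncation before n_samples scores were observed.
        if self.count < self.n_samples:
            return False
        std = math.sqrt(self.m2 / self.count)  # population standard deviation
        return abs(score - self.total_utility) < self.sigmas * std


policy = BootstrapSketch(total_utility=1.0, n_samples=3, sigmas=1.0)
decisions = [policy.check(score) for score in [0.1, 0.5, 0.6, 0.9]]
print(decisions)
```

In this run the third score (0.6) is still about 1.9 standard deviations away from the total utility, while the fourth (0.9) falls within one standard deviation and triggers truncation.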

__call__(idx, score)

Check whether the computation should be interrupted.

PARAMETER DESCRIPTION
idx

Position in the permutation currently being computed.

TYPE: int

score

Last utility computed.

TYPE: float

RETURNS DESCRIPTION
bool

True if the computation should be interrupted.

Source code in src/pydvl/value/shapley/truncated.py
def __call__(self, idx: int, score: float) -> bool:
    """Check whether the computation should be interrupted.

    Args:
        idx: Position in the permutation currently being computed.
        score: Last utility computed.

    Returns:
        `True` if the computation should be interrupted.
    """
    ret = self._check(idx, score)
    self.n_calls += 1
    self.n_truncations += 1 if ret else 0
    return ret

truncated_montecarlo_shapley(u, *, done, truncation, config=ParallelConfig(), n_jobs=1, coordinator_update_period=10, worker_update_period=5)

Warning

This function is deprecated since version 0.7.0, is scheduled for removal in 0.8.0, and is only a wrapper for permutation_montecarlo_shapley.

Todo

Think of how to add Gelman-Rubin or some other more principled stopping criterion.

PARAMETER DESCRIPTION
u

Utility object with model, data, and scoring function

TYPE: Utility

done

Check on the results which decides when to stop sampling permutations.

TYPE: StoppingCriterion

truncation

Callable that decides whether to stop computing marginals for a given permutation.

TYPE: TruncationPolicy

config

Object configuring parallel computation, with cluster address, number of CPUs, etc.

TYPE: ParallelConfig DEFAULT: ParallelConfig()

n_jobs

Number of permutation Monte Carlo jobs to run concurrently.

TYPE: int DEFAULT: 1

RETURNS DESCRIPTION
ValuationResult

Object with the data values.

Source code in src/pydvl/value/shapley/truncated.py
@deprecated(
    target=True,
    deprecated_in="0.7.0",
    remove_in="0.8.0",
    args_mapping=dict(coordinator_update_period=None, worker_update_period=None),
)
def truncated_montecarlo_shapley(
    u: Utility,
    *,
    done: StoppingCriterion,
    truncation: TruncationPolicy,
    config: ParallelConfig = ParallelConfig(),
    n_jobs: int = 1,
    coordinator_update_period: int = 10,
    worker_update_period: int = 5,
) -> ValuationResult:
    """
    !!! Warning
        This method is deprecated and only a wrapper for
        [permutation_montecarlo_shapley][pydvl.value.shapley.montecarlo.permutation_montecarlo_shapley].

    !!! Todo
        Think of how to add Gelman-Rubin or some other more principled stopping
        criterion.

    Args:
        u: Utility object with model, data, and scoring function
        done: Check on the results which decides when to stop sampling
            permutations.
        truncation: callable that decides whether to stop computing marginals
            for a given permutation.
        config: Object configuring parallel computation, with cluster address,
            number of cpus, etc.
        n_jobs: Number of permutation monte carlo jobs to run concurrently.
    Returns:
        Object with the data values.
    """
    from pydvl.value.shapley.montecarlo import permutation_montecarlo_shapley

    return cast(
        ValuationResult,
        permutation_montecarlo_shapley(
            u, done=done, truncation=truncation, config=config, n_jobs=n_jobs
        ),
    )
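
For orientation, the overall procedure can be sketched in plain Python. This is not pyDVL's implementation: it runs sequentially, replaces the `done` criterion with a fixed number of permutations, and assumes that every marginal after a truncation is taken to be zero. The function `truncated_permutation_shapley`, the toy utility and the truncation rule are all hypothetical.

```python
import random
from typing import Callable, List, Sequence


def truncated_permutation_shapley(
    n: int,
    utility: Callable[[Sequence[int]], float],
    should_truncate: Callable[[int, float], bool],
    n_permutations: int,
    seed: int = 0,
) -> List[float]:
    """Average marginal contributions over random permutations."""
    rng = random.Random(seed)
    values = [0.0] * n
    for _ in range(n_permutations):
        perm = list(range(n))
        rng.shuffle(perm)
        prev = utility([])
        truncated = False
        for pos, i in enumerate(perm):
            if truncated:
                marginal = 0.0  # assumption: skipped marginals count as zero
            else:
                score = utility(perm[: pos + 1])
                marginal = score - prev
                prev = score
                truncated = should_truncate(pos, score)
            values[i] += marginal / n_permutations
    return values


n_calls = [0]  # counts utility evaluations


def toy_utility(subset: Sequence[int]) -> float:
    # Saturating toy utility: nothing is gained beyond three points.
    n_calls[0] += 1
    return min(len(subset), 3) / 3


def near_total(pos: int, score: float) -> bool:
    # Stop once the score is within 0.01 of the total utility (1.0).
    return abs(score - 1.0) <= 0.01


values = truncated_permutation_shapley(
    n=6, utility=toy_utility, should_truncate=near_total, n_permutations=50
)
print(sum(values), n_calls[0])
```

Because the toy utility saturates after three points, truncation loses nothing here: the values still sum to the total utility, while each permutation needs 4 utility evaluations instead of 7.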

Last update: 2023-12-21
Created: 2023-12-21