
pydvl.valuation.methods.shapley

This module implements the Shapley valuation method.

Info

See the main documentation for a description of the algorithm and its properties.

We provide two main ways of computing Shapley values:

  1. A general approach that allows for any sampling scheme, including deterministic, uniform, permutation-based, and so on. This is implemented in ShapleyValuation.
  2. A default configuration for the Truncated Monte Carlo Shapley (TMCS) method, described in Ghorbani and Zou (2019) [1]. This is essentially a thin wrapper around the general class that uses a permutation sampler by default. Besides being convenient, it allows deactivating importance sampling internally (see Sampling strategies for semi-values).

Computing values in pyDVL typically follows this pattern: construct a ModelUtility, a sampler, and a stopping criterion, then pass them to the valuation and fit it.

General usage pattern
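from joblib import parallel_config

from pydvl.valuation import (
    ShapleyValuation,
    ModelUtility,
    SupervisedScorer,
    UniformSampler,
    MaxSamples,
)

# Placeholders: SomeSKLearnModel stands for any scikit-learn compatible model;
# training_data and test_data are pyDVL Datasets prepared beforehand.
model = SomeSKLearnModel()
scorer = SupervisedScorer("accuracy", test_data, default=0)
utility = ModelUtility(model, scorer, ...)  # remaining arguments elided
sampler = UniformSampler(seed=42)
stopping = MaxSamples(5000)
valuation = ShapleyValuation(utility, sampler, is_done=stopping)
with parallel_config(n_jobs=16):  # joblib backend configuration for fit()
    valuation.fit(training_data)
result = valuation.result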

Choosing samplers

Different choices of sampler yield approximations of different quality; see Sampling strategies for semi-values for a discussion of the internals.

The most basic one is DeterministicUniformSampler, which iterates over all possible subsets of the training set. This is the most accurate method, but also the most computationally expensive one (with complexity \(O(2^n)\)), so it is impractical for all but very small datasets.
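To make the \(O(2^n)\) cost concrete, here is a minimal plain-Python sketch of exact Shapley values by exhaustive enumeration. It is an illustration, not pyDVL's DeterministicUniformSampler: the utility is assumed to be an ordinary function on frozensets of indices (a hypothetical stand-in for a ModelUtility), and each subset \(S\) of the complement of \(i\) is weighted by \(\frac{|S|!\,(n-|S|-1)!}{n!}\).

```python
from itertools import combinations
from math import factorial
from typing import Callable, FrozenSet

# Hypothetical utility type: maps a set of training indices to a score.
Utility = Callable[[FrozenSet[int]], float]


def exact_shapley(n: int, u: Utility) -> list[float]:
    """Exact Shapley values by enumerating, for every index i, all subsets
    of the remaining n - 1 indices (2^(n-1) subsets per index)."""
    values = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for k in range(n):  # subset sizes 0..n-1
            # Shapley weight shared by all subsets of size k
            weight = factorial(k) * factorial(n - k - 1) / factorial(n)
            for subset in combinations(others, k):
                s = frozenset(subset)
                values[i] += weight * (u(s | {i}) - u(s))
    return values


def squared_size(s: FrozenSet[int]) -> float:
    """Toy utility (hypothetical): the squared size of the coalition."""
    return float(len(s) ** 2)


print(exact_shapley(3, squared_size))
```

For the toy utility \(u(S) = |S|^2\) with \(n = 3\), symmetry and efficiency imply that each value equals \(u(N)/3 = 3\), up to floating-point rounding. Adding a single data point doubles the number of subsets to enumerate per index.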

The most commonly used one, however, is PermutationSampler, which samples random permutations of the training set. Despite the apparently greater complexity of \(O(n!)\), this method converges much faster in practice, especially when truncation policies are used to stop processing each permutation early. As mentioned above, the default configuration of TMCS is available via TMCShapleyValuation.
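A plain-Python sketch of the permutation estimator (an illustration, not pyDVL's implementation) shows why it is cheaper than the \(O(n!)\) count suggests: a single pass over one random permutation yields a marginal contribution for every index with only \(n + 1\) utility evaluations. The utility is assumed to be a function on frozensets, and `permutation_shapley` is a hypothetical name.

```python
import random
from typing import Callable, FrozenSet

# Hypothetical utility type: maps a set of training indices to a score.
Utility = Callable[[FrozenSet[int]], float]


def permutation_shapley(n: int, u: Utility, n_permutations: int,
                        seed: int = 0) -> list[float]:
    """Monte Carlo Shapley estimate: average, over random permutations,
    of each index's marginal contribution to the indices preceding it."""
    rng = random.Random(seed)
    totals = [0.0] * n
    for _ in range(n_permutations):
        perm = list(range(n))
        rng.shuffle(perm)
        prefix: FrozenSet[int] = frozenset()
        prev_utility = u(prefix)
        # One pass gives one marginal for every index (n + 1 evaluations).
        for i in perm:
            prefix = prefix | {i}
            current = u(prefix)
            totals[i] += current - prev_utility
            prev_utility = current
    return [t / n_permutations for t in totals]


print(permutation_shapley(3, lambda s: float(len(s) ** 2), n_permutations=1000))
```

Along each permutation the marginals telescope, so the estimates always sum to \(u(N) - u(\emptyset)\), whatever the number of permutations.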

Manually instantiating TMCS

As an alternative to TMCShapleyValuation, the following configuration computes Shapley values as described in Ghorbani and Zou (2019) [1]:

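from pydvl.valuation import HistoryDeviation, PermutationSampler, RelativeTruncation, ShapleyValuation

truncation = RelativeTruncation(rtol=0.05)
sampler = PermutationSampler(truncation=truncation, seed=42)
stopping = HistoryDeviation(n_steps=100, rtol=0.05)
# `utility` as in the usage pattern above. Pass options by keyword: positionally, progress would land on show_warnings.
valuation = ShapleyValuation(utility, sampler, is_done=stopping, skip_converged=False, progress=True)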
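The idea behind truncation can be sketched without pyDVL: while scanning a permutation, once the utility of the prefix is within a relative tolerance of the utility of the full set, the remaining marginal contributions are taken to be zero and are not computed. The function below is a simplified illustration of that idea, not pyDVL's RelativeTruncation; `truncated_permutation_pass` and `saturating` are hypothetical names.

```python
from typing import Callable, FrozenSet

# Hypothetical utility type: maps a set of training indices to a score.
Utility = Callable[[FrozenSet[int]], float]


def truncated_permutation_pass(perm: list[int], u: Utility, total_utility: float,
                               rtol: float) -> tuple[dict[int, float], int]:
    """Marginal contributions along one permutation, truncated once the prefix
    utility is within rtol (relative) of the full-set utility.

    Returns the marginals and the number of utility evaluations performed."""
    marginals = {i: 0.0 for i in perm}  # truncated indices keep a zero marginal
    prefix: FrozenSet[int] = frozenset()
    prev = u(prefix)
    n_evals = 1
    for i in perm:
        if abs(prev - total_utility) <= rtol * abs(total_utility):
            break  # truncate: skip the remaining (expensive) evaluations
        prefix = prefix | {i}
        current = u(prefix)
        n_evals += 1
        marginals[i] = current - prev
        prev = current
    return marginals, n_evals


def saturating(s: FrozenSet[int]) -> float:
    """Toy utility that stops improving after two points."""
    return float(min(len(s), 2))


perm = [0, 1, 2, 3, 4]
print(truncated_permutation_pass(perm, saturating, saturating(frozenset(perm)), rtol=0.05))
# prints ({0: 1.0, 1: 1.0, 2: 0.0, 3: 0.0, 4: 0.0}, 3)
```

With a saturating utility, only three of the six possible evaluations are needed for this permutation; the saving is what makes truncation attractive when each evaluation means retraining a model.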

Other samplers introduce different importance sampling schemes for the computation of Shapley values, such as the Owen samplers [2] or the Maximum-Sample-Reuse sampler [3]. These can be either beneficial or detrimental, but the usage pattern remains the same.
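Owen sampling rests on the multilinear extension: \(\phi_i = \int_0^1 \mathbb{E}[u(S \cup \{i\}) - u(S)]\,dq\), where \(S\) contains each index other than \(i\) independently with probability \(q\). The sketch below is a simple Monte Carlo version in plain Python that draws \(q\) uniformly at random; pyDVL's Owen samplers may discretize \(q\) differently, and `owen_shapley` is a hypothetical name.

```python
import random
from typing import Callable, FrozenSet

# Hypothetical utility type: maps a set of training indices to a score.
Utility = Callable[[FrozenSet[int]], float]


def owen_shapley(n: int, u: Utility, n_samples: int, seed: int = 0) -> list[float]:
    """Monte Carlo version of Owen's multilinear-extension formula."""
    rng = random.Random(seed)
    values = [0.0] * n
    for i in range(n):
        total = 0.0
        for _ in range(n_samples):
            q = rng.random()  # inclusion probability, uniform on [0, 1)
            # Each index other than i joins S independently with probability q.
            s = frozenset(j for j in range(n) if j != i and rng.random() < q)
            total += u(s | {i}) - u(s)
        values[i] = total / n_samples
    return values
```

Integrating over \(q\) gives each subset size exactly the Shapley weight, which is why averaging these marginals estimates the Shapley value.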

Choosing stopping criteria

As mentioned, computing Shapley values can be computationally expensive, especially for large datasets. Some samplers converge faster than others, but not in all cases. A proper choice of stopping criterion is crucial for obtaining useful results while avoiding unnecessary computation.
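As a rough illustration of what a convergence-based criterion checks, the sketch below declares convergence when the running value estimates have changed, relative to their magnitude, by less than `rtol` over the last `n_steps` updates. This is a simplified, hypothetical criterion in the spirit of HistoryDeviation; pyDVL's actual formula may differ, and `history_converged` is a made-up name.

```python
import numpy as np


def history_converged(history: np.ndarray, n_steps: int, rtol: float) -> bool:
    """Hypothetical convergence check on a history of value estimates.

    history has shape (n_updates, n_indices); row t holds the estimates after
    update t. Converged if the total absolute change over the last n_steps
    updates is below rtol times the total magnitude of the latest estimates."""
    if len(history) <= n_steps:
        return False  # not enough history to compare against yet
    latest, past = history[-1], history[-1 - n_steps]
    change = np.abs(latest - past).sum()
    scale = np.abs(latest).sum()
    return bool(change < rtol * scale)
```

A count-based criterion such as MaxSamples, by contrast, ignores the estimates entirely, which bounds cost but says nothing about accuracy.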

Bogus configurations

While it is possible to mix and match different components of the valuation method, it is not always advisable, and it can sometimes be incorrect. For example, using a deterministic sampler with a count-based stopping criterion is likely to yield poor results, because stopping before the enumeration is complete leaves the remaining subsets unevaluated, and a partial deterministic enumeration is not a random sample. More importantly, not all samplers, nor all sampler configurations, are compatible with Shapley value computation. For instance, using NoIndexIteration with a PowersetSampler will not work, since the evaluation strategy expects each sample to consist of an index and a subset of its complement in the whole index set.
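The constraint follows from what a Shapley evaluation does with each sample: it needs an index \(i\) and a subset \(S\) of the complement of \(i\) to compute the marginal contribution \(u(S \cup \{i\}) - u(S)\). The check below is a hypothetical illustration, not pyDVL's internal code; `marginal_contribution` is a made-up name, and it assumes, for illustration, that a sample without an index carries `idx=None`.

```python
from typing import Callable, FrozenSet, Optional

# Hypothetical utility type: maps a set of training indices to a score.
Utility = Callable[[FrozenSet[int]], float]


def marginal_contribution(u: Utility, idx: Optional[int],
                          subset: FrozenSet[int]) -> float:
    """u(S ∪ {i}) - u(S) for a sample (i, S), rejecting malformed samples."""
    if idx is None:
        # Without an index there is no data point to attribute the change to.
        raise ValueError("Shapley evaluation requires an index in each sample")
    if idx in subset:
        raise ValueError("The subset must be drawn from the complement of idx")
    return u(subset | {idx}) - u(subset)
```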

References


  1. Ghorbani, A., & Zou, J. Y. (2019). Data Shapley: Equitable Valuation of Data for Machine Learning. In Proceedings of the 36th International Conference on Machine Learning, pp. 2242–2251. PMLR.

  2. Okhrati, R., & Lipani, A. (2021). A Multilinear Sampling Algorithm to Estimate Shapley Values. In 2020 25th International Conference on Pattern Recognition (ICPR), pp. 7992–7999. IEEE.

  3. Wang, J. T., & Jia, R. (2023). Data Banzhaf: A Robust Data Valuation Framework for Machine Learning. In Proceedings of the 26th International Conference on Artificial Intelligence and Statistics, pp. 6388–6421. PMLR.

ShapleyValuation

ShapleyValuation(
    utility: UtilityBase,
    sampler: IndexSampler,
    is_done: StoppingCriterion,
    skip_converged: bool = False,
    show_warnings: bool = True,
    progress: dict[str, Any] | bool = False,
)

Bases: SemivalueValuation

Computes Shapley values with any sampler.

Use this class to test different sampling schemes. For a default configuration, use TMCShapleyValuation.

For an introduction to the algorithm, see the main documentation.

PARAMETER DESCRIPTION
utility

Object to compute utilities.

TYPE: UtilityBase

sampler

Sampling scheme to use.

TYPE: IndexSampler

is_done

Stopping criterion to use.

TYPE: StoppingCriterion

skip_converged

Whether to skip converged indices, as determined by the stopping criterion's converged array.

TYPE: bool DEFAULT: False

show_warnings

Whether to show warnings.

TYPE: bool DEFAULT: True

progress

Whether to show a progress bar. If a dictionary, it is passed to tqdm as keyword arguments, and the progress bar is displayed.

TYPE: dict[str, Any] | bool DEFAULT: False

Source code in src/pydvl/valuation/methods/semivalue.py
def __init__(
    self,
    utility: UtilityBase,
    sampler: IndexSampler,
    is_done: StoppingCriterion,
    skip_converged: bool = False,
    show_warnings: bool = True,
    progress: dict[str, Any] | bool = False,
):
    super().__init__()
    self.utility = utility
    self.sampler = sampler
    self.is_done = is_done
    self.skip_converged = skip_converged
    if skip_converged:  # test whether the sampler supports skipping indices:
        self.sampler.skip_indices = np.array([], dtype=np.int_)
    self.show_warnings = show_warnings
    self.tqdm_args: dict[str, Any] = {"desc": str(self)}
    # HACK: parse additional args for the progress bar if any (we probably want
    #  something better)
    if isinstance(progress, bool):
        self.tqdm_args.update({"disable": not progress})
    elif isinstance(progress, dict):
        self.tqdm_args.update(progress)
    else:
        raise TypeError(f"Invalid type for progress: {type(progress)}")

result property

The current valuation result (not a copy).

fit

fit(data: Dataset, continue_from: ValuationResult | None = None) -> Self

Fits the semi-value valuation to the data.

Access the results through the result property.

PARAMETER DESCRIPTION
data

Data for which to compute values

TYPE: Dataset

continue_from

A previously computed valuation result to continue from.

TYPE: ValuationResult | None DEFAULT: None

Source code in src/pydvl/valuation/methods/semivalue.py
@suppress_warnings(flag="show_warnings")
def fit(self, data: Dataset, continue_from: ValuationResult | None = None) -> Self:
    """Fits the semi-value valuation to the data.

    Access the results through the `result` property.

    Args:
        data: Data for which to compute values
        continue_from: A previously computed valuation result to continue from.

    """
    self._result = self._init_or_check_result(data, continue_from)
    ensure_backend_has_generator_return()

    self.is_done.reset()
    self.utility = self.utility.with_dataset(data)

    strategy = self.sampler.make_strategy(self.utility, self.log_coefficient)
    updater = self.sampler.result_updater(self._result)
    processor = delayed(strategy.process)

    with Parallel(return_as="generator_unordered") as parallel:
        with make_parallel_flag() as flag:
            delayed_evals = parallel(
                processor(batch=list(batch), is_interrupted=flag)
                for batch in self.sampler.generate_batches(data.indices)
            )
            for batch in Progress(delayed_evals, self.is_done, **self.tqdm_args):
                for update in batch:
                    self._result = updater.process(update)
                if self.is_done(self._result):
                    flag.set()
                    self.sampler.interrupt()
                    break
                if self.skip_converged:
                    self.sampler.skip_indices = data.indices[self.is_done.converged]
    logger.debug(f"Fitting done after {updater.n_updates} value updates.")
    return self
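Stripped of parallelism, the loop in fit() draws batches, folds every update of a batch into the result, and checks the stopping criterion once per batch; conceptually, continue_from supplies the starting point of that accumulation. The sequential sketch below mirrors this control flow for a single running mean. It is a simplified illustration, not pyDVL code: `fit_sketch` and its arguments are hypothetical, and it omits the interruption flag, unordered parallel results, and index skipping.

```python
from typing import Callable, Iterable


def fit_sketch(
    batches: Iterable[list[int]],
    evaluate: Callable[[int], float],
    is_done: Callable[[float, int], bool],
    start: tuple[float, int] = (0.0, 0),
) -> tuple[float, int]:
    """Sequential analogue of the loop in fit() for a single running mean.

    `start` plays the role of continue_from: a previous (mean, count) pair
    from which accumulation resumes. Returns the final (mean, count)."""
    mean, count = start
    for batch in batches:
        for sample in batch:
            count += 1
            mean += (evaluate(sample) - mean) / count  # incremental mean update
        # As in fit(), the criterion is checked after each batch, not each sample.
        if is_done(mean, count):
            break
    return mean, count


# Stop after at least 3 updates: the check happens per batch, so 4 are used.
print(fit_sketch([[1, 2], [3, 4], [5, 6]], float, lambda m, c: c >= 3))
# prints (2.5, 4)
```

Because the criterion is only consulted between batches, a larger batch size can overshoot the stopping point by up to one batch of updates.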

values

values(sort: bool = False) -> ValuationResult

Returns a copy of the valuation result. Deprecated since version 0.10.0 and scheduled for removal in 0.11.0.

The valuation must have been run with fit() before calling this method.

PARAMETER DESCRIPTION
sort

Whether to sort the valuation result by value before returning it.

TYPE: bool DEFAULT: False

Returns: The result of the valuation.

Source code in src/pydvl/valuation/base.py
@deprecated(
    target=None,
    deprecated_in="0.10.0",
    remove_in="0.11.0",
)
def values(self, sort: bool = False) -> ValuationResult:
    """Returns a copy of the valuation result.

    The valuation must have been run with `fit()` before calling this method.

    Args:
        sort: Whether to sort the valuation result by value before returning it.
    Returns:
        The result of the valuation.
    """
    if not self.is_fitted:
        raise NotFittedException(type(self))
    assert self._result is not None

    r = self._result.copy()
    if sort:
        r.sort(inplace=True)
    return r

StratifiedShapleyValuation

StratifiedShapleyValuation(
    utility: UtilityBase,
    is_done: StoppingCriterion,
    batch_size: int = 1,
    seed: Seed | None = None,
    skip_converged: bool = False,
    show_warnings: bool = True,
    progress: dict[str, Any] | bool = False,
)

Bases: ShapleyValuation

Computes Shapley values using a uniform stratified sampler.

Uses a StratifiedSampler with uniform probability for the sample sizes. Under this sampling scheme, the expected marginal utility coincides with the Shapley value. See the documentation for details.
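Under uniform size strata, the Shapley value of \(i\) is the average over sizes \(k \in \{0, \dots, n-1\}\) of the mean marginal contribution over subsets of size \(k\) drawn from the complement of \(i\). The sketch below estimates exactly that by sampling, in plain Python. It is a simplified illustration of the scheme described above, not the StratifiedSampler implementation; `stratified_shapley` is a hypothetical name and the utility is assumed to be a function on frozensets.

```python
import random
from typing import Callable, FrozenSet

# Hypothetical utility type: maps a set of training indices to a score.
Utility = Callable[[FrozenSet[int]], float]


def stratified_shapley(n: int, u: Utility, n_samples: int,
                       seed: int = 0) -> list[float]:
    """For each index i: draw a size k uniformly from {0, ..., n-1}, then a
    uniformly random subset of that size from the complement of i, and
    average the marginal contributions."""
    rng = random.Random(seed)
    values = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for _ in range(n_samples):
            k = rng.randrange(n)  # uniform over sizes 0..n-1
            s = frozenset(rng.sample(others, k))
            total += u(s | {i}) - u(s)
        values[i] = total / n_samples
    return values
```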

When to use this class

This class is for illustrative purposes only. In general, permutation-based sampling exhibits better convergence; see, for example, TMCShapleyValuation.

If you need different size strategies, or wish to clip subset sizes outside a given range, instantiate ShapleyValuation directly with a StratifiedSampler.

PARAMETER DESCRIPTION
utility

Object to compute utilities.

TYPE: UtilityBase

is_done

Stopping criterion to use.

TYPE: StoppingCriterion

batch_size

The number of samples to generate per batch. When working in parallel, each batch is processed as a unit by a single worker.

TYPE: int DEFAULT: 1

seed

Random seed for the sampler.

TYPE: Seed | None DEFAULT: None

skip_converged

Whether to skip converged indices. Convergence is determined by the stopping criterion's converged array.

TYPE: bool DEFAULT: False

show_warnings

Whether to show warnings when the stopping criterion is not met.

TYPE: bool DEFAULT: True

progress

Whether to show a progress bar. If a dictionary, it is passed to tqdm as keyword arguments, and the progress bar is displayed.

TYPE: dict[str, Any] | bool DEFAULT: False

Source code in src/pydvl/valuation/methods/shapley.py
def __init__(
    self,
    utility: UtilityBase,
    is_done: StoppingCriterion,
    batch_size: int = 1,
    seed: Seed | None = None,
    skip_converged: bool = False,
    show_warnings: bool = True,
    progress: dict[str, Any] | bool = False,
):
    sampler = StratifiedSampler(
        sample_sizes=ConstantSampleSize(),
        sample_sizes_iteration=RandomSizeIteration,
        index_iteration=RandomIndexIteration,
        batch_size=batch_size,
        seed=seed,
    )
    super().__init__(
        utility, sampler, is_done, skip_converged, show_warnings, progress
    )

result property

The current valuation result (not a copy).

fit

fit(data: Dataset, continue_from: ValuationResult | None = None) -> Self

Fits the semi-value valuation to the data.

Access the results through the result property.

PARAMETER DESCRIPTION
data

Data for which to compute values

TYPE: Dataset

continue_from

A previously computed valuation result to continue from.

TYPE: ValuationResult | None DEFAULT: None

values

values(sort: bool = False) -> ValuationResult

Returns a copy of the valuation result. Deprecated since version 0.10.0 and scheduled for removal in 0.11.0.

The valuation must have been run with fit() before calling this method.

PARAMETER DESCRIPTION
sort

Whether to sort the valuation result by value before returning it.

TYPE: bool DEFAULT: False

Returns: The result of the valuation.

TMCShapleyValuation

TMCShapleyValuation(
    utility: UtilityBase,
    truncation: TruncationPolicy | None = None,
    is_done: StoppingCriterion | None = None,
    seed: Seed | None = None,
    skip_converged: bool = False,
    show_warnings: bool = True,
    progress: dict[str, Any] | bool = False,
)

Bases: ShapleyValuation

Computes Shapley values using the Truncated Monte Carlo method.

This class provides defaults similar to those in the experiments by Ghorbani and Zou (2019) [1].

PARAMETER DESCRIPTION
utility

Object to compute utilities.

TYPE: UtilityBase

truncation

Truncation policy to use. Defaults to RelativeTruncation with a relative tolerance of 0.01 and a burn-in fraction of 0.4.

TYPE: TruncationPolicy | None DEFAULT: None

is_done

Stopping criterion to use. Defaults to HistoryDeviation with a relative tolerance of 0.05 and a window of 100 samples.

TYPE: StoppingCriterion | None DEFAULT: None

seed

Random seed for the sampler.

TYPE: Seed | None DEFAULT: None

skip_converged

Whether to skip converged indices. Convergence is determined by the stopping criterion's converged array.

TYPE: bool DEFAULT: False

show_warnings

Whether to show warnings when the stopping criterion is not met.

TYPE: bool DEFAULT: True

progress

Whether to show a progress bar. If a dictionary, it is passed to tqdm as keyword arguments, and the progress bar is displayed.

TYPE: dict[str, Any] | bool DEFAULT: False

Source code in src/pydvl/valuation/methods/shapley.py
def __init__(
    self,
    utility: UtilityBase,
    truncation: TruncationPolicy | None = None,
    is_done: StoppingCriterion | None = None,
    seed: Seed | None = None,
    skip_converged: bool = False,
    show_warnings: bool = True,
    progress: dict[str, Any] | bool = False,
):
    if truncation is None:
        truncation = RelativeTruncation(rtol=0.01, burn_in_fraction=0.4)
    if is_done is None:
        is_done = HistoryDeviation(n_steps=100, rtol=0.05)
    sampler = PermutationSampler(truncation=truncation, seed=seed)
    super().__init__(
        utility, sampler, is_done, skip_converged, show_warnings, progress
    )
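The defaults above combine a relative tolerance with a burn-in. Assuming that burn_in_fraction denotes the fraction of each permutation processed before any truncation check is made (our reading of the parameter, not a statement of pyDVL's exact implementation), the resulting rule can be sketched as a small predicate; `should_truncate` is a hypothetical helper whose defaults mirror the values used by TMCShapleyValuation.

```python
def should_truncate(position: int, n: int, prefix_utility: float,
                    total_utility: float, rtol: float = 0.01,
                    burn_in_fraction: float = 0.4) -> bool:
    """Hypothetical truncation rule after `position` elements of a permutation
    of length n have been processed: never truncate during the burn-in, then
    truncate once the prefix utility is within rtol (relative) of the
    full-set utility."""
    if position < burn_in_fraction * n:
        return False  # still in the burn-in phase
    return abs(prefix_utility - total_utility) <= rtol * abs(total_utility)
```

The burn-in guards against truncating a permutation merely because a few early points happen to reach a utility close to the full-set score.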

result property

The current valuation result (not a copy).

fit

fit(data: Dataset, continue_from: ValuationResult | None = None) -> Self

Fits the semi-value valuation to the data.

Access the results through the result property.

PARAMETER DESCRIPTION
data

Data for which to compute values

TYPE: Dataset

continue_from

A previously computed valuation result to continue from.

TYPE: ValuationResult | None DEFAULT: None

values

values(sort: bool = False) -> ValuationResult

Returns a copy of the valuation result. Deprecated since version 0.10.0 and scheduled for removal in 0.11.0.

The valuation must have been run with fit() before calling this method.

PARAMETER DESCRIPTION
sort

Whether to sort the valuation result by value before returning it.

TYPE: bool DEFAULT: False

Returns: The result of the valuation.
