pydvl.valuation.methods.beta_shapley

This module implements Beta-Shapley valuation as introduced in Kwon and Zou (2022) [1].

Background on semi-values

Beta-Shapley is a special case of the semi-value valuation method. You can read a short introduction in the documentation.

Beta(\(\alpha\), \(\beta\))-Shapley is a semi-value whose coefficients are given by the Beta function. The coefficients are defined as:

\[ \begin{aligned} w_{\alpha, \beta} (n, k) & := \int_0^1 t^{k - 1} (1 - t)^{n - k} \frac{t^{\beta - 1} (1 - t)^{\alpha - 1}}{\text{Beta} (\alpha, \beta)} \, \mathrm{d} t \\ & = \frac{\text{Beta} (k + \beta - 1, n - k + \alpha)}{\text{Beta} (\alpha, \beta)}. \end{aligned} \]

Note that this differs by a factor of \(n\) from eq. (5) in Kwon and Zou (2022) [1] because of how sampler weights are defined here. The effective coefficient is the same when using any PowersetSampler or PermutationSampler.
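
As a sanity check on the closed form above, the following minimal sketch evaluates the coefficient both through log-gamma functions and through a midpoint-rule approximation of the defining integral. The function names are hypothetical and are not part of pyDVL; only the Python standard library is assumed.

```python
import math


def log_beta(a: float, b: float) -> float:
    """Natural log of the Beta function, computed via log-gamma."""
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def beta_shapley_weight(n: int, k: int, alpha: float, beta: float) -> float:
    """Closed form: Beta(k + beta - 1, n - k + alpha) / Beta(alpha, beta)."""
    return math.exp(log_beta(k + beta - 1, n - k + alpha) - log_beta(alpha, beta))


def beta_shapley_weight_numeric(n, k, alpha, beta, steps=20_000):
    """Midpoint-rule approximation of the defining integral over (0, 1)."""
    h = 1.0 / steps
    total = 0.0
    for i in range(steps):
        t = (i + 0.5) * h
        # Integrand t^(k-1) (1-t)^(n-k) * t^(beta-1) (1-t)^(alpha-1), merged.
        total += t ** (k + beta - 2) * (1 - t) ** (n - k + alpha - 1)
    return total * h / math.exp(log_beta(alpha, beta))


n, alpha, beta = 5, 4.0, 1.0
for k in range(1, n + 1):
    closed = beta_shapley_weight(n, k, alpha, beta)
    numeric = beta_shapley_weight_numeric(n, k, alpha, beta)
    print(f"k={k}: closed={closed:.6f} numeric={numeric:.6f}")
```

With \(\alpha = \beta = 1\) the coefficient reduces to \((k - 1)! (n - k)! / n!\), the usual Shapley weight. For any \(\alpha, \beta > 0\), the weights summed over all subsets, \(\sum_k \binom{n - 1}{k - 1} w_{\alpha, \beta}(n, k)\), equal 1.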

Connection to AME

Beta-Shapley can be seen as a special case of AME, introduced in Lin et al. (2022) [2].

Todo

Explain sampler choices for AME and how to estimate Beta-Shapley with lasso.

References


  1. Kwon, Yongchan, and James Zou. Beta Shapley: A Unified and Noise-Reduced Data Valuation Framework for Machine Learning. In Proceedings of The 25th International Conference on Artificial Intelligence and Statistics, 8780–8802. PMLR, 2022. 

  2. Lin, Jinkun, Anqi Zhang, Mathias Lécuyer, Jinyang Li, Aurojit Panda, and Siddhartha Sen. Measuring the Effect of Training Data on Deep Learning Predictions via Randomized Experiments. In Proceedings of the 39th International Conference on Machine Learning, 13468–504. PMLR, 2022. 

BetaShapleyValuation

BetaShapleyValuation(
    utility: UtilityBase,
    sampler: IndexSampler,
    is_done: StoppingCriterion,
    alpha: float,
    beta: float,
    skip_converged: bool = False,
    show_warnings: bool = True,
    progress: bool = False,
)

Bases: SemivalueValuation

Computes Beta-Shapley values.

PARAMETER DESCRIPTION
utility

Object to compute utilities.

TYPE: UtilityBase

sampler

Sampling scheme to use.

TYPE: IndexSampler

is_done

Stopping criterion to use.

TYPE: StoppingCriterion

alpha

The alpha parameter of the Beta distribution.

TYPE: float

beta

The beta parameter of the Beta distribution.

TYPE: float

skip_converged

Whether to skip converged indices. Convergence is determined by the stopping criterion's converged array.

TYPE: bool DEFAULT: False

show_warnings

Whether to show any runtime warnings.

TYPE: bool DEFAULT: True

progress

Whether to show a progress bar. If a dictionary, it is passed to tqdm as keyword arguments, and the progress bar is displayed.

TYPE: bool DEFAULT: False

Source code in src/pydvl/valuation/methods/beta_shapley.py
def __init__(
    self,
    utility: UtilityBase,
    sampler: IndexSampler,
    is_done: StoppingCriterion,
    alpha: float,
    beta: float,
    skip_converged: bool = False,
    show_warnings: bool = True,
    progress: bool = False,
):
    super().__init__(
        utility,
        sampler,
        is_done,
        skip_converged=skip_converged,
        show_warnings=show_warnings,
        progress=progress,
    )

    self.alpha = alpha
    self.beta = beta
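
To make concrete what the class estimates, the sketch below computes Beta-Shapley values exactly for a tiny game by enumerating all subsets. It does not use pyDVL: the helper names and the toy utility are hypothetical, and the weights follow the formula given at the top of this page. Exhaustive enumeration is only feasible for very small n, which is why the library relies on samplers and stopping criteria instead.

```python
import math
from itertools import combinations


def beta_shapley_weight(n, k, alpha, beta):
    """w(n, k) = Beta(k + beta - 1, n - k + alpha) / Beta(alpha, beta)."""
    def log_beta(a, b):
        return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    return math.exp(log_beta(k + beta - 1, n - k + alpha) - log_beta(alpha, beta))


def exact_beta_shapley(utility, n, alpha, beta):
    """Exact values: weighted sum of marginal contributions over all subsets."""
    values = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        value = 0.0
        for size in range(n):  # |S| = size, hence k = size + 1
            w = beta_shapley_weight(n, size + 1, alpha, beta)
            for subset in combinations(others, size):
                s = frozenset(subset)
                value += w * (utility(s | {i}) - utility(s))
        values.append(value)
    return values


# Hypothetical toy utility with diminishing returns.
quality = [1.0, 2.0, 0.5, 0.1]


def toy_utility(s):
    return math.sqrt(sum(quality[j] for j in s))


for alpha, beta in [(1.0, 1.0), (4.0, 1.0), (1.0, 4.0)]:
    print((alpha, beta), exact_beta_shapley(toy_utility, len(quality), alpha, beta))
```

Given the exponents in the formula, a larger alpha puts more weight on marginal contributions to small subsets and a larger beta on contributions to large subsets.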

log_coefficient property

log_coefficient: SemivalueCoefficient | None

Beta-Shapley coefficient.

Defined (up to a constant factor n) as in eq. (5) of Kwon and Zou (2022) [1].
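
Working with the logarithm of the coefficient matters for larger datasets, because the coefficient itself quickly underflows in double precision. The sketch below illustrates this with a hypothetical helper based on the formula at the top of this page. The actual callable returned by this property may differ in its signature and in how it indexes subset sizes.

```python
import math


def log_beta_shapley_coefficient(n: int, k: int, alpha: float, beta: float) -> float:
    """log w(n, k) = log Beta(k + beta - 1, n - k + alpha) - log Beta(alpha, beta)."""
    def log_beta(a, b):
        return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    return log_beta(k + beta - 1, n - k + alpha) - log_beta(alpha, beta)


n, k = 5000, 2500
log_w = log_beta_shapley_coefficient(n, k, alpha=1.0, beta=1.0)
print(log_w)            # finite, large negative number
print(math.exp(log_w))  # underflows to 0.0 in double precision
```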

result property

The current valuation result (not a copy).

fit

fit(data: Dataset, continue_from: ValuationResult | None = None) -> Self

Fits the semi-value valuation to the data.

Access the results through the result property.

PARAMETER DESCRIPTION
data

Data for which to compute values

TYPE: Dataset

continue_from

A previously computed valuation result to continue from.

TYPE: ValuationResult | None DEFAULT: None

Source code in src/pydvl/valuation/methods/semivalue.py
@suppress_warnings(flag="show_warnings")
def fit(self, data: Dataset, continue_from: ValuationResult | None = None) -> Self:
    """Fits the semi-value valuation to the data.

    Access the results through the `result` property.

    Args:
        data: Data for which to compute values
        continue_from: A previously computed valuation result to continue from.

    """
    self._result = self._init_or_check_result(data, continue_from)
    ensure_backend_has_generator_return()

    self.is_done.reset()
    self.utility = self.utility.with_dataset(data)

    strategy = self.sampler.make_strategy(self.utility, self.log_coefficient)
    updater = self.sampler.result_updater(self._result)
    processor = delayed(strategy.process)

    with Parallel(return_as="generator_unordered") as parallel:
        with make_parallel_flag() as flag:
            delayed_evals = parallel(
                processor(batch=list(batch), is_interrupted=flag)
                for batch in self.sampler.generate_batches(data.indices)
            )
            for batch in Progress(delayed_evals, self.is_done, **self.tqdm_args):
                for update in batch:
                    self._result = updater.process(update)
                if self.is_done(self._result):
                    flag.set()
                    self.sampler.interrupt()
                    break
                if self.skip_converged:
                    self.sampler.skip_indices = data.indices[self.is_done.converged]
    logger.debug(f"Fitting done after {updater.n_updates} value updates.")
    return self
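
The loop above follows a sample, evaluate, update, check pattern. The sketch below reproduces that pattern in a single process for permutation sampling. It is an illustration only and not pyDVL's implementation: all names are hypothetical, batching and parallelism are omitted, and a minimum number of updates per index stands in for the stopping criterion. Each marginal contribution is multiplied by n * C(n-1, k-1) * w(n, k) to correct for the fact that a permutation places an index at position k with probability 1/n.

```python
import math
import random


def log_beta(a, b):
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def permutation_weight(n, k, alpha, beta):
    """Importance weight n * C(n-1, k-1) * w(n, k) for position k in a permutation."""
    log_w = log_beta(k + beta - 1, n - k + alpha) - log_beta(alpha, beta)
    log_binom = math.lgamma(n) - math.lgamma(k) - math.lgamma(n - k + 1)
    return math.exp(math.log(n) + log_binom + log_w)


def fit_beta_shapley(utility, n, alpha, beta, min_updates=200, seed=0):
    """Monte Carlo estimate of Beta-Shapley values from random permutations."""
    rng = random.Random(seed)
    weights = [permutation_weight(n, k, alpha, beta) for k in range(1, n + 1)]
    sums = [0.0] * n
    counts = [0] * n
    while min(counts) < min_updates:  # stand-in for the stopping criterion
        permutation = rng.sample(range(n), n)
        prefix = set()
        previous = utility(frozenset(prefix))
        for position, i in enumerate(permutation, start=1):
            prefix.add(i)
            current = utility(frozenset(prefix))
            # Weighted marginal contribution of i to the preceding indices.
            sums[i] += weights[position - 1] * (current - previous)
            counts[i] += 1
            previous = current
    return [s / c for s, c in zip(sums, counts)]


# Hypothetical toy utility with diminishing returns.
quality = [1.0, 2.0, 0.5, 0.1]
estimates = fit_beta_shapley(
    lambda s: math.sqrt(sum(quality[j] for j in s)), len(quality), alpha=4.0, beta=1.0
)
print(estimates)
```

For \(\alpha = \beta = 1\) every importance weight equals 1, so the loop reduces to plain permutation Shapley sampling.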

values

values(sort: bool = False) -> ValuationResult

Returns a copy of the valuation result.

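The valuation must have been run with fit() before calling this method; otherwise a NotFittedException is raised. The method is marked as deprecated in version 0.10.0 and scheduled for removal in 0.11.0.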

PARAMETER DESCRIPTION
sort

Whether to sort the valuation result by value before returning it.

TYPE: bool DEFAULT: False

Returns: The result of the valuation.

Source code in src/pydvl/valuation/base.py
@deprecated(
    target=None,
    deprecated_in="0.10.0",
    remove_in="0.11.0",
)
def values(self, sort: bool = False) -> ValuationResult:
    """Returns a copy of the valuation result.

    The valuation must have been run with `fit()` before calling this method.

    Args:
        sort: Whether to sort the valuation result by value before returning it.
    Returns:
        The result of the valuation.
    """
    if not self.is_fitted:
        raise NotFittedException(type(self))
    assert self._result is not None

    r = self._result.copy()
    if sort:
        r.sort(inplace=True)
    return r