pydvl.valuation.methods.banzhaf

This module implements the Banzhaf valuation method, as described in Wang and Jia (2023) [1].

Data Banzhaf was proposed as a means to counteract the inherent stochasticity of the utility function in machine learning problems. It chooses the coefficients \(w(k)\) of the semi-value valuation function to be the constant \(1/2^{n-1}\) for all set sizes \(k\), yielding:

\[ v_\text{bzf}(i) = \frac{1}{2^{n-1}} \sum_{S \subseteq D_{-i}} [u(S_{+i}) - u(S)], \]

where \(D_{-i}\) is the dataset without the point \(i\) and \(S_{+i} = S \cup \{i\}\).
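For a handful of points, the definition can be evaluated exactly by enumerating every subset. The following is a minimal sketch in plain Python, independent of pyDVL: `exact_banzhaf` and `toy_utility` are hypothetical names chosen for illustration, and the utility is an arbitrary function on index sets rather than a trained model.

```python
from itertools import combinations


def exact_banzhaf(n, utility):
    """Exact Banzhaf values for n points by enumerating all subsets of D_{-i}.

    `utility` is a hypothetical callable mapping a frozenset of indices to a float.
    """
    values = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for k in range(len(others) + 1):
            for subset in combinations(others, k):
                s = frozenset(subset)
                # Marginal contribution of i to the subset S
                total += utility(s | {i}) - utility(s)
        # Constant weight 1 / 2^(n-1), independent of the subset size
        values.append(total / 2 ** (n - 1))
    return values


def toy_utility(s):
    # Additive part (point 0 is worth 1, point 1 is worth 2, point 2 is worth 3)
    # plus a bonus of 4 when points 0 and 1 appear together.
    weights = {0: 1.0, 1: 2.0, 2: 3.0}
    bonus = 4.0 if {0, 1} <= s else 0.0
    return sum(weights[j] for j in s) + bonus


banzhaf_values = exact_banzhaf(3, toy_utility)
print(banzhaf_values)  # [3.0, 4.0, 3.0]
```

The bonus is earned in half of the subsets seen by points 0 and 1, so each of them gains 2 on top of its additive worth, while point 2 keeps its additive value of 3. The enumeration needs \(2^{n-1}\) utility evaluations per point, which is why sampling is used in practice.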

Background on semi-values

The Banzhaf valuation is a special case of the semi-value valuation method. You can read a short introduction in the documentation.

The intuition for picking a constant weight is that for any choice of weight function \(w\), one can always construct a utility with higher variance where \(w\) is greater. Therefore, in a worst-case sense, the best one can do is to pick a constant weight.

Data Banzhaf has been reported to outperform many other valuation methods in downstream tasks such as best point removal.

Maximum Sample Reuse Banzhaf

A special sampling scheme (MSR) that reuses each sample to update every index in the dataset is shown by Wang and Jia to be optimal for the Banzhaf valuation: not only does it drastically reduce the number of sets needed, but the sampling distribution also matches the Banzhaf indices, in the sense explained in Sampling strategies for semi-values.

To use this sampler for Banzhaf values, instantiate MSRBanzhafValuation. In principle, it is also possible to pass an MSRSampler when instantiating BanzhafValuation, but this might introduce some numerical instability, as explained in the document linked above.
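The sample reuse idea can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not pyDVL's implementation: `msr_banzhaf` and `additive_utility` are hypothetical names, subsets are drawn uniformly from all subsets of the data, and the estimate for each point is the mean utility of the samples containing it minus the mean utility of the samples that do not.

```python
import random


def msr_banzhaf(n, utility, n_samples, seed=0):
    """Monte Carlo Banzhaf estimate in which every sample updates every index.

    Each subset is drawn uniformly from all subsets of the n points, i.e.
    every point is included independently with probability 1/2.
    """
    rng = random.Random(seed)
    sum_with = [0.0] * n
    sum_without = [0.0] * n
    count_with = [0] * n
    count_without = [0] * n
    for _ in range(n_samples):
        s = frozenset(j for j in range(n) if rng.random() < 0.5)
        u = utility(s)  # a single utility evaluation per sample
        for i in range(n):  # the same evaluation is reused for all indices
            if i in s:
                sum_with[i] += u
                count_with[i] += 1
            else:
                sum_without[i] += u
                count_without[i] += 1
    estimates = []
    for i in range(n):
        if count_with[i] == 0 or count_without[i] == 0:
            estimates.append(float("nan"))  # too few samples for this index
        else:
            estimates.append(
                sum_with[i] / count_with[i] - sum_without[i] / count_without[i]
            )
    return estimates


def additive_utility(s):
    # Hypothetical utility: each point contributes a fixed amount,
    # so the exact Banzhaf values are 1, 2 and 3.
    weights = (1.0, 2.0, 3.0)
    return sum(weights[j] for j in s)


estimates = msr_banzhaf(3, additive_utility, n_samples=20000)
print([round(v, 1) for v in estimates])
```

Conditional on a point being in the sample, the remaining points form a uniformly drawn subset of the other data, so the difference of the two means estimates the Banzhaf value without any extra reweighting.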

References


  1. Wang, Jiachen T., and Ruoxi Jia. Data Banzhaf: A Robust Data Valuation Framework for Machine Learning. In Proceedings of The 26th International Conference on Artificial Intelligence and Statistics, 6388–6421. PMLR, 2023. 

BanzhafValuation

BanzhafValuation(
    utility: UtilityBase,
    sampler: IndexSampler,
    is_done: StoppingCriterion,
    skip_converged: bool = False,
    show_warnings: bool = True,
    progress: dict[str, Any] | bool = False,
)

Bases: SemivalueValuation

Computes Banzhaf values.

Source code in src/pydvl/valuation/methods/semivalue.py
def __init__(
    self,
    utility: UtilityBase,
    sampler: IndexSampler,
    is_done: StoppingCriterion,
    skip_converged: bool = False,
    show_warnings: bool = True,
    progress: dict[str, Any] | bool = False,
):
    super().__init__()
    self.utility = utility
    self.sampler = sampler
    self.is_done = is_done
    self.skip_converged = skip_converged
    if skip_converged:  # test whether the sampler supports skipping indices:
        self.sampler.skip_indices = np.array([], dtype=np.int_)
    self.show_warnings = show_warnings
    self.tqdm_args: dict[str, Any] = {"desc": str(self)}
    # HACK: parse additional args for the progress bar if any (we probably want
    #  something better)
    if isinstance(progress, bool):
        self.tqdm_args.update({"disable": not progress})
    elif isinstance(progress, dict):
        self.tqdm_args.update(progress)
    else:
        raise TypeError(f"Invalid type for progress: {type(progress)}")
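The handling of the progress argument in the constructor above can be isolated as a small standalone function. This is a sketch for illustration only: `build_tqdm_args` is a hypothetical helper that mirrors the branching in the listing and is not part of pyDVL.

```python
def build_tqdm_args(desc, progress=False):
    """Translate a `progress` argument into keyword arguments for tqdm.

    Hypothetical helper mirroring the constructor logic shown above.
    """
    tqdm_args = {"desc": desc}
    if isinstance(progress, bool):
        # A boolean only switches the progress bar on or off
        tqdm_args["disable"] = not progress
    elif isinstance(progress, dict):
        # A dict is merged as-is and may override "desc"
        tqdm_args.update(progress)
    else:
        raise TypeError(f"Invalid type for progress: {type(progress)}")
    return tqdm_args


print(build_tqdm_args("BanzhafValuation", True))
# {'desc': 'BanzhafValuation', 'disable': False}
print(build_tqdm_args("BanzhafValuation", {"leave": False}))
# {'desc': 'BanzhafValuation', 'leave': False}
```

Note that passing a dict does not set "disable", so the bar is shown unless the dict itself contains that key.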

log_coefficient property

log_coefficient: SemivalueCoefficient | None

Returns the log-coefficient of the Banzhaf valuation.

values

values(sort: bool = False) -> ValuationResult

Returns a copy of the valuation result.

The valuation must have been run with fit() before calling this method.

Parameters:
    sort (bool, default: False): Whether to sort the valuation result by value before returning it.

Returns: The result of the valuation.

Source code in src/pydvl/valuation/base.py
def values(self, sort: bool = False) -> ValuationResult:
    """Returns a copy of the valuation result.

    The valuation must have been run with `fit()` before calling this method.

    Args:
        sort: Whether to sort the valuation result by value before returning it.
    Returns:
        The result of the valuation.
    """
    if not self.is_fitted:
        raise NotFittedException(type(self))
    assert self.result is not None

    from copy import copy

    r = copy(self.result)
    if sort:
        r.sort()
    return r

MSRBanzhafValuation

MSRBanzhafValuation(
    utility: UtilityBase,
    is_done: StoppingCriterion,
    batch_size: int = 1,
    seed: Seed | None = None,
    skip_converged: bool = False,
    show_warnings: bool = True,
    progress: dict[str, Any] | bool = False,
)

Bases: SemivalueValuation

Computes Banzhaf values with Maximum Sample Reuse.

This can be seen as a convenience class that wraps the MSRSampler, but it also skips importance sampling altogether, since the MSR sampling scheme already provides the correct weights for the Monte Carlo approximation. This can avoid some numerical inaccuracies that arise when using an MSRSampler with BanzhafValuation, even though the respective coefficients cancel each other out analytically.

Source code in src/pydvl/valuation/methods/banzhaf.py
def __init__(
    self,
    utility: UtilityBase,
    is_done: StoppingCriterion,
    batch_size: int = 1,
    seed: Seed | None = None,
    skip_converged: bool = False,
    show_warnings: bool = True,
    progress: dict[str, Any] | bool = False,
):
    sampler = MSRSampler(batch_size=batch_size, seed=seed)
    super().__init__(
        utility, sampler, is_done, skip_converged, show_warnings, progress
    )

log_coefficient property

log_coefficient: SemivalueCoefficient | None

Disables importance sampling for this method, since it uses a fixed sampler that already provides the correct weights for the Monte Carlo approximation.

values

values(sort: bool = False) -> ValuationResult

Returns a copy of the valuation result.

The valuation must have been run with fit() before calling this method.

Parameters:
    sort (bool, default: False): Whether to sort the valuation result by value before returning it.

Returns: The result of the valuation.

Source code in src/pydvl/valuation/base.py
def values(self, sort: bool = False) -> ValuationResult:
    """Returns a copy of the valuation result.

    The valuation must have been run with `fit()` before calling this method.

    Args:
        sort: Whether to sort the valuation result by value before returning it.
    Returns:
        The result of the valuation.
    """
    if not self.is_fitted:
        raise NotFittedException(type(self))
    assert self.result is not None

    from copy import copy

    r = copy(self.result)
    if sort:
        r.sort()
    return r