pydvl.valuation.methods.data_banzhaf

This module implements the Data Banzhaf valuation method, as described in (Wang and Jia, 2023) [1].

Data Banzhaf was proposed as a means to counteract the inherent stochasticity of the utility function in machine learning problems. It chooses the coefficients \(w(k)\) of the semi-value valuation function to be constant:

\[w(k) := \frac{1}{2^{n-1}},\]

for all set sizes \(k\), where \(n\) is the number of data points. The intuition for picking a constant weight is that, for any choice of weight function \(w\), one can always construct a utility with higher variance where \(w\) is greater. Therefore, in a worst-case sense, the best one can do is to pick a constant weight.
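As a minimal sketch of what a constant weight means in practice, the following brute-force computation averages every marginal contribution of a point with the same weight \(1/2^{n-1}\), which is the standard Banzhaf normalization. It does not use pyDVL; the names `banzhaf_values`, `toy_utility` and `QUALITY` are hypothetical and exist only for this illustration.

```python
from itertools import combinations


def banzhaf_values(n, utility):
    """Exact Banzhaf values for indices 0..n-1 by enumerating all subsets."""
    weight = 1.0 / 2 ** (n - 1)  # constant: the same for every set size k
    values = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for k in range(n):  # k = size of a subset S that excludes i
            for subset in combinations(others, k):
                s = frozenset(subset)
                # Marginal contribution of i to S
                total += utility(s | {i}) - utility(s)
        values.append(weight * total)
    return values


# Hypothetical toy utility: additive point "qualities", plus a bonus
# when points 0 and 1 appear together.
QUALITY = [1.0, 2.0, 3.0]


def toy_utility(subset):
    bonus = 1.0 if {0, 1} <= subset else 0.0
    return sum(QUALITY[j] for j in subset) + bonus


exact = banzhaf_values(3, toy_utility)
print(exact)  # [1.5, 2.5, 3.0]
```

Points 0 and 1 each receive half of the interaction bonus because the bonus appears in exactly half of the subsets that exclude them. Enumeration costs \(2^{n-1}\) utility evaluations per point, so it is only feasible for very small \(n\).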

Data Banzhaf has been reported to outperform many other valuation methods in downstream tasks like best point removal, but can show some […]

References


  1. Wang, Jiachen T., and Ruoxi Jia. Data Banzhaf: A Robust Data Valuation Framework for Machine Learning. In Proceedings of The 26th International Conference on Artificial Intelligence and Statistics, 6388–6421. PMLR, 2023. 

DataBanzhafValuation

DataBanzhafValuation(
    utility: UtilityBase,
    sampler: IndexSampler,
    is_done: StoppingCriterion,
    progress: dict[str, Any] | bool = False,
)

Bases: SemivalueValuation

Computes Banzhaf values.
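Because every subset carries the same weight, Banzhaf values can be estimated by sampling subsets uniformly at random and averaging marginal contributions. The sketch below illustrates that idea in plain Python under simplifying assumptions: it is not pyDVL's implementation, the fixed `n_samples` stands in for a stopping criterion, and the inline subset draw stands in for a sampler. All names (`mc_banzhaf`, `toy_utility`, `QUALITY`) are hypothetical.

```python
import random


def mc_banzhaf(n, utility, n_samples=2000, seed=0):
    """Monte Carlo estimate of Banzhaf values for indices 0..n-1."""
    rng = random.Random(seed)
    estimates = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for _ in range(n_samples):
            # Including each other index with probability 1/2 draws a
            # subset uniformly from the powerset of the complement of i.
            s = frozenset(j for j in others if rng.random() < 0.5)
            total += utility(s | {i}) - utility(s)
        estimates.append(total / n_samples)
    return estimates


# Hypothetical toy utility: additive point "qualities", plus a bonus
# when points 0 and 1 appear together.
QUALITY = [1.0, 2.0, 3.0]


def toy_utility(subset):
    bonus = 1.0 if {0, 1} <= subset else 0.0
    return sum(QUALITY[j] for j in subset) + bonus


estimates = mc_banzhaf(3, toy_utility)
```

For this toy utility the exact values are 1.5, 2.5 and 3.0, so the estimates for points 0 and 1 should be close to the first two, while point 2 is recovered exactly because its marginal contribution never varies.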

Source code in src/pydvl/valuation/methods/semivalue.py
def __init__(
    self,
    utility: UtilityBase,
    sampler: IndexSampler,
    is_done: StoppingCriterion,
    progress: dict[str, Any] | bool = False,
):
    super().__init__()
    self.utility = utility
    self.sampler = sampler
    self.is_done = is_done
    self.tqdm_args: dict[str, Any] = {
        "desc": f"{self.__class__.__name__}: {str(is_done)}"
    }
    # HACK: parse additional args for the progress bar if any (we probably want
    #  something better)
    if isinstance(progress, bool):
        self.tqdm_args.update({"disable": not progress})
    else:
        self.tqdm_args.update(progress if isinstance(progress, dict) else {})
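The handling of `progress` in the constructor above can be isolated as a small function to see what each kind of argument produces. This is a sketch that mirrors the logic of the listing; `build_tqdm_args` is a hypothetical helper, not part of pyDVL.

```python
from typing import Any, Dict, Union


def build_tqdm_args(
    desc: str, progress: Union[Dict[str, Any], bool] = False
) -> Dict[str, Any]:
    """Mirror of the constructor's progress-bar argument parsing."""
    tqdm_args: Dict[str, Any] = {"desc": desc}
    if isinstance(progress, bool):
        # A boolean only toggles the bar on or off.
        tqdm_args["disable"] = not progress
    elif isinstance(progress, dict):
        # A dict is merged as-is, so it can add keys or override "desc".
        tqdm_args.update(progress)
    return tqdm_args


off = build_tqdm_args("DataBanzhafValuation", False)
on = build_tqdm_args("DataBanzhafValuation", True)
custom = build_tqdm_args("DataBanzhafValuation", {"leave": False})
print(off)     # {'desc': 'DataBanzhafValuation', 'disable': True}
print(on)      # {'desc': 'DataBanzhafValuation', 'disable': False}
print(custom)  # {'desc': 'DataBanzhafValuation', 'leave': False}
```

Note that passing a dict never sets `disable`, so whether the bar is shown in that case is left to the defaults of the progress bar unless the dict contains that key itself.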

values

values(sort: bool = False) -> ValuationResult

Returns a copy of the valuation result.

The valuation must have been run with fit() before calling this method.

PARAMETER   DESCRIPTION
sort        Whether to sort the valuation result before returning it.
            TYPE: bool DEFAULT: False

RETURNS           DESCRIPTION
ValuationResult   The result of the valuation.

Source code in src/pydvl/valuation/base.py
def values(self, sort: bool = False) -> ValuationResult:
    """Returns a copy of the valuation result.

    The valuation must have been run with `fit()` before calling this method.

    Args:
        sort: Whether to sort the valuation result before returning it.
    Returns:
        The result of the valuation.
    """
    if not self.is_fitted:
        raise NotFittedException(type(self))
    assert self.result is not None

    from copy import copy

    r = copy(self.result)
    if sort:
        r.sort()
    return r
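The contract of `values()` (fit first, then receive a copy that may optionally be sorted) can be illustrated with stand-in classes. This is a sketch only: `ToyResult` and `ToyValuation` are hypothetical, `RuntimeError` replaces `NotFittedException`, and whether sorting the copy leaves the stored result untouched in pyDVL depends on how `ValuationResult` implements copying and sorting, which is not shown here.

```python
from copy import copy


class ToyResult:
    """Hypothetical stand-in for ValuationResult: just a list of scores."""

    def __init__(self, scores):
        self.scores = list(scores)

    def sort(self):
        # Rebinds the attribute instead of sorting in place, so a shallow
        # copy can be sorted without touching the original (toy behavior).
        self.scores = sorted(self.scores)


class ToyValuation:
    """Hypothetical stand-in for a valuation object."""

    def __init__(self):
        self.result = None

    @property
    def is_fitted(self):
        return self.result is not None

    def fit(self, scores):
        self.result = ToyResult(scores)
        return self

    def values(self, sort=False):
        if not self.is_fitted:
            raise RuntimeError("call fit() before values()")
        r = copy(self.result)
        if sort:
            r.sort()
        return r


valuation = ToyValuation()
try:
    valuation.values()  # not fitted yet
    error = None
except RuntimeError as exc:
    error = str(exc)

valuation.fit([0.3, -0.1, 0.2])
ordered = valuation.values(sort=True)
print(ordered.scores)           # [-0.1, 0.2, 0.3]
print(valuation.result.scores)  # [0.3, -0.1, 0.2]
```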