
pydvl.valuation.methods.banzhaf

This module implements the Banzhaf valuation method, as described in Wang and Jia (2023) [1].

Data Banzhaf was proposed as a means to counteract the inherent stochasticity of the utility function in machine learning problems. It chooses the coefficients \(w(k)\) of the semi-value valuation function to be the constant \(1/2^{n-1}\) for all set sizes \(k\), yielding:

\[ v_\text{bzf}(i) = \frac{1}{2^{n-1}} \sum_{S \subseteq D_{-i}} [u(S_{+i}) - u(S)]. \]
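
To make the formula concrete, here is a minimal sketch that evaluates it by brute force on a three-point toy problem. It does not use pyDVL: the function `exact_banzhaf` and the utility `toy_utility` are hypothetical names introduced only for this illustration, and full enumeration is feasible only for very small \(n\).

```python
from itertools import combinations


def exact_banzhaf(n, utility):
    """Exact Banzhaf values for a utility defined on subsets of range(n)."""
    values = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        # Enumerate all 2^(n-1) subsets S of D_{-i}.
        for k in range(len(others) + 1):
            for subset in combinations(others, k):
                s = frozenset(subset)
                # Marginal contribution u(S_{+i}) - u(S).
                total += utility(s | {i}) - utility(s)
        # Constant weight 1 / 2^(n-1) for every subset size.
        values.append(total / 2 ** (n - 1))
    return values


def toy_utility(s):
    """Hypothetical utility: points 0 and 1 only help together, point 2 always adds 1."""
    return 3.0 * (0 in s and 1 in s) + 1.0 * (2 in s)


values = exact_banzhaf(3, toy_utility)
print(values)  # [1.5, 1.5, 1.0]
```

Points 0 and 1 each receive 1.5 because their joint contribution of 3 is realised in half of the subsets of the remaining points, while point 2 contributes exactly 1 to every subset.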

Background on semi-values

The Banzhaf valuation is a special case of the semi-value valuation method. You can read a short introduction in the documentation.

The intuition for picking a constant weight is that for any choice of weight function \(w\), one can always construct a utility with higher variance where \(w\) is greater. Therefore, in a worst-case sense, the best one can do is to pick a constant weight.

In the experiments reported by Wang and Jia, Data Banzhaf outperforms many other valuation methods in downstream tasks such as best point removal.

Maximum Sample Reuse Banzhaf

A special sampling scheme (MSR) that reuses each sample to update every index in the dataset is shown by Wang and Jia to be optimal for the Banzhaf valuation: not only does it drastically reduce the number of sets needed, but the sampling distribution also matches the Banzhaf indices, in the sense explained in Sampling strategies for semi-values.
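
The idea behind sample reuse can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not pyDVL's MSRSampler: subsets are drawn uniformly from the power set of the data, every utility evaluation updates all indices, and the value of index \(i\) is estimated as the mean utility of the sampled sets containing \(i\) minus the mean utility of those that do not. The names `msr_banzhaf` and `toy_utility` are hypothetical.

```python
import random


def msr_banzhaf(n, utility, n_samples, seed=0):
    """Monte Carlo Banzhaf estimate that reuses every sample for every index."""
    rng = random.Random(seed)
    sum_in, count_in = [0.0] * n, [0] * n
    sum_out, count_out = [0.0] * n, [0] * n
    for _ in range(n_samples):
        # Uniform sample from the power set: include each point with probability 1/2.
        s = frozenset(i for i in range(n) if rng.random() < 0.5)
        u = utility(s)  # a single utility evaluation per sampled set
        for i in range(n):
            if i in s:
                sum_in[i] += u
                count_in[i] += 1
            else:
                sum_out[i] += u
                count_out[i] += 1
    estimates = []
    for i in range(n):
        if count_in[i] == 0 or count_out[i] == 0:
            estimates.append(float("nan"))  # not enough samples for this index
        else:
            estimates.append(sum_in[i] / count_in[i] - sum_out[i] / count_out[i])
    return estimates


def toy_utility(s):
    """Hypothetical utility: points 0 and 1 only help together, point 2 always adds 1."""
    return 3.0 * (0 in s and 1 in s) + 1.0 * (2 in s)


estimates = msr_banzhaf(3, toy_utility, n_samples=20_000)
print(estimates)  # close to the exact Banzhaf values 1.5, 1.5 and 1.0
```

Each call to the utility contributes to the estimate of all \(n\) indices, whereas an estimator based on marginal contributions needs two utility evaluations to update a single index.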

In order to work with this sampler for Banzhaf values, you can use MSRBanzhafValuation. In principle, it is also possible to select the MSRSampler when instantiating BanzhafValuation, but this might introduce some numerical instability, as explained in the document linked above.

References


  1. Wang, Jiachen T., and Ruoxi Jia. Data Banzhaf: A Robust Data Valuation Framework for Machine Learning. In Proceedings of The 26th International Conference on Artificial Intelligence and Statistics, 6388–6421. PMLR, 2023. 

BanzhafValuation

BanzhafValuation(
    utility: UtilityBase,
    sampler: IndexSampler,
    is_done: StoppingCriterion,
    skip_converged: bool = False,
    show_warnings: bool = True,
    progress: dict[str, Any] | bool = False,
)

Bases: SemivalueValuation

Computes Banzhaf values.

Source code in src/pydvl/valuation/methods/semivalue.py
def __init__(
    self,
    utility: UtilityBase,
    sampler: IndexSampler,
    is_done: StoppingCriterion,
    skip_converged: bool = False,
    show_warnings: bool = True,
    progress: dict[str, Any] | bool = False,
):
    super().__init__()
    self.utility = utility
    self.sampler = sampler
    self.is_done = is_done
    self.skip_converged = skip_converged
    if skip_converged:  # test whether the sampler supports skipping indices:
        self.sampler.skip_indices = np.array([], dtype=np.int_)
    self.show_warnings = show_warnings
    self.tqdm_args: dict[str, Any] = {"desc": str(self)}
    # HACK: parse additional args for the progress bar if any (we probably want
    #  something better)
    if isinstance(progress, bool):
        self.tqdm_args.update({"disable": not progress})
    elif isinstance(progress, dict):
        self.tqdm_args.update(progress)
    else:
        raise TypeError(f"Invalid type for progress: {type(progress)}")

log_coefficient property

log_coefficient: SemivalueCoefficient | None

Returns the log-coefficient of the Banzhaf valuation.

result property

The current valuation result (not a copy).

fit

fit(data: Dataset, continue_from: ValuationResult | None = None) -> Self

Fits the semi-value valuation to the data.

Access the results through the result property.

PARAMETER DESCRIPTION
data

Data for which to compute values

TYPE: Dataset

continue_from

A previously computed valuation result to continue from.

TYPE: ValuationResult | None DEFAULT: None

Source code in src/pydvl/valuation/methods/semivalue.py
@suppress_warnings(flag="show_warnings")
def fit(self, data: Dataset, continue_from: ValuationResult | None = None) -> Self:
    """Fits the semi-value valuation to the data.

    Access the results through the `result` property.

    Args:
        data: Data for which to compute values
        continue_from: A previously computed valuation result to continue from.

    """
    self._result = self._init_or_check_result(data, continue_from)
    ensure_backend_has_generator_return()

    self.is_done.reset()
    self.utility = self.utility.with_dataset(data)

    strategy = self.sampler.make_strategy(self.utility, self.log_coefficient)
    updater = self.sampler.result_updater(self._result)
    processor = delayed(strategy.process)

    with Parallel(return_as="generator_unordered") as parallel:
        with make_parallel_flag() as flag:
            delayed_evals = parallel(
                processor(batch=list(batch), is_interrupted=flag)
                for batch in self.sampler.generate_batches(data.indices)
            )
            for batch in Progress(delayed_evals, self.is_done, **self.tqdm_args):
                for update in batch:
                    self._result = updater.process(update)
                if self.is_done(self._result):
                    flag.set()
                    self.sampler.interrupt()
                    break
                if self.skip_converged:
                    self.sampler.skip_indices = data.indices[self.is_done.converged]
    logger.debug(f"Fitting done after {updater.n_updates} value updates.")
    return self

values

values(sort: bool = False) -> ValuationResult

Returns a copy of the valuation result.

The valuation must have been run with fit() before calling this method.

PARAMETER DESCRIPTION
sort

Whether to sort the valuation result by value before returning it.

TYPE: bool DEFAULT: False

Returns: The result of the valuation.

Source code in src/pydvl/valuation/base.py
@deprecated(
    target=None,
    deprecated_in="0.10.0",
    remove_in="0.11.0",
)
def values(self, sort: bool = False) -> ValuationResult:
    """Returns a copy of the valuation result.

    The valuation must have been run with `fit()` before calling this method.

    Args:
        sort: Whether to sort the valuation result by value before returning it.
    Returns:
        The result of the valuation.
    """
    if not self.is_fitted:
        raise NotFittedException(type(self))
    assert self._result is not None

    r = self._result.copy()
    if sort:
        r.sort(inplace=True)
    return r

MSRBanzhafValuation

MSRBanzhafValuation(
    utility: UtilityBase,
    is_done: StoppingCriterion,
    batch_size: int = 1,
    seed: Seed | None = None,
    skip_converged: bool = False,
    show_warnings: bool = True,
    progress: dict[str, Any] | bool = False,
)

Bases: SemivalueValuation

Computes Banzhaf values with Maximum Sample Reuse.

This can be seen as a convenience class that wraps the MSRSampler, but in fact it also skips importance sampling altogether, since the MSR sampling scheme already provides the correct weights for the Monte Carlo approximation. This avoids some numerical inaccuracies that can arise when using an MSRSampler with BanzhafValuation, even though the respective coefficients cancel each other out analytically.
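
As a generic floating-point illustration of why factors that cancel on paper are better skipped than computed, consider the sketch below. It is not a description of pyDVL's internals: the dataset size `n` and all variable names are assumptions made for the example. It only shows that a ratio of two equal but tiny quantities can be lost in linear space, whereas never forming the ratio (or forming it in log space) gives exactly 1.

```python
import math

n = 2000  # hypothetical dataset size

# In linear space, 2^-(n-1) underflows to 0.0 for large n.
coefficient = math.ldexp(1.0, -(n - 1))    # semi-value weight
sampling_prob = math.ldexp(1.0, -(n - 1))  # probability of the sampled set
try:
    linear_ratio = coefficient / sampling_prob
except ZeroDivisionError:
    # 0.0 / 0.0: the analytically exact ratio of 1 is unrecoverable.
    linear_ratio = math.nan

# In log space both terms are representable and their difference is exactly 0.
log_coefficient = -(n - 1) * math.log(2)
log_sampling_prob = -(n - 1) * math.log(2)
log_space_ratio = math.exp(log_coefficient - log_sampling_prob)

print(linear_ratio, log_space_ratio)  # nan 1.0
```

Skipping the correction altogether, as this class does, sidesteps the question of how accurately such a ratio is evaluated.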

Source code in src/pydvl/valuation/methods/banzhaf.py
def __init__(
    self,
    utility: UtilityBase,
    is_done: StoppingCriterion,
    batch_size: int = 1,
    seed: Seed | None = None,
    skip_converged: bool = False,
    show_warnings: bool = True,
    progress: dict[str, Any] | bool = False,
):
    sampler = MSRSampler(batch_size=batch_size, seed=seed)
    super().__init__(
        utility, sampler, is_done, skip_converged, show_warnings, progress
    )

log_coefficient property

log_coefficient: SemivalueCoefficient | None

Disables importance sampling for this method, since its fixed sampler already provides the correct weights for the Monte Carlo approximation.

result property

The current valuation result (not a copy).

fit

fit(data: Dataset, continue_from: ValuationResult | None = None) -> Self

Fits the semi-value valuation to the data.

Access the results through the result property.

PARAMETER DESCRIPTION
data

Data for which to compute values

TYPE: Dataset

continue_from

A previously computed valuation result to continue from.

TYPE: ValuationResult | None DEFAULT: None

Source code in src/pydvl/valuation/methods/semivalue.py
@suppress_warnings(flag="show_warnings")
def fit(self, data: Dataset, continue_from: ValuationResult | None = None) -> Self:
    """Fits the semi-value valuation to the data.

    Access the results through the `result` property.

    Args:
        data: Data for which to compute values
        continue_from: A previously computed valuation result to continue from.

    """
    self._result = self._init_or_check_result(data, continue_from)
    ensure_backend_has_generator_return()

    self.is_done.reset()
    self.utility = self.utility.with_dataset(data)

    strategy = self.sampler.make_strategy(self.utility, self.log_coefficient)
    updater = self.sampler.result_updater(self._result)
    processor = delayed(strategy.process)

    with Parallel(return_as="generator_unordered") as parallel:
        with make_parallel_flag() as flag:
            delayed_evals = parallel(
                processor(batch=list(batch), is_interrupted=flag)
                for batch in self.sampler.generate_batches(data.indices)
            )
            for batch in Progress(delayed_evals, self.is_done, **self.tqdm_args):
                for update in batch:
                    self._result = updater.process(update)
                if self.is_done(self._result):
                    flag.set()
                    self.sampler.interrupt()
                    break
                if self.skip_converged:
                    self.sampler.skip_indices = data.indices[self.is_done.converged]
    logger.debug(f"Fitting done after {updater.n_updates} value updates.")
    return self

values

values(sort: bool = False) -> ValuationResult

Returns a copy of the valuation result.

The valuation must have been run with fit() before calling this method.

PARAMETER DESCRIPTION
sort

Whether to sort the valuation result by value before returning it.

TYPE: bool DEFAULT: False

Returns: The result of the valuation.

Source code in src/pydvl/valuation/base.py
@deprecated(
    target=None,
    deprecated_in="0.10.0",
    remove_in="0.11.0",
)
def values(self, sort: bool = False) -> ValuationResult:
    """Returns a copy of the valuation result.

    The valuation must have been run with `fit()` before calling this method.

    Args:
        sort: Whether to sort the valuation result by value before returning it.
    Returns:
        The result of the valuation.
    """
    if not self.is_fitted:
        raise NotFittedException(type(self))
    assert self._result is not None

    r = self._result.copy()
    if sort:
        r.sort(inplace=True)
    return r