
pydvl.valuation.methods.msr_banzhaf

This module implements the MSR-Banzhaf valuation method, as described by Wang and Jia [1].

References


  1. Wang, J.T. and Jia, R., 2023. Data Banzhaf: A Robust Data Valuation Framework for Machine Learning. In: Proceedings of The 26th International Conference on Artificial Intelligence and Statistics, pp. 6388-6421. 

MSRBanzhafValuation

MSRBanzhafValuation(
    utility: UtilityBase,
    sampler: MSRSampler,
    is_done: StoppingCriterion,
    progress: bool = True,
)

Bases: SemivalueValuation

Class to compute Maximum Sample Re-use (MSR) Banzhaf values.

See Data Valuation for an overview.

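MSR Banzhaf approximates Banzhaf values far more efficiently than traditional Monte Carlo approaches because it reuses every sample: each sampled subset is evaluated once, and that single utility evaluation updates the estimate of every data point, as a subset that either contains the point or does not.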
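
For reference, the Banzhaf value of a point i and the MSR estimator of it can be written as follows, following Wang and Jia [1] (notation ours). Here N is the set of n training points and U the utility. Each of the m sampled subsets includes every point independently with probability 1/2, and the sampled subsets are split into those that contain i and those that do not.

```latex
\phi_i = \frac{1}{2^{n-1}} \sum_{S \subseteq N \setminus \{i\}} \bigl[ U(S \cup \{i\}) - U(S) \bigr]

\hat{\phi}_i = \frac{1}{|\mathcal{S}_{\ni i}|} \sum_{S \in \mathcal{S}_{\ni i}} U(S)
             - \frac{1}{|\mathcal{S}_{\not\ni i}|} \sum_{S \in \mathcal{S}_{\not\ni i}} U(S)
```

Because every point is included independently with probability 1/2, the set S \ {i} is uniformly distributed over the subsets of N \ {i} whether or not i belongs to S, so the difference of the two conditional means of U(S) equals the Banzhaf value of i.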

PARAMETER DESCRIPTION
utility

Utility object with model, data and scoring function.

TYPE: UtilityBase

sampler

Sampling scheme to use. Currently, only the unweighted MSRSampler is implemented; support for weighted MSR samplers is planned.

TYPE: MSRSampler

is_done

Stopping criterion to use.

TYPE: StoppingCriterion

progress

Whether to show a progress bar.

TYPE: bool DEFAULT: True
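
The sampling idea behind MSR can be sketched without pyDVL: every subset is drawn by including each index independently with probability 1/2, and subsets are grouped into batches so they can be dispatched to workers. The function `msr_sample_batches` below is hypothetical and only illustrates the idea; it is not the interface of pyDVL's `MSRSampler`.

```python
from typing import Iterator

import numpy as np


def msr_sample_batches(
    indices: np.ndarray, batch_size: int, n_batches: int, seed: int = 0
) -> Iterator[list[np.ndarray]]:
    """Hypothetical sampler: yield batches of subsets (arrays of included indices)."""
    rng = np.random.default_rng(seed)
    for _ in range(n_batches):
        batch = []
        for _ in range(batch_size):
            # Include each index independently with probability 1/2.
            mask = rng.random(len(indices)) < 0.5
            batch.append(indices[mask])
        yield batch


indices = np.arange(10)
batches = list(msr_sample_batches(indices, batch_size=50, n_batches=40))

# Fraction of sampled subsets that contain each index; should be close to 0.5.
inclusion = np.mean([[i in s for i in indices] for b in batches for s in b], axis=0)
print(inclusion.round(2))
```

No per-point sampling is needed: every subset drawn this way is informative about all indices at once.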

Source code in src/pydvl/valuation/methods/msr_banzhaf.py
def __init__(
    self,
    utility: UtilityBase,
    sampler: MSRSampler,
    is_done: StoppingCriterion,
    progress: bool = True,
):
    super().__init__(
        utility=utility,
        sampler=sampler,
        is_done=is_done,
        progress=progress,
    )

values

values(sort: bool = False) -> ValuationResult

Returns a copy of the valuation result.

The valuation must have been run with fit() before calling this method; otherwise a NotFittedException is raised.

PARAMETER DESCRIPTION
sort

Whether to sort the valuation result before returning it.

TYPE: bool DEFAULT: False

RETURNS

A copy of the result of the valuation.

TYPE: ValuationResult

Source code in src/pydvl/valuation/base.py
def values(self, sort: bool = False) -> ValuationResult:
    """Returns a copy of the valuation result.

    The valuation must have been run with `fit()` before calling this method.

    Args:
        sort: Whether to sort the valuation result before returning it.
    Returns:
        The result of the valuation.
    """
    if not self.is_fitted:
        raise NotFittedException(type(self))
    assert self.result is not None

    from copy import copy

    r = copy(self.result)
    if sort:
        r.sort()
    return r
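
The guard-then-copy pattern in values() can be illustrated with a small stand-in. `ToyResult`, `ToyValuation` and the local `NotFittedException` below are hypothetical simplifications, not pyDVL classes. Note that `copy()` is shallow: in this sketch it is safe only because `ToyResult.sort()` rebinds the attribute to a new list instead of mutating the list shared with the stored result.

```python
from copy import copy
from dataclasses import dataclass, field
from typing import List, Optional


class NotFittedException(Exception):
    """Hypothetical stand-in: raised when results are requested before fit()."""


@dataclass
class ToyResult:
    values: List[float] = field(default_factory=list)

    def sort(self) -> None:
        # Rebind to a new sorted list so shallow copies never see the change.
        self.values = sorted(self.values)


class ToyValuation:
    def __init__(self) -> None:
        self.result: Optional[ToyResult] = None

    @property
    def is_fitted(self) -> bool:
        return self.result is not None

    def fit(self, data: List[float]) -> "ToyValuation":
        self.result = ToyResult(values=list(data))
        return self

    def values(self, sort: bool = False) -> ToyResult:
        if not self.is_fitted:
            raise NotFittedException(type(self).__name__)
        r = copy(self.result)  # shallow copy of the stored result
        if sort:
            r.sort()
        return r


v = ToyValuation()
try:
    v.values()
except NotFittedException:
    print("fit() must be called first")

v.fit([0.3, -0.1, 0.2])
print(v.values(sort=True).values)  # [-0.1, 0.2, 0.3]
print(v.result.values)             # [0.3, -0.1, 0.2], stored result unchanged
```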

fit

fit(data: Dataset) -> Self

Calculate the MSR Banzhaf valuation on a dataset.

This method has to be called before calling values().

Computing Banzhaf values is computationally expensive, but the work can be parallelized: fit() dispatches batches of samples through joblib, so calling it inside a joblib.parallel_config context manager sets the number of workers, as follows:

from joblib import parallel_config

with parallel_config(n_jobs=4):
    valuation.fit(data)

Source code in src/pydvl/valuation/methods/msr_banzhaf.py
def fit(self, data: Dataset) -> Self:
    """Calculate the MSR Banzhaf valuation on a dataset.

    This method has to be called before calling `values()`.

    Calculating the Banzhaf valuation is a computationally expensive task that
    can be parallelized. To do so, call the `fit()` method inside a
    `joblib.parallel_config` context manager as follows:

    ```python
    from joblib import parallel_config

    with parallel_config(n_jobs=4):
        valuation.fit(data)
    ```

    """
    pos_result = ValuationResult.zeros(
        indices=data.indices,
        data_names=data.data_names,
        algorithm=self.algorithm_name,
    )

    neg_result = ValuationResult.zeros(
        indices=data.indices,
        data_names=data.data_names,
        algorithm=self.algorithm_name,
    )

    self.result = ValuationResult.zeros(
        indices=data.indices,
        data_names=data.data_names,
        algorithm=self.algorithm_name,
    )

    ensure_backend_has_generator_return()

    self.utility.training_data = data

    strategy = self.sampler.make_strategy(self.utility, self.coefficient)
    processor = delayed(strategy.process)

    with Parallel(return_as="generator_unordered") as parallel:
        with make_parallel_flag() as flag:
            delayed_evals = parallel(
                processor(batch=list(batch), is_interrupted=flag)
                for batch in self.sampler.generate_batches(data.indices)
            )
            for batch in Progress(delayed_evals, self.is_done, **self.tqdm_args):
                for evaluation in batch:
                    if evaluation.is_positive:
                        pos_result.update(evaluation.idx, evaluation.update)
                    else:
                        neg_result.update(evaluation.idx, evaluation.update)

                    self.result = self._combine_results(
                        pos_result, neg_result, data=data
                    )

                    if self.is_done(self.result):
                        flag.set()
                        self.sampler.interrupt()
                        break

                if self.is_done(self.result):
                    break

    return self
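
The bookkeeping in fit() keeps two running results: one fed by evaluations where the point was in the sampled subset (evaluation.is_positive) and one fed by evaluations where it was not. The two are combined into the estimate, and sampling stops once is_done is satisfied. The sketch below reproduces that logic with NumPy on a toy utility. Using plain running means, taking their difference as the combination step, and using a fixed `max_samples` budget in place of is_done are our assumptions for illustration; this does not reproduce pyDVL's `_combine_results`. `exact_banzhaf` enumerates all subsets, so it is feasible only for very small n.

```python
import itertools

import numpy as np


def exact_banzhaf(utility, n: int) -> np.ndarray:
    """Average marginal contribution of each point over all subsets of the others."""
    values = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for r in range(n):  # subset sizes 0 .. n-1
            for subset in itertools.combinations(others, r):
                s = frozenset(subset)
                values[i] += utility(s | {i}) - utility(s)
    return values / 2 ** (n - 1)


def msr_banzhaf(utility, n: int, max_samples: int, seed: int = 0) -> np.ndarray:
    rng = np.random.default_rng(seed)
    pos_sum, pos_cnt = np.zeros(n), np.zeros(n)  # samples whose subset contains i
    neg_sum, neg_cnt = np.zeros(n), np.zeros(n)  # samples whose subset lacks i
    for _ in range(max_samples):  # fixed budget stands in for is_done
        mask = rng.random(n) < 0.5
        # One utility evaluation updates every point (maximum sample reuse).
        u = utility(frozenset(np.flatnonzero(mask).tolist()))
        pos_sum[mask] += u
        pos_cnt[mask] += 1
        neg_sum[~mask] += u
        neg_cnt[~mask] += 1
    # Combine: mean over positive samples minus mean over negative samples.
    return pos_sum / np.maximum(pos_cnt, 1) - neg_sum / np.maximum(neg_cnt, 1)


weights = [1.0, 2.0, 3.0, 4.0]


def toy_utility(s: frozenset) -> float:
    # Additive part plus a bonus of 1.0 when points 0 and 1 are both present.
    return sum(weights[j] for j in s) + (1.0 if {0, 1} <= s else 0.0)


exact = exact_banzhaf(toy_utility, 4)  # [1.5, 2.5, 3.0, 4.0]
estimate = msr_banzhaf(toy_utility, 4, max_samples=20_000)
print(exact, estimate.round(2))
```

Points 0 and 1 each gain 0.5 over their weights because the other member of the pair is present in exactly half of the subsets; the MSR estimate approaches these values with 2^n times fewer utility calls per sample than evaluating every point's marginal contributions separately would need.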