pydvl.valuation.samplers.permutation

Permutation-based samplers.

TODO: explain the formulation and the different samplers.

References


  1. Mitchell, Rory, Joshua Cooper, Eibe Frank, and Geoffrey Holmes. Sampling Permutations for Shapley Value Estimation. Journal of Machine Learning Research 23, no. 43 (2022): 1–46. 

  2. Watson, Lauren, Zeno Kujawa, Rayna Andreeva, Hao-Tsung Yang, Tariq Elahi, and Rik Sarkar. Accelerated Shapley Value Approximation for Data Evaluation. arXiv, 9 November 2023. 

PermutationSampler

PermutationSampler(
    truncation: TruncationPolicy | None = None, seed: Seed | None = None
)

Bases: StochasticSamplerMixin, IndexSampler

Samples permutations of the indices and iterates through each one, returning subsets of increasing size, as required by the permutation definition of semi-values.

For a permutation (3,1,4,2), this sampler returns, in sequence, the following Samples (see pydvl.valuation.samplers.Sample), which are tuples of index and subset:

(3, {3}), (1, {3,1}), (4, {3,1,4}) and (2, {3,1,4,2}).
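
The sequence above can be reproduced with a few lines of plain Python. This is only a minimal sketch of the idea and does not use pyDVL: the function name `prefix_samples` is hypothetical, and the real sampler returns `Sample` objects rather than bare tuples.

```python
from typing import FrozenSet, Iterator, List, Sequence, Tuple


def prefix_samples(permutation: Sequence[int]) -> Iterator[Tuple[int, FrozenSet[int]]]:
    """Yield (index, subset) pairs for one permutation.

    The subset contains every index seen so far, including the current one,
    mirroring the example in the text. (Hypothetical helper, not pyDVL API.)
    """
    seen: List[int] = []
    for idx in permutation:
        seen.append(idx)
        yield idx, frozenset(seen)


samples = list(prefix_samples((3, 1, 4, 2)))
for idx, subset in samples:
    print(idx, sorted(subset))
```

One permutation of n indices produces exactly n such pairs, which is why a whole permutation maps naturally onto one batch.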

Batching

PermutationSamplers always batch their outputs to include a whole permutation of the index set, i.e. the batch size is always the number of indices.

PARAMETER DESCRIPTION

truncation: A policy to stop the permutation early. TYPE: TruncationPolicy | None. DEFAULT: None.

seed: Seed for the random number generator. TYPE: Seed | None. DEFAULT: None.

Source code in src/pydvl/valuation/samplers/permutation.py
def __init__(
    self, truncation: TruncationPolicy | None = None, seed: Seed | None = None
):
    super().__init__(seed=seed)
    self.truncation = truncation or NoTruncation()

generate_batches

generate_batches(indices: IndexSetT) -> BatchGenerator

Batches the samples and yields them.

Source code in src/pydvl/valuation/samplers/base.py
def generate_batches(self, indices: IndexSetT) -> BatchGenerator:
    """Batches the samples and yields them."""

    # create an empty generator if the indices are empty. `generate_batches` is
    # a generator function because it has a yield statement later in its body.
    # Inside generator function, `return` acts like a `break`, which produces an
    # empty generator function. See: https://stackoverflow.com/a/13243870
    if len(indices) == 0:
        return

    self._interrupted = False
    self._n_samples = 0
    for batch in chunked(self._generate(indices), self.batch_size):
        yield batch
        self._n_samples += len(batch)
        if self._interrupted:
            break
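
The early `return` and the chunking in the source above can be illustrated in isolation. The following is a simplified sketch, not the library code: the local `chunked` function is a stand-in written for this example (the import of the real helper is not shown above), and the interruption and sample-counting logic is left out.

```python
from itertools import islice
from typing import Iterable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def chunked(iterable: Iterable[T], size: int) -> Iterator[List[T]]:
    """Stand-in helper: yield lists of at most `size` consecutive items."""
    it = iter(iterable)
    while True:
        batch = list(islice(it, size))
        if not batch:
            return
        yield batch


def generate_batches(indices: Sequence[int], batch_size: int) -> Iterator[List[int]]:
    """Simplified analogue of the method above, batching the indices themselves."""
    if len(indices) == 0:
        # A bare `return` inside a generator function ends iteration at once,
        # so callers receive a generator that yields nothing.
        return
    yield from chunked(indices, batch_size)


empty = list(generate_batches([], batch_size=2))
batches = list(generate_batches([0, 1, 2, 3, 4], batch_size=2))
print(empty, batches)
```

Because the function body contains `yield`, calling it always returns a generator, even when the first statement executed is `return`.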

sample_limit

sample_limit(indices: IndexSetT) -> int | None

Number of samples that can be generated from the indices.

Returns None if the number of samples is infinite, which is the case for most stochastic samplers.

Source code in src/pydvl/valuation/samplers/base.py
def sample_limit(self, indices: IndexSetT) -> int | None:
    """Number of samples that can be generated from the indices.

    Returns None if the number of samples is infinite, which is the case for most
    stochastic samplers.
    """
    if len(indices) == 0:
        out = 0
    else:
        out = None
    return out

AntitheticPermutationSampler

AntitheticPermutationSampler(
    truncation: TruncationPolicy | None = None, seed: Seed | None = None
)

Bases: PermutationSampler

Samples permutations like PermutationSampler, but after each permutation, it returns the same permutation in reverse order.

This sampler was suggested by Mitchell et al. (2022) [1].

New in version 0.7.1
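
The pairing of each permutation with its reverse can be sketched as follows. This is an illustration under assumptions, not the pyDVL implementation: the function name `antithetic_permutations` is hypothetical, and the standard-library `random` module stands in for whatever generator the library uses.

```python
import random
from typing import Iterator, List, Sequence


def antithetic_permutations(indices: Sequence[int], seed: int = 0) -> Iterator[List[int]]:
    """Endlessly yield a random permutation followed by its reverse.

    (Hypothetical helper for illustration only.)
    """
    rng = random.Random(seed)
    while True:
        permutation = list(indices)
        rng.shuffle(permutation)
        yield permutation
        # The antithetic partner: the same permutation, traversed backwards.
        yield permutation[::-1]


gen = antithetic_permutations([0, 1, 2, 3], seed=42)
first, second = next(gen), next(gen)
print(first, second)
```

In the reversed permutation, an index that appeared late (after a large subset) appears early (after a small one). Pairing negatively correlated samples in this way is the usual motivation for antithetic sampling as a variance-reduction technique.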

Source code in src/pydvl/valuation/samplers/permutation.py
def __init__(
    self, truncation: TruncationPolicy | None = None, seed: Seed | None = None
):
    super().__init__(seed=seed)
    self.truncation = truncation or NoTruncation()

generate_batches

generate_batches(indices: IndexSetT) -> BatchGenerator

Batches the samples and yields them.

Source code in src/pydvl/valuation/samplers/base.py
def generate_batches(self, indices: IndexSetT) -> BatchGenerator:
    """Batches the samples and yields them."""

    # create an empty generator if the indices are empty. `generate_batches` is
    # a generator function because it has a yield statement later in its body.
    # Inside generator function, `return` acts like a `break`, which produces an
    # empty generator function. See: https://stackoverflow.com/a/13243870
    if len(indices) == 0:
        return

    self._interrupted = False
    self._n_samples = 0
    for batch in chunked(self._generate(indices), self.batch_size):
        yield batch
        self._n_samples += len(batch)
        if self._interrupted:
            break

sample_limit

sample_limit(indices: IndexSetT) -> int | None

Number of samples that can be generated from the indices.

Returns None if the number of samples is infinite, which is the case for most stochastic samplers.

Source code in src/pydvl/valuation/samplers/base.py
def sample_limit(self, indices: IndexSetT) -> int | None:
    """Number of samples that can be generated from the indices.

    Returns None if the number of samples is infinite, which is the case for most
    stochastic samplers.
    """
    if len(indices) == 0:
        out = 0
    else:
        out = None
    return out

DeterministicPermutationSampler

DeterministicPermutationSampler(
    truncation: TruncationPolicy | None = None, seed: Seed | None = None
)

Bases: PermutationSampler

Samples all n! permutations of the indices deterministically, and iterates through them, returning sets as required for the permutation-based definition of semi-values.
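
To give a sense of the n! growth, the enumeration can be sketched with the standard library. This is a minimal illustration that assumes nothing about how pyDVL orders or generates the permutations internally.

```python
from itertools import permutations
from math import factorial

indices = [0, 1, 2]

# itertools.permutations enumerates every ordering exactly once.
all_permutations = list(permutations(indices))

print(len(all_permutations), factorial(len(indices)))
for p in all_permutations:
    print(p)
```

Since the count grows factorially (10 indices already give 3,628,800 permutations), exhaustive enumeration is practical only for very small index sets.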

Source code in src/pydvl/valuation/samplers/permutation.py
def __init__(
    self, truncation: TruncationPolicy | None = None, seed: Seed | None = None
):
    super().__init__(seed=seed)
    self.truncation = truncation or NoTruncation()

generate_batches

generate_batches(indices: IndexSetT) -> BatchGenerator

Batches the samples and yields them.

Source code in src/pydvl/valuation/samplers/base.py
def generate_batches(self, indices: IndexSetT) -> BatchGenerator:
    """Batches the samples and yields them."""

    # create an empty generator if the indices are empty. `generate_batches` is
    # a generator function because it has a yield statement later in its body.
    # Inside generator function, `return` acts like a `break`, which produces an
    # empty generator function. See: https://stackoverflow.com/a/13243870
    if len(indices) == 0:
        return

    self._interrupted = False
    self._n_samples = 0
    for batch in chunked(self._generate(indices), self.batch_size):
        yield batch
        self._n_samples += len(batch)
        if self._interrupted:
            break

PermutationEvaluationStrategy

PermutationEvaluationStrategy(
    sampler: PermutationSampler,
    utility: UtilityBase,
    coefficient: Callable[[int, int], float] | None = None,
)

Bases: EvaluationStrategy[PermutationSampler, ValueUpdate]

Computes marginal values for permutation sampling schemes.

This strategy iterates over each permutation from left to right. At each step it computes the marginal utility with respect to the previous subset, reusing the utility already computed for that subset to save computation.

Source code in src/pydvl/valuation/samplers/permutation.py
def __init__(
    self,
    sampler: PermutationSampler,
    utility: UtilityBase,
    coefficient: Callable[[int, int], float] | None = None,
):
    super().__init__(sampler, utility, coefficient)
    self.truncation = copy(sampler.truncation)
    self.truncation.reset(utility)  # Perform initial setup (e.g. total_utility)
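
The left-to-right evaluation described above can be sketched in plain Python. This is a simplified illustration, not the strategy's actual code: the names `marginals_along` and `toy_utility` are hypothetical, and the coefficient and truncation policy are not modeled.

```python
from typing import Callable, Dict, FrozenSet, Sequence, Set


def marginals_along(
    permutation: Sequence[int],
    utility: Callable[[FrozenSet[int]], float],
) -> Dict[int, float]:
    """Marginal utility of each index when added in permutation order.

    The utility of the previous subset is cached, so each step needs only
    one new utility evaluation. (Hypothetical helper for illustration.)
    """
    subset: Set[int] = set()
    previous = utility(frozenset(subset))  # utility of the empty set
    marginals: Dict[int, float] = {}
    for idx in permutation:
        subset.add(idx)
        current = utility(frozenset(subset))
        marginals[idx] = current - previous
        previous = current  # reuse at the next step
    return marginals


def toy_utility(subset: FrozenSet[int]) -> float:
    """Toy utility assumed for this example: the squared subset size."""
    return float(len(subset)) ** 2


marginals = marginals_along((3, 1, 4, 2), toy_utility)
print(marginals)
```

With caching, a permutation of n indices costs n + 1 utility evaluations in this sketch, instead of the 2n needed if both subsets were evaluated at every step. The marginals also telescope: their sum equals the utility of the full set minus that of the empty set.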