pydvl.valuation.samplers.permutation

Permutation-based samplers.

TODO: explain the formulation and the different samplers.

References


  1. Mitchell, Rory, Joshua Cooper, Eibe Frank, and Geoffrey Holmes. Sampling Permutations for Shapley Value Estimation. Journal of Machine Learning Research 23, no. 43 (2022): 1–46. 

  2. Watson, Lauren, Zeno Kujawa, Rayna Andreeva, Hao-Tsung Yang, Tariq Elahi, and Rik Sarkar. Accelerated Shapley Value Approximation for Data Evaluation. arXiv, 9 November 2023. 

PermutationSamplerBase

PermutationSamplerBase(
    *args,
    truncation: TruncationPolicy | None = None,
    batch_size: int = 1,
    **kwargs,
)

Bases: IndexSampler

Base class for permutation samplers.

Source code in src/pydvl/valuation/samplers/permutation.py
def __init__(
    self,
    *args,
    truncation: TruncationPolicy | None = None,
    batch_size: int = 1,
    **kwargs,
):
    super().__init__(batch_size=batch_size)
    self.truncation = truncation or NoTruncation()

skip_indices property writable

skip_indices: IndexSetT

Indices skipped by the sampler. The exact behaviour is sampler-dependent, so setting this property is disabled by default.

interrupt

interrupt()

Signals the sampler to stop generating samples after the current batch.

Source code in src/pydvl/valuation/samplers/base.py
def interrupt(self):
    """Signals the sampler to stop generating samples after the current batch."""
    self._interrupted = True

__len__

__len__() -> int

Returns the length of the current sample generation in generate_batches.

RAISES DESCRIPTION
`TypeError`

if the sampler is infinite or generate_batches has not been called yet.

Source code in src/pydvl/valuation/samplers/base.py
def __len__(self) -> int:
    """Returns the length of the current sample generation in generate_batches.

    Raises:
        `TypeError`: if the sampler is infinite or
            [generate_batches][pydvl.valuation.samplers.IndexSampler.generate_batches]
            has not been called yet.
    """
    if self._len is None:
        raise TypeError(f"This {self.__class__.__name__} has no length")
    return self._len
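
As a minimal sketch of this length protocol, the toy class below mimics only the bookkeeping: the length is unknown until batch generation starts. `ToySampler` and its `start` method are hypothetical names used for illustration and are not part of pyDVL.

```python
class ToySampler:
    """Hypothetical minimal sampler; only mimics the length bookkeeping."""

    def __init__(self):
        self._len = None  # unknown until batches are requested

    def start(self, limit):
        # Stand-in for what generate_batches does with sample_limit(indices).
        self._len = limit

    def __len__(self) -> int:
        if self._len is None:
            raise TypeError(f"This {self.__class__.__name__} has no length")
        return self._len


sampler = ToySampler()
try:
    len(sampler)
    has_length = True
except TypeError:
    has_length = False  # no length before generation has started

sampler.start(6)
n_samples = len(sampler)  # now well defined
```

An infinite sampler would keep `_len` at `None` even after generation starts, so callers that need a progress total should be prepared to catch `TypeError`.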

generate_batches

generate_batches(indices: IndexSetT) -> BatchGenerator

Batches the samples and yields them.

Source code in src/pydvl/valuation/samplers/base.py
def generate_batches(self, indices: IndexSetT) -> BatchGenerator:
    """Batches the samples and yields them."""
    self._len = self.sample_limit(indices)

    # Create an empty generator if the indices are empty: `return` acts like a
    # `break`, and produces an empty generator.
    if len(indices) == 0:
        return

    self._interrupted = False
    self._n_samples = 0
    for batch in chunked(self.generate(indices), self.batch_size):
        self._n_samples += len(batch)
        yield batch
        if self._interrupted:
            break
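
The interplay between batching and `interrupt()` can be sketched with the standard library alone. The class below is a hypothetical stand-in (it is not the pyDVL implementation, and it replaces `chunked` with `itertools.islice`); it shows that the interruption flag is only checked after a whole batch has been yielded.

```python
from itertools import count, islice


class ToyBatcher:
    """Hypothetical stand-in for a sampler; not part of pyDVL."""

    def __init__(self, batch_size: int = 1):
        self.batch_size = batch_size
        self._interrupted = False
        self._n_samples = 0

    def interrupt(self) -> None:
        self._interrupted = True

    def generate(self):
        # An endless stream of dummy samples.
        return count()

    def generate_batches(self):
        self._interrupted = False
        self._n_samples = 0
        samples = self.generate()
        while True:
            batch = list(islice(samples, self.batch_size))
            if not batch:
                break
            self._n_samples += len(batch)
            yield batch
            # The flag is checked only after a full batch has been yielded.
            if self._interrupted:
                break


batcher = ToyBatcher(batch_size=3)
batches = []
for batch in batcher.generate_batches():
    batches.append(batch)
    if len(batches) == 2:
        batcher.interrupt()  # takes effect once the current batch is done
```

Because the generator is lazy, calling `interrupt()` while consuming the second batch ends the loop before a third batch is produced.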

sample_limit abstractmethod

sample_limit(indices: IndexSetT) -> int | None

Number of samples that can be generated from the indices.

PARAMETER DESCRIPTION
indices

The indices used in the sampler.

TYPE: IndexSetT

RETURNS DESCRIPTION
int | None

The maximum number of samples that will be generated, or None if the number of samples is infinite. This will depend, among other things, on the type of IndexIteration.

Source code in src/pydvl/valuation/samplers/base.py
@abstractmethod
def sample_limit(self, indices: IndexSetT) -> int | None:
    """Number of samples that can be generated from the indices.

    Args:
        indices: The indices used in the sampler.

    Returns:
        The maximum number of samples that will be generated, or  `None` if the
            number of samples is infinite. This will depend, among other things,
            on the type of [IndexIteration][pydvl.valuation.samplers.IndexIteration].
    """
    ...

generate abstractmethod

generate(indices: IndexSetT) -> SampleGenerator

Generates single samples.

IndexSampler.generate_batches() will batch these samples according to the batch size set upon construction.

PARAMETER DESCRIPTION
indices

TYPE: IndexSetT

YIELDS DESCRIPTION
SampleGenerator

A tuple (idx, subset) for each sample.

Source code in src/pydvl/valuation/samplers/base.py
@abstractmethod
def generate(self, indices: IndexSetT) -> SampleGenerator:
    """Generates single samples.

    `IndexSampler.generate_batches()` will batch these samples according to the
    batch size set upon construction.

    Args:
        indices:

    Yields:
        A tuple (idx, subset) for each sample.
    """
    ...

result_updater

result_updater(result: ValuationResult) -> ResultUpdater[ValueUpdateT]

Returns a callable that updates a valuation result with a value update.

Because we use log-space computation for numerical stability, the default result updater keeps track of several quantities required to maintain accurate running 1st and 2nd moments.

PARAMETER DESCRIPTION
result

The result to update

TYPE: ValuationResult

RETURNS DESCRIPTION
ResultUpdater[ValueUpdateT]

A callable object that updates the result with a value update.

Source code in src/pydvl/valuation/samplers/base.py
def result_updater(self, result: ValuationResult) -> ResultUpdater[ValueUpdateT]:
    """Returns a callable that updates a valuation result with a value update.

    Because we use log-space computation for numerical stability, the default result
    updater keeps track of several quantities required to maintain accurate running
    1st and 2nd moments.

    Args:
        result: The result to update
    Returns:
        A callable object that updates the result with a value update
    """
    return LogResultUpdater(result)
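
To illustrate what maintaining running 1st and 2nd moments involves, here is a sketch of Welford's online algorithm in ordinary (linear) space. This is an assumption-laden analogue for illustration only: it is not `LogResultUpdater`, it does not work in log-space, and the name `RunningMoments` is hypothetical.

```python
class RunningMoments:
    """Welford's online mean and variance (illustrative, linear space)."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self._m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, value: float) -> None:
        self.count += 1
        delta = value - self.mean
        self.mean += delta / self.count
        # Uses the already updated mean, which keeps the recurrence stable.
        self._m2 += delta * (value - self.mean)

    @property
    def variance(self) -> float:
        # Unbiased sample variance; defined as 0.0 for fewer than 2 updates.
        return self._m2 / (self.count - 1) if self.count > 1 else 0.0


moments = RunningMoments()
for value_update in [1.0, 2.0, 3.0, 4.0]:
    moments.update(value_update)
```

Each update costs constant time and memory, which is why this style of bookkeeping suits results that are refined one marginal at a time.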

PermutationSampler

PermutationSampler(
    truncation: TruncationPolicy | None = None,
    seed: Seed | None = None,
    batch_size: int = 1,
)

Bases: StochasticSamplerMixin, PermutationSamplerBase

Samples permutations of indices.

Batching

Although this sampler supports batching, using batches is not recommended: the PermutationEvaluationStrategy already processes each permutation in one go, effectively batching the computation of up to n-1 marginal utilities in one process.

PARAMETER DESCRIPTION
truncation

A policy to stop the permutation early.

TYPE: TruncationPolicy | None DEFAULT: None

seed

Seed for the random number generator.

TYPE: Seed | None DEFAULT: None

Source code in src/pydvl/valuation/samplers/permutation.py
def __init__(
    self,
    truncation: TruncationPolicy | None = None,
    seed: Seed | None = None,
    batch_size: int = 1,
):
    super().__init__(seed=seed, truncation=truncation, batch_size=batch_size)

interrupt

interrupt()

Signals the sampler to stop generating samples after the current batch.

Source code in src/pydvl/valuation/samplers/base.py
def interrupt(self):
    """Signals the sampler to stop generating samples after the current batch."""
    self._interrupted = True

__len__

__len__() -> int

Returns the length of the current sample generation in generate_batches.

RAISES DESCRIPTION
`TypeError`

if the sampler is infinite or generate_batches has not been called yet.

Source code in src/pydvl/valuation/samplers/base.py
def __len__(self) -> int:
    """Returns the length of the current sample generation in generate_batches.

    Raises:
        `TypeError`: if the sampler is infinite or
            [generate_batches][pydvl.valuation.samplers.IndexSampler.generate_batches]
            has not been called yet.
    """
    if self._len is None:
        raise TypeError(f"This {self.__class__.__name__} has no length")
    return self._len

generate_batches

generate_batches(indices: IndexSetT) -> BatchGenerator

Batches the samples and yields them.

Source code in src/pydvl/valuation/samplers/base.py
def generate_batches(self, indices: IndexSetT) -> BatchGenerator:
    """Batches the samples and yields them."""
    self._len = self.sample_limit(indices)

    # Create an empty generator if the indices are empty: `return` acts like a
    # `break`, and produces an empty generator.
    if len(indices) == 0:
        return

    self._interrupted = False
    self._n_samples = 0
    for batch in chunked(self.generate(indices), self.batch_size):
        self._n_samples += len(batch)
        yield batch
        if self._interrupted:
            break

result_updater

result_updater(result: ValuationResult) -> ResultUpdater[ValueUpdateT]

Returns a callable that updates a valuation result with a value update.

Because we use log-space computation for numerical stability, the default result updater keeps track of several quantities required to maintain accurate running 1st and 2nd moments.

PARAMETER DESCRIPTION
result

The result to update

TYPE: ValuationResult

RETURNS DESCRIPTION
ResultUpdater[ValueUpdateT]

A callable object that updates the result with a value update.

Source code in src/pydvl/valuation/samplers/base.py
def result_updater(self, result: ValuationResult) -> ResultUpdater[ValueUpdateT]:
    """Returns a callable that updates a valuation result with a value update.

    Because we use log-space computation for numerical stability, the default result
    updater keeps track of several quantities required to maintain accurate running
    1st and 2nd moments.

    Args:
        result: The result to update
    Returns:
        A callable object that updates the result with a value update
    """
    return LogResultUpdater(result)

generate

generate(indices: IndexSetT) -> SampleGenerator

Generates the permutation samples.

PARAMETER DESCRIPTION
indices

The indices to sample from. If empty, no samples are generated. If skip_indices is set, these indices are removed from the set before generating the permutation.

TYPE: IndexSetT

Source code in src/pydvl/valuation/samplers/permutation.py
def generate(self, indices: IndexSetT) -> SampleGenerator:
    """Generates the permutation samples.

    Args:
        indices: The indices to sample from. If empty, no samples are generated. If
            [skip_indices][pydvl.valuation.samplers.base.IndexSampler.skip_indices]
            is set, these indices are removed from the set before generating the
            permutation.
    """
    if len(indices) == 0:
        return
    while True:
        _indices = np.setdiff1d(indices, self.skip_indices)
        yield Sample(None, self._rng.permutation(_indices))
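
The generation loop above can be approximated without NumPy or pyDVL. The sketch below uses only the standard library; `random_permutations` is a hypothetical helper, and unlike the sampler it reads the skipped indices once rather than on every iteration.

```python
import random
from itertools import islice


def random_permutations(indices, skip_indices=(), seed=None):
    """Yield random permutations of indices, minus skipped ones, forever."""
    if len(indices) == 0:
        return  # empty generator, as in the sampler
    rng = random.Random(seed)
    skipped = set(skip_indices)
    while True:
        # sorted() mirrors np.setdiff1d, which returns sorted unique values.
        remaining = sorted(set(indices) - skipped)
        rng.shuffle(remaining)
        yield remaining


# The generator is infinite, so take a finite number of samples.
perms = list(islice(random_permutations([0, 1, 2, 3], skip_indices=[2], seed=42), 3))
empty = list(random_permutations([]))
```

Since the stream never ends on its own, consumers need an external stopping rule, such as a sample budget or a call to `interrupt()` on the real sampler.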

AntitheticPermutationSampler

AntitheticPermutationSampler(
    truncation: TruncationPolicy | None = None,
    seed: Seed | None = None,
    batch_size: int = 1,
)

Bases: PermutationSampler

Samples permutations like PermutationSampler, but after each permutation, it returns the same permutation in reverse order.

This sampler was suggested by Mitchell et al. (2022) [1].

New in version 0.7.1

Source code in src/pydvl/valuation/samplers/permutation.py
def __init__(
    self,
    truncation: TruncationPolicy | None = None,
    seed: Seed | None = None,
    batch_size: int = 1,
):
    super().__init__(seed=seed, truncation=truncation, batch_size=batch_size)
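
A minimal sketch of the antithetic idea, assuming nothing beyond the description above: every random permutation is followed by the same permutation reversed. `antithetic_permutations` is a hypothetical standard-library helper, not the class's actual `generate` method.

```python
import random
from itertools import islice


def antithetic_permutations(indices, seed=None):
    """Yield each random permutation followed by its reverse."""
    if len(indices) == 0:
        return
    rng = random.Random(seed)
    while True:
        permutation = list(indices)
        rng.shuffle(permutation)
        yield permutation
        yield permutation[::-1]  # the antithetic partner


first, second, third, fourth = islice(
    antithetic_permutations([0, 1, 2, 3], seed=0), 4
)
```

In a reversed permutation an index that appeared early appears late, so its two marginal contributions are computed against very different subset sizes.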

interrupt

interrupt()

Signals the sampler to stop generating samples after the current batch.

Source code in src/pydvl/valuation/samplers/base.py
def interrupt(self):
    """Signals the sampler to stop generating samples after the current batch."""
    self._interrupted = True

__len__

__len__() -> int

Returns the length of the current sample generation in generate_batches.

RAISES DESCRIPTION
`TypeError`

if the sampler is infinite or generate_batches has not been called yet.

Source code in src/pydvl/valuation/samplers/base.py
def __len__(self) -> int:
    """Returns the length of the current sample generation in generate_batches.

    Raises:
        `TypeError`: if the sampler is infinite or
            [generate_batches][pydvl.valuation.samplers.IndexSampler.generate_batches]
            has not been called yet.
    """
    if self._len is None:
        raise TypeError(f"This {self.__class__.__name__} has no length")
    return self._len

generate_batches

generate_batches(indices: IndexSetT) -> BatchGenerator

Batches the samples and yields them.

Source code in src/pydvl/valuation/samplers/base.py
def generate_batches(self, indices: IndexSetT) -> BatchGenerator:
    """Batches the samples and yields them."""
    self._len = self.sample_limit(indices)

    # Create an empty generator if the indices are empty: `return` acts like a
    # `break`, and produces an empty generator.
    if len(indices) == 0:
        return

    self._interrupted = False
    self._n_samples = 0
    for batch in chunked(self.generate(indices), self.batch_size):
        self._n_samples += len(batch)
        yield batch
        if self._interrupted:
            break

result_updater

result_updater(result: ValuationResult) -> ResultUpdater[ValueUpdateT]

Returns a callable that updates a valuation result with a value update.

Because we use log-space computation for numerical stability, the default result updater keeps track of several quantities required to maintain accurate running 1st and 2nd moments.

PARAMETER DESCRIPTION
result

The result to update

TYPE: ValuationResult

RETURNS DESCRIPTION
ResultUpdater[ValueUpdateT]

A callable object that updates the result with a value update.

Source code in src/pydvl/valuation/samplers/base.py
def result_updater(self, result: ValuationResult) -> ResultUpdater[ValueUpdateT]:
    """Returns a callable that updates a valuation result with a value update.

    Because we use log-space computation for numerical stability, the default result
    updater keeps track of several quantities required to maintain accurate running
    1st and 2nd moments.

    Args:
        result: The result to update
    Returns:
        A callable object that updates the result with a value update
    """
    return LogResultUpdater(result)

DeterministicPermutationSampler

DeterministicPermutationSampler(
    *args,
    truncation: TruncationPolicy | None = None,
    batch_size: int = 1,
    **kwargs,
)

Bases: PermutationSamplerBase

Iterates deterministically through all n! permutations of the indices, returning the sets required by the permutation-based definition of semi-values.

Source code in src/pydvl/valuation/samplers/permutation.py
def __init__(
    self,
    *args,
    truncation: TruncationPolicy | None = None,
    batch_size: int = 1,
    **kwargs,
):
    super().__init__(batch_size=batch_size)
    self.truncation = truncation or NoTruncation()
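
Exhaustive enumeration can be sketched with `itertools.permutations`. The helper name `all_permutations` is hypothetical, and the enumeration order shown here is that of `itertools`, which is an assumption and may differ from the order the sampler uses.

```python
from itertools import permutations
from math import factorial


def all_permutations(indices):
    """Yield every permutation of indices exactly once."""
    for permutation in permutations(indices):
        yield list(permutation)


indices = [0, 1, 2]
perms = list(all_permutations(indices))
expected_count = factorial(len(indices))  # n! samples in total
```

The factorial growth makes this practical only for very small index sets: 10 indices already give 3,628,800 permutations.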

skip_indices property writable

skip_indices: IndexSetT

Indices skipped by the sampler. The exact behaviour is sampler-dependent, so setting this property is disabled by default.

interrupt

interrupt()

Signals the sampler to stop generating samples after the current batch.

Source code in src/pydvl/valuation/samplers/base.py
def interrupt(self):
    """Signals the sampler to stop generating samples after the current batch."""
    self._interrupted = True

__len__

__len__() -> int

Returns the length of the current sample generation in generate_batches.

RAISES DESCRIPTION
`TypeError`

if the sampler is infinite or generate_batches has not been called yet.

Source code in src/pydvl/valuation/samplers/base.py
def __len__(self) -> int:
    """Returns the length of the current sample generation in generate_batches.

    Raises:
        `TypeError`: if the sampler is infinite or
            [generate_batches][pydvl.valuation.samplers.IndexSampler.generate_batches]
            has not been called yet.
    """
    if self._len is None:
        raise TypeError(f"This {self.__class__.__name__} has no length")
    return self._len

generate_batches

generate_batches(indices: IndexSetT) -> BatchGenerator

Batches the samples and yields them.

Source code in src/pydvl/valuation/samplers/base.py
def generate_batches(self, indices: IndexSetT) -> BatchGenerator:
    """Batches the samples and yields them."""
    self._len = self.sample_limit(indices)

    # Create an empty generator if the indices are empty: `return` acts like a
    # `break`, and produces an empty generator.
    if len(indices) == 0:
        return

    self._interrupted = False
    self._n_samples = 0
    for batch in chunked(self.generate(indices), self.batch_size):
        self._n_samples += len(batch)
        yield batch
        if self._interrupted:
            break

result_updater

result_updater(result: ValuationResult) -> ResultUpdater[ValueUpdateT]

Returns a callable that updates a valuation result with a value update.

Because we use log-space computation for numerical stability, the default result updater keeps track of several quantities required to maintain accurate running 1st and 2nd moments.

PARAMETER DESCRIPTION
result

The result to update

TYPE: ValuationResult

RETURNS DESCRIPTION
ResultUpdater[ValueUpdateT]

A callable object that updates the result with a value update.

Source code in src/pydvl/valuation/samplers/base.py
def result_updater(self, result: ValuationResult) -> ResultUpdater[ValueUpdateT]:
    """Returns a callable that updates a valuation result with a value update.

    Because we use log-space computation for numerical stability, the default result
    updater keeps track of several quantities required to maintain accurate running
    1st and 2nd moments.

    Args:
        result: The result to update
    Returns:
        A callable object that updates the result with a value update
    """
    return LogResultUpdater(result)

PermutationEvaluationStrategy

PermutationEvaluationStrategy(
    sampler: PermutationSamplerBase,
    utility: UtilityBase,
    coefficient: Callable[[int, int], float] | None = None,
)

Bases: EvaluationStrategy[PermutationSamplerBase, ValueUpdate]

Computes marginal values for permutation sampling schemes in log-space.

This strategy iterates over each permutation from left to right. At every step it computes the marginal utility of the newly added index with respect to the preceding prefix, reusing the utility of that prefix to save computation.

Source code in src/pydvl/valuation/samplers/permutation.py
def __init__(
    self,
    sampler: PermutationSamplerBase,
    utility: UtilityBase,
    coefficient: Callable[[int, int], float] | None = None,
):
    super().__init__(sampler, utility, coefficient)
    self.truncation = copy(sampler.truncation)
    self.truncation.reset(utility)  # Perform initial setup (e.g. total_utility)
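
The left-to-right evaluation can be sketched as follows, in linear rather than log-space. All names here are hypothetical: `utility` maps a frozenset of indices to a float, and `should_truncate` is an illustrative callback standing in for a truncation policy. Assigning a zero marginal to indices after truncation is an assumption borrowed from truncated Monte Carlo Shapley.

```python
def marginals_along_permutation(permutation, utility, should_truncate=None):
    """Return {index: marginal utility} for one permutation."""
    marginals = {}
    prefix = set()
    previous = utility(frozenset(prefix))  # utility of the empty set
    truncated = False
    for index in permutation:
        if truncated:
            marginals[index] = 0.0  # no further utility evaluations
            continue
        prefix.add(index)
        current = utility(frozenset(prefix))
        # One new evaluation per step: the previous one is reused.
        marginals[index] = current - previous
        previous = current
        if should_truncate is not None:
            truncated = should_truncate(index, current)
    return marginals


calls = []


def capped_size(subset):
    """Toy utility: subset size capped at 2; records every evaluation."""
    calls.append(subset)
    return float(min(len(subset), 2))


full = marginals_along_permutation([2, 0, 1], capped_size)
calls_without_truncation = len(calls)

calls.clear()
cut = marginals_along_permutation([2, 0, 1], capped_size, lambda i, u: u >= 2.0)
calls_with_truncation = len(calls)
```

In this toy case truncation yields the same marginals with one fewer utility evaluation, which is the saving a truncation policy aims for when each evaluation means retraining a model.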