Inversion

Contains methods for computing inverse Hessian-vector products, i.e. for solving linear systems whose matrix is the Hessian of a model.

InversionMethod

Bases: str, Enum

The available inversion method types.
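Because the class derives from both str and Enum, its members behave like plain strings. The following minimal sketch illustrates this with a stand-in class; the class name `InversionMethodSketch` and the member names and values are assumptions made for illustration, since this page does not list the actual members.

```python
from enum import Enum


class InversionMethodSketch(str, Enum):
    # Member names and values are illustrative assumptions,
    # not taken from this page.
    Direct = "direct"
    Cg = "cg"


# Mixing in str makes each member an actual string instance,
assert isinstance(InversionMethodSketch.Cg, str)
# so it compares equal to its plain-string value,
assert InversionMethodSketch.Cg == "cg"
# and a member can be looked up from that value.
assert InversionMethodSketch("direct") is InversionMethodSketch.Direct
```

A practical consequence of this pattern is that a method can be read from a configuration file as a string and converted to an enum member with a single constructor call.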

InversionRegistry

A registry to hold inversion methods for different models.

register(model_type, inversion_method, overwrite=False) classmethod

Register a function for a specific model type and inversion method.

The function to be registered must conform to the following signature: (model: TwiceDifferentiable, training_data: DataLoaderType, b: TensorType, hessian_perturbation: float = 0.0, ...).

PARAMETER DESCRIPTION
model_type

The type of the model the function should be registered for.

TYPE: Type[TwiceDifferentiable]

inversion_method

The inversion method the function should be registered for.

TYPE: InversionMethod

overwrite

If True, allows overwriting of an existing registered function for the same model type and inversion method. If False, a warning is emitted when a function is already registered for that model type and inversion method; as the source below shows, the new function is still stored in the registry after the warning.

TYPE: bool DEFAULT: False

RAISES DESCRIPTION
TypeError

If the provided model_type or inversion_method is of the wrong type.

ValueError

If the function to be registered does not match the required signature.

RETURNS DESCRIPTION

A decorator for registering a function.

Source code in src/pydvl/influence/inversion.py
@classmethod
def register(
    cls,
    model_type: Type[TwiceDifferentiable],
    inversion_method: InversionMethod,
    overwrite: bool = False,
):
    """
    Register a function for a specific model type and inversion method.

    The function to be registered must conform to the following signature:
    `(model: TwiceDifferentiable, training_data: DataLoaderType, b: TensorType,
    hessian_perturbation: float = 0.0, ...)`.

    Args:
        model_type: The type of the model the function should be registered for.
        inversion_method: The inversion method the function should be
            registered for.
        overwrite: If ``True``, allows overwriting of an existing registered
            function for the same model type and inversion method. If ``False``,
            logs a warning when attempting to register a function for an already
            registered model type and inversion method.

    Raises:
        TypeError: If the provided model_type or inversion_method are of the wrong type.
        ValueError: If the function to be registered does not match the required signature.

    Returns:
        A decorator for registering a function.
    """

    if not isinstance(model_type, type):
        raise TypeError(
            f"'model_type' is of type {type(model_type)} but should be a Type[TwiceDifferentiable]"
        )

    if not isinstance(inversion_method, InversionMethod):
        raise TypeError(
            f"'inversion_method' must be an 'InversionMethod' "
            f"but has type {type(inversion_method)} instead."
        )

    key = (model_type, inversion_method)

    def decorator(func):
        if not overwrite and key in cls.registry:
            warnings.warn(
                f"There is already a function registered for model type {model_type} "
                f"and inversion method {inversion_method}. "
                f"To overwrite the existing function {cls.registry.get(key)} with {func}, set overwrite to True."
            )
        sig = inspect.signature(func)
        params = list(sig.parameters.values())

        expected_args = [
            ("model", model_type),
            ("training_data", DataLoaderType.__bound__),
            ("b", model_type.tensor_type()),
            ("hessian_perturbation", float),
        ]

        for (name, typ), param in zip(expected_args, params):
            if not (
                isinstance(param.annotation, typ)
                or issubclass(param.annotation, typ)
            ):
                raise ValueError(
                    f'Parameter "{name}" must be of type "{typ.__name__}"'
                )

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        cls.registry[key] = wrapper
        return wrapper

    return decorator
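The decorator pattern used by register can be tried out in isolation. The sketch below is a simplified stand-in, not the library class: `ToyRegistry`, `ToyModel`, `Method` and the two solver functions are hypothetical names, and the annotation check from the source above is left out. It shows the type validation of the arguments and the fact that a second registration for the same key warns and then replaces the first.

```python
import warnings
from enum import Enum
from typing import Callable, Dict, Tuple


class Method(str, Enum):  # hypothetical stand-in for InversionMethod
    Direct = "direct"


class ToyModel:  # hypothetical stand-in for a TwiceDifferentiable subclass
    pass


class ToyRegistry:
    """Simplified stand-in; omits the signature validation shown above."""

    registry: Dict[Tuple[type, Method], Callable] = {}

    @classmethod
    def register(cls, model_type: type, method: Method, overwrite: bool = False):
        if not isinstance(model_type, type):
            raise TypeError("'model_type' must be a class")
        if not isinstance(method, Method):
            raise TypeError("'method' must be a Method")
        key = (model_type, method)

        def decorator(func):
            if not overwrite and key in cls.registry:
                warnings.warn(f"{key} is already registered")
            # As in the source above, registration proceeds after the warning.
            cls.registry[key] = func
            return func

        return decorator


@ToyRegistry.register(ToyModel, Method.Direct)
def solve_direct(model, training_data, b, hessian_perturbation=0.0):
    return "first"


# Registering the same key again without overwrite=True warns, then replaces.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")

    @ToyRegistry.register(ToyModel, Method.Direct)
    def solve_direct_again(model, training_data, b, hessian_perturbation=0.0):
        return "second"

registered = ToyRegistry.registry[(ToyModel, Method.Direct)]
```

Keying the dictionary on the pair (model type, method) lets each backend supply its own implementation of every inversion method without a central if/else chain.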

call(inversion_method, model, training_data, b, hessian_perturbation, **kwargs) classmethod

Call a registered function with the provided parameters.

PARAMETER DESCRIPTION
inversion_method

The inversion method to use.

TYPE: InversionMethod

model

A model wrapped in the TwiceDifferentiable interface.

TYPE: TwiceDifferentiable

training_data

The training data to use.

TYPE: DataLoaderType

b

Array serving as the right-hand side of the equation \(Ax = b\).

TYPE: TensorType

hessian_perturbation

Regularization of the hessian.

kwargs

Additional keyword arguments to pass to the inversion method.

DEFAULT: {}

RETURNS DESCRIPTION
InverseHvpResult

An instance of InverseHvpResult containing an array \(x\) that solves the inverse problem \(Ax = b\), together with a dictionary of information about the inversion process.

Source code in src/pydvl/influence/inversion.py
@classmethod
def call(
    cls,
    inversion_method: InversionMethod,
    model: TwiceDifferentiable,
    training_data: DataLoaderType,
    b: TensorType,
    hessian_perturbation,
    **kwargs,
) -> InverseHvpResult:
    r"""
    Call a registered function with the provided parameters.

    Args:
        inversion_method: The inversion method to use.
        model: A model wrapped in the TwiceDifferentiable interface.
        training_data: The training data to use.
        b: Array as the right hand side of the equation \(Ax = b\).
        hessian_perturbation: Regularization of the hessian.
        kwargs: Additional keyword arguments to pass to the inversion method.

    Returns:
        An instance of [InverseHvpResult][pydvl.influence.twice_differentiable.InverseHvpResult],
            that contains an array, which solves the inverse problem,
            i.e. it returns \(x\) such that \(Ax = b\), and a dictionary containing information
            about the inversion process.
    """

    return cls.get(type(model), inversion_method)(
        model, training_data, b, hessian_perturbation, **kwargs
    )
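The body of call is a two-step dispatch: look up a function by the concrete type of the model and the inversion method, then forward the arguments. The sketch below reproduces that flow with stand-ins. The `get` method is not shown on this page, so its behaviour here (a plain dictionary lookup that raises ValueError for unknown keys) is an assumption, and `ToyModel`, `toy_direct` and the string method name are hypothetical.

```python
from typing import Any, Callable, Dict, Tuple

# Hypothetical registry contents: keys pair a model class with a method name.
registry: Dict[Tuple[type, str], Callable[..., Any]] = {}


class ToyModel:  # hypothetical stand-in for a TwiceDifferentiable model
    pass


def toy_direct(model, training_data, b, hessian_perturbation, **kwargs):
    # Echo the arguments so the forwarding is visible.
    return {"b": b, "hessian_perturbation": hessian_perturbation, "extra": kwargs}


registry[(ToyModel, "direct")] = toy_direct


def get(model_type: type, method: str) -> Callable[..., Any]:
    # Assumed behaviour: a plain lookup that fails loudly for unknown keys.
    try:
        return registry[(model_type, method)]
    except KeyError:
        raise ValueError(
            f"No function registered for {(model_type, method)}"
        ) from None


def call(method: str, model, training_data, b, hessian_perturbation, **kwargs):
    # Mirrors the body above: look up by the model's concrete type, then
    # forward the four fixed arguments positionally and the rest by keyword.
    return get(type(model), method)(
        model, training_data, b, hessian_perturbation, **kwargs
    )


result = call("direct", ToyModel(), None, [1.0, 2.0], 0.1, maxiter=10)
```

Note that hessian_perturbation is forwarded positionally, so a registered function must accept it as its fourth parameter, which is what the signature required by register prescribes.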

solve_hvp(inversion_method, model, training_data, b, *, hessian_perturbation=0.0, **kwargs)

Finds \( x \) such that \( Ax = b \), where \( A \) is the Hessian of the model and \( b \) is a vector. Depending on the inversion method, the Hessian is either computed explicitly and then inverted, or kept implicit and inverted using only matrix-vector products. A small regularization term (hessian_perturbation) can be added to make the inversion feasible for models that have not been trained to convergence.
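The page describes hessian_perturbation only as a regularization of the Hessian. Assuming it enters as a multiple of the identity (a common damping scheme, stated here as an assumption rather than taken from the source), and writing \( H \) for the unregularized Hessian and \( \lambda \) for the perturbation, the system being solved would read:

```latex
(H + \lambda I)\, x = b, \qquad \lambda = \texttt{hessian\_perturbation} \ge 0
```

Under that assumption, if the smallest eigenvalue of \( H \) is \( \mu_{\min} \), then any \( \lambda > -\mu_{\min} \) makes \( H + \lambda I \) positive definite and hence invertible. This matters for models that are not at a minimum of the loss, where \( H \) can have zero or negative eigenvalues.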

PARAMETER DESCRIPTION
inversion_method

The inversion method to use.

TYPE: InversionMethod

model

A model wrapped in the TwiceDifferentiable interface.

TYPE: TwiceDifferentiable

training_data

The training data to use.

TYPE: DataLoaderType

b

Array serving as the right-hand side of the equation \( Ax = b \).

TYPE: TensorType

hessian_perturbation

Regularization of the Hessian.

TYPE: float DEFAULT: 0.0

kwargs

Additional keyword arguments to pass to the inversion method.

TYPE: Any DEFAULT: {}

RETURNS DESCRIPTION
InverseHvpResult

An instance of InverseHvpResult containing an array \( x \) that solves the inverse problem \( Ax = b \), together with a dictionary of information about the inversion process.

Source code in src/pydvl/influence/inversion.py
def solve_hvp(
    inversion_method: InversionMethod,
    model: TwiceDifferentiable,
    training_data: DataLoaderType,
    b: TensorType,
    *,
    hessian_perturbation: float = 0.0,
    **kwargs: Any,
) -> InverseHvpResult:
    r"""
    Finds \( x \) such that \( Ax = b \), where \( A \) is the hessian of the model,
    and \( b \) a vector. Depending on the inversion method, the hessian is either
    calculated directly and then inverted, or implicitly and then inverted through
    matrix vector product. The method also allows to add a small regularization term
    (hessian_perturbation) to facilitate inversion of non fully trained models.

    Args:
        inversion_method:
        model: A model wrapped in the TwiceDifferentiable interface.
        training_data:
        b: Array as the right hand side of the equation \( Ax = b \)
        hessian_perturbation: regularization of the hessian.
        kwargs: kwargs to pass to the inversion method.

    Returns:
        Instance of [InverseHvpResult][pydvl.influence.twice_differentiable.InverseHvpResult], with
            an array that solves the inverse problem, i.e., it returns \( x \) such that \( Ax = b \)
            and a dictionary containing information about the inversion process.
    """

    return InversionRegistry.call(
        inversion_method,
        model,
        training_data,
        b,
        hessian_perturbation=hessian_perturbation,
        **kwargs,
    )
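To make "inverted through matrix vector product" concrete, the sketch below solves a small regularized system with the conjugate gradient method, which only needs a function computing \( v \mapsto Av \). It is a pure-Python illustration under stated assumptions, not the library's implementation: the matrix `H`, the value of `lam`, and the helper names are made up, and the perturbation is assumed to be added as a multiple of the identity.

```python
from typing import Callable, List

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    return sum(a * c for a, c in zip(u, v))


def conjugate_gradient(
    matvec: Callable[[Vector], Vector],
    b: Vector,
    tol: float = 1e-10,
    maxiter: int = 100,
) -> Vector:
    """Solve A x = b for symmetric positive definite A, using only v -> A v."""
    x = [0.0] * len(b)
    r = list(b)  # residual b - A x for the initial guess x = 0
    p = list(r)  # first search direction
    rs_old = dot(r, r)
    for _ in range(maxiter):
        if rs_old ** 0.5 < tol:
            break
        Ap = matvec(p)
        alpha = rs_old / dot(p, Ap)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


# Toy symmetric "Hessian" (assumed data) and perturbation lambda.
H = [[2.0, 1.0], [1.0, 3.0]]
lam = 0.5


def perturbed_hvp(v: Vector) -> Vector:
    # Computes (H + lam * I) v without ever forming an inverse.
    return [dot(row, v) + lam * vi for row, vi in zip(H, v)]


b = [1.0, 2.0]
x = conjugate_gradient(perturbed_hvp, b)
residual = [bi - yi for bi, yi in zip(b, perturbed_hvp(x))]
```

For a real model the explicit matrix `H` would be replaced by a Hessian-vector product obtained through automatic differentiation, which avoids ever storing the full Hessian.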

Last update: 2023-10-14
Created: 2023-10-14