pydvl.influence.base_influence_function_model
InfluenceMode
Enum representation for the types of influence.
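ATTRIBUTE | DESCRIPTION |
---|---|
Up | Influence obtained by up-weighting a training point in the loss |
Perturbation | Influence obtained by perturbing the input of a training point |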
InfluenceFunctionModel
Bases: Generic[TensorType, DataLoaderType], ABC
Generic abstract base class for computing influence-related quantities. To implement a specific influence algorithm for a given tensor framework, inherit from this base class.
n_parameters (abstractmethod, property)
Number of trainable parameters of the underlying model.
is_thread_safe (abstractmethod, property)
is_thread_safe: bool
Whether the influence computation is thread-safe.
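One way a caller might use this flag is to decide whether batches can be processed concurrently. The helper `map_batches` below is hypothetical and not part of pyDVL; it only assumes a boolean like the one this property returns, and falls back to sequential evaluation otherwise.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def map_batches(
    fn: Callable[[T], R],
    batches: Sequence[T],
    thread_safe: bool,
    max_workers: int = 4,
) -> List[R]:
    """Hypothetical helper: use threads only if the model reports thread safety."""
    if thread_safe:
        # executor.map yields results in the same order as the inputs
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            return list(executor.map(fn, batches))
    # Sequential fallback for models that are not thread-safe
    return [fn(batch) for batch in batches]
```

A call could then look like `map_batches(lambda b: model.influences_from_factors(z_test_factors, *b), batches, model.is_thread_safe)`, where `model`, `z_test_factors` and `batches` (a list of `(x, y)` pairs) are placeholders.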
is_fitted (abstractmethod, property)
Override this to expose whether the instance has been fitted.
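The abstract properties above follow the standard abc pattern of stacking @property on @abstractmethod. The sketch below mirrors them in a hypothetical `InfluenceModelSketch` base (a stand-in, not the pyDVL class) and implements them for a toy linear model whose parameters are a NumPy weight vector; the class names and the `_hessian_inv` fitted-state attribute are assumptions for illustration.

```python
from abc import ABC, abstractmethod
from typing import Generic, Iterable, Optional, Tuple, TypeVar

import numpy as np

TensorType = TypeVar("TensorType")
DataLoaderType = TypeVar("DataLoaderType")


class InfluenceModelSketch(Generic[TensorType, DataLoaderType], ABC):
    """Hypothetical stand-in for the abstract properties documented above."""

    @property
    @abstractmethod
    def n_parameters(self) -> int: ...

    @property
    @abstractmethod
    def is_thread_safe(self) -> bool: ...

    @property
    @abstractmethod
    def is_fitted(self) -> bool: ...


class ToyLinearInfluence(
    InfluenceModelSketch[np.ndarray, Iterable[Tuple[np.ndarray, np.ndarray]]]
):
    """Toy concrete subclass for a linear model with weight vector `weights`."""

    def __init__(self, weights: np.ndarray):
        self.weights = weights
        self._hessian_inv: Optional[np.ndarray] = None  # set by a fitting step

    @property
    def n_parameters(self) -> int:
        # Every entry of the weight vector is a trainable parameter
        return int(self.weights.size)

    @property
    def is_thread_safe(self) -> bool:
        # Illustrative choice: influence computation here only reads NumPy arrays
        return True

    @property
    def is_fitted(self) -> bool:
        # Considered fitted once the inverse Hessian has been pre-computed
        return self._hessian_inv is not None
```

Because the base class is abstract, instantiating it directly raises TypeError; only subclasses that implement every abstract member can be created.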
fit (abstractmethod)
fit(data: DataLoaderType) -> InfluenceFunctionModel
Override this method to fit the influence function model to training data, e.g. by pre-computing the Hessian matrix or matrix decompositions.
PARAMETER | DESCRIPTION |
---|---|
data | Training data to fit the model to. TYPE: DataLoaderType |
RETURNS | DESCRIPTION |
---|---|
InfluenceFunctionModel | The fitted instance. |
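As a concrete illustration of what fitting can mean, the following sketch pre-computes a damped inverse Hessian for a linear least-squares model. The class name, the `regularization` damping parameter, and the choice of model and loss are assumptions; `data` is a plain iterable of `(x, y)` batches standing in for a framework data loader.

```python
from typing import Iterable, Optional, Tuple

import numpy as np


class LinearRegressionInfluenceSketch:
    """Hypothetical model f(x) = x @ w with per-sample loss 0.5 * (x @ w - y) ** 2."""

    def __init__(self, weights: np.ndarray, regularization: float = 1e-3):
        self.weights = weights
        self.regularization = regularization  # damping added to the Hessian diagonal
        self.hessian_inv: Optional[np.ndarray] = None

    def fit(
        self, data: Iterable[Tuple[np.ndarray, np.ndarray]]
    ) -> "LinearRegressionInfluenceSketch":
        # Stack all input batches into one (n, d) design matrix
        x = np.concatenate([x_batch for x_batch, _ in data])
        n, d = x.shape
        # Hessian of the mean loss w.r.t. w is X^T X / n; it does not depend on y
        hessian = x.T @ x / n + self.regularization * np.eye(d)
        # Pre-compute the inverse once so later influence calls need only products
        self.hessian_inv = np.linalg.inv(hessian)
        return self  # return the fitted instance
```

A production implementation would more likely use a factorization (e.g. Cholesky) or an iterative solver than an explicit inverse; the inverse only keeps the sketch short.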
influences_from_factors (abstractmethod)
influences_from_factors(
z_test_factors: TensorType,
x: TensorType,
y: TensorType,
mode: InfluenceMode = InfluenceMode.Up,
) -> TensorType
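Override this method to implement the computation of
\[ \langle z_{\text{test\_factors}}, \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \]
for the case of up-weighting influence, resp.
\[ \langle z_{\text{test\_factors}}, \nabla_{x} \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \]
for the perturbation-type influence. The gradients are computed per sample of the batch \((x, y)\).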
PARAMETER | DESCRIPTION |
---|---|
z_test_factors | Pre-computed tensor approximating \(H^{-1}\nabla_{\theta} \ell(y_{\text{test}}, f_{\theta}(x_{\text{test}}))\). TYPE: TensorType |
x | Model input to use in the gradient computations \(\nabla_{\theta}\ell(y, f_{\theta}(x))\), resp. \(\nabla_{x}\nabla_{\theta}\ell(y, f_{\theta}(x))\); if None, use \(x=x_{\text{test}}\). TYPE: TensorType |
y | Label tensor used to compute the gradients. TYPE: TensorType |
mode | Enum value of InfluenceMode selecting the type of influence. TYPE: InfluenceMode DEFAULT: InfluenceMode.Up |
RETURNS | DESCRIPTION |
---|---|
TensorType | Tensor of element-wise scalar products for the provided batch. |
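To make the two scalar products concrete, here is a minimal NumPy sketch for a linear model \(f_w(x) = x^\top w\) with per-sample loss \(\tfrac12 (x^\top w - y)^2\). The helper name `influences_from_factors_linear`, the explicit `weights` argument, the output shapes and the enum's string values are assumptions for illustration; this is not pyDVL's implementation.

```python
from enum import Enum

import numpy as np


class InfluenceMode(str, Enum):
    # Stand-in for the documented enum; member values are assumptions
    Up = "up"
    Perturbation = "perturbation"


def influences_from_factors_linear(
    z_test_factors: np.ndarray,  # (n_test, d): approx. H^{-1} grad_w loss at test points
    x: np.ndarray,  # (n_batch, d): batch inputs
    y: np.ndarray,  # (n_batch,): batch labels
    weights: np.ndarray,  # (d,): model parameters w
    mode: InfluenceMode = InfluenceMode.Up,
) -> np.ndarray:
    """Hypothetical helper for f(x) = x @ w and loss 0.5 * (x @ w - y) ** 2."""
    residuals = x @ weights - y  # r_k = x_k . w - y_k for each sample
    if mode is InfluenceMode.Up:
        # Per-sample gradient w.r.t. w is r_k * x_k
        grads = residuals[:, None] * x  # (n_batch, d)
        return z_test_factors @ grads.T  # (n_test, n_batch)
    if mode is InfluenceMode.Perturbation:
        # d/dx_j of (r x_i) is x_i w_j + r delta_ij, hence
        # <z, grad_x grad_w loss> = (z . x) w + r z
        zx = z_test_factors @ x.T  # (n_test, n_batch)
        return (
            zx[:, :, None] * weights[None, None, :]
            + residuals[None, :, None] * z_test_factors[:, None, :]
        )  # (n_test, n_batch, d)
    raise ValueError(f"Unsupported mode: {mode}")
```

In the up-weighting case each output entry is one scalar per (test factor, batch sample) pair; in the perturbation case it is a vector with the shape of a single input, since the mixed derivative is contracted only over the parameter dimension.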