Util
TorchTensorContainerType = TypeVar('TorchTensorContainerType', torch.Tensor, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor])
module-attribute
Type variable for a PyTorch tensor or a container thereof.
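A constrained `TypeVar` like this one is typically used so that a function's return type matches its input's container kind. The sketch below re-declares it and types a hypothetical helper, `detach_all`, which is not part of this module. A type checker then infers dict in, dict out, and the same for tuples and plain tensors.

```python
from typing import Dict, Tuple, TypeVar

import torch

# Re-declared here so the sketch is self-contained; mirrors the definition above.
TorchTensorContainerType = TypeVar(
    "TorchTensorContainerType",
    torch.Tensor,
    Tuple[torch.Tensor, ...],
    Dict[str, torch.Tensor],
)


def detach_all(x: TorchTensorContainerType) -> TorchTensorContainerType:
    """Hypothetical helper: detach every tensor while keeping the container kind.

    Because the TypeVar is constrained, a type checker treats a dict input as
    producing a dict output, a tuple as producing a tuple, and so on.
    """
    if isinstance(x, torch.Tensor):
        return x.detach()
    if isinstance(x, tuple):
        return tuple(t.detach() for t in x)
    if isinstance(x, dict):
        return {k: v.detach() for k, v in x.items()}
    raise TypeError(f"Unsupported type: {type(x)}")
```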
to_model_device(x, model)
Returns the tensor `x` moved to the device of `model`, if the device of the model is set.
PARAMETER | DESCRIPTION |
---|---|
x | The tensor to be moved to the device of the model. TYPE: `Tensor` |
model | The model whose device will be used to move the tensor. TYPE: `Module` |

RETURNS | DESCRIPTION |
---|---|
Tensor | The tensor `x`, moved to the device of `model` if that device is set. |

Source code in src/pydvl/influence/torch/util.py
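Here is a minimal sketch of the described behavior. It assumes that "device of model is set" means the model exposes a `device` attribute; a plain `torch.nn.Module` has none. `to_model_device_sketch` and `DeviceAwareModel` are hypothetical names used for illustration, not the library's code.

```python
import torch


def to_model_device_sketch(x: torch.Tensor, model: torch.nn.Module) -> torch.Tensor:
    """Move x to model.device if the model exposes such an attribute.

    Assumption: "device is set" is read as the model having a `device`
    attribute; a plain torch.nn.Module does not define one.
    """
    device = getattr(model, "device", None)
    if device is None:
        return x  # no device information: return the tensor unchanged
    return x.to(device)


class DeviceAwareModel(torch.nn.Module):
    """Hypothetical model that records the device it runs on."""

    def __init__(self, device: str = "cpu"):
        super().__init__()
        self.linear = torch.nn.Linear(3, 1)
        self.device = torch.device(device)  # plain attribute, not a buffer
        self.to(self.device)
```

With a plain `torch.nn.Linear`, the tensor comes back unchanged. With a model that records its device, the tensor is moved with `Tensor.to`, which returns the same tensor when it already lives on that device.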
flatten_tensors_to_vector(tensors)
Flatten multiple tensors into a single 1D tensor (vector).
This function takes an iterable of tensors and reshapes each of them into a 1D tensor. These reshaped tensors are then concatenated together into a single 1D tensor in the order they were given.
PARAMETER | DESCRIPTION |
---|---|
tensors | An iterable of tensors to be reshaped and concatenated. |

RETURNS | DESCRIPTION |
---|---|
Tensor | A 1D tensor that is the concatenation of all the reshaped input tensors. |

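Here is a minimal sketch of this behavior using only standard PyTorch calls. `flatten_tensors_to_vector_sketch` is a hypothetical name, and the library's own implementation may differ.

```python
from typing import Iterable

import torch


def flatten_tensors_to_vector_sketch(tensors: Iterable[torch.Tensor]) -> torch.Tensor:
    """Reshape each tensor to 1D and concatenate them in the given order."""
    # reshape(-1) also accepts non-contiguous inputs, copying only if needed
    return torch.cat([t.reshape(-1) for t in tensors])
```

`reshape(-1)` is used rather than `view(-1)` because it also works on non-contiguous tensors, such as transposed ones. Note that `torch.cat` raises an error on an empty list, so this sketch does not handle an empty iterable.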
reshape_vector_to_tensors(input_vector, target_shapes)
Reshape a 1D tensor into multiple tensors with specified shapes.
This function takes a 1D tensor (input_vector) and reshapes it into a series of tensors with shapes given by 'target_shapes'. The reshaped tensors are returned as a tuple in the same order as their corresponding shapes.
Note: The total number of elements in `input_vector` must equal the total number of elements described by `target_shapes`, i.e., the sum over all shapes of the product of each shape's dimensions.
PARAMETER | DESCRIPTION |
---|---|
input_vector | The tensor to be reshaped. Must be 1D. TYPE: `Tensor` |
target_shapes | An iterable of tuples. Each tuple defines the shape of one tensor to be reshaped from `input_vector`. |

RETURNS | DESCRIPTION |
---|---|
Tuple[Tensor, ...] | A tuple of reshaped tensors. |

RAISES | DESCRIPTION |
---|---|
ValueError | If `input_vector` is not a 1D tensor, or if its number of elements does not equal the total number of elements described by `target_shapes`. |

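Here is a minimal sketch of the validation and reshaping described above. It splits the vector into consecutive chunks with `torch.split` and then reshapes each chunk. The name `reshape_vector_to_tensors_sketch` is hypothetical, and the library's code may differ.

```python
import math
from typing import Iterable, Tuple

import torch


def reshape_vector_to_tensors_sketch(
    input_vector: torch.Tensor, target_shapes: Iterable[Tuple[int, ...]]
) -> Tuple[torch.Tensor, ...]:
    """Split a 1D tensor into consecutive chunks and reshape each chunk."""
    if input_vector.dim() != 1:
        raise ValueError("input_vector must be a 1D tensor")
    # Materialize the shapes: they are traversed twice below.
    shapes = [tuple(s) for s in target_shapes]
    sizes = [math.prod(s) for s in shapes]  # number of elements per shape
    if sum(sizes) != input_vector.numel():
        raise ValueError(
            f"target_shapes describe {sum(sizes)} elements, "
            f"but input_vector has {input_vector.numel()}"
        )
    chunks = torch.split(input_vector, sizes)  # consecutive 1D slices
    return tuple(c.reshape(s) for c, s in zip(chunks, shapes))
```

This is the inverse of flattening. For example, for `a` of shape (2, 3) and `b` of shape (2,), passing `torch.cat([a.reshape(-1), b])` with `[a.shape, b.shape]` recovers `a` and `b` exactly. The shapes are materialized into a list so that a generator passed as `target_shapes` is not exhausted after the first pass.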
align_structure(source, target)
This function transforms `target` to have the same structure as `source`, i.e., it should be a dictionary with the same keys as `source`, and each corresponding value in `target` should have the same shape as the value in `source`.
PARAMETER | DESCRIPTION |
---|---|
source | The reference dictionary containing PyTorch tensors. |
target | The input to be harmonized. It can be a dictionary, tuple, or tensor. TYPE: `TorchTensorContainerType` |

RETURNS | DESCRIPTION |
---|---|
Dict[str, Tensor] | The harmonized version of `target`. |

RAISES | DESCRIPTION |
---|---|
ValueError | If `target` cannot be harmonized to match `source`. |

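Here is one possible sketch of these rules. It assumes, hypothetically, that:

1. a dictionary `target` must have the same keys in the same order as `source`;
2. a tuple is matched to `source`'s values by position;
3. a bare tensor is a 1D vector laid out in `source`'s key order, as `flatten_tensors_to_vector` would produce.

`align_structure_sketch` is a hypothetical name, and the library's checks may be stricter or looser.

```python
import math
from typing import Dict, Mapping, Tuple, Union

import torch

Container = Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]


def align_structure_sketch(
    source: Mapping[str, torch.Tensor], target: Container
) -> Dict[str, torch.Tensor]:
    """Return target as a dict with source's keys and matching value shapes."""
    keys = list(source.keys())
    shapes = [v.shape for v in source.values()]
    if isinstance(target, dict):
        # Assumption: keys must match in the same order
        if list(target.keys()) != keys:
            raise ValueError("target keys do not match source keys")
        if [v.shape for v in target.values()] != shapes:
            raise ValueError("target value shapes do not match source")
        return dict(target)
    if isinstance(target, tuple):
        # Assumption: tuple elements correspond to source values by position
        if [t.shape for t in target] != shapes:
            raise ValueError("target element shapes do not match source")
        return dict(zip(keys, target))
    if isinstance(target, torch.Tensor):
        # Assumption: a flat vector laid out in source's key order
        sizes = [math.prod(s) for s in shapes]
        if target.dim() != 1 or target.numel() != sum(sizes):
            raise ValueError("target tensor cannot be reshaped to match source")
        chunks = torch.split(target, sizes)
        return {k: c.reshape(s) for k, c, s in zip(keys, chunks, shapes)}
    raise ValueError(f"unsupported target type: {type(target)}")
```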
as_tensor(a, warn=True, **kwargs)
Converts an array into a torch tensor.
PARAMETER | DESCRIPTION |
---|---|
a | Array to convert to tensor. TYPE: `Any` |
warn | If True, warns that `a` will be converted. DEFAULT: `True` |

RETURNS | DESCRIPTION |
---|---|
Tensor | A torch tensor converted from the input array. |

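Here is a minimal sketch of this conversion. It assumes the warning is emitted through Python's `warnings` module and only when the input is not already a tensor; the library may log the message instead. `as_tensor_sketch` is a hypothetical name. Extra keyword arguments such as `dtype` or `device` are forwarded to `torch.as_tensor`, which tries to avoid copying the input's data.

```python
import warnings
from typing import Any

import torch


def as_tensor_sketch(a: Any, warn: bool = True, **kwargs: Any) -> torch.Tensor:
    """Convert a to a tensor, optionally warning when a conversion happens."""
    if warn and not isinstance(a, torch.Tensor):
        # Assumption: only warn when an actual conversion takes place
        warnings.warn("Converting input to torch.Tensor", stacklevel=2)
    # Forward dtype/device and similar kwargs; avoids a copy when possible
    return torch.as_tensor(a, **kwargs)
```

If the input is already a tensor with the requested dtype and device, `torch.as_tensor` returns that same object, so no warning and no copy occur.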
Created: 2023-09-02