Source code for quanda.utils.functions.similarities
"""Similarity functions."""
import torch
[docs]
def cosine_similarity(test, train, replace_nan=0) -> torch.Tensor:
    """Compute cosine similarity between test and train activations.

    Both inputs are flattened to 2-D (first dimension kept, all remaining
    dimensions merged), every row is divided by its L2 norm, and the pairwise
    similarities are obtained with a single matrix product.

    Parameters
    ----------
    test : torch.Tensor or list of torch.Tensor
        The test activations. A list is concatenated along the first
        dimension before any further processing.
    train : torch.Tensor
        The train activations.
    replace_nan : int, optional
        The value written into every entry of a row whose L2 norm is zero,
        i.e. where the normalization would otherwise produce NaN.
        Default is 0.

    Returns
    -------
    torch.Tensor
        The cosine similarity between the test and train activations, with
        one row per test sample and one column per train sample.

    """
    # TODO: Captum returns test activations as a list
    if isinstance(test, list):
        test = torch.cat(test)

    # reshape (unlike view) also accepts non-contiguous inputs, e.g. permuted
    # activations, and still returns a view whenever that is possible.
    test = test.reshape(test.shape[0], -1)
    train = train.reshape(train.shape[0], -1)

    test_norm = torch.linalg.norm(test, ord=2, dim=1, keepdim=True)
    train_norm = torch.linalg.norm(train, ord=2, dim=1, keepdim=True)

    # The fill value has to live on the same device as the tensor it is
    # combined with in torch.where; the default dtype is kept so results for
    # CPU inputs are unchanged.
    fill_dtype = torch.get_default_dtype()
    test_fill = torch.tensor(
        [replace_nan], dtype=fill_dtype, device=test.device
    )
    train_fill = torch.tensor(
        [replace_nan], dtype=fill_dtype, device=train.device
    )

    # Rows with zero norm cannot be normalized; they are filled instead.
    test = torch.where(test_norm != 0.0, test / test_norm, test_fill)
    train = torch.where(train_norm != 0.0, train / train_norm, train_fill)

    similarity = torch.mm(test, train.T)
    return similarity
[docs]
def dot_product_similarity(test, train, replace_nan=0) -> torch.Tensor:
    """Compute dot product similarity between test and train activations.

    Both inputs are flattened to 2-D (first dimension kept, all remaining
    dimensions merged) and multiplied, so that entry ``[i, j]`` of the result
    is the dot product of test row ``i`` with train row ``j``.

    Parameters
    ----------
    test : torch.Tensor or list of torch.Tensor
        The test activations. A list is concatenated along the first
        dimension before any further processing.
    train : torch.Tensor
        The train activations.
    replace_nan : int, optional
        Not used by this function; it is accepted so the signature matches
        ``cosine_similarity``. Default is 0.

    Returns
    -------
    torch.Tensor
        The dot product similarity between the test and train activations,
        with one row per test sample and one column per train sample.

    """
    # TODO: Captum returns test activations as a list
    if isinstance(test, list):
        test = torch.cat(test)

    # reshape (unlike view) also accepts non-contiguous inputs, e.g. permuted
    # activations, and still returns a view whenever that is possible.
    test = test.reshape(test.shape[0], -1)
    train = train.reshape(train.shape[0], -1)

    similarity = torch.mm(test, train.T)
    return similarity