quanda.utils.functions.similarities module
Similarity functions for comparing test activations with train activations.
- quanda.utils.functions.similarities.cosine_similarity(test, train, replace_nan=0) → Tensor
Compute cosine similarity between test and train activations.
- Parameters:
test (torch.Tensor) – The test activations.
train (torch.Tensor) – The train activations.
replace_nan (int, optional) – The value to replace NaN values with. Default is 0.
- Returns:
The cosine similarity between the test and train activations.
- Return type:
torch.Tensor
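The entry above does not state the output shape or when NaN values appear. The following is a minimal sketch, not quanda's source, written under two assumptions: each sample is flattened to a 1-D feature vector, and the result is a pairwise matrix of shape (n_test, n_train). The function name `cosine_similarity_sketch` is hypothetical. In this sketch, NaN arises as 0/0 when a test or train vector has zero norm, and `replace_nan` fills those entries.

```python
import torch


def cosine_similarity_sketch(
    test: torch.Tensor, train: torch.Tensor, replace_nan: float = 0.0
) -> torch.Tensor:
    """Hypothetical pairwise cosine similarity (illustration only, not quanda's code).

    Assumes rows are samples; returns a tensor of shape (n_test, n_train).
    """
    # Flatten every sample to a 1-D feature vector (assumed convention).
    test = test.reshape(test.shape[0], -1)
    train = train.reshape(train.shape[0], -1)

    # Dot product between every test row and every train row: (n_test, n_train).
    dots = test @ train.T

    # Outer product of L2 norms: (n_test, 1) * (1, n_train) -> (n_test, n_train).
    norms = torch.linalg.norm(test, dim=1, keepdim=True) * torch.linalg.norm(
        train, dim=1
    ).unsqueeze(0)

    # A zero-norm row gives 0 / 0 = NaN; replace those entries with replace_nan.
    sim = dots / norms
    return torch.nan_to_num(sim, nan=replace_nan)


# Usage: the second test row is all zeros, so its similarities become replace_nan.
test = torch.tensor([[1.0, 0.0], [0.0, 0.0]])
train = torch.tensor([[1.0, 0.0], [0.0, 2.0], [-3.0, 0.0]])
print(cosine_similarity_sketch(test, train))
```

Normalizing by the norm product keeps every finite entry in [-1, 1], so the scores depend only on direction, not on activation magnitude.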
- quanda.utils.functions.similarities.dot_product_similarity(test, train, replace_nan=0) → Tensor
Compute dot product similarity between test and train activations.
- Parameters:
test (torch.Tensor) – The test activations.
train (torch.Tensor) – The train activations.
replace_nan (int, optional) – The value to replace NaN values with. Default is 0.
- Returns:
The dot product similarity between the test and train activations.
- Return type:
torch.Tensor
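For contrast with cosine similarity, the following minimal sketch (again not quanda's source) computes raw pairwise dot products without normalization. It makes the same assumptions as before: each sample is flattened to a vector and the output has shape (n_test, n_train). The name `dot_product_similarity_sketch` is hypothetical. Without division by norms, NaN can only come from NaN (or infinite) values already present in the inputs, and `replace_nan` fills those entries.

```python
import torch


def dot_product_similarity_sketch(
    test: torch.Tensor, train: torch.Tensor, replace_nan: float = 0.0
) -> torch.Tensor:
    """Hypothetical pairwise dot product similarity (illustration only, not quanda's code).

    Assumes rows are samples; returns a tensor of shape (n_test, n_train).
    """
    # Flatten every sample to a 1-D feature vector (assumed convention).
    test = test.reshape(test.shape[0], -1)
    train = train.reshape(train.shape[0], -1)

    # Unnormalized inner product of every test row with every train row.
    sim = test @ train.T

    # NaN here can only propagate from the inputs; replace it with replace_nan.
    return torch.nan_to_num(sim, nan=replace_nan)


# Usage: larger-magnitude activations produce larger scores, unlike cosine similarity.
test = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
train = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
print(dot_product_similarity_sketch(test, train))
```

Because the scores are not normalized, they grow with activation magnitude. This can be desirable or not, depending on whether magnitude carries meaning for the attribution task.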