quanda.utils.functions.similarities module

Similarity functions.

quanda.utils.functions.similarities.cosine_similarity(test, train, replace_nan=0) → Tensor

Compute cosine similarity between test and train activations.

Parameters:
  • test (torch.Tensor) – The test activations.

  • train (torch.Tensor) – The train activations.

  • replace_nan (int, optional) – The value to replace NaN values with. Default is 0.

Returns:

The cosine similarity between the test and train activations.

Return type:

torch.Tensor
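
As an illustration of the computation described above, here is a minimal sketch and not the library's actual implementation. The name cosine_similarity_sketch is hypothetical. It assumes that each input holds one activation per row (extra dimensions are flattened), that the result is a pairwise matrix of shape (n_test, n_train), and that NaN values arise only from zero-norm activations.

```python
import torch


def cosine_similarity_sketch(test, train, replace_nan=0):
    """Hypothetical sketch: pairwise cosine similarity of activations."""
    # Assumption: the first dimension indexes samples; flatten the rest.
    test = test.reshape(test.shape[0], -1)
    train = train.reshape(train.shape[0], -1)

    # L2 norm of each activation vector, kept as a column for broadcasting.
    test_norm = torch.linalg.norm(test, dim=1, keepdim=True)
    train_norm = torch.linalg.norm(train, dim=1, keepdim=True)

    # Dividing a zero vector by its zero norm yields NaN entries.
    similarity = (test / test_norm) @ (train / train_norm).T

    # Replace those NaN entries with the requested value.
    return torch.nan_to_num(similarity, nan=float(replace_nan))


test = torch.tensor([[1.0, 0.0], [0.0, 0.0]])  # second row has zero norm
train = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
similarity = cosine_similarity_sketch(test, train)
print(similarity.shape)  # one row per test sample, one column per train sample
```

In this sketch the zero-norm test row produces NaN for every pair, so the whole row takes the value of replace_nan.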

quanda.utils.functions.similarities.dot_product_similarity(test, train, replace_nan=0) → Tensor

Compute dot product similarity between test and train activations.

Parameters:
  • test (torch.Tensor) – The test activations.

  • train (torch.Tensor) – The train activations.

  • replace_nan (int, optional) – The value to replace NaN values with. Default is 0.

Returns:

The dot product similarity between the test and train activations.

Return type:

torch.Tensor
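
For comparison with the cosine variant, the dot product similarity skips normalization, so its values scale with activation magnitude. The following is a minimal sketch and not the library's actual implementation. The name dot_product_similarity_sketch is hypothetical. It assumes one activation per row and a pairwise result of shape (n_test, n_train).

```python
import torch


def dot_product_similarity_sketch(test, train, replace_nan=0):
    """Hypothetical sketch: pairwise dot product of activations."""
    # Assumption: the first dimension indexes samples; flatten the rest.
    test = test.reshape(test.shape[0], -1)
    train = train.reshape(train.shape[0], -1)

    # Entry (i, j) is the dot product of test row i and train row j.
    similarity = test @ train.T

    # NaN can only appear here if the inputs already contain NaN.
    return torch.nan_to_num(similarity, nan=float(replace_nan))


test = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
train = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
dot_similarity = dot_product_similarity_sketch(test, train)
print(dot_similarity)
```

Because the vectors are not normalized, a larger activation yields a larger score even when its direction is unchanged, which is the main practical difference from cosine similarity.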