Source code for quanda.utils.functions.correlations
"""Correlation functions."""
from typing import Literal
from torchmetrics.functional.regression import (
kendall_rank_corrcoef,
spearman_corrcoef,
)
[docs]
def kendall_rank_corr(tensor1, tensor2):
"""Calculate torchmetrics kendall_corrcoef function.
The difference is that the input tensors are transposed before passing to
the function.
Parameters
----------
tensor1, tensor2 : torch.Tensor
The input tensors to compute the correlation coefficient.
Returns
-------
torch.Tensor
The Kendall Rank correlation coefficient between the two tensors.
"""
return kendall_rank_corrcoef(tensor1.T, tensor2.T)
[docs]
def spearman_rank_corr(tensor1, tensor2):
"""Calculate torchmetrics spearman_corrcoef function.
The difference is that the input tensors are transposed before passing to
the function.
Parameters
----------
tensor1, tensor2 : torch.Tensor
The input tensors to compute the correlation coefficient.
Returns
-------
torch.Tensor
The Spearman correlation coefficient between the two tensors.
"""
return spearman_corrcoef(tensor1.T, tensor2.T)
CorrelationFnLiterals = Literal["kendall", "spearman"]
correlation_functions = {
"kendall": kendall_rank_corr,
"spearman": spearman_rank_corr,
}