import torch
from torch import Tensor
from flambe.nn.distance import DistanceModule, MeanModule
[docs]class CosineDistance(DistanceModule):
"""Implement a CosineDistance object.
"""
def __init__(self, eps: float = 1e-8) -> None:
"""Initialize the CosineDistance module.
Parameters
----------
eps : float, optional
Used for numerical stability
"""
super().__init__()
self.eps = eps
[docs] def forward(self, mat_1: Tensor, mat_2: Tensor) -> Tensor:
"""Returns the cosine distance between each
element in mat_1 and each element in mat_2.
Parameters
----------
mat_1: torch.Tensor
matrix of shape (n_1, n_features)
mat_2: torch.Tensor
matrix of shape (n_2, n_features)
Returns
-------
dist: torch.Tensor
distance matrix of shape (n_1, n_2)
"""
w1 = mat_1.norm(p=2, dim=1, keepdim=True)
w2 = mat_2.norm(p=2, dim=1, keepdim=True)
return 1 - torch.mm(mat_1, mat_2.t()) / (w1 * w2.t()).clamp(min=self.eps)
[docs]class CosineMean(MeanModule):
"""Implement a CosineMean object.
"""
[docs] def forward(self, data: Tensor) -> Tensor:
"""Performs a forward pass through the network.
Parameters
----------
data : torch.Tensor
The input data, as a float tensor
Returns
-------
torch.Tensor
The encoded output, as a float tensor
"""
data = data / (data.norm(dim=1, keepdim=True))
return data.mean(0)