import torch
from torch import Tensor
from flambe.nn.distance.distance import DistanceModule, MeanModule
[docs]class EuclideanDistance(DistanceModule):
"""Implement a EuclideanDistance object."""
[docs] def forward(self, mat_1: Tensor, mat_2: Tensor) -> Tensor:
"""Returns the squared euclidean distance between each
element in mat_1 and each element in mat_2.
Parameters
----------
mat_1: torch.Tensor
matrix of shape (n_1, n_features)
mat_2: torch.Tensor
matrix of shape (n_2, n_features)
Returns
-------
dist: torch.Tensor
distance matrix of shape (n_1, n_2)
"""
dist = [torch.sum((mat_1 - mat_2[i])**2, dim=1) for i in range(mat_2.size(0))]
dist = torch.stack(dist, dim=1)
return dist
[docs]class EuclideanMean(MeanModule):
"""Implement a EuclideanMean object."""
[docs] def forward(self, data: Tensor) -> Tensor:
"""Performs a forward pass through the network.
Parameters
----------
data : torch.Tensor
The input data, as a float tensor
Returns
-------
torch.Tensor
The encoded output, as a float tensor
"""
return data.mean(0)