# Source code for flambe.nn.distance.euclidean

# type: ignore[override]

import torch
from torch import Tensor
from flambe.nn.distance.distance import DistanceModule, MeanModule


class EuclideanDistance(DistanceModule):
    """Implement a EuclideanDistance object.

    The forward pass produces the pairwise *squared* euclidean distance
    between the rows of two matrices (no square root is taken).

    """

    def forward(self, mat_1: Tensor, mat_2: Tensor) -> Tensor:
        """Return the squared euclidean distance between each element
        in mat_1 and each element in mat_2.

        Parameters
        ----------
        mat_1: torch.Tensor
            matrix of shape (n_1, n_features)
        mat_2: torch.Tensor
            matrix of shape (n_2, n_features)

        Returns
        -------
        dist: torch.Tensor
            distance matrix of shape (n_1, n_2); when mat_2 has no
            rows, an empty matrix of shape (n_1, 0) is returned.

        """
        if mat_2.size(0) == 0:
            # torch.stack refuses an empty list of tensors, so build the
            # (n_1, 0) result directly instead of raising.
            return mat_1.new_empty((mat_1.size(0), 0))

        # One column per row of mat_2: broadcasting subtracts that row from
        # every row of mat_1, and the squared differences are summed over
        # the feature dimension. Working row by row avoids materialising an
        # (n_1, n_2, n_features) intermediate.
        columns = [
            torch.sum((mat_1 - row) ** 2, dim=1)
            for row in mat_2.unbind(dim=0)
        ]
        return torch.stack(columns, dim=1)
class EuclideanMean(MeanModule):
    """Implement a EuclideanMean object."""

    def forward(self, data: Tensor) -> Tensor:
        """Average the input tensor over its first dimension.

        Parameters
        ----------
        data : torch.Tensor
            The input data, as a float tensor

        Returns
        -------
        torch.Tensor
            The mean of ``data`` taken along dimension 0, as a float
            tensor (dimension 0 is reduced away)

        """
        averaged = torch.mean(data, dim=0)
        return averaged