# Source code for flambe.nn.distance.hyperbolic

# type: ignore[override]

import torch
from torch import Tensor
from flambe.nn.distance.distance import DistanceModule, MeanModule

# Small positive floor: keeps arccosh's argument strictly above 1 and bounds
# the tangent-vector norm away from zero in exp_map.
EPSILON = 1e-5

def arccosh(x):
    """Compute arccosh(x) elementwise, numerically stable.

    Inputs are clamped to at least ``1 + EPSILON`` so the result (and its
    gradient) stays finite. Uses the identity

        arccosh(x) = log(x) + log1p(sqrt(x - 1) * sqrt(x + 1) / x)

    which never forms ``x * x``. That avoids overflow to ``inf`` for very
    large inputs, and avoids the catastrophic cancellation of ``x * x - 1``
    for inputs close to 1 (``x - 1`` is exact there).
    """
    x = torch.clamp(x, min=1 + EPSILON)
    a = torch.log(x)
    # sqrt(x - 1) * sqrt(x + 1) == sqrt(x^2 - 1), without squaring x.
    b = torch.log1p(torch.sqrt(x - 1) * torch.sqrt(x + 1) / x)
    return a + b

def mdot(x, y):
    """Compute the Minkowski (Lorentzian) inner product of x and y, row-wise.

    <x, y>_L = -x_0 * y_0 + sum_{i >= 1} x_i * y_i

    Parameters
    ----------
    x, y : torch.Tensor
        2-D tensors whose second dimension holds the (n+1) hyperboloid
        coordinates; they must broadcast against each other.

    Returns
    -------
    torch.Tensor
        Shape (rows, 1); the trailing singleton dimension is kept so the
        result broadcasts against the inputs (as ``log_map`` requires).
    """
    # Metric signature (-1, +1, ..., +1), built with x's dtype and device.
    m = x.new_ones(1, x.size(1))
    m[0, 0] = -1
    return torch.sum(m * x * y, dim=1, keepdim=True)

def dist(x, y):
    """Return the hyperbolic (geodesic) distance between x and y.

    The distance is arccosh of the negated Minkowski inner product,
    arccosh(-<x, y>_L), with the product supplied by ``mdot``.
    """
    inner = mdot(x, y)
    return arccosh(inner.neg())

def project(x):
    """Project onto the hyperboloid embedded in n+1 dimensions.

    Each row ``x`` in R^n is lifted to ``(sqrt(1 + ||x||^2), x)``. The
    lifted point then satisfies <p, p>_L = -1 (the upper sheet of the
    hyperboloid). This is the same leading coordinate that
    ``HyperbolicDistance`` recomputes from the spatial part.

    Parameters
    ----------
    x : torch.Tensor
        Matrix of shape (n_points, n).

    Returns
    -------
    torch.Tensor
        Matrix of shape (n_points, n + 1).
    """
    x_0 = torch.sqrt(1 + x.pow(2).sum(dim=1, keepdim=True))
    return torch.cat([x_0, x], dim=1)

def log_map(x, y):
    """Map y into the tangent space at x (the log step).

    Returns (d / sinh(d)) * (y - cosh(d) * x), where d = dist(x, y).
    """
    distance = dist(x, y)
    scale = distance / torch.sinh(distance)
    direction = y - torch.cosh(distance) * x
    return scale * direction

def norm(x):
    """Compute the Minkowski norm sqrt(|<x, x>_L|) using ``mdot``.

    The absolute value lets this work for both space-like vectors
    (positive product) and time-like vectors (negative product).
    """
    return mdot(x, x).abs().sqrt()

def exp_map(x, y):
    """Perform the exp step: move from x along tangent vector y.

    exp_x(y) = cosh(|y|_L) * x + (sinh(|y|_L) / |y|_L) * y

    This is the inverse of ``log_map``. The norm is floored at EPSILON so
    that ``sinh(n) / n`` stays finite when y is (numerically) zero; in
    that case the result is approximately x.
    """
    n = torch.clamp(norm(y), min=EPSILON)
    return torch.cosh(n) * x + (torch.sinh(n) / n) * y

def loss(x, y):
    """Get the loss for the optimizer: sum of squared hyperbolic distances.

    Up to a constant factor, ``HyperbolicMean`` follows the gradient of
    this objective (-2 * log_map per point).
    NOTE(review): it averages rather than sums; confirm sum is intended here.
    """
    return torch.sum(dist(x, y) ** 2)

class HyperbolicDistance(DistanceModule):
    """Squared geodesic distance in the hyperboloid (Lorentz) model.

    Inputs hold only the n spatial coordinates of each point. The leading
    hyperboloid coordinate x_0 = sqrt(1 + ||x||^2) is recomputed here.
    """

    def forward(self, mat_1: Tensor, mat_2: Tensor) -> Tensor:
        """Returns the squared hyperbolic distance between each
        element in mat_1 and each element in mat_2.

        Parameters
        ----------
        mat_1: torch.Tensor
            matrix of shape (n_1, n_features)
        mat_2: torch.Tensor
            matrix of shape (n_2, n_features)

        Returns
        -------
        dist: torch.Tensor
            distance matrix of shape (n_1, n_2)

        """
        # Recompute the leading hyperboloid coordinate from all features.
        mat_1_x_0 = torch.sqrt(1 + mat_1.pow(2).sum(dim=1, keepdim=True))
        mat_2_x_0 = torch.sqrt(1 + mat_2.pow(2).sum(dim=1, keepdim=True))

        # -<X, Y>_L = x_0 * y_0 - <x, y>. Because x_0 was built from every
        # feature column, the Euclidean part must use every column as well;
        # otherwise identical points would get a non-zero distance.
        left = mat_1_x_0.mm(mat_2_x_0.t())  # n_1 x n_2
        right = mat_1.mm(mat_2.t())  # n_1 x n_2

        # Arccosh of -<X, Y>_L is the geodesic distance; square it.
        return arccosh(left - right).pow(2)

class HyperbolicMean(MeanModule):
    """Compute the mean point in the hyperboloid model."""

    def forward(self, data: Tensor) -> Tensor:
        """Approximate the hyperbolic mean of the rows of ``data``.

        Starts from the normalized average of the projected points, then
        takes a fixed number of Riemannian gradient steps (5 in training
        mode, 100 otherwise) with step size 1e-2.

        Parameters
        ----------
        data : torch.Tensor
            The input points, as a float tensor; presumably of shape
            (n_points, n_features) — TODO confirm against callers.

        Returns
        -------
        torch.Tensor
            The mean point with its leading hyperboloid coordinate dropped
            (the distance module recomputes that coordinate).

        """
        if self.training:
            steps = 5
        else:
            steps = 100
        step_size = 1e-2

        # Lift the inputs into n+1 dimensions.
        points = project(data)

        # Initial guess: average of the lifted points, rescaled by its
        # Minkowski norm.
        center = points.mean(dim=0, keepdim=True)
        center = center / norm(center)

        for _ in range(steps):
            grad = -2 * log_map(center, points).mean(dim=0, keepdim=True)
            center = exp_map(center, -step_size * grad)
            center = center / norm(center)

        # Drop the leading coordinate; it is recomputed downstream.
        return center.squeeze()[1:]