-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathlorentz_feedforward.py
More file actions
50 lines (42 loc) · 1.59 KB
/
lorentz_feedforward.py
File metadata and controls
50 lines (42 loc) · 1.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from .. import nn as hnn
import math
import torch
from torch import nn
import torch.nn.functional as F
from ..manifolds import Lorentz
class LorentzFeedForward(nn.Module):
    """
    Gated Lorentz MLP block used as a feed-forward layer.

    Two parallel ``LorentzLinear`` maps are applied to the input; one branch
    is passed through SiLU and multiplied element-wise with the other. A time
    coordinate is then rebuilt for the product and the concatenated result is
    fed to a third ``LorentzLinear`` layer.

    Attributes:
        manifold: Manifold object handed to every ``LorentzLinear`` layer.
        c: Value read from ``manifold.c``; added under the square root when
            the time coordinate is rebuilt.
        w1: Linear layer whose output is SiLU-activated (the gate branch).
        w2: Linear layer for hidden-to-output transformation.
        w3: Linear layer whose output is multiplied with the gate branch.
    """

    def __init__(self, manifold: Lorentz, dim: int, inter_dim: int):
        """
        Builds the three linear layers of the block.

        Args:
            manifold: Input manifold
            dim (int): Input and output dimensionality.
            inter_dim (int): Hidden layer dimensionality.
        """
        super().__init__()
        self.manifold = manifold
        self.c = manifold.c

        # NOTE(review): every output size is passed as `size - 1`; presumably
        # LorentzLinear accounts for the time coordinate itself — confirm.
        hidden_space_dim = inter_dim - 1
        self.w1 = hnn.LorentzLinear(self.manifold, dim, hidden_space_dim)
        self.w2 = hnn.LorentzLinear(self.manifold, inter_dim, dim - 1)
        self.w3 = hnn.LorentzLinear(self.manifold, dim, hidden_space_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Runs the gated feed-forward computation.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Result of applying ``w2`` to the gated hidden tensor.
        """
        # NOTE(review): `return_space=True` presumably makes LorentzLinear
        # return only the space-like coordinates — confirm against that layer.
        gate = F.silu(self.w1(x, return_space=True))
        value = self.w3(x, return_space=True)
        hidden_space = gate * value

        # Rebuild the leading (time) coordinate from the gated features:
        # sqrt(sum(space^2) + c), clamped away from zero before the root.
        squared_norm = hidden_space.pow(2).sum(dim=-1, keepdim=True)
        hidden_time = torch.sqrt((squared_norm + self.c).clamp_min(1e-6))

        hidden = torch.cat((hidden_time, hidden_space), dim=-1)
        return self.w2(hidden)