Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/syna/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ class Function:
"""

def __call__(self, *input: Tensor | np.ndarray | int | float) -> Any:
inputs = [as_tensor(x) for x in input]
inputs = [as_tensor(x) if x is not None else Tensor(x) for x in input]
xs = [x.data for x in inputs]
ys = self.forward(*xs)
if not isinstance(ys, tuple):
Expand Down
1 change: 1 addition & 0 deletions src/syna/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from syna.layers.layer import Layer, Linear, Parameter
from syna.layers.normalization import LayerNorm
from syna.layers.rnn import LSTM, RNN
39 changes: 39 additions & 0 deletions src/syna/layers/normalization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Normalization layers"""

from __future__ import annotations

import numpy as np

import syna.functions as F
from syna.core import Parameter
from syna.layers.layer import Layer


class LayerNorm(Layer):
    r"""Layer Normalization over the last axis of the input.

    Implements the transform from Jimmy Lei Ba et al. (2016),
    https://arxiv.org/abs/1607.06450:

    .. math::
        y = \gamma \odot \frac{x - \mathbb{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} + \beta

    ``gamma`` (scale) is initialized to ones and ``beta`` (shift) to zeros,
    both of length ``dim``.

    Args:
        dim: size of the last (normalized) dimension
        eps: small constant added to variance for numerical stability
        dtype: numpy dtype used to initialize the learnable parameters
    """

    def __init__(self, dim: int, eps: float = 1e-5, dtype=np.float32) -> None:
        super().__init__()
        self.eps = eps
        # Learnable affine parameters applied after normalization.
        self.gamma = Parameter(np.ones(dim, dtype=dtype), name="gamma")
        self.beta = Parameter(np.zeros(dim, dtype=dtype), name="beta")

    def forward(self, x):
        """Normalize ``x`` over its last axis, then apply the affine scale/shift."""
        last = len(x.shape) - 1
        mu = F.mean(x, axis=last, keepdims=True)
        centered = x - mu
        # Biased variance (mean of squared deviations), as in the paper.
        variance = F.mean(centered * centered, axis=last, keepdims=True)
        normalized = centered / F.sqrt(variance + self.eps)
        return normalized * self.gamma + self.beta