*Serket is the goddess of magic in Egyptian mythology
Install development version

```bash
pip install git+https://github.com/ASEM000/serket
```

`serket` aims to be the most intuitive and easy-to-use machine learning library in `jax`. `serket` is fully transparent to `jax` transformations (e.g. `vmap`, `grad`, `jit`, ...).
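serket layers are pytrees, so they pass through `jax.vmap`, `jax.grad` and `jax.jit` the same way any pytree of arrays does. The sketch below shows that composition in plain JAX. It assumes a dict of arrays as a stand-in for a layer; `linear` and `loss` are hypothetical names, and no serket API is used.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a layer: a plain dict of arrays is already a pytree.
params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}


def linear(params, x):
    # Forward pass for a single example x of shape (3,).
    return x @ params["w"] + params["b"]


def loss(params, xs):
    # vmap maps `linear` over the leading (batch) axis of xs; params are shared.
    ys = jax.vmap(linear, in_axes=(None, 0))(params, xs)
    return jnp.mean(ys ** 2)


# grad differentiates w.r.t. every array in the params pytree; jit compiles the result.
grad_fn = jax.jit(jax.grad(loss))
grads = grad_fn(params, jnp.ones((5, 3)))
```

Because each transformation returns an ordinary function, the order can be changed (for example `jax.vmap` applied on top of `jax.grad`) without any special support from the model.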
- Full documentation
- Train MNIST, UNet, ConvLSTM, PINN
- Model surgery, Parallelism, Mixed precision
- Optimizers, Augmentation composition
- Interoperability with Keras and TensorFlow
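The quick example below wraps the network in `sk.tree_mask` before it crosses `jax.jit` and `jax.grad`, and calls `sk.tree_unmask` inside the loss. Masking keeps non-array leaves out of the traced arguments: a function such as `jnp.ravel` is not a valid `jit` argument, and `grad` rejects integer inputs. The sketch below illustrates the same idea with a different technique, a partition/combine split written in plain JAX. `is_param`, `partition`, `combine`, `apply` and the toy `net` are hypothetical and are not how serket implements masking.

```python
import jax
import jax.numpy as jnp


def is_param(leaf):
    # Only floating-point arrays are treated as traceable, trainable leaves.
    return isinstance(leaf, jax.Array) and jnp.issubdtype(leaf.dtype, jnp.inexact)


def partition(tree):
    """Split a pytree into (params, static, treedef). Hypothetical helper."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    params = [leaf if is_param(leaf) else None for leaf in leaves]
    static = [None if is_param(leaf) else leaf for leaf in leaves]
    return params, static, treedef


def combine(params, static, treedef):
    """Inverse of `partition`: merge both leaf lists back into the original tree."""
    leaves = [p if p is not None else s for p, s in zip(params, static)]
    return jax.tree_util.tree_unflatten(treedef, leaves)


# Toy "network": a function, an int hyperparameter, and float weights.
key = jax.random.key(0)
net = (jnp.ravel, 3, {"w": jax.random.normal(key, (4, 2)), "b": jnp.zeros(2)})


def apply(net, x):
    flatten, _, layer = net
    return flatten(x) @ layer["w"] + layer["b"]


params, static, treedef = partition(net)


@jax.jit
def loss(params, x):
    # Only `params` is traced; `static` and `treedef` are closed over as constants.
    model = combine(params, static, treedef)
    return jnp.sum(apply(model, x) ** 2)


# Gradients come back in the same list layout, with None in the static slots.
grads = jax.grad(loss)(params, jnp.ones((2, 2)))
```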
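```python
import functools as ft

import jax
import jax.numpy as jnp
import serket as sk

# placeholder: iterables of image batches (batch, 28, 28) and integer label batches (batch,)
x_train, y_train = ..., ...
k1, k2 = jax.random.split(jax.random.key(0))

# tree_mask hides non-differentiable leaves (e.g. jnp.ravel, jax.nn.relu) from jax transformations
net = sk.tree_mask(sk.Sequential(
    jnp.ravel,
    sk.nn.Linear(28 * 28, 64, key=k1),
    jax.nn.relu,
    sk.nn.Linear(64, 10, key=k2),
))

def softmax_cross_entropy(logits, onehot):
    # per-example cross-entropy between softmax(logits) and one-hot targets
    return -jnp.sum(onehot * jax.nn.log_softmax(logits), axis=-1)

def accuracy_func(logits, y):
    return jnp.mean(jnp.argmax(logits, axis=-1) == y)

@ft.partial(jax.grad, has_aux=True)
def loss_func(net, x, y):
    # unmask inside the transformed function, then map the network over the batch
    logits = jax.vmap(sk.tree_unmask(net))(x)
    onehot = jax.nn.one_hot(y, 10)
    loss = jnp.mean(softmax_cross_entropy(logits, onehot))
    return loss, (loss, logits)

@jax.jit
def train_step(net, x, y):
    grads, (loss, logits) = loss_func(net, x, y)
    # plain SGD update with learning rate 1e-3
    net = jax.tree_util.tree_map(lambda p, g: p - g * 1e-3, net, grads)
    return net, (loss, logits)

for j, (xb, yb) in enumerate(zip(x_train, y_train)):
    net, (loss, logits) = train_step(net, xb, yb)
    accuracy = accuracy_func(logits, yb)  # accuracy on the current batch

# unmask to get back the plain network after training
net = sk.tree_unmask(net)
```

📚 Layers catalog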
Braces list alternatives and `_` means no prefix: for example, `{FFT,_}Conv{1D,2D,3D}` covers `FFTConv1D`, `FFTConv2D`, `FFTConv3D`, `Conv1D`, `Conv2D` and `Conv3D`.

| Group | Layers |
|---|---|
| Containers | - Sequential, Random{Choice} |

| Group | Layers |
|---|---|
| Attention | - MultiHeadAttention |
| Convolution | - {FFT,_}Conv{1D,2D,3D}<br>- {FFT,_}Conv{1D,2D,3D}Transpose<br>- Depthwise{FFT,_}Conv{1D,2D,3D}<br>- Separable{FFT,_}Conv{1D,2D,3D}<br>- Conv{1D,2D,3D}Local<br>- SpectralConv{1D,2D,3D} |
| Dropout | - Dropout<br>- Dropout{1D,2D,3D}<br>- RandomCutout{1D,2D,3D} |
| Linear | - Linear, MLP, Identity |
| Normalization | - {Layer,Instance,Group,Batch}Norm |
| Pooling | - {Avg,Max,LP}Pool{1D,2D,3D}<br>- Global{Avg,Max}Pool{1D,2D,3D}<br>- Adaptive{Avg,Max}Pool{1D,2D,3D} |
| Reshaping | - Upsample{1D,2D,3D}<br>- {Random,Center}Crop{1D,2D,3D} |
| Recurrent cells | - {SimpleRNN,LSTM,GRU,Dense}Cell<br>- {Conv,FFTConv}{LSTM,GRU}{1D,2D,3D}Cell |
| Activations | - Adaptive{LeakyReLU,ReLU,Sigmoid,Tanh}<br>- CeLU, ELU, GELU, GLU<br>- Hard{SILU,Shrink,Sigmoid,Swish,Tanh}<br>- Soft{Plus,Sign,Shrink}<br>- LeakyReLU, LogSigmoid, LogSoftmax, Mish, PReLU<br>- ReLU, ReLU6, SeLU, Sigmoid<br>- Swish, Tanh, TanhShrink, ThresholdedReLU, Snake |

| Group | Layers |
|---|---|
| Filter | - {FFT,_}{Avg,Box,Gaussian,Motion}Blur2D<br>- {JointBilateral,Bilateral,Median}Blur2D<br>- {FFT,_}{UnsharpMask}2D<br>- {FFT,_}{Sobel,Laplacian}2D<br>- {FFT,_}BlurPool2D |
| Augment | - Adjust{Sigmoid,Log}2D<br>- {Adjust,Random}{Brightness,Contrast,Hue,Saturation}2D<br>- RandomJigSaw2D, PixelShuffle2D<br>- Pixelate2D, Posterize2D, Solarize2D<br>- FourierDomainAdapt2D |
| Geometric | - {Random,_}{Horizontal,Vertical}{Translate,Flip,Shear}2D<br>- {Random,_}{Rotate}2D<br>- RandomPerspective2D<br>- {FFT,_}ElasticTransform2D |
| Color | - RGBToGrayscale2D, GrayscaleToRGB2D<br>- RGBToHSV2D, HSVToRGB2D |
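The brace patterns in the catalog can be expanded mechanically. The hypothetical helper below (not part of serket) turns one pattern into the class names it abbreviates, treating `_` as an empty prefix. It is a sketch that assumes flat, non-nested `{a,b,...}` groups, which holds for every entry in the tables above.

```python
import itertools
import re


def expand(pattern: str) -> list[str]:
    """Expand catalog brace notation such as '{FFT,_}Conv{1D,2D,3D}'."""
    # re.split with a capturing group keeps the {…} groups alongside the literal text.
    parts = re.split(r"(\{[^}]*\})", pattern)
    choices = []
    for part in parts:
        if part.startswith("{") and part.endswith("}"):
            options = [opt.strip() for opt in part[1:-1].split(",")]
            # `_` stands for "no prefix", i.e. the empty string.
            choices.append(["" if opt == "_" else opt for opt in options])
        else:
            choices.append([part])
    # Cartesian product over all groups, left to right.
    return ["".join(combo) for combo in itertools.product(*choices)]


print(expand("{FFT,_}Conv{1D,2D,3D}"))
```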