xlstm_jax.models.xlstm_clean.components.ln

xlstm_jax.models.xlstm_clean.components.ln

Functions

LayerNorm([weight, bias, eps, dtype])

MultiHeadLayerNorm([weight, bias, eps, dtype, axis])

Module Contents

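xlstm_jax.models.xlstm_clean.components.ln.LayerNorm(weight=True, bias=False, eps=1e-05, dtype=jnp.float32, **kwargs)
Parameters:
    weight – default True.
    bias – default False.
    eps – default 1e-05.
    dtype – default jnp.float32.
    **kwargs – additional keyword arguments.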
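xlstm_jax.models.xlstm_clean.components.ln.MultiHeadLayerNorm(weight=True, bias=False, eps=1e-05, dtype=jnp.float32, axis=1, **kwargs)
Parameters:
    weight – default True.
    bias – default False.
    eps – default 1e-05.
    dtype – default jnp.float32.
    axis – default 1.
    **kwargs – additional keyword arguments.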