xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.fwbw#

Classes#

mLSTMfwbwConfig

mLSTMfwbw(config)

Functions#

rev_cumsum_off(x)

rev_cumsum(x)

causal_forget_matrix(forget_gates)

mLSTMTorchFunction(config)

Returns an autograd function that computes the gradient itself.

Module Contents#

xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.fwbw.rev_cumsum_off(x)#
xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.fwbw.rev_cumsum(x)#
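
Neither reverse-cumsum helper has a docstring on this page. The sketch below shows what the names suggest, under two assumptions that the listing does not confirm: that `rev_cumsum` sums from the end of the last dimension towards the start, and that `rev_cumsum_off` is the same sum shifted by one position so the current element is excluded. The names `reverse_cumsum` and `reverse_cumsum_offset` are hypothetical stand-ins, not the module's functions.

```python
import torch


def reverse_cumsum(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in: y[..., i] = sum of x[..., j] for all j >= i."""
    # Flip the last dimension, take an ordinary cumulative sum, flip back.
    return x.flip(dims=(-1,)).cumsum(dim=-1).flip(dims=(-1,))


def reverse_cumsum_offset(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in: y[..., i] = sum of x[..., j] for all j > i."""
    # Assumed meaning of the "_off" suffix: exclude the current element.
    return reverse_cumsum(x) - x


x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(reverse_cumsum(x))         # tensor([10.,  9.,  7.,  4.])
print(reverse_cumsum_offset(x))  # tensor([9., 7., 4., 0.])
```

Sums of this shape typically appear when gradients are propagated back through a cumulative sum, which would fit a hand-written backward pass.
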
xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.fwbw.causal_forget_matrix(forget_gates)#
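
The page gives no description of `causal_forget_matrix`. As a rough illustration only, the sketch below builds the kind of matrix the name suggests in the parallel mLSTM formulation: entry (i, j) holds the accumulated log forget gates between steps j and i, and entries for future positions (j > i) are masked out. The input convention (log-space gates along the last dimension), the masking value, and the name `causal_log_forget_matrix` are all assumptions, not the module's documented behaviour.

```python
import torch


def causal_log_forget_matrix(log_f: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: D[..., i, j] = sum_{k=j+1..i} log_f[..., k] for j <= i,
    and -inf for j > i (future positions)."""
    seq_len = log_f.shape[-1]
    csum = log_f.cumsum(dim=-1)
    # csum[i] - csum[j] equals the sum of log gates over steps j+1 .. i.
    mat = csum[..., :, None] - csum[..., None, :]
    # Strict upper triangle marks positions that lie in the future.
    future = torch.ones(seq_len, seq_len, device=log_f.device).triu(diagonal=1).bool()
    return mat.masked_fill(future, float("-inf"))


log_f = torch.log(torch.tensor([0.5, 0.5, 0.5]))
D = causal_log_forget_matrix(log_f)
print(D.exp())  # lower triangle holds products of forget gates, upper triangle is 0
```

Working in log space keeps long products of gates numerically manageable; exponentiating the result turns the masked entries into exact zeros.
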
class xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.fwbw.mLSTMfwbwConfig#
chunk_size: int | None = 512#
return_state: bool = False#
use_initial_state: bool = False#
keep_G: bool = False#
keep_gates: bool = True#
keep_M: bool = False#
keep_c: bool = False#
stabilize_correctly: bool = False#
scale = None#
device_type: str = 'cuda'#
assign_model_config_params(model_config)#
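
To make the defaults above easier to read at a glance, here is a stand-in that mirrors the listed attributes as a plain dataclass. It is a sketch only: `FwbwConfigSketch` is a hypothetical name, whether the real `mLSTMfwbwConfig` is a dataclass is an assumption, and `assign_model_config_params` is omitted because its behaviour is not documented here.

```python
from dataclasses import dataclass, fields, replace
from typing import Optional


@dataclass
class FwbwConfigSketch:
    """Hypothetical mirror of the attributes and defaults listed above."""

    chunk_size: Optional[int] = 512
    return_state: bool = False
    use_initial_state: bool = False
    keep_G: bool = False
    keep_gates: bool = True
    keep_M: bool = False
    keep_c: bool = False
    stabilize_correctly: bool = False
    # Listed without a type annotation, so a dataclass would treat it as a
    # plain class attribute rather than a constructor field.
    scale = None
    device_type: str = "cuda"


# Override selected defaults at construction time.
cfg = FwbwConfigSketch(chunk_size=256, return_state=True)

# Derive a variant without mutating the original.
cpu_cfg = replace(cfg, device_type="cpu")

print([f.name for f in fields(cfg)])
print(cpu_cfg)
```
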
xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.fwbw.mLSTMTorchFunction(config)#

Returns an autograd function that computes the gradient itself.

Parameters:

config (mLSTMfwbwConfig) – Configuration for the returned autograd function.
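
The pattern described here, a factory that returns an autograd function with a hand-written gradient, can be sketched with a toy operation. The example below is not the mLSTM computation: `make_scaled_square` and its `scale` argument are hypothetical, and the only thing it shares with `mLSTMTorchFunction` is the general shape of a `torch.autograd.Function` subclass created inside a function and closed over its configuration.

```python
import torch


def make_scaled_square(scale: float):
    """Hypothetical factory: returns an autograd Function closed over `scale`."""

    class ScaledSquare(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # Save what the hand-written backward pass will need.
            ctx.save_for_backward(x)
            return scale * x * x

        @staticmethod
        def backward(ctx, grad_out):
            (x,) = ctx.saved_tensors
            # d/dx (scale * x^2) = 2 * scale * x, written out explicitly
            # instead of being traced by autograd.
            return grad_out * 2.0 * scale * x

    return ScaledSquare


fn = make_scaled_square(3.0)
x = torch.tensor([1.0, 2.0], requires_grad=True)
fn.apply(x).sum().backward()
print(x.grad)  # tensor([ 6., 12.])
```

Writing the backward pass by hand lets the author decide which intermediates to store and which to recompute, which is presumably what flags such as `keep_G` and `keep_M` control, although this page does not say so.
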

class xlstm_jax.models.xlstm_pytorch.blocks.mlstm.backend.fwbw.mLSTMfwbw(config)#

Bases: torch.nn.Module

Parameters:

config (mLSTMfwbwConfig)

config_class#
config#
func#
forward(*args)#
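
The attribute list above suggests a thin `torch.nn.Module` wrapper that stores a config, builds an autograd function from it, and forwards its positional arguments to that function. The sketch below shows this wiring with a toy function; that `forward(*args)` simply calls `func.apply(*args)` is an assumption, and `DoubleFn` and `FnWrapper` are hypothetical names.

```python
import torch
from torch import nn


class DoubleFn(torch.autograd.Function):
    """Toy autograd function with an explicit backward pass."""

    @staticmethod
    def forward(ctx, x):
        return 2.0 * x

    @staticmethod
    def backward(ctx, grad_out):
        return 2.0 * grad_out


class FnWrapper(nn.Module):
    """Hypothetical wrapper mirroring the `config`, `func`, `forward` layout."""

    def __init__(self, config=None):
        super().__init__()
        self.config = config
        # In the documented class this would presumably be built from the config.
        self.func = DoubleFn

    def forward(self, *args):
        # Pass positional arguments straight through to the autograd function.
        return self.func.apply(*args)


layer = FnWrapper()
inp = torch.ones(3, requires_grad=True)
layer(inp).sum().backward()
print(inp.grad)  # tensor([2., 2., 2.])
```
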