xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell#
Attributes#
Classes#
Functions#
|
Module Contents#
- xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.LOGGER#
- xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.DTYPE_DICT#
- xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.DTYPES#
- xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.curdir#
- xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.rnn_function_registry#
- xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell._python_dtype_to_cuda_dtype#
- class xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.sLSTMCellConfig#
- backend: Literal['vanilla', 'cuda'] = 'cuda'#
- bias_init: Literal['powerlaw_blockdependent', 'small_init', 'standard'] = 'powerlaw_blockdependent'#
- recurrent_weight_init: Literal['zeros', 'standard'] = 'zeros'#
- input_shape: Literal['BSGNH', 'SBGNH'] = 'BSGNH'#
- internal_input_shape: Literal['SBNGH', 'SBGNH', 'SBNHG'] = 'SBNGH'#
- output_shape: Literal['BNSH', 'SBH', 'BSH', 'SBNH'] = 'BNSH'#
- dtype: DTYPES = 'bfloat16'#
- initial_val: float | collections.abc.Sequence[float] = 0.0#
- property head_dim#
- property input_dim#
- property torch_dtype: torch.dtype#
- Return type:
torch.dtype
- property torch_dtype_b: torch.dtype#
- Return type:
torch.dtype
- property torch_dtype_r: torch.dtype#
- Return type:
torch.dtype
- property torch_dtype_w: torch.dtype#
- Return type:
torch.dtype
- property torch_dtype_s: torch.dtype#
- Return type:
torch.dtype
- property defines#
- class xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.sLSTMCellBase(config)#
Bases:
torch.nn.Module
- Parameters:
config (sLSTMCellConfig)
- config_class#
- config#
- _recurrent_kernel_#
- recurrent_kernel#
- _bias_#
- bias#
- property _recurrent_kernel#
- property _bias#
- _recurrent_kernel_ext2int(recurrent_kernel_ext)#
- Parameters:
recurrent_kernel_ext (torch.Tensor)
- Return type:
- _bias_ext2int(bias_ext)#
- Parameters:
bias_ext (torch.Tensor)
- Return type:
- _recurrent_kernel_int2ext(recurrent_kernel_int)#
- Parameters:
recurrent_kernel_int (torch.Tensor)
- Return type:
- _bias_int2ext(bias_int)#
- Parameters:
bias_int (torch.Tensor)
- Return type:
- parameters_to_dtype()#
- property head_dim#
- _permute_input(x)#
- Parameters:
x (torch.Tensor)
- Return type:
- _permute_output(x)#
- Parameters:
x (torch.Tensor)
- Return type:
- reset_parameters()#
Resets this layer’s parameters to their initial values.
- _check_input(input)#
- Parameters:
input (torch.Tensor)
- Return type:
None
- _zero_state(input)#
Return a zero state matching dtype and batch size of input.
- Parameters:
input (torch.Tensor) – Tensor that specifies the device and dtype of the returned tensors.
- Returns:
a nested structure of zero Tensors.
- Return type:
zero_state
- _get_state(input, state=None)#
- Parameters:
input (torch.Tensor)
state (torch.Tensor | None)
- Return type:
- static _get_final_state(all_states)#
All states have the structure [STATES, SEQUENCE, BATCH, HIDDEN]
- Parameters:
all_states (torch.Tensor)
- Return type:
- step(input, state)#
- Parameters:
input (torch.Tensor)
state (torch.Tensor)
- Return type:
- forward(input, state=None)#
- class xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.sLSTMCellCUDA#
- mod#
- classmethod instance(config)#
- Parameters:
config (sLSTMCellConfig)
- xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.sLSTMCellFuncGenerator(training, config)#
- Parameters:
config (sLSTMCellConfig)
- class xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.sLSTMCell_vanilla(config)#
Bases:
sLSTMCellBase
- Parameters:
config (sLSTMCellConfig)
- config_class#
- pointwise#
- _recurrent_kernel_ext2int(recurrent_kernel_ext)#
- Parameters:
recurrent_kernel_ext (torch.Tensor)
- Return type:
- _recurrent_kernel_int2ext(recurrent_kernel_int)#
- Parameters:
recurrent_kernel_int (torch.Tensor)
- Return type:
- _bias_ext2int(bias_ext)#
- Parameters:
bias_ext (torch.Tensor)
- Return type:
- _bias_int2ext(bias_int)#
- Parameters:
bias_int (torch.Tensor)
- Return type:
- _impl(input, state)#
- Parameters:
input (torch.Tensor)
state (torch.Tensor)
- Return type:
- _impl_step(input, state)#
- Parameters:
input (torch.Tensor)
state (torch.Tensor)
- Return type:
- config#
- _recurrent_kernel_#
- recurrent_kernel#
- _bias_#
- bias#
- property _recurrent_kernel#
- property _bias#
- parameters_to_dtype()#
- property head_dim#
- _permute_input(x)#
- Parameters:
x (torch.Tensor)
- Return type:
- _permute_output(x)#
- Parameters:
x (torch.Tensor)
- Return type:
- reset_parameters()#
Resets this layer’s parameters to their initial values.
- _check_input(input)#
- Parameters:
input (torch.Tensor)
- Return type:
None
- _zero_state(input)#
Return a zero state matching dtype and batch size of input.
- Parameters:
input (torch.Tensor) – Tensor that specifies the device and dtype of the returned tensors.
- Returns:
a nested structure of zero Tensors.
- Return type:
zero_state
- _get_state(input, state=None)#
- Parameters:
input (torch.Tensor)
state (torch.Tensor | None)
- Return type:
- static _get_final_state(all_states)#
All states have the structure [STATES, SEQUENCE, BATCH, HIDDEN]
- Parameters:
all_states (torch.Tensor)
- Return type:
- step(input, state)#
- Parameters:
input (torch.Tensor)
state (torch.Tensor)
- Return type:
- forward(input, state=None)#
- class xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.sLSTMCell_cuda(config, skip_backend_init=False)#
Bases:
sLSTMCellBase
- Parameters:
config (sLSTMCellConfig)
skip_backend_init (bool)
- config_class#
- internal_input_shape = 'SBNGH'#
- _recurrent_kernel_ext2int(recurrent_kernel_ext)#
- Parameters:
recurrent_kernel_ext (torch.Tensor)
- Return type:
- _recurrent_kernel_int2ext(recurrent_kernel_int)#
- Parameters:
recurrent_kernel_int (torch.Tensor)
- Return type:
- _bias_ext2int(bias_ext)#
- Parameters:
bias_ext (torch.Tensor)
- Return type:
- _bias_int2ext(bias_int)#
- Parameters:
bias_int (torch.Tensor)
- Return type:
- _impl_step(training, input, state)#
- Parameters:
training (bool)
input (torch.Tensor)
state (torch.Tensor)
- Return type:
- _impl(training, input, state)#
- Parameters:
training (bool)
input (torch.Tensor)
state (torch.Tensor)
- Return type:
- config#
- _recurrent_kernel_#
- recurrent_kernel#
- _bias_#
- bias#
- property _recurrent_kernel#
- property _bias#
- parameters_to_dtype()#
- property head_dim#
- _permute_input(x)#
- Parameters:
x (torch.Tensor)
- Return type:
- _permute_output(x)#
- Parameters:
x (torch.Tensor)
- Return type:
- reset_parameters()#
Resets this layer’s parameters to their initial values.
- _check_input(input)#
- Parameters:
input (torch.Tensor)
- Return type:
None
- _zero_state(input)#
Return a zero state matching dtype and batch size of input.
- Parameters:
input (torch.Tensor) – Tensor that specifies the device and dtype of the returned tensors.
- Returns:
a nested structure of zero Tensors.
- Return type:
zero_state
- _get_state(input, state=None)#
- Parameters:
input (torch.Tensor)
state (torch.Tensor | None)
- Return type:
- static _get_final_state(all_states)#
All states have the structure [STATES, SEQUENCE, BATCH, HIDDEN]
- Parameters:
all_states (torch.Tensor)
- Return type:
- step(input, state)#
- Parameters:
input (torch.Tensor)
state (torch.Tensor)
- Return type:
- forward(input, state=None)#
- class xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.sLSTMCell(config)#
Bases:
sLSTMCellBase
- Parameters:
config (sLSTMCellConfig)
- config_class#
- config#
- _recurrent_kernel_#
- recurrent_kernel#
- _bias_#
- bias#
- property _recurrent_kernel#
- property _bias#
- _recurrent_kernel_ext2int(recurrent_kernel_ext)#
- Parameters:
recurrent_kernel_ext (torch.Tensor)
- Return type:
- _bias_ext2int(bias_ext)#
- Parameters:
bias_ext (torch.Tensor)
- Return type:
- _recurrent_kernel_int2ext(recurrent_kernel_int)#
- Parameters:
recurrent_kernel_int (torch.Tensor)
- Return type:
- _bias_int2ext(bias_int)#
- Parameters:
bias_int (torch.Tensor)
- Return type:
- parameters_to_dtype()#
- property head_dim#
- _permute_input(x)#
- Parameters:
x (torch.Tensor)
- Return type:
- _permute_output(x)#
- Parameters:
x (torch.Tensor)
- Return type:
- reset_parameters()#
Resets this layer’s parameters to their initial values.
- _check_input(input)#
- Parameters:
input (torch.Tensor)
- Return type:
None
- _zero_state(input)#
Return a zero state matching dtype and batch size of input.
- Parameters:
input (torch.Tensor) – Tensor that specifies the device and dtype of the returned tensors.
- Returns:
a nested structure of zero Tensors.
- Return type:
zero_state
- _get_state(input, state=None)#
- Parameters:
input (torch.Tensor)
state (torch.Tensor | None)
- Return type:
- static _get_final_state(all_states)#
All states have the structure [STATES, SEQUENCE, BATCH, HIDDEN]
- Parameters:
all_states (torch.Tensor)
- Return type:
- step(input, state)#
- Parameters:
input (torch.Tensor)
state (torch.Tensor)
- Return type:
- forward(input, state=None)#