xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell

Contents

xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell#

Attributes#

LOGGER

DTYPE_DICT

DTYPES

curdir

rnn_function_registry

_python_dtype_to_cuda_dtype

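Classes#

sLSTMCellConfig

sLSTMCellBase(config)

sLSTMCellCUDA

sLSTMCell_vanilla(config)

sLSTMCell_cuda(config, skip_backend_init=False)

sLSTMCell(config)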

Functions#

sLSTMCellFuncGenerator(training, config)

Module Contents#

xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.LOGGER#
xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.DTYPE_DICT#
xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.DTYPES#
xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.curdir#
xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.rnn_function_registry#
xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell._python_dtype_to_cuda_dtype#
class xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.sLSTMCellConfig#
hidden_size: int = -1#
num_heads: int = 4#
num_states: int = 4#
backend: Literal['vanilla', 'cuda'] = 'cuda'#
function: str = 'slstm'#
bias_init: Literal['powerlaw_blockdependent', 'small_init', 'standard'] = 'powerlaw_blockdependent'#
recurrent_weight_init: Literal['zeros', 'standard'] = 'zeros'#
_block_idx: int = 0#
_num_blocks: int = 1#
num_gates: int = 4#
gradient_recurrent_cut: bool = False#
gradient_recurrent_clipval: float | None = None#
forward_clipval: float | None = None#
batch_size: int = 8#
input_shape: Literal['BSGNH', 'SBGNH'] = 'BSGNH'#
internal_input_shape: Literal['SBNGH', 'SBGNH', 'SBNHG'] = 'SBNGH'#
output_shape: Literal['BNSH', 'SBH', 'BSH', 'SBNH'] = 'BNSH'#
constants: dict#
dtype: DTYPES = 'bfloat16'#
dtype_b: DTYPES | None = 'float32'#
dtype_r: DTYPES | None = None#
dtype_w: DTYPES | None = None#
dtype_g: DTYPES | None = None#
dtype_s: DTYPES | None = None#
dtype_a: DTYPES | None = None#
enable_automatic_mixed_precision: bool = True#
initial_val: float | collections.abc.Sequence[float] = 0.0#
property head_dim#
property input_dim#
property torch_dtype: torch.dtype#
Return type:

torch.dtype

property torch_dtype_b: torch.dtype#
Return type:

torch.dtype

property torch_dtype_r: torch.dtype#
Return type:

torch.dtype

property torch_dtype_w: torch.dtype#
Return type:

torch.dtype

property torch_dtype_s: torch.dtype#
Return type:

torch.dtype

property defines#
class xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.sLSTMCellBase(config)#

Bases: torch.nn.Module

Parameters:

config (sLSTMCellConfig)

config_class#
config#
_recurrent_kernel_#
recurrent_kernel#
_bias_#
bias#
property _recurrent_kernel#
property _bias#
_recurrent_kernel_ext2int(recurrent_kernel_ext)#
Parameters:

recurrent_kernel_ext (torch.Tensor)

Return type:

torch.Tensor

_bias_ext2int(bias_ext)#
Parameters:

bias_ext (torch.Tensor)

Return type:

torch.Tensor

_recurrent_kernel_int2ext(recurrent_kernel_int)#
Parameters:

recurrent_kernel_int (torch.Tensor)

Return type:

torch.Tensor

_bias_int2ext(bias_int)#
Parameters:

bias_int (torch.Tensor)

Return type:

torch.Tensor

parameters_to_dtype()#
property head_dim#
_permute_input(x)#
Parameters:

x (torch.Tensor)

Return type:

torch.Tensor

_permute_output(x)#
Parameters:

x (torch.Tensor)

Return type:

torch.Tensor

reset_parameters()#

Resets this layer’s parameters to their initial values.

_check_input(input)#
Parameters:

input (torch.Tensor)

Return type:

None

_zero_state(input)#

Return a zero state matching dtype and batch size of input.

Parameters:

input (torch.Tensor) – Tensor, to specify the device and dtype of the returned tensors.

Returns:

a nested structure of zero Tensors.

Return type:

zero_state

_get_state(input, state=None)#
Parameters:
Return type:

torch.Tensor

static _get_final_state(all_states)#

All states have the structure [STATES, SEQUENCE, BATCH, HIDDEN]

Parameters:

all_states (torch.Tensor)

Return type:

torch.Tensor

_is_cuda()#
Return type:

bool

step(input, state)#
Parameters:
Return type:

tuple[torch.Tensor, torch.Tensor]

forward(input, state=None)#
class xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.sLSTMCellCUDA#
mod#
classmethod instance(config)#
Parameters:

config (sLSTMCellConfig)

xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.sLSTMCellFuncGenerator(training, config)#
Parameters:

config (sLSTMCellConfig)

class xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.sLSTMCell_vanilla(config)#

Bases: sLSTMCellBase

Parameters:

config (sLSTMCellConfig)

config_class#
pointwise#
_recurrent_kernel_ext2int(recurrent_kernel_ext)#
Parameters:

recurrent_kernel_ext (torch.Tensor)

Return type:

torch.Tensor

_recurrent_kernel_int2ext(recurrent_kernel_int)#
Parameters:

recurrent_kernel_int (torch.Tensor)

Return type:

torch.Tensor

_bias_ext2int(bias_ext)#
Parameters:

bias_ext (torch.Tensor)

Return type:

torch.Tensor

_bias_int2ext(bias_int)#
Parameters:

bias_int (torch.Tensor)

Return type:

torch.Tensor

_impl(input, state)#
Parameters:
Return type:

torch.Tensor

_impl_step(input, state)#
Parameters:
Return type:

torch.Tensor

config#
_recurrent_kernel_#
recurrent_kernel#
_bias_#
bias#
property _recurrent_kernel#
property _bias#
parameters_to_dtype()#
property head_dim#
_permute_input(x)#
Parameters:

x (torch.Tensor)

Return type:

torch.Tensor

_permute_output(x)#
Parameters:

x (torch.Tensor)

Return type:

torch.Tensor

reset_parameters()#

Resets this layer’s parameters to their initial values.

_check_input(input)#
Parameters:

input (torch.Tensor)

Return type:

None

_zero_state(input)#

Return a zero state matching dtype and batch size of input.

Parameters:

input (torch.Tensor) – Tensor, to specify the device and dtype of the returned tensors.

Returns:

a nested structure of zero Tensors.

Return type:

zero_state

_get_state(input, state=None)#
Parameters:
Return type:

torch.Tensor

static _get_final_state(all_states)#

All states have the structure [STATES, SEQUENCE, BATCH, HIDDEN]

Parameters:

all_states (torch.Tensor)

Return type:

torch.Tensor

_is_cuda()#
Return type:

bool

step(input, state)#
Parameters:
Return type:

tuple[torch.Tensor, torch.Tensor]

forward(input, state=None)#
class xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.sLSTMCell_cuda(config, skip_backend_init=False)#

Bases: sLSTMCellBase

Parameters:
config_class#
internal_input_shape = 'SBNGH'#
_recurrent_kernel_ext2int(recurrent_kernel_ext)#
Parameters:

recurrent_kernel_ext (torch.Tensor)

Return type:

torch.Tensor

_recurrent_kernel_int2ext(recurrent_kernel_int)#
Parameters:

recurrent_kernel_int (torch.Tensor)

Return type:

torch.Tensor

_bias_ext2int(bias_ext)#
Parameters:

bias_ext (torch.Tensor)

Return type:

torch.Tensor

_bias_int2ext(bias_int)#
Parameters:

bias_int (torch.Tensor)

Return type:

torch.Tensor

_impl_step(training, input, state)#
Parameters:
Return type:

torch.Tensor

_impl(training, input, state)#
Parameters:
Return type:

torch.Tensor

config#
_recurrent_kernel_#
recurrent_kernel#
_bias_#
bias#
property _recurrent_kernel#
property _bias#
parameters_to_dtype()#
property head_dim#
_permute_input(x)#
Parameters:

x (torch.Tensor)

Return type:

torch.Tensor

_permute_output(x)#
Parameters:

x (torch.Tensor)

Return type:

torch.Tensor

reset_parameters()#

Resets this layer’s parameters to their initial values.

_check_input(input)#
Parameters:

input (torch.Tensor)

Return type:

None

_zero_state(input)#

Return a zero state matching dtype and batch size of input.

Parameters:

input (torch.Tensor) – Tensor, to specify the device and dtype of the returned tensors.

Returns:

a nested structure of zero Tensors.

Return type:

zero_state

_get_state(input, state=None)#
Parameters:
Return type:

torch.Tensor

static _get_final_state(all_states)#

All states have the structure [STATES, SEQUENCE, BATCH, HIDDEN]

Parameters:

all_states (torch.Tensor)

Return type:

torch.Tensor

_is_cuda()#
Return type:

bool

step(input, state)#
Parameters:
Return type:

tuple[torch.Tensor, torch.Tensor]

forward(input, state=None)#
class xlstm_jax.models.xlstm_pytorch.blocks.slstm.cell.sLSTMCell(config)#

Bases: sLSTMCellBase

Parameters:

config (sLSTMCellConfig)

config_class#
config#
_recurrent_kernel_#
recurrent_kernel#
_bias_#
bias#
property _recurrent_kernel#
property _bias#
_recurrent_kernel_ext2int(recurrent_kernel_ext)#
Parameters:

recurrent_kernel_ext (torch.Tensor)

Return type:

torch.Tensor

_bias_ext2int(bias_ext)#
Parameters:

bias_ext (torch.Tensor)

Return type:

torch.Tensor

_recurrent_kernel_int2ext(recurrent_kernel_int)#
Parameters:

recurrent_kernel_int (torch.Tensor)

Return type:

torch.Tensor

_bias_int2ext(bias_int)#
Parameters:

bias_int (torch.Tensor)

Return type:

torch.Tensor

parameters_to_dtype()#
property head_dim#
_permute_input(x)#
Parameters:

x (torch.Tensor)

Return type:

torch.Tensor

_permute_output(x)#
Parameters:

x (torch.Tensor)

Return type:

torch.Tensor

reset_parameters()#

Resets this layer’s parameters to their initial values.

_check_input(input)#
Parameters:

input (torch.Tensor)

Return type:

None

_zero_state(input)#

Return a zero state matching dtype and batch size of input.

Parameters:

input (torch.Tensor) – Tensor, to specify the device and dtype of the returned tensors.

Returns:

a nested structure of zero Tensors.

Return type:

zero_state

_get_state(input, state=None)#
Parameters:
Return type:

torch.Tensor

static _get_final_state(all_states)#

All states have the structure [STATES, SEQUENCE, BATCH, HIDDEN]

Parameters:

all_states (torch.Tensor)

Return type:

torch.Tensor

_is_cuda()#
Return type:

bool

step(input, state)#
Parameters:
Return type:

tuple[torch.Tensor, torch.Tensor]

forward(input, state=None)#