Neural networks

feedbax.nn.orthogonal_gru_cell (input_size: int, hidden_size: int, use_bias: bool = True, scale: float = 1.0, *, key: PRNGKeyArray)

Returns an eqx.nn.GRUCell with orthogonal weight matrix initialization.
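The exact initialization scheme is not spelled out here. As a rough sketch, assume that each gate block of the stacked GRU weights is drawn independently as an orthogonal matrix multiplied by `scale` (an assumption, not necessarily what feedbax does). Equinox's `GRUCell` stacks its three gate blocks along the first axis, so its recurrent weight `weight_hh` has shape `(3 * hidden_size, hidden_size)`. The helper `orthogonal_gate_blocks` is hypothetical and uses `jax.nn.initializers.orthogonal`.

```python
import jax
import jax.numpy as jnp


def orthogonal_gate_blocks(key, rows_per_gate, cols, n_gates=3, scale=1.0):
    """Hypothetical helper: stack `n_gates` independently drawn orthogonal blocks.

    Mimics the (3 * hidden_size, cols) layout of eqx.nn.GRUCell weight arrays.
    Illustrative sketch only; per-gate initialization is an assumption.
    """
    init = jax.nn.initializers.orthogonal(scale=scale)
    keys = jax.random.split(key, n_gates)
    blocks = [init(k, (rows_per_gate, cols), jnp.float32) for k in keys]
    return jnp.concatenate(blocks, axis=0)


hidden_size = 8
weight_hh = orthogonal_gate_blocks(
    jax.random.PRNGKey(0), hidden_size, hidden_size, scale=1.5
)

# Each square gate block satisfies W.T @ W == scale**2 * I (up to float error).
blocks = jnp.split(weight_hh, 3, axis=0)
print([bool(jnp.allclose(b.T @ b, 1.5**2 * jnp.eye(hidden_size), atol=1e-4)) for b in blocks])
```

Drawing each gate block separately keeps every square recurrent block orthogonal. A single orthogonal draw over the whole `(3 * hidden_size, hidden_size)` array would only make its columns orthonormal jointly.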

feedbax.nn.NetworkState (Module)

Type of state PyTree operated on by SimpleStagedNetwork instances.


Name Type Description
hidden PyTree[Float[Array, 'unit']]

The (output) activity of the hidden layer units.

output Optional[PyTree[Array]]

The activity of the readout layer, if the network has one.

encoding Optional[PyTree[Array]]

The activity of the encoding layer, if the network has one.
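In feedbax, `NetworkState` is an Equinox `Module`. For illustration only, a `NamedTuple` behaves similarly as a JAX PyTree. A field left as `None` (for example `encoding` when the network has no encoder) contributes no leaves, so tree operations skip it. The class `NetworkStateSketch` is a hypothetical stand-in, not the feedbax class.

```python
from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp


class NetworkStateSketch(NamedTuple):
    """Hypothetical stand-in for feedbax.nn.NetworkState (which is an eqx.Module)."""

    hidden: jax.Array
    output: Optional[jax.Array] = None
    encoding: Optional[jax.Array] = None


# A network with a readout layer but no encoder: `encoding` stays None.
state = NetworkStateSketch(hidden=jnp.zeros(8), output=jnp.zeros(2))

# None fields are empty subtrees, so they contribute no leaves...
print(len(jax.tree_util.tree_leaves(state)))

# ...and tree_map passes them through untouched.
shifted = jax.tree_util.tree_map(lambda x: x + 1.0, state)
print(shifted.encoding is None)
```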

feedbax.nn.SimpleStagedNetwork (AbstractStagedModel[NetworkState])

A single step of a neural network layer, with optional encoder and readout layers.


Name Type Description
hidden_size int

The number of units in the hidden layer.

out_size int

The number of readout units, if the network has a readout layer. Otherwise this is equal to hidden_size.

encoding_size Optional[int]

The number of encoder units, if the network has an encoder layer.

hidden Module

The module implementing the hidden layer.

hidden_nonlinearity Callable[[Float], Float]

The nonlinearity applied to the hidden layer output.

encoder Optional[Module]

The module implementing the encoder layer, if present.

readout Optional[Module]

The module implementing the readout layer, if present.

model_spec: OrderedDict[str, ModelStage[Self, NetworkState]] property

Specifies the network model stages: layers, nonlinearities, and noise.

Stages for the encoding layer, readout layer, hidden noise, and hidden nonlinearity are included only if the corresponding options were requested at construction time.
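The inclusion rule can be sketched as follows. The helper `select_stages`, the stage names, and their order are all hypothetical. They only illustrate how each optional stage appears when its constructor argument requests it. `identity_func` is a local stand-in for the default nonlinearity named in the constructor signature.

```python
from collections import OrderedDict
from typing import Callable, Optional


def identity_func(x):
    """Local stand-in for the default (no-op) nonlinearity."""
    return x


def select_stages(
    encoding_size: Optional[int],
    out_size: Optional[int],
    hidden_noise_std: Optional[float],
    hidden_nonlinearity: Callable = identity_func,
) -> OrderedDict:
    """Hypothetical: map stage names to short descriptions, in execution order."""
    stages = OrderedDict()
    if encoding_size is not None:
        stages["encoder"] = "encoder layer"
    stages["hidden"] = "hidden layer (always present)"
    if hidden_nonlinearity is not identity_func:
        stages["hidden_nonlinearity"] = "unitwise hidden nonlinearity"
    if hidden_noise_std is not None:
        stages["hidden_noise"] = "additive Gaussian noise"
    if out_size is not None:
        stages["readout"] = "readout layer"
    return stages


print(list(select_stages(None, 2, None)))
print(list(select_stages(16, None, 0.1, hidden_nonlinearity=abs)))
```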


Note: This property inspects the instantiated hidden layer to determine whether it is stateful (e.g. an RNN). If it is not (e.g. Linear), the layer is wrapped so that it is compatible with the state-passing of AbstractStagedModel. This assumes that stateful layers take 2 positional arguments, and stateless layers only 1.
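The wrapping rule can be sketched with a hypothetical helper, `as_stateful`. It counts required positional parameters with `inspect.signature`. A callable taking two is treated as a stateful cell `(input, state) -> new_state`. A callable taking one is wrapped so that it accepts, and ignores, the previous state. This is not feedbax's code. The two toy layers are plain functions standing in for modules such as `eqx.nn.Linear` and `eqx.nn.GRUCell`.

```python
import inspect

import jax.numpy as jnp


def n_positional(fn) -> int:
    """Count positional parameters that have no default value."""
    kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    return sum(
        1
        for p in inspect.signature(fn).parameters.values()
        if p.kind in kinds and p.default is inspect.Parameter.empty
    )


def as_stateful(layer):
    """Hypothetical: give a one-argument layer the (input, state) calling convention."""
    if n_positional(layer) == 2:
        return layer  # already stateful, e.g. an RNN cell
    return lambda x, state: layer(x)  # stateless: the previous state is ignored


W = jnp.eye(3) * 2.0
linear = lambda x: W @ x                  # stateless stand-in
rnn_cell = lambda x, h: jnp.tanh(x + h)  # stateful stand-in

x, h = jnp.ones(3), jnp.zeros(3)
print(as_stateful(linear)(x, h))
print(as_stateful(rnn_cell) is rnn_cell)
```

Counting only parameters without defaults means keyword-only extras such as a `key` argument do not affect the classification.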

__call__ (input: ModelInput, state: StateT, key: PRNGKeyArray) -> StateT
state_consistency_update (state: StateT) -> StateT
__init__ (input_size: int, hidden_size: int, out_size: Optional[int] = None, encoding_size: Optional[int] = None, hidden_type: Callable[..., Module] = eqx.nn.GRUCell, encoder_type: Callable[..., Module] = eqx.nn.Linear, readout_type: Callable[..., Module] = eqx.nn.Linear, use_bias: bool = True, hidden_nonlinearity: Callable[[Float], Float] = identity_func, out_nonlinearity: Callable[[Float], Float] = identity_func, hidden_noise_std: Optional[float] = None, intervenors: Optional[ArgIntervenors] = None, *, key: PRNGKeyArray)


If an integer is passed for encoding_size, input encoding is enabled. Otherwise network inputs are passed directly to the hidden layer.

If an integer is passed for out_size, readout is enabled. Otherwise the network's outputs are the outputs of the hidden units.
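A single step with an optional encoder and readout can be sketched in plain JAX. Everything here is an illustrative assumption. `network_step` is hypothetical, dense matrices stand in for the encoder, hidden, and readout modules, and a simple `tanh` recurrence stands in for the default `GRUCell`. An absent layer (`None`) passes its input straight through. So without a readout, the output is the hidden activity and its size equals `hidden_size`.

```python
import jax
import jax.numpy as jnp


def network_step(x, h, W_hid_in, W_hid_rec, W_enc=None, W_out=None):
    """Hypothetical single step: [encoder] -> recurrent hidden layer -> [readout]."""
    encoding = None if W_enc is None else W_enc @ x
    hidden_input = x if encoding is None else encoding
    h_new = jnp.tanh(W_hid_in @ hidden_input + W_hid_rec @ h)  # stand-in for a GRUCell
    output = h_new if W_out is None else W_out @ h_new
    return h_new, output, encoding


input_size, encoding_size, hidden_size, out_size = 4, 6, 8, 2
k_enc, k_in, k_in_enc, k_rec, k_out = jax.random.split(jax.random.PRNGKey(0), 5)
x, h = jnp.ones(input_size), jnp.zeros(hidden_size)
W_rec = jax.random.normal(k_rec, (hidden_size, hidden_size))

# No encoder, no readout: the hidden layer sees raw inputs and output == hidden.
W_in = jax.random.normal(k_in, (hidden_size, input_size))
h1, out1, enc1 = network_step(x, h, W_in, W_rec)
print(out1.shape, enc1)

# With encoder and readout: the hidden layer sees `encoding_size` inputs.
W_enc = jax.random.normal(k_enc, (encoding_size, input_size))
W_in_enc = jax.random.normal(k_in_enc, (hidden_size, encoding_size))
W_out = jax.random.normal(k_out, (out_size, hidden_size))
h2, out2, enc2 = network_step(x, h, W_in_enc, W_rec, W_enc=W_enc, W_out=W_out)
print(out2.shape, enc2.shape)
```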

In principle, hidden_type can be a class defining a multi-layer network, as long as it is instantiated as hidden_type(input_size, hidden_size, use_bias, *, key).

Use functools.partial to set use_bias for the encoder or readout types before passing them to this constructor.
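Both conventions can be sketched without Equinox. `TwoLayerTanhCell` and `ToyLinear` are hypothetical classes. The first follows the `hidden_type(input_size, hidden_size, use_bias, *, key)` pattern. The second is a linear layer type whose bias setting is fixed with `functools.partial`. `ToyLinear`'s argument order loosely mirrors `eqx.nn.Linear(in_features, out_features, use_bias=True, *, key)`. How the constructor then calls the encoder or readout type is not shown and is assumed.

```python
from functools import partial

import jax
import jax.numpy as jnp


class ToyLinear:
    """Hypothetical stand-in for a linear layer type such as eqx.nn.Linear."""

    def __init__(self, in_features, out_features, use_bias=True, *, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_features, in_features))
        self.bias = jax.random.normal(bkey, (out_features,)) if use_bias else None

    def __call__(self, x):
        y = self.weight @ x
        return y if self.bias is None else y + self.bias


class TwoLayerTanhCell:
    """Hypothetical multi-layer hidden type: (input_size, hidden_size, use_bias, *, key)."""

    def __init__(self, input_size, hidden_size, use_bias=True, *, key):
        k1, k2 = jax.random.split(key)
        self.layer1 = ToyLinear(input_size + hidden_size, hidden_size, use_bias, key=k1)
        self.layer2 = ToyLinear(hidden_size, hidden_size, use_bias, key=k2)

    def __call__(self, x, h):  # two positional args: input and previous hidden state
        return jnp.tanh(self.layer2(jnp.tanh(self.layer1(jnp.concatenate([x, h])))))


k_hidden, k_enc = jax.random.split(jax.random.PRNGKey(0))

hidden = TwoLayerTanhCell(4, 8, True, key=k_hidden)
h_next = hidden(jnp.ones(4), jnp.zeros(8))
print(h_next.shape)

# Fix use_bias for an encoder type before handing it to a constructor.
BiaslessLinear = partial(ToyLinear, use_bias=False)
encoder = BiaslessLinear(4, 6, key=k_enc)
print(encoder.bias is None)
```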


Name Type Description Default
input_size int

The number of input channels in the network. If encoding_size is not None, this is the number of inputs to the encoder layer; otherwise, it is the number of inputs to the hidden layer.

hidden_size int

The number of units in the hidden layer.

out_size Optional[int]

The number of readout units. If None, do not add a readout layer.

encoding_size Optional[int]

The number of encoder units. If None, do not add an encoder layer.

hidden_type Callable[..., Module]

The type of hidden layer to use.

encoder_type Callable[..., Module]

The type of encoder layer to use.

readout_type Callable[..., Module]

The type of readout layer to use.

use_bias bool

Whether the hidden layer should have a bias term.

hidden_nonlinearity Callable[[Float], Float]

A function to apply unitwise to the hidden layer output. This is typically not used if hidden_type is GRUCell or LSTMCell, since those cells already apply internal nonlinearities.

out_nonlinearity Callable[[Float], Float]

A function to apply unitwise to the readout layer output.

hidden_noise_std Optional[float]

Standard deviation of Gaussian noise to add to the hidden layer output. If None, no noise is added.
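Here is a minimal sketch of a noise stage. It assumes zero-mean noise drawn independently for each unit at every step. The helper `add_hidden_noise` is hypothetical. When `hidden_noise_std` is `None`, the activity passes through unchanged.

```python
from typing import Optional

import jax
import jax.numpy as jnp


def add_hidden_noise(hidden, key, hidden_noise_std: Optional[float]):
    """Hypothetical: add zero-mean Gaussian noise with the given standard deviation."""
    if hidden_noise_std is None:
        return hidden  # no noise requested
    return hidden + hidden_noise_std * jax.random.normal(key, hidden.shape)


hidden = jnp.zeros(10_000)
key = jax.random.PRNGKey(0)

noisy = add_hidden_noise(hidden, key, 0.5)
print(float(noisy.std()))  # close to 0.5 for this many units
print(add_hidden_noise(hidden, key, None) is hidden)
```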

intervenors Optional[ArgIntervenors]

Intervenors to add to the model at construction time.

key PRNGKeyArray

Random key for initialising the network.
