Neural networks
feedbax.nn.orthogonal_gru_cell(input_size: int, hidden_size: int, use_bias: bool = True, scale: float = 1.0, *, key: PRNGKeyArray)
Returns an eqx.nn.GRUCell with orthogonal weight matrix initialization.
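The function body isn't shown here. As a rough illustration of what "orthogonal initialization" means, the sketch below builds a square matrix with orthonormal rows by Gram-Schmidt orthogonalization of Gaussian random vectors, using the standard-library random module in place of a JAX PRNG key. The helper name orthogonal_matrix is hypothetical, and treating scale as a plain multiplier on the matrix is an assumption. feedbax may use a different method (e.g. a QR-based initializer) and may apply it to different blocks of the GRU's gate weights.

```python
import math
import random


def orthogonal_matrix(n: int, scale: float = 1.0, seed: int = 0) -> list[list[float]]:
    """Hypothetical helper: an n x n matrix with orthonormal rows, times `scale`.

    Sketch only. Uses Gram-Schmidt on Gaussian vectors drawn from the stdlib
    `random` module instead of a JAX PRNG key. Multiplying by `scale` is an
    assumption about what the `scale` argument does.
    """
    rng = random.Random(seed)
    basis: list[list[float]] = []
    while len(basis) < n:
        v = [rng.gauss(0.0, 1.0) for _ in range(n)]
        # Remove the components of v along every row accepted so far.
        for b in basis:
            dot = sum(vi * bi for vi, bi in zip(v, b))
            v = [vi - dot * bi for vi, bi in zip(v, b)]
        norm = math.sqrt(sum(vi * vi for vi in v))
        if norm > 1e-8:  # skip (vanishingly unlikely) degenerate draws
            basis.append([vi / norm for vi in v])
    return [[scale * x for x in row] for row in basis]
```

For an orthogonal matrix W, W multiplied by its transpose is the identity (here scale squared times the identity), so repeated multiplication by W neither shrinks nor blows up vector norms, which is the usual motivation for this initialization in recurrent weights.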
feedbax.nn.NetworkState(Module)
The type of state PyTree operated on by SimpleStagedNetwork instances.
Attributes:
Name | Type | Description |
---|---|---|
hidden | PyTree[Float[Array, unit]] | The (output) activity of the hidden layer units. |
output | Optional[PyTree[Array]] | The activity of the readout layer, if the network has one. |
encoding | Optional[PyTree[Array]] | The activity of the encoding layer, if the network has one. |
feedbax.nn.SimpleStagedNetwork(AbstractStagedModel[NetworkState])
A single step of a neural network layer, with optional encoder and readout layers.
Attributes:
Name | Type | Description |
---|---|---|
hidden_size | int | The number of units in the hidden layer. |
out_size | int | The number of readout units, if the network has a readout layer. Otherwise this is equal to |
encoding_size | Optional[int] | The number of encoder units, if the network has an encoder layer. |
hidden | Module | The module implementing the hidden layer. |
hidden_nonlinearity | Callable[[Float], Float] | The nonlinearity applied to the hidden layer output. |
encoder | Optional[Module] | The module implementing the encoder layer, if present. |
readout | Optional[Module] | The module implementing the readout layer, if present. |
model_spec: OrderedDict[str, ModelStage[Self, NetworkState]] (property)
Specifies the network model stages: layers, nonlinearities, and noise.
Stages for the encoding layer, readout layer, hidden noise, and hidden nonlinearity are only included if the user requested each of them at construction time.
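To make the conditional structure concrete, here is a plain-Python sketch that only lists which stages would be present for a given set of constructor arguments. The function select_stages, the stage names, and their order (in particular, nonlinearity before noise) are hypothetical; the real model_spec maps names to ModelStage objects, and its keys and ordering may differ. Modelling "requested a hidden nonlinearity" as a boolean flag is also an assumption.

```python
from typing import Optional


def select_stages(
    encoding_size: Optional[int] = None,
    out_size: Optional[int] = None,
    hidden_noise_std: Optional[float] = None,
    use_hidden_nonlinearity: bool = False,
) -> list[str]:
    """Hypothetical: names of the stages a network would include, in order."""
    stages = []
    if encoding_size is not None:
        stages.append("encoder")              # input -> encoding
    stages.append("hidden")                   # the hidden layer is always present
    if use_hidden_nonlinearity:
        stages.append("hidden_nonlinearity")  # assumed to come before noise
    if hidden_noise_std is not None:
        stages.append("hidden_noise")         # Gaussian noise on hidden output
    if out_size is not None:
        stages.append("readout")              # hidden -> output
    return stages


# Encoder and readout requested, no noise or hidden nonlinearity.
print(select_stages(encoding_size=8, out_size=2))
```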
Note
Inspects the instantiated hidden layer to determine whether it is a stateful network (e.g. an RNN). If it is not (e.g. Linear), the layer is wrapped so that it works with the state passing of AbstractStagedModel. This assumes that stateful layers take two positional arguments and stateless layers take only one.
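The following sketch shows one way such a check and wrapper could work, counting positional parameters with inspect.signature. The names count_positional_params, StatelessWrapper, and as_stateful are hypothetical and not part of feedbax. What the real wrapper returns as the new state isn't documented here; this sketch simply returns the layer's output and ignores the incoming state.

```python
import inspect
from typing import Any, Callable


def count_positional_params(fn: Callable[..., Any]) -> int:
    """Count parameters that can be passed positionally (not *args or keyword-only)."""
    kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    return sum(1 for p in inspect.signature(fn).parameters.values() if p.kind in kinds)


class StatelessWrapper:
    """Hypothetical adapter giving a one-argument layer an (input, state) call."""

    def __init__(self, layer: Callable[[Any], Any]):
        self.layer = layer

    def __call__(self, x: Any, state: Any) -> Any:
        # The previous state is ignored; the layer output becomes the new "state".
        return self.layer(x)


def as_stateful(layer: Callable[..., Any]) -> Callable[[Any, Any], Any]:
    # Assumption from the note: two positional arguments means stateful (e.g. an RNN cell).
    if count_positional_params(layer) >= 2:
        return layer
    return StatelessWrapper(layer)


def rnn_step(x, h):  # stateful: takes input and previous hidden state
    return x + h


def linear(x):       # stateless: takes input only
    return 2 * x


print(as_stateful(rnn_step) is rnn_step, as_stateful(linear)(3, None))
```

For a callable object such as an Equinox module, inspect.signature inspects the bound __call__, so self is not counted.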
__call__(input: ModelInput, state: StateT, key: PRNGKeyArray) -> StateT
Inherited from feedbax._staged.AbstractStagedModel
state_consistency_update(state: StateT) -> StateT
Inherited from feedbax._staged.AbstractStagedModel
__init__(input_size: int, hidden_size: int, out_size: Optional[int] = None, encoding_size: Optional[int] = None, hidden_type: Callable[..., Module] = eqx.nn.GRUCell, encoder_type: Callable[..., Module] = eqx.nn.Linear, readout_type: Callable[..., Module] = eqx.nn.Linear, use_bias: bool = True, hidden_nonlinearity: Callable[[Float], Float] = identity_func, out_nonlinearity: Callable[[Float], Float] = identity_func, hidden_noise_std: Optional[float] = None, intervenors: Optional[ArgIntervenors] = None, *, key: PRNGKeyArray)
Note
If an integer is passed for encoding_size, input encoding is enabled. Otherwise, network inputs are passed directly to the hidden layer.
If an integer is passed for out_size, readout is enabled. Otherwise, the network's outputs are the outputs of the hidden units.
In principle, hidden_type can be a class defining a multi-layer network, as long as it is instantiated as hidden_type(input_size, hidden_size, use_bias, *, key).
Use functools.partial to set use_bias for the encoder or readout types before passing them to this constructor.
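A minimal sketch of that pattern follows. To keep it runnable without Equinox, it uses a hypothetical stand-in class Linear whose constructor mirrors the leading parameters of eqx.nn.Linear (in_features, out_features, use_bias=True, then a keyword-only key); the integer key is likewise only a placeholder for a PRNG key.

```python
from functools import partial


class Linear:
    """Hypothetical stand-in mirroring eqx.nn.Linear's leading constructor parameters."""

    def __init__(self, in_features, out_features, use_bias=True, *, key):
        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias
        self.key = key  # placeholder for a PRNGKeyArray


# Fix use_bias=False ahead of time; the sizes and key are supplied later,
# when the network constructor instantiates the layer.
readout_type = partial(Linear, use_bias=False)
layer = readout_type(16, 2, key=0)
print(layer.use_bias)
```

With Equinox installed, the corresponding value would be partial(eqx.nn.Linear, use_bias=False), passed as encoder_type or readout_type. This assumes the constructor supplies the layer sizes and key but not use_bias, which is what the note above implies.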
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_size | int | The number of input channels in the network. If | required |
hidden_size | int | The number of units in the hidden layer. | required |
out_size | Optional[int] | The number of readout units. If | None |
encoding_size | Optional[int] | The number of encoder units. If | None |
hidden_type | Callable[..., Module] | The type of hidden layer to use. | GRUCell |
encoder_type | Callable[..., Module] | The type of encoder layer to use. | Linear |
use_bias | bool | Whether the hidden layer should have a bias term. | True |
hidden_nonlinearity | Callable[[Float], Float] | A function to apply unitwise to the hidden layer output. This is typically not used if | identity_func |
out_nonlinearity | Callable[[Float], Float] | A function to apply unitwise to the readout layer output. | identity_func |
hidden_noise_std | Optional[float] | Standard deviation of Gaussian noise to add to the hidden layer output. | None |
intervenors | Optional[ArgIntervenors] | Intervenors to add to the model at construction time. | None |
key | PRNGKeyArray | Random key for initialising the network. | required |